当前位置: 首页 > news >正文

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言

K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较

在本文中,我将K近邻(KNN)算法的应用分为两种情况:

  • 全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。

  • 局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。


全局查询

def knn(x, k):"""Input:x: all points, [B, C, N]k: k nearest points of each pointReturn:idx: grouped points index, [B, N, k]"""inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idx

这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。

函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。

插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。


局部查询

(1)knn_point 函数

def square_distance(src, dst):"""Calculate Euclid distance between each two points.src^T * dst = xn * xm + yn * ym + zn * zm;sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dstInput:src: source points, [B, N, C]dst: target points, [B, M, C]Output:dist: per-point square distance, [B, N, M]"""B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):"""Input:nsample: max sample number in local regionxyz: all points, [B, N, C]new_xyz: query points, [B, S, C]Return:group_idx: grouped points index, [B, S, nsample]"""sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx

这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。


(2)knn_cuda 库函数

import torch# Make sure your CUDA is available.
assert torch.cuda.is_available()from knn_cuda import KNN
"""
if transpose_mode is True, ref   is Tensor [bs x nr x dim]query is Tensor [bs x nq x dim]return dist is Tensor [bs x nq x k]indx is Tensor [bs x nq x k]
elseref   is Tensor [bs x dim x nr]query is Tensor [bs x dim x nq]return dist is Tensor [bs x k x nq]indx is Tensor [bs x k x nq]
"""knn = KNN(k=10, transpose_mode=True)ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()dist, indx = knn(ref, query)  # 32 x 50 x 10

大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。

需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了


性能比较

(1)测试代码

import torch
import time
from knn_cuda import KNNdef knn(x, k):inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idxdef square_distance(src, dst):B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx# Custom knn implementation
def test_knn(query, k, times):query = query.permute(0,2,1)start_time = time.time()  # Start timerfor i in range(times):indx = knn(query, k = k)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Custom knn_point implementation
def test_knn_point(ref, query, k, times):start_time = time.time()  # Start timerfor i in range(times):indx = knn_point(k, ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):knn = KNN(k=k, transpose_mode=True)start_time = time.time()  # Start timerfor i in range(times):dist, indx = knn(ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Main testing function
def test_knn_methods(ref, query, k, times):print("Test times: %d" % times)# Test custom knntime_knn = test_knn(query, k, times)print(f"knn      : {time_knn:.6f} seconds")# Test custom knn_pointtime_point = test_knn_point(ref, query, k, times)print(f"knn_point: {time_point:.6f} seconds")# Test knn_cudatime_cuda = test_knn_cuda(ref, query, k, times)print(f"knn_cuda : {time_cuda:.6f} seconds")if __name__ == '__main__':# Sample inputB, N, S, C = 32, 1024, 50, 3      # Batch size, total points, query points, coordinatesk = 24                            # Number of nearest neighborsref = torch.randn(B, N, C).cuda() # Reference points# Test above methodstimes_list = [1,2,3,10,50,100]for times in times_list:test_knn_methods(ref, ref, k, times)

这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knnknn_point 以及基于 knn_cuda 库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。

图注:三种实现方法的性能测评结果

上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。

当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。


(2)模型训练

图注:使用 knn 函数的训练时间
图注:使用 knn 函数的训练时间

图注:使用 knn_point 的训练时间

图注:使用 knn_cuda 库的训练时间

 

直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。


总结

我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。

相关文章:

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言 K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程&#xff…...

【OJ】常用技巧

1. 模版 #include<bits/stdc.h> using namespace std;int main(){ios::sync_with_stdio(false);cin.tie(0);// write herereturn 0; }2. 填充数组 memset是一个字节一个字节填充&#xff0c;如果是使int类型填充非0或者-1就会报错&#xff0c;如 int a[100]; memset(a…...

Redis:Redis性能变慢的原因

一、淘汰策略性能问题 当使用Redis当作缓存使用时&#xff0c;通常会给这个实例设置内存上限maxmemory&#xff0c;然后设置一个数据淘汰策略&#xff1b;如果Redis实例设置了内存上限maxmemory&#xff0c;那么也有可能导致Redis变慢。 原因在于&#xff0c;当Redis内存达到…...

Linux多线程——利用C++模板对pthread线程库封装

文章目录 线程封装主要框架线程启动线程等待其他信息 测试函数 线程封装 我们之前介绍过pthread的线程库&#xff0c;这个线程库主要是基于C语言的void*指针来进行传参和返回 我们使用C的模板对其封装可以让他的使用更加方便&#xff0c;并且经过测试可以让我们更加直观的了解…...

SpringBoot教程(十五) | SpringBoot集成RabbitMq(消息丢失、消息重复、消息顺序、消息顺序)

SpringBoot教程&#xff08;十五&#xff09; | SpringBoot集成RabbitMq&#xff08;消息丢失、消息重复、消息顺序、消息顺序&#xff09; RabbitMQ常见问题解决方案问题一&#xff1a;消息丢失的解决方案&#xff08;1&#xff09;生成者丢失消息丢失的情景解决方案1&#xf…...

TensorRT-LLM高级用法

--multi_block_mode decoding phase, 推理1个新token&#xff0c; 平时&#xff1a;按照batch样本&#xff0c;按照head&#xff0c;将计算平均分给所有SM&#xff1b; batch_size*num_heads和SM数目相比较小时&#xff1a;有些SM会空闲&#xff1b;加了--multi_block_mode&…...

文心一言功能新升级:读文档、懂翻译、能识图

9月4日&#xff0c;百度文心一言官网显示&#xff0c;在向全社会开放一周年之际&#xff0c;文心一言进行了功能最新全面升级&#xff0c;同时在周年期间为新老会员增加1个月专业版免费使用体验。 据了解&#xff0c;针对网页版用户需求&#xff0c;文心一言实现了创作内容更加…...

C++机试——走方格的方案

题目 请计算n*m的棋盘格子&#xff08;n为横向的格子数&#xff0c;m为竖向的格子数&#xff09;从棋盘左上角出发沿着边缘线从左上角走到右下角&#xff0c;总共有多少种走法&#xff0c;要求不能走回头路&#xff0c;即&#xff1a;只能往右和往下走&#xff0c;不能往左和往…...

Bootstrap 字体图标无法显示问题,<i>标签字体图标无法显示问题

bootstrap fileInput 以及 Bootstrap 字体图标无法显示问题。 今天在用 bootstrap fileInput 插件的时候发现图标无法显示&#xff0c;如下&#xff1a; 查看DOM&#xff0c;发现那些图标是<i>标签做的&#xff1a; 网上的方案 方案1 网上很多人说是我们打乱了boots…...

docker registry 仓库加密

docker registry 仓库加密 1、背景 ​ 公司一直用的镜像仓库是docker registry&#xff0c;但是有个安全问题&#xff0c;就是仓库从web ui的浏览到镜像的拉取都是可以直接使用的&#xff0c;还是放到了公网上&#xff0c;只需要知道你的域名那就是畅通无阻了&#xff0c;可以…...

利用高德+ArcGIS优雅获取任何感兴趣的矢量边界

荷花十里&#xff0c;清风鉴水&#xff0c;明月天衣。 四时之景不同&#xff0c;乐亦无穷尽也。今天呢&#xff0c;梧桐君给大家讲解一下&#xff0c;如何利用高德地图&#xff0c;随机所欲的获取shp边界数据。 文章主要分成以下几个步骤&#xff1a; 首先搜索你想获取的矢量…...

炮弹【USACO】

题目背景 时/空限制&#xff1a;1s / 64MB 题目描述 贝茜已经精通了变成炮弹并沿着长度为 N 的数轴弹跳的艺术&#xff0c;数轴上的位置从左到右编号为 1,2,…,N 。 她从某个整数位置 S 开始&#xff0c;以 1 的起始能量向右弹跳。 如果贝茜的能量为 k &#xff0c;则她将…...

python如何读取excel文件内的数据

目录 前言一、安装openpyxl二、读取Excel数据总结前言 在Python中读取Excel数据,最常用的库之一是openpyxl(用于.xlsx格式)和xlrd(尽管xlrd从版本2.0开始不再支持.xlsx,仅支持旧的.xls格式)。然而,对于大多数现代应用来说,openpyxl是一个更好的选择,因为它支持.xlsx格…...

Java项目: 基于SpringBoot+mybatis+maven+mysql教师工作量管理系统(含源码+数据库+毕业论文)

一、项目简介 本项目是一套基于SpringBootmybatismavenmysql教师工作量管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观…...

项目开发--数据库--postgresql数据库操作

背景 1、安装postgresql的基础方法 2、基本操作命令 解决方案 安装命令 在ubuntu环境当中进行安装。 sudo apt install postgresql安装完毕之后直接进行测试&#xff0c;如果看到如下内容则安装成功。 sudo systemctl status postgresql使用DBeaver进行连接报错&#xff…...

c语言——用一维数组输出杨辉三角形

一.代码 #include <stdio.h> int Num[100]; int Hang; int Lie; int a; int Flag; int main() {Lie 1;Hang 1;a 0;while (1) {//列1为1if (Lie 1) {Num[1] 1;Lie;}//数据存到数组里面while (Hang > Lie && Hang ! 2) { if (Hang!Lie) {Flag Num[Lie] …...

Codeforces Round 971 (Div. 4) (A~G1)

A、B题太简单&#xff0c;不做解释 C 对于 x y 两个方向&#xff0c;每一个方向至少需要 x / k 向上取整的步数&#xff0c;取最大值。 由于 x 方向先移动&#xff0c;假如 x 方向需要的步数多于 y 方向的步数&#xff0c;那么最后 y 方向的那一步就不需要了&#xff0c;答案…...

为什么构造函数不能为虚函数?为什么析构函数可以为虚函数,如果不设为虚函数可能会存在什么问题?

目录 一、为什么构造函数不能为虚函数&#xff1f; 二、为什么析构函数可以是虚函数&#xff1f;如果不设为虚函数可能会存在什么问题&#xff1f; 构造函数不能为虚函数&#xff0c;因为在构造过程中&#xff0c;虚函数机制尚未生效&#xff0c;对象还未完成构造&#xff0c…...

【数据结构】单链表功能的实现

目录 1.链表的概念及结构 2.单链表功能的实现 2.1打印单链表 2.2创建节点 2.3单链表尾插 2.3单链表头插 2.5单链表尾删 2.6单链表头删 2.7单链表的查找 2.8在指定位置之前插入数据 2.9在指定位置之后插入数据 2.10删除pos节点 2.11删除pos之后的节点 2.12销毁链表…...

最新车型库大全|阿里云实现调用API接口

整体请求流程&#xff1a; 介绍&#xff1a; 本次解析通过阿里云云市场的云服务来实现查询车型库大全查询&#xff0c;首先需要选择一家可以提供查询的商品。 [探数API]车型库查询_API专区_云市场-阿里云 步骤1: 选择商品 如图点击免费试用&#xff0c;即可免费申请该接口数…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件

在选煤厂、化工厂、钢铁厂等过程生产型企业&#xff0c;其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进&#xff0c;需提前预防假检、错检、漏检&#xff0c;推动智慧生产运维系统数据的流动和现场赋能应用。同时&#xff0c;…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

Objective-C常用命名规范总结

【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名&#xff08;Class Name)2.协议名&#xff08;Protocol Name)3.方法名&#xff08;Method Name)4.属性名&#xff08;Property Name&#xff09;5.局部变量/实例变量&#xff08;Local / Instance Variables&…...

Nuxt.js 中的路由配置详解

Nuxt.js 通过其内置的路由系统简化了应用的路由配置&#xff0c;使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

【单片机期末】单片机系统设计

主要内容&#xff1a;系统状态机&#xff0c;系统时基&#xff0c;系统需求分析&#xff0c;系统构建&#xff0c;系统状态流图 一、题目要求 二、绘制系统状态流图 题目&#xff1a;根据上述描述绘制系统状态流图&#xff0c;注明状态转移条件及方向。 三、利用定时器产生时…...

解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错

出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上&#xff0c;所以报错&#xff0c;到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本&#xff0c;cu、torch、cp 的版本一定要对…...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

push [特殊字符] present

push &#x1f19a; present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中&#xff0c;push 和 present 是两种不同的视图控制器切换方式&#xff0c;它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...