当前位置: 首页 > news >正文

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言

K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较

在本文中,我将K近邻(KNN)算法的应用分为两种情况:

  • 全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。

  • 局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。


全局查询

def knn(x, k):"""Input:x: all points, [B, C, N]k: k nearest points of each pointReturn:idx: grouped points index, [B, N, k]"""inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idx

这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。

函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。

插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。


局部查询

(1)knn_point 函数

def square_distance(src, dst):"""Calculate Euclid distance between each two points.src^T * dst = xn * xm + yn * ym + zn * zm;sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dstInput:src: source points, [B, N, C]dst: target points, [B, M, C]Output:dist: per-point square distance, [B, N, M]"""B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):"""Input:nsample: max sample number in local regionxyz: all points, [B, N, C]new_xyz: query points, [B, S, C]Return:group_idx: grouped points index, [B, S, nsample]"""sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx

这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。


(2)knn_cuda 库函数

import torch# Make sure your CUDA is available.
assert torch.cuda.is_available()from knn_cuda import KNN
"""
if transpose_mode is True, ref   is Tensor [bs x nr x dim]query is Tensor [bs x nq x dim]return dist is Tensor [bs x nq x k]indx is Tensor [bs x nq x k]
elseref   is Tensor [bs x dim x nr]query is Tensor [bs x dim x nq]return dist is Tensor [bs x k x nq]indx is Tensor [bs x k x nq]
"""knn = KNN(k=10, transpose_mode=True)ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()dist, indx = knn(ref, query)  # 32 x 50 x 10

大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。

需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了


性能比较

(1)测试代码

import torch
import time
from knn_cuda import KNNdef knn(x, k):inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idxdef square_distance(src, dst):B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx# Custom knn implementation
def test_knn(query, k, times):query = query.permute(0,2,1)start_time = time.time()  # Start timerfor i in range(times):indx = knn(query, k = k)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Custom knn_point implementation
def test_knn_point(ref, query, k, times):start_time = time.time()  # Start timerfor i in range(times):indx = knn_point(k, ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):knn = KNN(k=k, transpose_mode=True)start_time = time.time()  # Start timerfor i in range(times):dist, indx = knn(ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Main testing function
def test_knn_methods(ref, query, k, times):print("Test times: %d" % times)# Test custom knntime_knn = test_knn(query, k, times)print(f"knn      : {time_knn:.6f} seconds")# Test custom knn_pointtime_point = test_knn_point(ref, query, k, times)print(f"knn_point: {time_point:.6f} seconds")# Test knn_cudatime_cuda = test_knn_cuda(ref, query, k, times)print(f"knn_cuda : {time_cuda:.6f} seconds")if __name__ == '__main__':# Sample inputB, N, S, C = 32, 1024, 50, 3      # Batch size, total points, query points, coordinatesk = 24                            # Number of nearest neighborsref = torch.randn(B, N, C).cuda() # Reference points# Test above methodstimes_list = [1,2,3,10,50,100]for times in times_list:test_knn_methods(ref, ref, k, times)

这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knnknn_point 以及基于 knn_cuda 库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。

图注:三种实现方法的性能测评结果

上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。

当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。


(2)模型训练

图注:使用 knn 函数的训练时间
图注:使用 knn 函数的训练时间

图注:使用 knn_point 的训练时间

图注:使用 knn_cuda 库的训练时间

 

直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。


总结

我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。

相关文章:

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言 K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程&#xff…...

【OJ】常用技巧

1. 模版 #include<bits/stdc.h> using namespace std;int main(){ios::sync_with_stdio(false);cin.tie(0);// write herereturn 0; }2. 填充数组 memset是一个字节一个字节填充&#xff0c;如果是使int类型填充非0或者-1就会报错&#xff0c;如 int a[100]; memset(a…...

Redis:Redis性能变慢的原因

一、淘汰策略性能问题 当使用Redis当作缓存使用时&#xff0c;通常会给这个实例设置内存上限maxmemory&#xff0c;然后设置一个数据淘汰策略&#xff1b;如果Redis实例设置了内存上限maxmemory&#xff0c;那么也有可能导致Redis变慢。 原因在于&#xff0c;当Redis内存达到…...

Linux多线程——利用C++模板对pthread线程库封装

文章目录 线程封装主要框架线程启动线程等待其他信息 测试函数 线程封装 我们之前介绍过pthread的线程库&#xff0c;这个线程库主要是基于C语言的void*指针来进行传参和返回 我们使用C的模板对其封装可以让他的使用更加方便&#xff0c;并且经过测试可以让我们更加直观的了解…...

SpringBoot教程(十五) | SpringBoot集成RabbitMq(消息丢失、消息重复、消息顺序、消息顺序)

SpringBoot教程&#xff08;十五&#xff09; | SpringBoot集成RabbitMq&#xff08;消息丢失、消息重复、消息顺序、消息顺序&#xff09; RabbitMQ常见问题解决方案问题一&#xff1a;消息丢失的解决方案&#xff08;1&#xff09;生成者丢失消息丢失的情景解决方案1&#xf…...

TensorRT-LLM高级用法

--multi_block_mode decoding phase, 推理1个新token&#xff0c; 平时&#xff1a;按照batch样本&#xff0c;按照head&#xff0c;将计算平均分给所有SM&#xff1b; batch_size*num_heads和SM数目相比较小时&#xff1a;有些SM会空闲&#xff1b;加了--multi_block_mode&…...

文心一言功能新升级:读文档、懂翻译、能识图

9月4日&#xff0c;百度文心一言官网显示&#xff0c;在向全社会开放一周年之际&#xff0c;文心一言进行了功能最新全面升级&#xff0c;同时在周年期间为新老会员增加1个月专业版免费使用体验。 据了解&#xff0c;针对网页版用户需求&#xff0c;文心一言实现了创作内容更加…...

C++机试——走方格的方案

题目 请计算n*m的棋盘格子&#xff08;n为横向的格子数&#xff0c;m为竖向的格子数&#xff09;从棋盘左上角出发沿着边缘线从左上角走到右下角&#xff0c;总共有多少种走法&#xff0c;要求不能走回头路&#xff0c;即&#xff1a;只能往右和往下走&#xff0c;不能往左和往…...

Bootstrap 字体图标无法显示问题,<i>标签字体图标无法显示问题

bootstrap fileInput 以及 Bootstrap 字体图标无法显示问题。 今天在用 bootstrap fileInput 插件的时候发现图标无法显示&#xff0c;如下&#xff1a; 查看DOM&#xff0c;发现那些图标是<i>标签做的&#xff1a; 网上的方案 方案1 网上很多人说是我们打乱了boots…...

docker registry 仓库加密

docker registry 仓库加密 1、背景 ​ 公司一直用的镜像仓库是docker registry&#xff0c;但是有个安全问题&#xff0c;就是仓库从web ui的浏览到镜像的拉取都是可以直接使用的&#xff0c;还是放到了公网上&#xff0c;只需要知道你的域名那就是畅通无阻了&#xff0c;可以…...

利用高德+ArcGIS优雅获取任何感兴趣的矢量边界

荷花十里&#xff0c;清风鉴水&#xff0c;明月天衣。 四时之景不同&#xff0c;乐亦无穷尽也。今天呢&#xff0c;梧桐君给大家讲解一下&#xff0c;如何利用高德地图&#xff0c;随机所欲的获取shp边界数据。 文章主要分成以下几个步骤&#xff1a; 首先搜索你想获取的矢量…...

炮弹【USACO】

题目背景 时/空限制&#xff1a;1s / 64MB 题目描述 贝茜已经精通了变成炮弹并沿着长度为 N 的数轴弹跳的艺术&#xff0c;数轴上的位置从左到右编号为 1,2,…,N 。 她从某个整数位置 S 开始&#xff0c;以 1 的起始能量向右弹跳。 如果贝茜的能量为 k &#xff0c;则她将…...

python如何读取excel文件内的数据

目录 前言一、安装openpyxl二、读取Excel数据总结前言 在Python中读取Excel数据,最常用的库之一是openpyxl(用于.xlsx格式)和xlrd(尽管xlrd从版本2.0开始不再支持.xlsx,仅支持旧的.xls格式)。然而,对于大多数现代应用来说,openpyxl是一个更好的选择,因为它支持.xlsx格…...

Java项目: 基于SpringBoot+mybatis+maven+mysql教师工作量管理系统(含源码+数据库+毕业论文)

一、项目简介 本项目是一套基于SpringBootmybatismavenmysql教师工作量管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观…...

项目开发--数据库--postgresql数据库操作

背景 1、安装postgresql的基础方法 2、基本操作命令 解决方案 安装命令 在ubuntu环境当中进行安装。 sudo apt install postgresql安装完毕之后直接进行测试&#xff0c;如果看到如下内容则安装成功。 sudo systemctl status postgresql使用DBeaver进行连接报错&#xff…...

c语言——用一维数组输出杨辉三角形

一.代码 #include <stdio.h> int Num[100]; int Hang; int Lie; int a; int Flag; int main() {Lie 1;Hang 1;a 0;while (1) {//列1为1if (Lie 1) {Num[1] 1;Lie;}//数据存到数组里面while (Hang > Lie && Hang ! 2) { if (Hang!Lie) {Flag Num[Lie] …...

Codeforces Round 971 (Div. 4) (A~G1)

A、B题太简单&#xff0c;不做解释 C 对于 x y 两个方向&#xff0c;每一个方向至少需要 x / k 向上取整的步数&#xff0c;取最大值。 由于 x 方向先移动&#xff0c;假如 x 方向需要的步数多于 y 方向的步数&#xff0c;那么最后 y 方向的那一步就不需要了&#xff0c;答案…...

为什么构造函数不能为虚函数?为什么析构函数可以为虚函数,如果不设为虚函数可能会存在什么问题?

目录 一、为什么构造函数不能为虚函数&#xff1f; 二、为什么析构函数可以是虚函数&#xff1f;如果不设为虚函数可能会存在什么问题&#xff1f; 构造函数不能为虚函数&#xff0c;因为在构造过程中&#xff0c;虚函数机制尚未生效&#xff0c;对象还未完成构造&#xff0c…...

【数据结构】单链表功能的实现

目录 1.链表的概念及结构 2.单链表功能的实现 2.1打印单链表 2.2创建节点 2.3单链表尾插 2.3单链表头插 2.5单链表尾删 2.6单链表头删 2.7单链表的查找 2.8在指定位置之前插入数据 2.9在指定位置之后插入数据 2.10删除pos节点 2.11删除pos之后的节点 2.12销毁链表…...

最新车型库大全|阿里云实现调用API接口

整体请求流程&#xff1a; 介绍&#xff1a; 本次解析通过阿里云云市场的云服务来实现查询车型库大全查询&#xff0c;首先需要选择一家可以提供查询的商品。 [探数API]车型库查询_API专区_云市场-阿里云 步骤1: 选择商品 如图点击免费试用&#xff0c;即可免费申请该接口数…...

70. 爬楼梯

70. 爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;2 解释&#xff1a;有两种方法可以爬到楼顶。 1.1 阶 1 阶 2.2 阶 示例…...

pytorch正向传播没问题,loss.backward()使定义的神经网络中权重参数变为nan

记录一个非常坑爹的bug:loss回传导致神经网络中一个linear层的权重参数变为nan 1.首先loss值是正常数值&#xff1b; 2.查了好多网上的解决办法&#xff1a;检查原始输入神经网络数据有没有nan值&#xff0c;初始化权重参数&#xff0c;使用relu激活函数&#xff0c;梯度裁剪&a…...

❤《实战纪录片 1 》原生开发小程序中遇到的问题和解决方案

《实战纪录片 1 》原生开发小程序中遇到的问题和解决方案 文章目录 《实战纪录片 1 》原生开发小程序中遇到的问题和解决方案1、问题一&#xff1a;原生开发中 request请求中返回 的数据无法 使用this传递给 data{}中怎么办&#xff1f;2、刚登录后如何将token信息保存&#xf…...

2024.9.6 作业

手写unique_ptr指针指针 代码&#xff1a; #include <iostream> #include <stdexcept>template <typename T> class unique_ptr { public:// 构造函数explicit unique_ptr(T* ptr nullptr) : m_ptr(ptr) {}// 析构函数~unique_ptr() {delete m_ptr;}// 禁…...

2024年架构设计师论文-“模型驱动架构设计方法及其应用”

论模型驱动架构设计方法及其应用 模型驱动架构设计是一种用于应用系统开发的软件设计方法&#xff0c;以模型构造、模型转换和精化为核心&#xff0c;提供了一套软件设计的指导规范。在模型驱动架构环境下&#xff0c;通过创建出机器可读和高度抽象的模型实现对不同问题域的描述…...

Tapd敏捷开发平台的使用心得

Tapd敏捷开发平台的使用心得 一、Tapd 简介 TAPD(Tencent Agile Product Development),腾讯敏捷产品研发平台行业领先的敏捷协作方案,贯穿敏捷产品研发生命周期的一站式服务,了解敏捷如下图 二、几个核心模块概念 需求迭代缺陷故事墙前期项目需求的管理,可以按类别建…...

远程桌面 Rust Desk 自建服务器

因为某些原因(诈骗)&#xff0c;Rush Desk 服务已暂停国内访问&#xff0c;今天我们介绍如何利用自己的服务器搭建 Rust Desk 远程桌面&#xff0c;低延迟电脑远程手机&#xff0c;手机远程电脑等 一、准备工作 准备一台服务器&#xff0c;我用的腾讯云服务器&#xff0c;一年…...

开源网安引领AIGC+开发安全,智能防护铸就软件安全新高度

近日&#xff0c;国内网络安全领域知名媒体数说安全正式发布了《2024年中国网络安全市场100强》和《2024年中国网络安全十大创新方向》。开源网安凭借在市场表现力、资源支持力以及产品在AI方向的创新力上的优秀表现成功入选百强榜单&#xff0c;并被评为“AIGC开发安全”典型厂…...

树和二叉树

树 节点&#xff08;Node&#xff1a;&#xff09; 树由一系列的节点组成&#xff0c;每个节点可以包含数据和指向其他节点的链接。 节点通常包含一个数据元素和若干指向其他节点的指针 根节点&#xff08;Root&#xff09;&#xff1a; 树的顶部节点称为根节点&#xff0c…...

一篇带你速通差分算法(C/C++)

个人主页&#xff1a;摆烂小白敲代码 创作领域&#xff1a;算法、C/C 持续更新算法领域的文章&#xff0c;让博主在您的算法之路上祝您一臂之力 欢迎各位大佬莅临我的博客&#xff0c;您的关注、点赞、收藏、评论是我持续创作最大的动力 差分算法是一种在计算机科学中常用的算法…...