当前位置: 首页 > news >正文

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言

K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较

在本文中,我将K近邻(KNN)算法的应用分为两种情况:

  • 全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。

  • 局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。


全局查询

def knn(x, k):"""Input:x: all points, [B, C, N]k: k nearest points of each pointReturn:idx: grouped points index, [B, N, k]"""inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idx

这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。

函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。

插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。


局部查询

(1)knn_point 函数

def square_distance(src, dst):"""Calculate Euclid distance between each two points.src^T * dst = xn * xm + yn * ym + zn * zm;sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dstInput:src: source points, [B, N, C]dst: target points, [B, M, C]Output:dist: per-point square distance, [B, N, M]"""B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):"""Input:nsample: max sample number in local regionxyz: all points, [B, N, C]new_xyz: query points, [B, S, C]Return:group_idx: grouped points index, [B, S, nsample]"""sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx

这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。


(2)knn_cuda 库函数

import torch# Make sure your CUDA is available.
assert torch.cuda.is_available()from knn_cuda import KNN
"""
if transpose_mode is True, ref   is Tensor [bs x nr x dim]query is Tensor [bs x nq x dim]return dist is Tensor [bs x nq x k]indx is Tensor [bs x nq x k]
elseref   is Tensor [bs x dim x nr]query is Tensor [bs x dim x nq]return dist is Tensor [bs x k x nq]indx is Tensor [bs x k x nq]
"""knn = KNN(k=10, transpose_mode=True)ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()dist, indx = knn(ref, query)  # 32 x 50 x 10

大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。

需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了


性能比较

(1)测试代码

import torch
import time
from knn_cuda import KNNdef knn(x, k):inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idxdef square_distance(src, dst):B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx# Custom knn implementation
def test_knn(query, k, times):query = query.permute(0,2,1)start_time = time.time()  # Start timerfor i in range(times):indx = knn(query, k = k)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Custom knn_point implementation
def test_knn_point(ref, query, k, times):start_time = time.time()  # Start timerfor i in range(times):indx = knn_point(k, ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):knn = KNN(k=k, transpose_mode=True)start_time = time.time()  # Start timerfor i in range(times):dist, indx = knn(ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Main testing function
def test_knn_methods(ref, query, k, times):print("Test times: %d" % times)# Test custom knntime_knn = test_knn(query, k, times)print(f"knn      : {time_knn:.6f} seconds")# Test custom knn_pointtime_point = test_knn_point(ref, query, k, times)print(f"knn_point: {time_point:.6f} seconds")# Test knn_cudatime_cuda = test_knn_cuda(ref, query, k, times)print(f"knn_cuda : {time_cuda:.6f} seconds")if __name__ == '__main__':# Sample inputB, N, S, C = 32, 1024, 50, 3      # Batch size, total points, query points, coordinatesk = 24                            # Number of nearest neighborsref = torch.randn(B, N, C).cuda() # Reference points# Test above methodstimes_list = [1,2,3,10,50,100]for times in times_list:test_knn_methods(ref, ref, k, times)

这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knnknn_point 以及基于 knn_cuda 库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。

图注:三种实现方法的性能测评结果

上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。

当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。


(2)模型训练

图注:使用 knn 函数的训练时间
图注:使用 knn 函数的训练时间

图注:使用 knn_point 的训练时间

图注:使用 knn_cuda 库的训练时间

 

直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。


总结

我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。

相关文章:

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言 K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程&#xff…...

【OJ】常用技巧

1. 模版 #include<bits/stdc.h> using namespace std;int main(){ios::sync_with_stdio(false);cin.tie(0);// write herereturn 0; }2. 填充数组 memset是一个字节一个字节填充&#xff0c;如果是使int类型填充非0或者-1就会报错&#xff0c;如 int a[100]; memset(a…...

Redis:Redis性能变慢的原因

一、淘汰策略性能问题 当使用Redis当作缓存使用时&#xff0c;通常会给这个实例设置内存上限maxmemory&#xff0c;然后设置一个数据淘汰策略&#xff1b;如果Redis实例设置了内存上限maxmemory&#xff0c;那么也有可能导致Redis变慢。 原因在于&#xff0c;当Redis内存达到…...

Linux多线程——利用C++模板对pthread线程库封装

文章目录 线程封装主要框架线程启动线程等待其他信息 测试函数 线程封装 我们之前介绍过pthread的线程库&#xff0c;这个线程库主要是基于C语言的void*指针来进行传参和返回 我们使用C的模板对其封装可以让他的使用更加方便&#xff0c;并且经过测试可以让我们更加直观的了解…...

SpringBoot教程(十五) | SpringBoot集成RabbitMq(消息丢失、消息重复、消息顺序、消息顺序)

SpringBoot教程&#xff08;十五&#xff09; | SpringBoot集成RabbitMq&#xff08;消息丢失、消息重复、消息顺序、消息顺序&#xff09; RabbitMQ常见问题解决方案问题一&#xff1a;消息丢失的解决方案&#xff08;1&#xff09;生成者丢失消息丢失的情景解决方案1&#xf…...

TensorRT-LLM高级用法

--multi_block_mode decoding phase, 推理1个新token&#xff0c; 平时&#xff1a;按照batch样本&#xff0c;按照head&#xff0c;将计算平均分给所有SM&#xff1b; batch_size*num_heads和SM数目相比较小时&#xff1a;有些SM会空闲&#xff1b;加了--multi_block_mode&…...

文心一言功能新升级:读文档、懂翻译、能识图

9月4日&#xff0c;百度文心一言官网显示&#xff0c;在向全社会开放一周年之际&#xff0c;文心一言进行了功能最新全面升级&#xff0c;同时在周年期间为新老会员增加1个月专业版免费使用体验。 据了解&#xff0c;针对网页版用户需求&#xff0c;文心一言实现了创作内容更加…...

C++机试——走方格的方案

题目 请计算n*m的棋盘格子&#xff08;n为横向的格子数&#xff0c;m为竖向的格子数&#xff09;从棋盘左上角出发沿着边缘线从左上角走到右下角&#xff0c;总共有多少种走法&#xff0c;要求不能走回头路&#xff0c;即&#xff1a;只能往右和往下走&#xff0c;不能往左和往…...

Bootstrap 字体图标无法显示问题,<i>标签字体图标无法显示问题

bootstrap fileInput 以及 Bootstrap 字体图标无法显示问题。 今天在用 bootstrap fileInput 插件的时候发现图标无法显示&#xff0c;如下&#xff1a; 查看DOM&#xff0c;发现那些图标是<i>标签做的&#xff1a; 网上的方案 方案1 网上很多人说是我们打乱了boots…...

docker registry 仓库加密

docker registry 仓库加密 1、背景 ​ 公司一直用的镜像仓库是docker registry&#xff0c;但是有个安全问题&#xff0c;就是仓库从web ui的浏览到镜像的拉取都是可以直接使用的&#xff0c;还是放到了公网上&#xff0c;只需要知道你的域名那就是畅通无阻了&#xff0c;可以…...

利用高德+ArcGIS优雅获取任何感兴趣的矢量边界

荷花十里&#xff0c;清风鉴水&#xff0c;明月天衣。 四时之景不同&#xff0c;乐亦无穷尽也。今天呢&#xff0c;梧桐君给大家讲解一下&#xff0c;如何利用高德地图&#xff0c;随机所欲的获取shp边界数据。 文章主要分成以下几个步骤&#xff1a; 首先搜索你想获取的矢量…...

炮弹【USACO】

题目背景 时/空限制&#xff1a;1s / 64MB 题目描述 贝茜已经精通了变成炮弹并沿着长度为 N 的数轴弹跳的艺术&#xff0c;数轴上的位置从左到右编号为 1,2,…,N 。 她从某个整数位置 S 开始&#xff0c;以 1 的起始能量向右弹跳。 如果贝茜的能量为 k &#xff0c;则她将…...

python如何读取excel文件内的数据

目录 前言一、安装openpyxl二、读取Excel数据总结前言 在Python中读取Excel数据,最常用的库之一是openpyxl(用于.xlsx格式)和xlrd(尽管xlrd从版本2.0开始不再支持.xlsx,仅支持旧的.xls格式)。然而,对于大多数现代应用来说,openpyxl是一个更好的选择,因为它支持.xlsx格…...

Java项目: 基于SpringBoot+mybatis+maven+mysql教师工作量管理系统(含源码+数据库+毕业论文)

一、项目简介 本项目是一套基于SpringBootmybatismavenmysql教师工作量管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观…...

项目开发--数据库--postgresql数据库操作

背景 1、安装postgresql的基础方法 2、基本操作命令 解决方案 安装命令 在ubuntu环境当中进行安装。 sudo apt install postgresql安装完毕之后直接进行测试&#xff0c;如果看到如下内容则安装成功。 sudo systemctl status postgresql使用DBeaver进行连接报错&#xff…...

c语言——用一维数组输出杨辉三角形

一.代码 #include <stdio.h> int Num[100]; int Hang; int Lie; int a; int Flag; int main() {Lie 1;Hang 1;a 0;while (1) {//列1为1if (Lie 1) {Num[1] 1;Lie;}//数据存到数组里面while (Hang > Lie && Hang ! 2) { if (Hang!Lie) {Flag Num[Lie] …...

Codeforces Round 971 (Div. 4) (A~G1)

A、B题太简单&#xff0c;不做解释 C 对于 x y 两个方向&#xff0c;每一个方向至少需要 x / k 向上取整的步数&#xff0c;取最大值。 由于 x 方向先移动&#xff0c;假如 x 方向需要的步数多于 y 方向的步数&#xff0c;那么最后 y 方向的那一步就不需要了&#xff0c;答案…...

为什么构造函数不能为虚函数?为什么析构函数可以为虚函数,如果不设为虚函数可能会存在什么问题?

目录 一、为什么构造函数不能为虚函数&#xff1f; 二、为什么析构函数可以是虚函数&#xff1f;如果不设为虚函数可能会存在什么问题&#xff1f; 构造函数不能为虚函数&#xff0c;因为在构造过程中&#xff0c;虚函数机制尚未生效&#xff0c;对象还未完成构造&#xff0c…...

【数据结构】单链表功能的实现

目录 1.链表的概念及结构 2.单链表功能的实现 2.1打印单链表 2.2创建节点 2.3单链表尾插 2.3单链表头插 2.5单链表尾删 2.6单链表头删 2.7单链表的查找 2.8在指定位置之前插入数据 2.9在指定位置之后插入数据 2.10删除pos节点 2.11删除pos之后的节点 2.12销毁链表…...

最新车型库大全|阿里云实现调用API接口

整体请求流程&#xff1a; 介绍&#xff1a; 本次解析通过阿里云云市场的云服务来实现查询车型库大全查询&#xff0c;首先需要选择一家可以提供查询的商品。 [探数API]车型库查询_API专区_云市场-阿里云 步骤1: 选择商品 如图点击免费试用&#xff0c;即可免费申请该接口数…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习&#xff08;Reinforcement Learning, RL&#xff09;是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程&#xff0c;然后使用强化学习的Actor-Critic机制&#xff08;中文译作“知行互动”机制&#xff09;&#xff0c;逐步迭代求解…...

多场景 OkHttpClient 管理器 - Android 网络通信解决方案

下面是一个完整的 Android 实现&#xff0c;展示如何创建和管理多个 OkHttpClient 实例&#xff0c;分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

什么是EULA和DPA

文章目录 EULA&#xff08;End User License Agreement&#xff09;DPA&#xff08;Data Protection Agreement&#xff09;一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA&#xff08;End User License Agreement&#xff09; 定义&#xff1a; EULA即…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

用机器学习破解新能源领域的“弃风”难题

音乐发烧友深有体会&#xff0c;玩音乐的本质就是玩电网。火电声音偏暖&#xff0c;水电偏冷&#xff0c;风电偏空旷。至于太阳能发的电&#xff0c;则略显朦胧和单薄。 不知你是否有感觉&#xff0c;近两年家里的音响声音越来越冷&#xff0c;听起来越来越单薄&#xff1f; —…...

省略号和可变参数模板

本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...

深入理解Optional:处理空指针异常

1. 使用Optional处理可能为空的集合 在Java开发中&#xff0c;集合判空是一个常见但容易出错的场景。传统方式虽然可行&#xff0c;但存在一些潜在问题&#xff1a; // 传统判空方式 if (!CollectionUtils.isEmpty(userInfoList)) {for (UserInfo userInfo : userInfoList) {…...

libfmt: 现代C++的格式化工具库介绍与酷炫功能

libfmt: 现代C的格式化工具库介绍与酷炫功能 libfmt 是一个开源的C格式化库&#xff0c;提供了高效、安全的文本格式化功能&#xff0c;是C20中引入的std::format的基础实现。它比传统的printf和iostream更安全、更灵活、性能更好。 基本介绍 主要特点 类型安全&#xff1a…...

链式法则中 复合函数的推导路径 多变量“信息传递路径”

非常好&#xff0c;我们将之前关于偏导数链式法则中不能“约掉”偏导符号的问题&#xff0c;统一使用 二重复合函数&#xff1a; z f ( u ( x , y ) , v ( x , y ) ) \boxed{z f(u(x,y),\ v(x,y))} zf(u(x,y), v(x,y))​ 来全面说明。我们会展示其全微分形式&#xff08;偏导…...