【gridsample】地平线如何支持gridsample算子
文章目录
- 1. grid_sample算子功能解析
- 1.1 理论介绍
- 1.2 代码分析
- 1.2.1 x,y取值范围[-1,1]
- 1.2.2 x,y取值范围超出[-1,1]
- 2. 使用grid_sample算子构建一个网络
- 3. 走PTQ进行模型转换与编译
实操以J5 OE1.1.60对应的docker为例
1. grid_sample算子功能解析
该段主要参考:https://blog.csdn.net/jameschen9051/article/details/124714759,不想看理论可直接跳至第2节
1.1 理论介绍
在图像处理领域,grid_sample 是一个常用的操作,通常用于对图像进行仿射变换或透视变换。它可以在给定输入图像和一个变换矩阵的情况下,对输入图像进行采样,生成一个新的输出图像。
pytorch中调用接口:
torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)
- input:输入特征图,可以是四维或者五维张量,本文主要以四维为例进行介绍,表示为 (N,C,Hin,Win) 。
- grid:采样网格,包含输出特征图的shape大小(Hout、Wout),每个 网格值 通过变换对应到输入特征图的采样点位,当对应四维input时,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入input为五维张量,那么最后一维大小必须为3。
为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(x,y,z),对应到输入特征图的二维或三维特征图的坐标维度,x,y取值范围一般为[-1,1],该范围映射到输入特征图的全图,一通操作变换后对应于输出图像上的一个像素点。
- mode:采样模式,可以是 ‘bilinear’(双线性插值)、 ‘nearest’(最近邻插值)、‘bicubic’ 双三次插值。。
- padding_mode:填充模式,用于处理采样时超出输入图像边界的情况,可以是 ‘zeros’ 、 ‘border’、 ‘reflection’。
- align_corners:一个布尔值,用于指定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。
总的说来,grid_sample 算子会根据给定的网格(grid)在输入图像上进行采样,然后根据选择的插值方法在采样点周围的像素上进行插值,最终生成输出图像。
画一个在BEV方案中grid_sample原理图来帮助理解grid_sample怎么回事:

1.2 代码分析
对照代码进行下一步解读。
假设输入shape为(N,C,H_in,W_in),grid的shape设定为(N,H_out,W_out,2),使用双线性差值,填充模式为zeros,align_corners需要设置为True。
首先根据input和grid设定,输出特征图tensor的shape为(N,C,H_out,W_out),输出特征图上每一个cell上的值与grid最后一维(x,y)息息相关,那么如何计算输出tensor上每一个点的值?
首先,通过(x,y)找到输入特征图上的采样位置:由于x,y取值范围为[-1,1],为了便于计算,先将x,y取值范围调整为[0,1],方法是(x+1)/2,(y+1)/2。因此,将x,y映射为输入特征图的具体坐标位置:(w-1)(x+1)/2、(h-1)(y+1)/2。
将x,y映射到输入特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。
注意:x,y映射后的坐标可能是输入特征图上任意位置。
基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,之后取grid中的第一个位置中的x,y,根据x,y从input中通过双线性插值计算出output第一个位置的值。类比使用pytorch中的grid_sample算子生成output。
其它的看代码注释即可。
1.2.1 x,y取值范围[-1,1]
import torch
import numpy as npdef grid_sample(input, grid):N, C, H_in, W_in = input.shapeN, H_out, W_out, _ = grid.shapeoutput = np.random.random((N,C,H_out,W_out))for i in range(N):for j in range(C):for k in range(H_out):for l in range(W_out):param = [0.0, 0.0]# 通过(w-1)*(x+1)/2、(h-1)*(y+1)/2将x,y映射为输入特征图的具体坐标位置。param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2x0 = int(param[0]) # int取整规则:将小数部分截断去掉。x1 = x0 + 1y0 = int(param[1])y1 = y0 + 1param[0] -= x0 # 此时param里装的是小数部分param[1] -= y0# 双线性插值left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1])left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1]right_top = input[i][j][y0][x1] * param[0] * (1 - param[1])right_bottom = input[i][j][y1][x1] * param[0] * param[1]result = left_bottom + left_top + right_bottom + right_topoutput[i][j][k][l] = resultreturn outputif __name__=='__main__':N, C, H_in, W_in, H_out, W_out = 1, 1, 4, 4, 2, 2input = np.random.random((N,C,H_in,W_in))# np.random.random()范围是[0,1),想要[a,b)的数据,需要(b-a)*np.random.random() + agrid = -1 + 2*np.random.random((N,H_out,W_out,2)) # 最后一维2,生成了坐标out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)# 注意:这儿align_corners=Trueoutput = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)print(f'grid_sample输出结果:\n{output}')
输出:

从输出结果上看,与pytorch基本一致。
注意:这里没有对超出[-1,1]范围的x,y值做处理,只能处理四维input,五维input的实现思路与这里基本一致:再加一层循环,内插算法改为3维。。
1.2.2 x,y取值范围超出[-1,1]
考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置’zero’,也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:
import torch
import numpy as npdef grid_sample(input, grid):N, C, H_in, W_in = input.shapeN, H_out, W_out, _ = grid.shapeoutput = np.random.random((N,C,H_out,W_out))for i in range(N):for j in range(C):for k in range(H_out):for l in range(W_out):param = [0.0, 0.0]# 通过(w-1)*(x+1)/2、(h-1)*(y+1)/2将x,y映射为输入特征图的具体坐标位置。param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2x1 = int(param[0] + 1) # int取整规则:将小数部分截断去掉。x0 = x1 - 1 y1 = int(param[1] + 1)y0 = y1 - 1param[0] = abs(param[0] - x0) # 此时param里装的是离x0,y0的距离param[1] = abs(param[1] - y0)# 填充left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0if 0 <= x0 < W_in and 0 <= y0 < H_in:left_top_value = input[i][j][y0][x0]if 0 <= x1 < W_in and 0 <= y0 < H_in:right_top_value = input[i][j][y0][x1]if 0 <= x0 < W_in and 0 <= y1 < H_in:left_bottom_value = input[i][j][y1][x0]if 0 <= x1 < W_in and 0 <= y1 < H_in:right_bottom_value = input[i][j][y1][x1]# 双线性插值left_top = left_top_value * (1 - param[0]) * (1 - param[1])left_bottom = left_bottom_value * (1 - param[0]) * param[1]right_top = right_top_value * param[0] * (1 - param[1])right_bottom = right_bottom_value * param[0] * param[1]result = left_bottom + left_top + right_bottom + right_topoutput[i][j][k][l] = resultreturn outputif __name__=='__main__':N, C, H_in, W_in, H_out, W_out = 1, 1, 4, 4, 2, 2input = np.random.random((N,C,H_in,W_in))# np.random.random()范围是[0,1),想要[a,b)的数据,需要(b-a)*np.random.random() + agrid = -1 + 2*np.random.random((N,H_out,W_out,2)) # 最后一维2,生成了坐标grid[0][0][0] = [-1.2, 1.3] # 超出[-1,1]的范围out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)# 注意:这儿align_corners=Trueoutput = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)print(f'grid_sample输出结果:\n{output}')
输出:

2. 使用grid_sample算子构建一个网络
先看一下地平线提供的算子支持与约束列表:

据此,构建一个简单的网络,test.py代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom horizon_nn.torch import export_onnxclass GridSampleModel(nn.Module):def __init__(self):super(GridSampleModel, self).__init__()self.unitconv = nn.Conv2d(24, 24, (1, 1), groups=3)nn.init.constant_(self.unitconv.weight, 1)nn.init.constant_(self.unitconv.bias, 0)def forward(self, x1, x2):x1 = self.unitconv(x1)x = F.grid_sample(x1,grid=x2,mode='bilinear',padding_mode='zeros',align_corners=True)x = self.unitconv(x)return xif __name__ == "__main__":model = GridSampleModel()model.eval()input_names = ['x1', 'x2']output_names = ['output']x1 = torch.randn((1, 24, 600, 800))x2 = torch.randn((1, 48, 64, 2))export_onnx(model, (x1, x2), 'gridsample.onnx', verbose=True, opset_version=11,input_names=input_names, output_names=output_names)print('convert to gridsampe onnx finish!!!')
运行test.py,生成onnx模型,可视化结构如下图:

3. 走PTQ进行模型转换与编译
对应config.yaml文件:
# 模型转化相关的参数
model_parameters:onnx_model: './gridsample.onnx'march: "bayes"working_dir: 'model_output'output_model_file_prefix: 'gridsample'# 模型输入相关参数, 若输入多个节点, 则应使用';'进行分隔, 使用默认缺省设置则写None
input_parameters:input_name: "x1;x2"input_type_rt: 'featuremap;featuremap'input_layout_rt: 'NCHW;NCHW'input_type_train: 'featuremap;featuremap'input_layout_train: 'NCHW;NCHW'input_shape: '1x24x600x800;1x48x64x2'norm_type: 'no_preprocess;no_preprocess'# 模型量化相关参数
calibration_parameters:calibration_type: 'skip'# 编译器相关参数
compiler_parameters:compile_mode: 'latency'optimize_level: 'O3'
使用的是OE1.1.60对应的docker
hb_mapper makertbin --config config.yaml --model-type onnx

全一段,且都在BPU上
相关文章:
【gridsample】地平线如何支持gridsample算子
文章目录 1. grid_sample算子功能解析1.1 理论介绍1.2 代码分析1.2.1 x,y取值范围[-1,1]1.2.2 x,y取值范围超出[-1,1] 2. 使用grid_sample算子构建一个网络3. 走PTQ进行模型转换与编译 实操以J5 OE1.1.60对应的docker为例 1. grid_sample算子功能解析 该段主要参考:…...
JPA实现存储实体类型信息
本文已收录于专栏 《Java》 目录 背景介绍概念说明DiscriminatorValue 注解:DiscriminatorColumn 注解:Inheritance(strategy InheritanceType.SINGLE_TABLE) 注解: 实现方式父类子类执行效果 总结提升 背景介绍 在我们项目开发的过程中经常…...
阿里云快速部署开发环境 (Apache + Mysql8.0+Redis7.0.x)
本文章的内容截取于云服务器管理控制台提供的安装步骤,再整合前人思路而成,文章末端会提供原文连接 ApacheMysql 8.0部署MySQL数据库(Linux)步骤一:安装MySQL步骤二:配置MySQL步骤三:远程访问My…...
语音秘书:让录音转文字识别软件成为你的智能工作助手
每当在需要写文章的深夜,我的思绪经常跟不上我的笔,即便是说出来用录音机录下,再书写出来,也需要耗费大量时间。这个困扰了我很久的问题终于有了解决的办法,那就是录音转文字软件。它像个语言魔术师,将我所…...
【腾讯云 Cloud Studio 实战训练营】用于编写、运行和调试代码的云 IDE泰裤辣
文章目录 一、引言✉️二、什么是腾讯云 Cloud Studio🔍三、Cloud Studio优点和功能🌈四、Cloud Studio初体验(注册篇)🎆五、Cloud Studio实战演练(实战篇)🔬1. 初始化工作空间2. 安…...
[C#] 简单的俄罗斯方块实现
一个控制台俄罗斯方块游戏的简单实现. 已在 github.com/SlimeNull/Tetris 开源. 思路 很简单, 一个二维数组存储当前游戏的方块地图, 用 bool 即可, true 表示当前块被填充, false 表示没有. 然后, 抽一个 “形状” 类, 形状表示当前玩家正在操作的一个形状, 例如方块, 直线…...
postman官网下载安装登录详细教程
目录 一、介绍 二、官网下载 三、安装 四、注册登录postman账号(不注册也可以) postman注册登录和不注册登录的使用区别 五、关于汉化的说明 一、介绍 简单来说:是一款前后端都用来测试接口的工具。 展开来说:Postman 是一个…...
(贪心) 剑指 Offer 14- I. 剪绳子 ——【Leetcode每日一题】
❓剑指 Offer 14- I. 剪绳子 难度:中等 给你一根长度为 n 的绳子,请把绳子剪成整数长度的 m 段(m、n都是整数,n > 1 并且 m > 1),每段绳子的长度记为 k[0],k[1]...k[m-1] 。请问 k[0]*k[1]*...*k[m…...
如何将Linux上的cpolar内网穿透设置成 - > 开机自启动
如何将Linux上的cpolar内网穿透设置成 - > 开机自启动 文章目录 如何将Linux上的cpolar内网穿透设置成 - > 开机自启动前言一、进入命令行模式二、输入token码三、输入内网穿透命令 前言 我们将cpolar安装到了Ubuntu系统上,并通过web-UI界面对cpolar的功能有…...
50.两数之和(力扣)
目录 问题描述 核心代码解决 代码思想 时间复杂度和空间复杂度 问题描述 给定一个整数数组 和一个整数目标值 ,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。numstarget 你可以假设每种输入只会对应一个答案。但是&am…...
k8s基础
k8s基础 文章目录 k8s基础一、k8s组件二、k8s组件作用1.master节点2.worker node节点 三、K8S创建Pod的工作流程?四、K8S资源对象1.Pod2.Pod控制器3.service && ingress 五、K8S资源配置信息六、K8s部署1.K8S二进制部署2.K8S kubeadm搭建 七、K8s网络八、K8…...
【自然语言处理】大模型高效微调:PEFT 使用案例
文章目录 一、PEFT介绍二、PEFT 使用2.1 PeftConfig2.2 PeftModel2.3 保存和加载模型 三、PEFT支持任务3.1 Models support matrix3.1.1 Causal Language Modeling3.1.2 Conditional Generation3.1.3 Sequence Classification3.1.4 Token Classification3.1.5 Text-to-Image Ge…...
FFmpeg将编码后数据保存成mp4
以下测试代码实现的功能是:持续从内存块中获取原始数据,然后依次进行解码、编码、最后保存成mp4视频文件。 可保存成单个视频文件,也可指定每个视频文件的总帧数,保存多个视频文件。 为了便于查看和修改,这里将可独立的…...
设置VsCode 将打开的多个文件分行(栏)排列,实现全部显示
目录 1. 前言 2. 设置VsCode 多文件分行(栏)排列显示 1. 前言 主流编程IDE几乎都有排列切换选择所要查看的文件功能,如下为Visual Studio 2022的该功能界面: 图 1 图 2 当在Visual Studio 2022打开很多文件时,可以按照图1、图2所示找到自…...
Vue.js2+Cesium1.103.0 六、标绘与测量
Vue.js2Cesium1.103.0 六、标绘与测量 点,线,面的绘制,可实时编辑图形,点击折线或多边形边的中心点,可进行添加线段移动顶点位置等操作,并同时计算出点的经纬度,折线的距离和多边形的面积。 De…...
【redis 延时队列】使用go-redis的list做异步,生产消费者模式
分享一个用到的,使用go-redis的list做异步,生产消费者模式,接着再用 go 协程去检测队列里是否有东西去消费 如果队列为空,就会一直pop,空轮询导致 cpu 资源浪费和redis qps无效升高,所以可以通过 time.Sec…...
激光焊接塑料多点测试全画面穿透率测试仪
工程塑料由于其具有高比强度、电绝缘性、耐磨性、耐腐蚀性等优点,已广泛应用于各个重要领域。另一方面,工程塑料还具有良好的焊接性,是制成复合材料的基体材料的优良选择,因此目前已成为国内外新型复合材料的研究热点。 工程塑料…...
用 Uno 当烧录器给 atmega328 烧录 bootloader
用 Uno 当烧录器给 atmega328 烧录 bootloader date: 2023-8-10 https://backmountaindevil.github.io/#/hackaday/arduino/isp 引脚接线 把两个板子的 11(MOSI)、12(MISO)、13(SCK)、5V、GND 两两相连,还要把 Uno(烧录器)的 10 接到atmeg…...
spring boot策略模式实用: 告警模块为例
spring boot策略模式实用: 告警模块 0 涉及知识点 策略模式, 模板方法, 代理, 多态, 反射 1 需求概括 场景: 每隔一段时间, 会获取设备运行数据, 如通过温湿度计获取到当前环境温湿度;需求: 对获取回来的进行分析, 超过配置的阈值需要产生对应的告警 2 方案设计 告警的类…...
Camunda 7.x 系列【10】使用 Rest API 运行流程实例
有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 2.7.9 本系列Camunda 版本 7.19.0 源码地址:https://gitee.com/pearl-organization/camunda-study-demo 文章目录 1. 前言2. 官方接口文档3. 本地接口文档3.1 Postman3.2 Camunda Platform Run Swagger3.3 S…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错
出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上,所以报错,到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本,cu、torch、cp 的版本一定要对…...
python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...
sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!
简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求,并检查收到的响应。它以以下模式之一…...
R 语言科研绘图第 55 期 --- 网络图-聚类
在发表科研论文的过程中,科研绘图是必不可少的,一张好看的图形会是文章很大的加分项。 为了便于使用,本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中,获取方式: R 语言科研绘图模板 --- sciRplothttps://mp.…...
关于uniapp展示PDF的解决方案
在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项: 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库: npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...
Vue 模板语句的数据来源
🧩 Vue 模板语句的数据来源:全方位解析 Vue 模板(<template> 部分)中的表达式、指令绑定(如 v-bind, v-on)和插值({{ }})都在一个特定的作用域内求值。这个作用域由当前 组件…...
MySQL的pymysql操作
本章是MySQL的最后一章,MySQL到此完结,下一站Hadoop!!! 这章很简单,完整代码在最后,详细讲解之前python课程里面也有,感兴趣的可以往前找一下 一、查询操作 我们需要打开pycharm …...
