【RKNN】YOLO V5中pytorch2onnx,pytorch和onnx模型输出不一致,精度降低
在yolo v5训练的模型,转onnx,再转rknn后,测试发现:
rknn模型,量化与非量化,相较于pytorch模型,测试精度都有降低onnx模型,相较于pytorch模型,测试精度也有降低,且与rknn模型的精度更接近
于是,根据这种测试情况,rknn模型的上游,就是onnx。onnx这里发现不对劲,肯定是这步就出现了问题。于是就查pytorch转onnx阶段,就存在转化的精度降低了。
本篇就是记录这样一个过程,也请各位针对本文的问题,给一些建议,毕竟目前是发现了问题,同时还存在一些问题在。
一、pytorch转onnx:torch.onnx.export
yolo v5 export.py: def export_onnx()中,添加下面代码,检查转储的onnx模型,与pytorch模型的输出结果是否一致。代码如下:
torch.onnx.export(model.cpu() if dynamic else model, # --dynamic only compatible with cpuim.cpu() if dynamic else im,f,verbose=False,opset_version=opset,export_params=True, # 将训练好的权重保存到模型文件中do_constant_folding=True, # 执行常数折叠进行优化input_names=['images'],output_names=output_names,dynamic_axes={"image": {0: "batch_size"}, # variable length axes"output": {0: "batch_size"},}
)# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx modelimport onnxruntime
import numpy as np
print('onnxruntime run start', f)
sess = onnxruntime.InferenceSession('best.onnx')
print('sess run start')
output = sess.run(['output0'], {'images': im.detach().numpy()})[0]
print('pytorch model inference start')pytorch_result = model(im)[0].detach().numpy()
print(' allclose start')
print('output:', output)
print('pytorch_result:', pytorch_result)
assert np.allclose(output, pytorch_result), 'the output is different between pytorch and onnx !!!'
对其中的输出结果进行了打印,将差异性比较明显的地方进行了标记,如下所示:

也可以直接使用我下面这个版本,在转完onnx后,进行评测,转好的onnx和pt文件之间的差异性。如下:
参考pytorch官方:(OPTIONAL) EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME
import os
import platform
import sys
import warnings
from pathlib import Path
import torchFILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:sys.path.append(str(ROOT)) # add ROOT to PATH
if platform.system() != 'Windows':ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relativefrom models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
from utils.torch_utils import select_device, smart_inference_modeimport numpy as np
def cosine_distance(arr1, arr2):# flatten the arrays to shape (16128, 7)arr1_flat = arr1.reshape(-1, 7)arr2_flat = arr2.reshape(-1, 7)# calculate the cosine distancecosine_distance = np.dot(arr1_flat.T, arr2_flat) / (np.linalg.norm(arr1_flat) * np.linalg.norm(arr2_flat))return cosine_distance.mean()def check_onnx(model, im):import onnxruntimeimport numpy as npprint('onnxruntime run start')sess = onnxruntime.InferenceSession('best.onnx')print('sess run start')output = sess.run(['output0'], {'images': im.detach().numpy()})[0]print('pytorch model inference start')with torch.no_grad():pytorch_result = model(im)[0].detach().numpy()print(' allclose start')print('output:', output, output.shape)print('pytorch_result:', pytorch_result, pytorch_result.shape)cosine_dis = cosine_distance(output, pytorch_result)print('cosine_dis:', cosine_dis)# 判断小数点后几位(4),是否相等,不相等就报错# np.testing.assert_almost_equal(pytorch_result, output, decimal=4)# compare ONNX Runtime and PyTorch resultsnp.testing.assert_allclose(pytorch_result, output, rtol=1e-03, atol=1e-05)# assert np.allclose(output, pytorch_result), 'the output is different between pytorch and onnx !!!'import cv2
from utils.augmentations import letterbox
def preprocess(img, device):img = cv2.resize(img, (512, 512))img = img.transpose((2, 0, 1))[::-1]img = np.ascontiguousarray(img)img = torch.from_numpy(img).to(device)img = img.float()img /= 255if len(img.shape) == 3:img = img[None]return img
def main(weights=ROOT / 'weights/best.pt', # weights pathimgsz=(512, 512), # image (height, width)batch_size=1, # batch sizedevice='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpuinplace=False, # set YOLOv5 Detect() inplace=Truedynamic=False, # ONNX/TF/TensorRT: dynamic axes):# Load PyTorch modeldevice = select_device(device)model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model# Checksimgsz *= 2 if len(imgsz) == 1 else 1 # expand# Inputgs = int(max(model.stride)) # grid size (max stride)imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiplesim = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection# im = cv2.imread(r'F:\tmp\yolov5_multiDR\data\0000005_20200929_M_063Y16640.jpeg')# im = preprocess(im, device)print(im.shape)# Update modelmodel.eval()for k, m in model.named_modules():if isinstance(m, Detect):m.inplace = inplacem.dynamic = dynamicm.export = Truewarnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarningcheck_onnx(model, im)if __name__ == "__main__":main()
测试1:图像是一个全0的数组,一致性检查如下:
Mismatched elements: 76 / 112896 (0.0673%)
Max absolute difference: 0.00053406
Max relative difference: 2.2101output: [[[ 3.1054 3.965 8.9553 ... 6.8545e-07 0.36458 0.53113][ 9.0205 2.5498 13.39 ... 6.2585e-07 0.18449 0.70698][ 20.786 2.2233 13.489 ... 2.3842e-06 0.033101 0.95657]...[ 419.42 493.04 106.14 ... 8.4937e-06 0.24135 0.60916][ 485.68 500.22 46.923 ... 1.1176e-05 0.33573 0.48875][ 488.37 503.87 68.881 ... 5.9605e-08 0.00030029 0.99639]]] (1, 16128, 7)
pytorch_result: [[[ 3.1054 3.965 8.9553 ... 7.0523e-07 0.36458 0.53113][ 9.0205 2.5498 13.39 ... 6.0181e-07 0.18449 0.70698][ 20.786 2.2233 13.489 ... 2.4172e-06 0.033101 0.95657]...[ 419.42 493.04 106.14 ... 8.5151e-06 0.24135 0.60916][ 485.68 500.22 46.923 ... 1.1174e-05 0.33573 0.48875][ 488.37 503.87 68.881 ... 9.3094e-08 0.0003003 0.99639]]] (1, 16128, 7)
cosine_dis: 0.04229331
测试2:图像是加载的本地图像,一致性检查如下:
Mismatched elements: 158 / 112896 (0.14%)
Max absolute difference: 0.0016251
Max relative difference: 1.2584output: [[[ 3.0569 2.4338 10.758 ... 2.0862e-07 0.16333 0.78551][ 11.028 2.0251 13.407 ... 3.5763e-07 0.090503 0.88087][ 19.447 1.8957 13.431 ... 6.8545e-07 0.047358 0.95029]...[ 418.66 487.8 80.157 ... 1.4573e-05 0.65453 0.23448][ 472.99 491.78 79.313 ... 1.3232e-05 0.79356 0.15061][ 496.41 488.49 44.447 ... 2.6256e-05 0.89966 0.08772]]] (1, 16128, 7)
pytorch_result: [[[ 3.0569 2.4338 10.758 ... 2.5371e-07 0.16333 0.78551][ 11.028 2.0251 13.407 ... 3.3069e-07 0.090503 0.88087][ 19.447 1.8957 13.431 ... 6.6051e-07 0.047358 0.95029]...[ 418.66 487.8 80.157 ... 1.4618e-05 0.65453 0.23448][ 472.99 491.78 79.313 ... 1.3215e-05 0.79356 0.15061][ 496.41 488.49 44.447 ... 2.6262e-05 0.89966 0.08772]]] (1, 16128, 7)
cosine_dis: 0.04071107
发现,输出结果中,差异的数据点还是挺多的,那么就说明在模型中,有些部分的参数是有差异的,这才导致相同的输入,在最后的输出结果中存在差异。
但是在一定的误差内,结果是一致的。比如我验证了小数点后3位,都是一样的,但是到第4位的时候,就开始出现了差异性。
那么,如何降低,甚至没有这种差异,该怎么办呢?不知道你们有没有这方面的知识储备或经验,欢迎评论区给出指导,感谢。
二、新的pytorch转onnx:torch.onnx.dynamo_export
在参考pytorch官方,关于torch.onnx.export的模型转换,相关文档中:(OPTIONAL) EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME

上述案例,是pytorch官方给出评测pytorch和onnx转出模型,在相同输入的情况下,输出结果一致性对比的评测代码。对比这里:
testing.assert_allclose(actual, desired, rtol=1e-07, atol=0, equal_nan=True, err_msg='', verbose=True)
其中:
- rtol:相对tolerance(容忍度,公差,容许偏差)
- atol:绝对tolerance
- 要求
actual的desired值的差别不超过atol + rtol * abs(desired),否则弹出错误提示
可以看出,这是在误差允许的范围内,进行的评测。只要满足一定的误差要求,还是满足的。并且在本测试案例中,也确实通过了上述设定值的误差要求。
但是,峰回路转,有个提示,如下:

于是,就转到torch.onnx.dynamo_export链接,点击这里直达:EXPORT A PYTORCH MODEL TO ONNX
同样的流程,导出模型,然后进行一致性评价,发现官方竟然没有采用允许误差的评测,而是下面这样:
输出完全一致,这是一个大好消息。至此,开始验证
2.1、验证结果
与此同时,发现yolo v5更新到了v7.0.0的版本,于是就想着把yolo 进行升级,同时将pytorch版本也更新到最新的2.1.0,这样就可以采用torch.onnx.dynamo_export 进行转onnx模型的操作尝试了。
当一起就绪后,采用下面的代码转出onnx模型的时候,却出现了错误提示。
export_output = torch.onnx.dynamo_export(model.cpu() if dynamic else model,im.cpu() if dynamic else im)
export_output.save("my_image_classifier.onnx")
2.2、转出失败

给出失败的的提示:torch.onnx.OnnxExporterError,转出onnx模型失败,产生了一个SARIF的文件。然后介绍了什么是SARIF文件,可以通过VS Code SARIF,也可以 SARIF web查看。最后说吧这个错误,报告给pytorch的GitHub的issue地方。
产生了一个名为:report_dynamo_export.sarif是文件,打开文件,记录的信息如下:
{"runs":[{"tool":{"driver":{"name":"torch.onnx.dynamo_export","contents":["localizedData","nonLocalizedData"],"language":"en-US","rules":[],"version":"2.1.0+cu118"}},"language":"en-US","newlineSequences":["\r\n","\n"],"results":[]}],"version":"2.1.0","schemaUri":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
}
这更像是一个运行环境收集的一个记录文件。在我对全网进行搜索时候,发现了类似的报错提示,但并没有解决办法。不知道是不是因为这个函数还在内测阶段,并没有很好的适配。
如果你也遇到了同样的问题,欢迎给评论,指导问题出在了哪里?如何解决这个问题。感谢
三、总结
原本想着验证最终转rknn的模型,与原始pytorch模型是否一致的问题,最后发现在转onnx阶段,这种差异性就已经存在了。并且发现rknn的测试结果,与onnx模型的测试结果更加的贴近。无论是量化后的rknn,还是未量化的,均存在这个问题。
同时发现,量化后的rknn模型,在config阶段改变量化的方式,确实会提升模型的性能,且几乎接近于未量化的模型版本。
原本以为采用pytorch新的转出onnx的模型函数,可以解决这个问题。但是,发现还是内测版本,不知道问题是出在了哪里,还需要大神帮助,暂时未跑通。
最后,如果你也遇到了同样的问题,欢迎给评论,指导问题出在了哪里?如何解决这个问题。感谢
相关文章:
【RKNN】YOLO V5中pytorch2onnx,pytorch和onnx模型输出不一致,精度降低
在yolo v5训练的模型,转onnx,再转rknn后,测试发现: rknn模型,量化与非量化,相较于pytorch模型,测试精度都有降低onnx模型,相较于pytorch模型,测试精度也有降低ÿ…...
六分科技CEO李阳:精准定位助力汽车智能化普及
10月10日,2023四维图新用户大会在上海成功举办。大会现场,六分科技展示了基于PPP-RTK技术的“星璨”产品和软硬件一体化解决方案。同时在智能驾驶主题论坛上,六分科技CEO李阳受邀发表了以《精准定位助力汽车智能化普及》为主题的演讲。 高精度…...
信号完整性分析基础知识之有损传输线、上升时间衰减和材料特性(六):衰减和dB
线路中的损耗对信号的主要影响是当信号沿线路长度传播时幅度减小。如果将幅度为 V 的正弦波电压信号引入传输线,则其幅度将随着传输线向下移动而下降。图 9-16 显示了如果我们可以冻结时间并观察直线上存在的正弦波,则正弦波在不同位置可能会是什么样子。…...
吃鸡达人必备:分享顶级干货+作图工具推荐+账号安全查询!
吃鸡达人们,你们好!今天我来给大家介绍一些炙手可热的吃鸡话题,以及一些让你实力飙升的独家干货! 首先,让我们说一下如何提高自己的游戏战斗力。作为一名专业吃鸡行家,我将与你们分享一些顶级游戏作战干货&…...
帆软报表解决单元格不显示问题
前言 使用帆软报表设计器制作普通报表时、设计器界面经常有一根垂直的 “虚线”。一旦单元格超过这条 “虚线” ,那么真正打开报表就看不到这些列了。以下提供了简单的修正方法、欢迎大家讨论交流。 操作环境 设计器是帆软报表 9.0,操作系统是 Window…...
LeetCode讲解篇之138. 随机链表的复制
LeetCode讲解篇之138. 随机链表的复制 文章目录 LeetCode讲解篇之138. 随机链表的复制题目描述题解思路题解代码 题目描述 题解思路 先遍历一遍链表,用哈希表保存原始节点和克隆节点的映射关系,先只克隆节点的Val,然后再次遍历链表ÿ…...
主定理(简化版)
主定理(Master Theorem)是用于分析递归算法时间复杂度的一个重要工具。它适用于形式化定义的一类递归关系,通常采用分治策略解决问题的情况。 假设我们有一个递归算法,它将问题分解成 a a a 个子问题,每个子问题的规模…...
HTTP1.0和HTTP2.0的区别
相同点:所有的HTTP请求都要基于TCP连接。 HTTP1.0:每次发送请求时建立一个TCP连接,得到响应后,释放TCP连接。 HTP1.1:**相比于1.0,引入了Keep live,客户端得到响应后,不会立刻释放T…...
ARM资源记录《AI嵌入式系统:算法优化与实现》第八章(暂时用不到)
1.CMSIS的代码 书里给的5,https://github.com/ARM-software/CMSIS_5 现在有6了,https://github.com/ARM-software/CMSIS_6 这是官网的书,介绍cmsis函数的https://arm-software.github.io/CMSIS_5/Core/html/index.html 2.CMSIS介绍 Cort…...
微信小程序2
一,视图层 1.什么视图层 框架的视图层由 WXML 与 WXSS 编写,由组件来进行展示。 将逻辑层的数据反映成视图,同时将视图层的事件发送给逻辑层。 WXML(WeiXin Markup language) 用于描述页面的结构。 WXS(WeiXin Script) 是小程序的一套脚本语…...
G.711语音编解码器详解
语音编解码利用人听觉上的冗余对语音信息进行压缩从而达到节省带宽的目的。值得注意的是,本文说的是语音编解码器,也就Speech codec,而常用的还有另一种编解码器称作音频编解码器,英文是Audio codec,它们的区别如下。 以前在学校的时候研究了很多VoIP的编解码器从G.723到A…...
蓝桥杯每日一题2023.10.17
迷宫 - 蓝桥云课 (lanqiao.cn) 题目描述 样例: 01010101001011001001010110010110100100001000101010 00001000100000101010010000100000001001100110100101 01111011010010001000001101001011100011000000010000 0100000000101010001101000010100000101010101100…...
16.SpringBoot前后端分离项目之简要配置一
SpringBoot前后端分离项目之简要配置一 前面对后端所需操作及前端页面进行了了解及操作,这一节开始前后端分离之简要配置 为什么要前后端分离 为了更低成本、更高效率的开发模式。 前端有一个独立的服务器。 后端有一个独立的服务器。两个服务器之间实时数据交换…...
Probability Calibration概率校准大比拼:性能、应用场景和可视化对比总结
在机器学习中,概率校准(Probability Calibration)是一个重要的概念。简单来说,概率校准就是将分类器输出的原始预测概率转换为更准确、更可靠的概率值。这样做的目的是为了让模型的预测结果更接近实际情况,从而提高模型在特定应用场景中的可用性。 在Python的Scikit-Lear…...
PHP 球鞋在线商城系统mysql数据库web结构apache计算机软件工程网页wamp计算机毕业设计
一、源码特点 PHP球鞋在线商城系统是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 php球鞋在线商城系统 代码 https://download.csdn.net/download/qq_41221322/8843725…...
使用Apache和内网穿透实现私有服务公网远程访问——“cpolar内网穿透”
文章目录 前言1.Apache服务安装配置1.1 进入官网下载安装包1.2 Apache服务配置 2.安装cpolar内网穿透2.1 注册cpolar账号2.2 下载cpolar客户端 3. 获取远程桌面公网地址3.1 登录cpolar web ui管理界面3.2 创建公网地址 4. 固定公网地址 前言 Apache作为全球使用较高的Web服务器…...
PreparedStatement
使用参数化查询:使用预编译的语句和参数化查询来执行SQL语句,而不是将用户输入直接嵌入到SQL语句中。这将帮助防止恶意输入注入SQL语句。...
CSS3 新增属性-边框圆角-文字阴影-盒子阴影
边框圆角 CSS 边框圆角可以通过 border-radius 属性来实现。该属性用于设置元素的圆角大小,支持四个值分别表示上左、上右、下右和下左四个角的圆角半径大小,也可以使用两个值分别表示上下和左右两个方向的圆角大小,甚至可以只使用一个值来…...
制作.a静态库 (封盒)
//云库房间 1.GitHub上创建开源框架项目须包含文件: LICENSE:开源许可证;README.md:仓库说明文件;开源项目;(登录GitHub官网) 2. 云仓储库构建成功(此时云库中没有内容三方框架)!!! 3. 4.5. //…...
一台服务器,一个新世界
我如何看待服务器 当我拥有一台服务器,我看到的不仅仅是一块硬件,而是一扇打开未来的大门,一个我可以将自己的愿景和创意投射到其中的平台。这台服务器是我的工具,我的画布,我将在其中铸造我的数字梦想。 第一步我要…...
挑战杯推荐项目
“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 - 个性化梦境…...
Leetcode 3576. Transform Array to All Equal Elements
Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接:3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到…...
【OSG学习笔记】Day 18: 碰撞检测与物理交互
物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...
让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比
在机器学习的回归分析中,损失函数的选择对模型性能具有决定性影响。均方误差(MSE)作为经典的损失函数,在处理干净数据时表现优异,但在面对包含异常值的噪声数据时,其对大误差的二次惩罚机制往往导致模型参数…...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...
动态 Web 开发技术入门篇
一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...
Mysql8 忘记密码重置,以及问题解决
1.使用免密登录 找到配置MySQL文件,我的文件路径是/etc/mysql/my.cnf,有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...
逻辑回归暴力训练预测金融欺诈
简述 「使用逻辑回归暴力预测金融欺诈,并不断增加特征维度持续测试」的做法,体现了一种逐步建模与迭代验证的实验思路,在金融欺诈检测中非常有价值,本文作为一篇回顾性记录了早年间公司给某行做反欺诈预测用到的技术和思路。百度…...
