当前位置: 首页 > news >正文

rknn转换后精度差异很大,失真算子自纠

下面是添加了详细注释的优化代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inferencedef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): ONNX 模型。返回:list: 包含所有节点名称的列表。"""return [node.name for node in model.graph.node]def remove_node_and_following(model, node_name):"""删除指定节点及其后续节点,并返回新的模型。参数:model (onnx.ModelProto): 原始 ONNX 模型。node_name (str): 要删除的节点名称。返回:onnx.ModelProto: 修改后的 ONNX 模型。"""nodes_to_keep = []  # 要保留的节点nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点start_removal = False  # 是否开始删除节点output = []  # 输出节点列表for node in model.graph.node:if node.name == node_name:start_removal = Trueif start_removal:nodes_to_remove.add(node.name)else:nodes_to_keep.append(node)output.extend(node.output)for node in model.graph.value_info:if node.name in output:shape = [dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else Nonefor dim in node.type.tensor_type.shape.dim]output_tensor = helper.make_tensor_value_info(node.name,onnx.TensorProto.FLOAT,shape)model.graph.output.append(output_tensor)new_graph = helper.make_graph(nodes_to_keep,model.graph.name,model.graph.input,[output for output in model.graph.output if output.name not in nodes_to_remove],model.graph.initializer,)new_model = helper.make_model(new_graph, producer_name=model.producer_name)new_model = shape_inference.infer_shapes(new_model)return new_modeldef preprocess_image(image_path, target_shape):"""加载并预处理图像。参数:image_path (str): 图像文件路径。target_shape (tuple): 目标形状 (宽, 高)。返回:np.ndarray: 预处理后的图像数组。"""im = cv2.imread(image_path)im = cv2.resize(im, target_shape)im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGBreturn np.ascontiguousarray(im)def main():model_path = 'yolov5s.onnx'model = onnx.load(model_path)dtype_map = {'tensor(float)': np.float32,'tensor(int32)': np.int32,'tensor(int64)': np.int64,}all_node_names = get_all_node_names(model)im = preprocess_image('./sample.jpg', (640, 640))for node_name in all_node_names:if node_name:new_model = remove_node_and_following(model, node_name)new_model_path = 'test.onnx'onnx.save(new_model, new_model_path)sess = rt.InferenceSession(new_model_path)output_names = [output.name for output in sess.get_outputs()]input_data = {input.name: im.reshape(input.shape).astype(dtype_map[input.type])for input in sess.get_inputs()}outputs = sess.run(output_names, input_data)# TODO: 补充 rknn 转换、加载和获取输出的代码# TODO: 补充评估对比两个输出 onnx 的 target 和 rknn 的 predict# TODO: 补充当评估准确率低于 0.8 时终止循环,此时的模型最后一个算子就是 rknn 不正确的算子if __name__ == "__main__":main()

在这里插入图片描述

解释

在补充 TODO 部分的代码之前,我们需要确保你有正确的 RKNN API 环境,并且你知道如何使用它来转换、加载和获取输出。以下是一个示例代码,展示了如何使用 RKNN API 完成这些任务。

1. 安装 RKNN Toolkit

首先,确保你已经安装了 RKNN Toolkit。你可以通过以下命令安装:

pip install rknn-toolkit

2. 补充代码

以下是补充了 TODO 部分的代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference
from rknn.api import RKNNdef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): ONNX 模型。返回:list: 包含所有节点名称的列表。"""return [node.name for node in model.graph.node]def remove_node_and_following(model, node_name):"""删除指定节点及其后续节点,并返回新的模型。参数:model (onnx.ModelProto): 原始 ONNX 模型。node_name (str): 要删除的节点名称。返回:onnx.ModelProto: 修改后的 ONNX 模型。"""nodes_to_keep = []  # 要保留的节点nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点start_removal = False  # 是否开始删除节点output = []  # 输出节点列表for node in model.graph.node:if node.name == node_name:start_removal = Trueif start_removal:nodes_to_remove.add(node.name)else:nodes_to_keep.append(node)output.extend(node.output)for node in model.graph.value_info:if node.name in output:shape = [dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else Nonefor dim in node.type.tensor_type.shape.dim]output_tensor = helper.make_tensor_value_info(node.name,onnx.TensorProto.FLOAT,shape)model.graph.output.append(output_tensor)new_graph = helper.make_graph(nodes_to_keep,model.graph.name,model.graph.input,[output for output in model.graph.output if output.name not in nodes_to_remove],model.graph.initializer,)new_model = helper.make_model(new_graph, producer_name=model.producer_name)new_model = shape_inference.infer_shapes(new_model)return new_modeldef preprocess_image(image_path, target_shape):"""加载并预处理图像。参数:image_path (str): 图像文件路径。target_shape (tuple): 目标形状 (宽, 高)。返回:np.ndarray: 预处理后的图像数组。"""im = cv2.imread(image_path)im = cv2.resize(im, target_shape)im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGBreturn np.ascontiguousarray(im)def convert_onnx_to_rknn(onnx_model_path, rknn_model_path):"""将 ONNX 模型转换为 RKNN 模型。参数:onnx_model_path (str): ONNX 模型路径。rknn_model_path (str): 转换后的 RKNN 模型路径。"""rknn = RKNN()# 加载 ONNX 模型print('--> Loading model')ret = rknn.load_onnx(model=onnx_model_path)if ret != 0:print('Load ONNX model failed!')returnprint('done')# 配置模型print('--> Building model')ret = rknn.build(do_quantization=False)if ret != 0:print('Build RKNN model failed!')returnprint('done')# 导出 RKNN 模型print('--> Export RKNN model')ret = rknn.export_rknn(rknn_model_path)if ret != 0:print('Export RKNN model failed!')returnprint('done')def load_and_run_rknn_model(rknn_model_path, input_data):"""加载 RKNN 模型并运行推理。参数:rknn_model_path (str): RKNN 模型路径。input_data (np.ndarray): 输入数据。返回:list: RKNN 模型的输出结果。"""rknn = RKNN()# 加载 RKNN 模型print('--> Loading RKNN model')ret = rknn.load_rknn(rknn_model_path)if ret != 0:print('Load RKNN model failed!')return []print('done')# 初始化 RKNN 模型print('--> Init runtime environment')ret = rknn.init_runtime()if ret != 0:print('Init runtime environment failed!')return []print('done')# 运行推理print('--> Running model')outputs = rknn.inference(inputs=[input_data])print('done')rknn.release()return outputsdef compare_outputs(onnx_outputs, rknn_outputs, threshold=0.8):"""比较 ONNX 和 RKNN 模型的输出结果。参数:onnx_outputs (list): ONNX 模型的输出结果。rknn_outputs (list): RKNN 模型的输出结果。threshold (float): 准确率阈值。返回:bool: 如果准确率低于阈值,则返回 False,否则返回 True。"""# 计算准确率 (这里假设是简单的相对误差)accuracy = np.mean([np.allclose(onnx_out, rknn_out, rtol=threshold) for onnx_out, rknn_out in zip(onnx_outputs, rknn_outputs)])return accuracy >= thresholddef main():model_path = 'yolov5s.onnx'model = onnx.load(model_path)dtype_map = {'tensor(float)': np.float32,'tensor(int32)': np.int32,'tensor(int64)': np.int64,}all_node_names = get_all_node_names(model)im = preprocess_image('./sample.jpg', (640, 640))for node_name in all_node_names:if node_name:new_model = remove_node_and_following(model, node_name)new_model_path = 'test.onnx'onnx.save(new_model, new_model_path)sess = rt.InferenceSession(new_model_path)output_names = [output.name for output in sess.get_outputs()]input_data = {input.name: im.reshape(input.shape).astype(dtype_map[input.type])for input in sess.get_inputs()}onnx_outputs = sess.run(output_names, input_data)# 转换 ONNX 模型为 RKNN 模型rknn_model_path = 'test.rknn'convert_onnx_to_rknn(new_model_path, rknn_model_path)# 加载并运行 RKNN 模型rknn_outputs = load_and_run_rknn_model(rknn_model_path, im)# 比较 ONNX 和 RKNN 模型的输出结果if not compare_outputs(onnx_outputs, rknn_outputs):print(f'Node {node_name} is the incorrect operator in RKNN model.')breakif __name__ == "__main__":main()

相关文章:

rknn转换后精度差异很大,失真算子自纠

下面是添加了详细注释的优化代码: import cv2 import numpy as np import onnx import onnxruntime as rt from onnx import helper, shape_inferencedef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): O…...

【C语言】解决C语言报错:Stack Overflow

文章目录 简介什么是Stack OverflowStack Overflow的常见原因如何检测和调试Stack Overflow解决Stack Overflow的最佳实践详细实例解析示例1:递归调用过深示例2:分配过大的局部变量示例3:嵌套函数调用过多 进一步阅读和参考资料总结 简介 St…...

【滚动哈希 二分查找】1044. 最长重复子串

本文涉及知识点 滚动哈希 二分查找算法合集 LeetCode 1044. 最长重复子串 给你一个字符串 s ,考虑其所有 重复子串 :即 s 的(连续)子串,在 s 中出现 2 次或更多次。这些出现之间可能存在重叠。 返回 任意一个 可能具…...

webid、sec_poison_id、a1、web_session参数分析与算法实现

文章目录 1. 写在前面2. 参数分析3. 核心算法【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Python与爬虫领域研究与开发工作! 【🌟作者推荐】:对爬…...

Qt|QWebSocket与Web进行通讯,实时接收语音流

实现功能主要思路:在网页端进行语音输入,PC机可以实时接收并播放语音流。 此时,Qt程序做客户端,Web端做服务器,使用QWebSocket进行通讯,实时播放接收的语音流。 功能实现 想要实现该功能,需要…...

「51媒体」电视台媒体邀约采访报道怎么做?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 电视台作为地方主流媒体,对于新闻报道有着严格的选题标准和报道流程。如果您希望电视台对某个会议或活动进行报道,可以按这样的方法来做: 1.明确活动信…...

Python提取PDF文本和图片,以及提前PDF页面中指定矩形区域的文本

前言 从PDF中提取内容能帮助我们获取文件中的信息,以便进行进一步的分析和处理。此外,在遇到类似项目时,提取出来的文本或图片也能再次利用。要在Python中通过代码提取PDF文件中的文本和图片,可以使用 Spire.PDF for Python 这个…...

C#实现边缘锐化(图像处理)

在 C# 中进行图像的边缘锐化,可以通过卷积滤波器实现。边缘锐化的基本思想是通过卷积核(也称为滤波器或掩模)来增强图像中的边缘。我们可以使用一个简单的锐化核,例如: [ 0, -1, 0][-1, 5, -1][ 0, -1, 0]这个卷积核…...

ffmpeg windows系统详细教程

视频做预览时黑屏,但有声音问题解决方案。 需要将 .mp4编成H.264格式的.mp4 一般上传视频的站点,如YouTube、Vimeo 等,通常会在用户上传视频时自动对视频进行转码,以确保视频能够在各种设备和网络条件下流畅播放。这些网站通常…...

【单片机】MSP430G2553单片机 Could not find MSP-FET430UIF on specified COM port 解决方案

文章目录 MSP430G2553开发板基础知识解决办法如何实施解决办法4步骤一步骤二步骤三 MSP430G2553开发板基础知识 MSP430G2553开发板如下图,上半部分就是UIF程序下载调试区域的硬件。个人觉得MSP430G2553开发板的这个部分没有做好硬件设计,导致很多系统兼…...

每日一题——力扣104. 二叉树的最大深度(举一反三+思想解读+逐步优化)四千字好文

一个认为一切根源都是“自己不够强”的INTJ 个人主页:用哲学编程-CSDN博客专栏:每日一题——举一反三Python编程学习Python内置函数 目录 我的写法 代码功能 代码结构 时间复杂度分析 空间复杂度分析 总结 我要更强 优化方法:迭代&…...

wpf textbox 有焦点 导致后台更新 前台不跟着改变

这个问题可能是由于 WPF 的数据绑定机制导致的。当 TextBox 有焦点时,它会独立于数据绑定进行更新,这可能会导致前台界面不能及时反映后台数据的变化。 1.使用 UpdateSourceTrigger 属性: 在数据绑定时,将 UpdateSourceTrigger 属性设置为 PropertyChanged。这样当 TextBox 的…...

数字化物资管理系统的未来:RFID技术的创新应用

在信息化和智能化不断发展的背景下,物资管理系统的数字化转型已成为各行各业关注的焦点。RFID技术作为一种先进的物联网技术,通过全面数字化实现物资信息的实时追踪和高效管理,为企业的物资管理提供了强有力的支持。 首先,RFID技…...

【docker】常用指令-表格整理

以下列出的指令是Docker中常用的命令,但并不是全部。Docker的指令非常丰富,可以根据具体的需求和场景选择合适的指令。同时,每个指令都有很多选项和参数可以使用,可以通过 docker COMMAND --help 来获取更详细的信息。 一、容器命…...

洛谷——P2824 排序

题目来源:[HEOI2016/TJOI2016] 排序 - 洛谷https://www.luogu.com.cn/problem/P2824 问题思路 本文介绍一种二分答案的做法,时间复杂度为:(nm)*log(n)*log(n).本题存在nlog(n)的做法,然而其做法没有二分答案的做法通俗易懂. 默认读…...

echart在线图表demo下载直接运行

echart 全面的数据可视化图表解决方案 | 折线图、柱状图、饼图、散点图、水球图等各类图表展示 持续更新中 三色带下表题速度仪表盘 地图自定义图标 动态环形图饼状图 动态水波动圆形 多标题指针仪表盘 温度仪表盘带下标题 横向柱状图排名 环形饼状图 双折线趋势变化...

MLX5_SET_TO_ONES宏解析

看代码时,遇到一个非常复杂的宏MLX5_SET_TO_ONES,这个宏的主要作用是对特定的数据结构置位,宏的上下文如下: #define __mlx5_nullp(typ) ((struct mlx5_ifc_##typ##_bits *)0) #define __mlx5_bit_off(typ, fld) (offsetof(struc…...

SQL Server入门-SSMS简单使用(2008R2版)-1

环境: win10,SQL Server 2008 R2 参考: SQL Server 新建数据库 - 菜鸟教程 https://www.cainiaoya.com/sqlserver/sql-server-create-db.html 第 2 课:编写 Transact-SQL | Microsoft Learn https://learn.microsoft.com/zh-cn/…...

高考专业抉择探索计算机专业的未来展望及适合人群

身份:一位正在面临人生重要抉择的高考生,一位计算机行业从业者  正文:  随着2024年高考落幕,我与数百万高三学生一样,又将面临人生中的重要抉择:选择大学专业。对于许多学生来说,计算机科学…...

windows安装spark

在 Windows 上安装 Spark 并进行配置需要一些步骤,包括安装必要的软件和配置环境变量。以下是详细的步骤指南: 步骤一:安装 Java 下载和安装 Java Development Kit (JDK) 到 Oracle JDK 下载页面 或 OpenJDK 下载页面 下载适合你系统的 JDK。…...

多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度​

一、引言:多云环境的技术复杂性本质​​ 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,​​基础设施的技术债呈现指数级积累​​。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎:品融电商,一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中,品牌如何破浪前行?自建团队成本高、效果难控;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展,消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁,不仅优化了客户体验,还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用,并…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年,作为行业领先的3D工业相机及视觉系统供应商,累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成,通过稳定、易用、高回报的AI3D视觉系统,为汽车、新能源、金属制造等行…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析:CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展,AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者,分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...