当前位置: 首页 > news >正文

rknn转换后精度差异很大,失真算子自纠

下面是添加了详细注释的优化代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inferencedef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): ONNX 模型。返回:list: 包含所有节点名称的列表。"""return [node.name for node in model.graph.node]def remove_node_and_following(model, node_name):"""删除指定节点及其后续节点,并返回新的模型。参数:model (onnx.ModelProto): 原始 ONNX 模型。node_name (str): 要删除的节点名称。返回:onnx.ModelProto: 修改后的 ONNX 模型。"""nodes_to_keep = []  # 要保留的节点nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点start_removal = False  # 是否开始删除节点output = []  # 输出节点列表for node in model.graph.node:if node.name == node_name:start_removal = Trueif start_removal:nodes_to_remove.add(node.name)else:nodes_to_keep.append(node)output.extend(node.output)for node in model.graph.value_info:if node.name in output:shape = [dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else Nonefor dim in node.type.tensor_type.shape.dim]output_tensor = helper.make_tensor_value_info(node.name,onnx.TensorProto.FLOAT,shape)model.graph.output.append(output_tensor)new_graph = helper.make_graph(nodes_to_keep,model.graph.name,model.graph.input,[output for output in model.graph.output if output.name not in nodes_to_remove],model.graph.initializer,)new_model = helper.make_model(new_graph, producer_name=model.producer_name)new_model = shape_inference.infer_shapes(new_model)return new_modeldef preprocess_image(image_path, target_shape):"""加载并预处理图像。参数:image_path (str): 图像文件路径。target_shape (tuple): 目标形状 (宽, 高)。返回:np.ndarray: 预处理后的图像数组。"""im = cv2.imread(image_path)im = cv2.resize(im, target_shape)im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGBreturn np.ascontiguousarray(im)def main():model_path = 'yolov5s.onnx'model = onnx.load(model_path)dtype_map = {'tensor(float)': np.float32,'tensor(int32)': np.int32,'tensor(int64)': np.int64,}all_node_names = get_all_node_names(model)im = preprocess_image('./sample.jpg', (640, 640))for node_name in all_node_names:if node_name:new_model = remove_node_and_following(model, node_name)new_model_path = 'test.onnx'onnx.save(new_model, new_model_path)sess = rt.InferenceSession(new_model_path)output_names = [output.name for output in sess.get_outputs()]input_data = {input.name: im.reshape(input.shape).astype(dtype_map[input.type])for input in sess.get_inputs()}outputs = sess.run(output_names, input_data)# TODO: 补充 rknn 转换、加载和获取输出的代码# TODO: 补充评估对比两个输出 onnx 的 target 和 rknn 的 predict# TODO: 补充当评估准确率低于 0.8 时终止循环,此时的模型最后一个算子就是 rknn 不正确的算子if __name__ == "__main__":main()

在这里插入图片描述

解释

在补充 TODO 部分的代码之前,我们需要确保你有正确的 RKNN API 环境,并且你知道如何使用它来转换、加载和获取输出。以下是一个示例代码,展示了如何使用 RKNN API 完成这些任务。

1. 安装 RKNN Toolkit

首先,确保你已经安装了 RKNN Toolkit。你可以通过以下命令安装:

pip install rknn-toolkit

2. 补充代码

以下是补充了 TODO 部分的代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference
from rknn.api import RKNNdef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): ONNX 模型。返回:list: 包含所有节点名称的列表。"""return [node.name for node in model.graph.node]def remove_node_and_following(model, node_name):"""删除指定节点及其后续节点,并返回新的模型。参数:model (onnx.ModelProto): 原始 ONNX 模型。node_name (str): 要删除的节点名称。返回:onnx.ModelProto: 修改后的 ONNX 模型。"""nodes_to_keep = []  # 要保留的节点nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点start_removal = False  # 是否开始删除节点output = []  # 输出节点列表for node in model.graph.node:if node.name == node_name:start_removal = Trueif start_removal:nodes_to_remove.add(node.name)else:nodes_to_keep.append(node)output.extend(node.output)for node in model.graph.value_info:if node.name in output:shape = [dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else Nonefor dim in node.type.tensor_type.shape.dim]output_tensor = helper.make_tensor_value_info(node.name,onnx.TensorProto.FLOAT,shape)model.graph.output.append(output_tensor)new_graph = helper.make_graph(nodes_to_keep,model.graph.name,model.graph.input,[output for output in model.graph.output if output.name not in nodes_to_remove],model.graph.initializer,)new_model = helper.make_model(new_graph, producer_name=model.producer_name)new_model = shape_inference.infer_shapes(new_model)return new_modeldef preprocess_image(image_path, target_shape):"""加载并预处理图像。参数:image_path (str): 图像文件路径。target_shape (tuple): 目标形状 (宽, 高)。返回:np.ndarray: 预处理后的图像数组。"""im = cv2.imread(image_path)im = cv2.resize(im, target_shape)im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGBreturn np.ascontiguousarray(im)def convert_onnx_to_rknn(onnx_model_path, rknn_model_path):"""将 ONNX 模型转换为 RKNN 模型。参数:onnx_model_path (str): ONNX 模型路径。rknn_model_path (str): 转换后的 RKNN 模型路径。"""rknn = RKNN()# 加载 ONNX 模型print('--> Loading model')ret = rknn.load_onnx(model=onnx_model_path)if ret != 0:print('Load ONNX model failed!')returnprint('done')# 配置模型print('--> Building model')ret = rknn.build(do_quantization=False)if ret != 0:print('Build RKNN model failed!')returnprint('done')# 导出 RKNN 模型print('--> Export RKNN model')ret = rknn.export_rknn(rknn_model_path)if ret != 0:print('Export RKNN model failed!')returnprint('done')def load_and_run_rknn_model(rknn_model_path, input_data):"""加载 RKNN 模型并运行推理。参数:rknn_model_path (str): RKNN 模型路径。input_data (np.ndarray): 输入数据。返回:list: RKNN 模型的输出结果。"""rknn = RKNN()# 加载 RKNN 模型print('--> Loading RKNN model')ret = rknn.load_rknn(rknn_model_path)if ret != 0:print('Load RKNN model failed!')return []print('done')# 初始化 RKNN 模型print('--> Init runtime environment')ret = rknn.init_runtime()if ret != 0:print('Init runtime environment failed!')return []print('done')# 运行推理print('--> Running model')outputs = rknn.inference(inputs=[input_data])print('done')rknn.release()return outputsdef compare_outputs(onnx_outputs, rknn_outputs, threshold=0.8):"""比较 ONNX 和 RKNN 模型的输出结果。参数:onnx_outputs (list): ONNX 模型的输出结果。rknn_outputs (list): RKNN 模型的输出结果。threshold (float): 准确率阈值。返回:bool: 如果准确率低于阈值,则返回 False,否则返回 True。"""# 计算准确率 (这里假设是简单的相对误差)accuracy = np.mean([np.allclose(onnx_out, rknn_out, rtol=threshold) for onnx_out, rknn_out in zip(onnx_outputs, rknn_outputs)])return accuracy >= thresholddef main():model_path = 'yolov5s.onnx'model = onnx.load(model_path)dtype_map = {'tensor(float)': np.float32,'tensor(int32)': np.int32,'tensor(int64)': np.int64,}all_node_names = get_all_node_names(model)im = preprocess_image('./sample.jpg', (640, 640))for node_name in all_node_names:if node_name:new_model = remove_node_and_following(model, node_name)new_model_path = 'test.onnx'onnx.save(new_model, new_model_path)sess = rt.InferenceSession(new_model_path)output_names = [output.name for output in sess.get_outputs()]input_data = {input.name: im.reshape(input.shape).astype(dtype_map[input.type])for input in sess.get_inputs()}onnx_outputs = sess.run(output_names, input_data)# 转换 ONNX 模型为 RKNN 模型rknn_model_path = 'test.rknn'convert_onnx_to_rknn(new_model_path, rknn_model_path)# 加载并运行 RKNN 模型rknn_outputs = load_and_run_rknn_model(rknn_model_path, im)# 比较 ONNX 和 RKNN 模型的输出结果if not compare_outputs(onnx_outputs, rknn_outputs):print(f'Node {node_name} is the incorrect operator in RKNN model.')breakif __name__ == "__main__":main()

相关文章:

rknn转换后精度差异很大,失真算子自纠

下面是添加了详细注释的优化代码: import cv2 import numpy as np import onnx import onnxruntime as rt from onnx import helper, shape_inferencedef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): O…...

【C语言】解决C语言报错:Stack Overflow

文章目录 简介什么是Stack OverflowStack Overflow的常见原因如何检测和调试Stack Overflow解决Stack Overflow的最佳实践详细实例解析示例1:递归调用过深示例2:分配过大的局部变量示例3:嵌套函数调用过多 进一步阅读和参考资料总结 简介 St…...

【滚动哈希 二分查找】1044. 最长重复子串

本文涉及知识点 滚动哈希 二分查找算法合集 LeetCode 1044. 最长重复子串 给你一个字符串 s ,考虑其所有 重复子串 :即 s 的(连续)子串,在 s 中出现 2 次或更多次。这些出现之间可能存在重叠。 返回 任意一个 可能具…...

webid、sec_poison_id、a1、web_session参数分析与算法实现

文章目录 1. 写在前面2. 参数分析3. 核心算法【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Python与爬虫领域研究与开发工作! 【🌟作者推荐】:对爬…...

Qt|QWebSocket与Web进行通讯,实时接收语音流

实现功能主要思路:在网页端进行语音输入,PC机可以实时接收并播放语音流。 此时,Qt程序做客户端,Web端做服务器,使用QWebSocket进行通讯,实时播放接收的语音流。 功能实现 想要实现该功能,需要…...

「51媒体」电视台媒体邀约采访报道怎么做?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 电视台作为地方主流媒体,对于新闻报道有着严格的选题标准和报道流程。如果您希望电视台对某个会议或活动进行报道,可以按这样的方法来做: 1.明确活动信…...

Python提取PDF文本和图片,以及提前PDF页面中指定矩形区域的文本

前言 从PDF中提取内容能帮助我们获取文件中的信息,以便进行进一步的分析和处理。此外,在遇到类似项目时,提取出来的文本或图片也能再次利用。要在Python中通过代码提取PDF文件中的文本和图片,可以使用 Spire.PDF for Python 这个…...

C#实现边缘锐化(图像处理)

在 C# 中进行图像的边缘锐化,可以通过卷积滤波器实现。边缘锐化的基本思想是通过卷积核(也称为滤波器或掩模)来增强图像中的边缘。我们可以使用一个简单的锐化核,例如: [ 0, -1, 0][-1, 5, -1][ 0, -1, 0]这个卷积核…...

ffmpeg windows系统详细教程

视频做预览时黑屏,但有声音问题解决方案。 需要将 .mp4编成H.264格式的.mp4 一般上传视频的站点,如YouTube、Vimeo 等,通常会在用户上传视频时自动对视频进行转码,以确保视频能够在各种设备和网络条件下流畅播放。这些网站通常…...

【单片机】MSP430G2553单片机 Could not find MSP-FET430UIF on specified COM port 解决方案

文章目录 MSP430G2553开发板基础知识解决办法如何实施解决办法4步骤一步骤二步骤三 MSP430G2553开发板基础知识 MSP430G2553开发板如下图,上半部分就是UIF程序下载调试区域的硬件。个人觉得MSP430G2553开发板的这个部分没有做好硬件设计,导致很多系统兼…...

每日一题——力扣104. 二叉树的最大深度(举一反三+思想解读+逐步优化)四千字好文

一个认为一切根源都是“自己不够强”的INTJ 个人主页:用哲学编程-CSDN博客专栏:每日一题——举一反三Python编程学习Python内置函数 目录 我的写法 代码功能 代码结构 时间复杂度分析 空间复杂度分析 总结 我要更强 优化方法:迭代&…...

wpf textbox 有焦点 导致后台更新 前台不跟着改变

这个问题可能是由于 WPF 的数据绑定机制导致的。当 TextBox 有焦点时,它会独立于数据绑定进行更新,这可能会导致前台界面不能及时反映后台数据的变化。 1.使用 UpdateSourceTrigger 属性: 在数据绑定时,将 UpdateSourceTrigger 属性设置为 PropertyChanged。这样当 TextBox 的…...

数字化物资管理系统的未来:RFID技术的创新应用

在信息化和智能化不断发展的背景下,物资管理系统的数字化转型已成为各行各业关注的焦点。RFID技术作为一种先进的物联网技术,通过全面数字化实现物资信息的实时追踪和高效管理,为企业的物资管理提供了强有力的支持。 首先,RFID技…...

【docker】常用指令-表格整理

以下列出的指令是Docker中常用的命令,但并不是全部。Docker的指令非常丰富,可以根据具体的需求和场景选择合适的指令。同时,每个指令都有很多选项和参数可以使用,可以通过 docker COMMAND --help 来获取更详细的信息。 一、容器命…...

洛谷——P2824 排序

题目来源:[HEOI2016/TJOI2016] 排序 - 洛谷https://www.luogu.com.cn/problem/P2824 问题思路 本文介绍一种二分答案的做法,时间复杂度为:(nm)*log(n)*log(n).本题存在nlog(n)的做法,然而其做法没有二分答案的做法通俗易懂. 默认读…...

echart在线图表demo下载直接运行

echart 全面的数据可视化图表解决方案 | 折线图、柱状图、饼图、散点图、水球图等各类图表展示 持续更新中 三色带下表题速度仪表盘 地图自定义图标 动态环形图饼状图 动态水波动圆形 多标题指针仪表盘 温度仪表盘带下标题 横向柱状图排名 环形饼状图 双折线趋势变化...

MLX5_SET_TO_ONES宏解析

看代码时,遇到一个非常复杂的宏MLX5_SET_TO_ONES,这个宏的主要作用是对特定的数据结构置位,宏的上下文如下: #define __mlx5_nullp(typ) ((struct mlx5_ifc_##typ##_bits *)0) #define __mlx5_bit_off(typ, fld) (offsetof(struc…...

SQL Server入门-SSMS简单使用(2008R2版)-1

环境: win10,SQL Server 2008 R2 参考: SQL Server 新建数据库 - 菜鸟教程 https://www.cainiaoya.com/sqlserver/sql-server-create-db.html 第 2 课:编写 Transact-SQL | Microsoft Learn https://learn.microsoft.com/zh-cn/…...

高考专业抉择探索计算机专业的未来展望及适合人群

身份:一位正在面临人生重要抉择的高考生,一位计算机行业从业者  正文:  随着2024年高考落幕,我与数百万高三学生一样,又将面临人生中的重要抉择:选择大学专业。对于许多学生来说,计算机科学…...

windows安装spark

在 Windows 上安装 Spark 并进行配置需要一些步骤,包括安装必要的软件和配置环境变量。以下是详细的步骤指南: 步骤一:安装 Java 下载和安装 Java Development Kit (JDK) 到 Oracle JDK 下载页面 或 OpenJDK 下载页面 下载适合你系统的 JDK。…...

保姆级教程:用迪文屏官方工具生成30x30点阵汉字库,搞定界面文本显示

嵌入式UI开发实战:迪文屏3030点阵汉字库生成全流程指南 在嵌入式设备的人机交互界面开发中,文本显示是最基础却最容易出问题的环节之一。许多开发者第一次使用迪文屏时,往往会被字库生成工具的参数设置难住——为什么明明生成了字库&#xf…...

企业数字化转型基石:全面认识4A企业架构数据架构方案

数据架构是企业架构中连接业务、应用与技术的桥梁,通过数据资产目录厘清家底,数据标准统一语言,数据模型指导开发,数据分布拉通业务流,从而提升数据质量与运作效率,支撑业务决策与系统建设。 统一语言&…...

UE4蓝图插件推荐:这5款免费工具让你的开发效率翻倍(附详细使用技巧)

UE4蓝图插件推荐:5款免费工具解锁高效开发新姿势 第一次在虚幻引擎中搭建复杂交互逻辑时,我盯着满屏纠缠的连线发呆了半小时——这简直比解毛线团还令人崩溃。直到发现那些藏在社区角落的蓝图效率神器,才意识到原来80%的重复劳动都可以交给插…...

告别PDF转换烦恼:Marker让学术文档秒变Markdown的完整指南

告别PDF转换烦恼:Marker让学术文档秒变Markdown的完整指南 【免费下载链接】marker 一个高效、准确的工具,能够将 PDF 和图像快速转换为 Markdown、JSON 和 HTML 格式,支持多语言和复杂布局处理,可选集成 LLM 提升精度&#xff0c…...

StructBERT在代码仓库管理中的重复代码检测应用

StructBERT在代码仓库管理中的重复代码检测应用 你有没有遇到过这种情况?在代码审查时,总觉得某段代码似曾相识,但又说不清在哪见过。或者,团队里不同成员为了解决类似问题,各自写了一套逻辑相近但细节不同的代码&…...

Sketchfab 3D模型本地化工具:Firefox浏览器专业解决方案

Sketchfab 3D模型本地化工具:Firefox浏览器专业解决方案 【免费下载链接】sketchfab sketchfab download userscipt for Tampermonkey by firefox only 项目地址: https://gitcode.com/gh_mirrors/sk/sketchfab 在数字创作领域,3D资源的离线获取与…...

智能家居选遥控器?RF 2.4G vs 蓝牙 vs IR 保姆级对比指南

智能家居遥控技术终极对决:RF 2.4G vs 蓝牙 vs IR 深度解析 当你深夜躺在沙发上想调暗灯光,却发现必须起身对准空调才能操作——这种尴尬正是选错遥控技术的代价。智能家居的"最后一米"控制体验,往往取决于那只看不见的传输协议。本…...

模型加载与初始化(3)

前言 在 llama.cpp 中,模型推理主要基于 GGUF 格式展开。GGUF 是一种专为存储基于 GGML 及其相关执行器进行推理的模型文件而设计的格式。作为一种二进制格式,其设计初衷在于实现模型的高效加载与保存,并确保良好的易读性。本章将深入探讨大语…...

CentOS8网络管理大变革:从network.service到NetworkManager的全面解析

CentOS8网络管理架构深度解析:从传统命令到NetworkManager的进化之路 如果你是一位长期使用CentOS的系统管理员,最近升级到CentOS8后可能会遇到一个令人困惑的问题:当你习惯性地输入systemctl restart network命令时,系统却无情地…...

5步攻克MZmine 3质谱数据分析:从问题解决到专业应用的实战指南

5步攻克MZmine 3质谱数据分析:从问题解决到专业应用的实战指南 【免费下载链接】mzmine3 MZmine 3 source code repository 项目地址: https://gitcode.com/gh_mirrors/mz/mzmine3 MZmine 3作为开源质谱数据分析领域的核心工具,在代谢组学、蛋白质…...