PyTorch模型转ONNX例子
参考:(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.6.0+cu124 documentation
import numpy as np
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init
import onnx
import onnxruntime
import time
import os
from PIL import Image
import torchvision.transforms as transformsclass SuperResolutionNet(nn.Module):def __init__(self, upscale_factor, inplace=False):super(SuperResolutionNet, self).__init__()self.relu = nn.ReLU(inplace=inplace)self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))self.pixel_shuffle = nn.PixelShuffle(upscale_factor)self._initialize_weights()def forward(self, x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.relu(self.conv3(x))x = self.pixel_shuffle(self.conv4(x))return xdef _initialize_weights(self):init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv4.weight)def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()def evaluation_accuracy(x, torch_model, ort_session):torch_out = torch_model(x)ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}ort_outs = ort_session.run(None, ort_inputs)np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)print("Exported model has been tested with ONNXRuntime, and the result looks good!")def evaluation_speed(x, torch_model, ort_session):start = time.time()torch_out = torch_model(x)end = time.time()print(f"Inference of Pytorch model used {end - start} seconds")ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}start = time.time()ort_outs = ort_session.run(None, ort_inputs)end = time.time()print(f"Inference of ONNX model used {end - start} seconds")def evaluation_result(ort_session):img = Image.open("cat.jpg")resize = transforms.Resize([224, 224])img = resize(img)img_ycbcr = img.convert('YCbCr')img_y, img_cb, img_cr = img_ycbcr.split()to_tensor = transforms.ToTensor()img_y = to_tensor(img_y)img_y.unsqueeze_(0)ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}ort_outs = ort_session.run(None, ort_inputs)img_out_y = ort_outs[0]img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')final_img = Image.merge("YCbCr", [img_out_y,img_cb.resize(img_out_y.size, Image.BICUBIC),img_cr.resize(img_out_y.size, Image.BICUBIC),]).convert("RGB")final_img.save("cat_superres_with_ort.jpg")img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)img.save("cat_resized.jpg")if __name__ == '__main__':torch_model = SuperResolutionNet(upscale_factor=3)model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'batch_size = 64map_location = lambda storage, loc: storageif torch.cuda.is_available():map_location = Nonetorch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))torch_model.eval()x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)if not os.path.exists( "super_resolution.onnx"):torch.onnx.export(torch_model, # model being runx, # model input (or a tuple for multiple inputs)"super_resolution.onnx", # where to save the model (can be a file or file-like object)export_params=True, # store the trained parameter weights inside the model fileopset_version=10, # the ONNX version to export the model todo_constant_folding=True, # whether to execute constant folding for optimizationinput_names = ['input'], # the model's input namesoutput_names = ['output'], # the model's output namesdynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes'output' : {0 : 'batch_size'}})onnx_model = onnx.load("super_resolution.onnx")onnx.checker.check_model(onnx_model)ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])evaluation_accuracy(x, torch_model, ort_session)evaluation_speed(x, torch_model, ort_session)evaluation_result(ort_session)
相关文章:
PyTorch模型转ONNX例子
参考:(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.6.0cu124 documentation import numpy as np import torch.utils.model_zoo as model_zoo import torch.onnx import torch.nn as nn import t…...
科技云报到:AI Agent打了个响指,商业齿轮加速转动
科技云报到原创。 3月16日,百度旗下文心大模型4.5和文心大模型X1正式发布。目前,两款模型已在文心一言官网上线,免费向用户开放。 同时,文心大模型4.5已上线百度智能云千帆大模型平台,企业用户和开发者登录即可调用AP…...
【蓝桥杯python研究生组备赛】005 数学与简单DP
题目1 01背包 有 N 件物品和一个容量是 V 的背包。每件物品只能使用一次。 第 i 件物品的体积是 vi,价值是 wi。 求解将哪些物品装入背包,可使这些物品的总体积不超过背包容量,且总价值最大。 输出最大价值。 输入格式 第一行两个整数&a…...
Chapter 4-16. Troubleshooting Congestion in Fibre Channel Fabrics
Show FCS Ie Example 4-17 shows the NX-OS command show fcs ie on Cisco MDS switches. 例 4-17 显示了 Cisco MDS 交换机上的 NX-OS 命令 show fcs ie。 Example 4-17 NX-OS command show fcs ie on Cisco MDS switches MDS9706-C# show fcs ie IE List for VSAN: 20 --…...
抖音视频数据获取实战:从API调用到热门内容挖掘
在短视频流量为王的时代,掌握抖音热门视频数据已成为内容运营、竞品分析及营销决策的关键。本文将手把手教你通过抖音开放平台API获取视频详情数据,并提供完整的代码实现及商业化应用思路。 一、抖音API权限申请与核心接口 抖音API需企业资质认证&…...
大白话读懂java对象创建的过程
1. java对象创建流程(大白话版) 咱们java对象被创建的过程大致如下,即: 在 JVM 中对象的创建,从⼀个 new 指令开始: 首先检查这个指令的参数是否能在常量池中定位到⼀个类的符号引用检查这个符号引用代表…...
Ubutu20.04安装docker与docker-compose
系统:20.04.6 LTS (Focal Fossa)" 1.配置apt源(在/etc/apt/sources.list中输入以下内容) # deb cdrom:[Ubuntu 20.04.6 LTS _Focal Fossa_ - Release amd64 (20230316)]/ focal main restricted deb http://mirrors.aliyun.com/ubuntu/ focal main restricted …...
AI图像理解技术的演进
在CLIP等现代多模态模型出现之前,早期的图生文技术主要依赖人工标注的ImageNet等数据集,但其技术路线与当前方法存在本质差异。 一、传统图生文技术的标注依赖 ImageNet的核心地位 在2012-2020年间,ImageNet的1,400万张人工标注图像ÿ…...
STM32 —— MCU、MPU、ARM、FPGA、DSP
在嵌入式系统中,MCU、MPU、ARM、FPGA和DSP是核心组件,各自在架构、功能和应用场景上有显著差异。以下从专业角度详细解析这些概念: 一、 MCU(Microcontroller Unit,微控制器单元) 核心定义 集成系统芯片&a…...
aiosignal
文章目录 安装 一、关于 aiosignal Github : https://github.com/aio-libs/aiosignal官方文档:https://aiosignal.aio-libs.org/gitter聊天:https://gitter.im/aio-libs/Lobby许可证 : Apache 2 aiosignal 管理 asyncio 项目中回调的项目。 Signal是已…...
在 VSCode 远程开发环境下使用 Git 常用命令
在日常开发过程中,无论是单人项目还是团队协作,Git 都是版本管理的利器。尤其是在使用 VSCode 连接远程服务器进行代码开发时,Git 不仅能帮助你管理代码版本,还能让多人协作变得更加高效。本文将介绍一些常用的 Git 命令ÿ…...
电脑节电模式怎么退出 分享5种解决方法
在使用电脑的过程中,许多用户为了节省电力,通常会选择开启电脑的节能模式。然而,在需要更高性能或进行图形密集型任务时,节能模式可能会限制系统的性能表现。这时,了解如何正确地关闭或调整节能设置就显得尤为重要了。…...
kubernetes高级实战
一、模拟企业环境进行一个实战部署 [rootmaster node]# kubectl apply -f pod-tomcat.yaml pod/tomcat-test created [rootmaster node]# kubectl get pods NAME READY STATUS RESTARTS AGE tomcat-test 2/2 Running 0 2s [rootmaster node]…...
【Java】——程序逻辑控制(构建稳健代码的基石)
🎁个人主页:User_芊芊君子 🎉欢迎大家点赞👍评论📝收藏⭐文章 🔍系列专栏:【Java】内容概括 文章目录: 一.顺序结构二.分支结构1.if 语句1.1 语法格式11.2 语法格式21.3 语法格式3 …...
QT编程之PCM音频处理
一、高级播放接口(未压缩编码的音频文件) QMediaPlayer 支持MP3/WMA等压缩格式及网络流媒体播放,集成媒体控制(播放/暂停/进度调节)需设置QAudioOutput指定输出设备,支持播放速度调节(setPl…...
卫星互联网智慧杆:开启智能城市新时代
哇哦!在当下这个数字化浪潮正以雷霆万钧之势席卷全球的超酷时代,智慧城市建设已然成为世界各国你追我赶、竞相发力的核心重点领域啦!而咱们的卫星互联网智慧杆,作为一项完美融合了卫星通信与物联网顶尖技术的创新结晶,…...
Numpy broadcasting规则
Numpy的broadcast操作是为了将两个不同形状的数组,通过一系列规则,变换成形状相同的数组,从而使得它们之间可以进行按元素进行的计算。 Broadcasting的机制并不复杂,只要记住以下几条规则就可以了: 1. 顺序。首先&am…...
掌握 Shopee 商品数据:用爬虫解锁无限商机
在电商的浩瀚宇宙中,Shopee 宛如一颗璀璨星辰,吸引着无数卖家与买家在此汇聚。对于电商从业者、市场调研人员或是数据分析师而言,获取 Shopee 店铺的商品信息就如同掌握了开启财富之门的钥匙。而爬虫技术,正是帮助我们高效获取这些…...
Qt-QChart实现折线图
一、介绍场景 动态查看数据变化,或者了解数据发展趋势,让数据可以形象直观展现出来,这里推荐使用折线图的方式展现,本文抛砖引玉,简单实现一个实例,效果图如下: 二、实现步骤 1、charts组件 …...
取消Win10锁屏界面上显示的天气、市场和广告的操作
要取消Win10锁屏界面上显示的天气、市场和广告,您可以按照以下步骤操作: 方法一:更改锁屏界面设置 打开“设置”: 点击“开始”菜单,然后点击齿轮状的“设置”图标。 进入“个性化”: 在“设置”窗口中&a…...
IoT设备测试:从协议到硬件的全栈验证体系与实践指南
一、引言:IoT技术浪潮下的质量挑战 根据IDC预测,到2027年全球IoT设备数量将突破290亿台,涵盖智能家居、工业物联网(IIoT)、智慧城市、车联网等场景。然而,IoT系统的复杂性远超传统嵌入式设备——硬件异构性…...
大白话详细解读React框架的diffing算法
1. Diffing 算法是什么? Diffing 算法是 React 用来比较虚拟 DOM(Virtual DOM)树的一种算法。它的作用是找出前后两次渲染之间的差异(diff),然后只更新这些差异部分,而不是重新渲染整个页面。 …...
自然语言处理入门
第一章 自然语言处理入门 1 什么是自然语言处理 【什么是人工智能,分别对应哪几个领域】 AI是模仿甚至超越人的某项机能,NLP、CV、ASR NLP是机器理解并生成人类语言2 自然语言处理的发展简史 1950 -- 图灵提出“机器能思考吗”,划时代性的…...
Arduino示例代码讲解:Pitch follower 跟随
Arduino示例代码讲解:Pitch follower 跟随 Pitch follower代码功能代码逐行解释1. 注释部分功能:硬件连接:2. `setup()` 函数3. `loop()` 函数硬件连接**扬声器连接**:**光敏电阻连接**:**Arduino板**:运行结果修改建议视频讲解Pitch follower 这段代码是一个Arduino示例…...
从TouchDriver Pro到Touchdriver G1,Weart触觉手套全系解析:XR交互的“真实触感”如何实现?
Weart旗下的Touchdriver Pro触觉手套和Touchdriver G1触觉手套,凭借其技术创新,为用户带来了全新的触觉体验。Touchdriver Pro触觉手套通过多模态触觉反馈技术,提供力反馈、纹理渲染和温度提示,让用户在虚拟环境中感受到真实的触觉…...
华为OD机试-阿里巴巴找黄金宝箱(I)-双指针(Java 2023 B卷 100分)
题目描述 阿里巴巴在去砍柴的路上发现了强盗集团的藏宝地,藏宝地有编号从 0 到 N 的箱子,每个箱子上贴有一个数字。黄金宝箱满足排在它之前的所有箱子数字和等于排在它之后的所有箱子数字和。第一个箱子左边部分的数字和定义为 0;最后一个宝箱右边部分的数字和定义为 0。请…...
ubuntu20如何升级nginx到最新版本(其它版本大概率也可以)
前言: Nginx非常常用,所以在网络安全方面备受“关注”。其漏洞非常多,要经常保持软件更新版本才能更好的保证安全。但是Ubuntu官网适配nginx非常慢,所以nginx官方也会推出针对主流Linux操作系统的包管理工具安装方式。 步骤&…...
排序算法实现:插入排序与希尔排序
目录 一、引言 二、代码整体结构 三、宏定义与头文件 四、插入排序函数(Insertsort) 函数作用 代码要点分析 五、希尔排序函数(ShellSort) 函数作用 代码要点分析 六、打印数组函数(PrintSort&#x…...
UDP协议原理
UDP协议原理 本篇介绍 在前面使用UDP编程时已经基本了解了UDP的工作模式,也知道了UDP有三个特点: 无连接不可靠面向数据报 但是当时并没有具体谈论为什么UDP有以上三个特点,基于这个原因,本篇就会针对这三个原因进行介绍 UDP…...
EtherCAT转Modbus网关如何在倍福plc组态快速配置
EtherCAT转Modbus网关如何在倍福plc组态快速配置 在工业控制领域,EtherCAT和Modbus是两种常见的总线通信协议。EtherCAT以其高速的数据传输和灵活的网络配置被广泛应用于高性能自动化控制系统中,而Modbus则因其简单、稳定且兼容性强而被许多设备所支持。…...
