当前位置: 首页 > news >正文

【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序:

import torch
import torch.onnx
from torchvision import transforms, models
from PIL import Image
import os# Load MobileNetV2 model
model = models.mobilenet_v2(pretrained=True)
model.eval()# Download an example image from the PyTorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try:os.system(f"wget {url} -O {filename}")
except Exception as e:print(f"Error downloading image: {e}")# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)# Download ImageNet labels using wget
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")# Read the categories
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())# Save the PyTorch model
torch.save(model.state_dict(), "mobilenet_v2.pth")# Convert the PyTorch model to ONNX with specified input and output names
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "mobilenet_v2.onnx"
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)print(f"PyTorch model saved to 'mobilenet_v2.pth'")
print(f"ONNX model saved to '{onnx_path}'")# Load the ONNX model
import onnx
import onnxruntimeonnx_model = onnx.load(onnx_path)
onnx_session = onnxruntime.InferenceSession(onnx_path)# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX with the correct input name
onnx_output = onnx_session.run(['output'], {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("\nTop categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

得到:

在这里插入图片描述
用本地pth推理:

import torch
from torchvision import transforms, models
from PIL import Image# Load MobileNetV2 model
model = models.mobilenet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
model.eval()# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())

用onnx推理:

import torch
import onnxruntime
from torchvision import transforms
from PIL import Image# Load the ONNX model
onnx_path = "mobilenet_v2.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_path)# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX
onnx_output = onnx_session.run(None, {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("Top categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

相关文章:

【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序: import torch import torch.onnx from torchvision import transforms, models from PIL import Image import os# Load MobileNetV2 model model models.mobilenet_v2(pretrainedTrue) model.eval()# Download an example image from the P…...

高防CDN安全防护系统在业务方面的应用

在当今数字化的时代,网络安全问题日益严峻,保护网站和数据免受攻击变得至关重要。CDN安全防护系统作为一种有效的解决方案,受到了广泛关注。小德将向您介绍CDN安全防护系统的原理、应用场景以及使用方法,助您更好地保障网络安全。…...

opencv(3):控制鼠标,创建 tackbar控件

文章目录 控制鼠标相关APIsetMouseCallbackcallback TrackBar 控件cv2.createTrackbarcv2.getTrackbarPos: 控制鼠标相关API setMouseCallback(winname, callback, userdata)callback(event, x, y, flags, userdata) setMouseCallback 在 OpenCV 中,s…...

UE4动作游戏实例RPG Action解析二:GAS系统播放武器绑定的技能,以及GE效果

一、GAS系统播放武器技能 官方实例激活技能通过装备系统数据激活,我先用武器数据资产直接激活 官方实例蒙太奇播放是自定义的AbilityTask,我先用更简单的方法实现效果 1.1、技能系统必要步骤: 1.1.1 插件启用AbilitySystem 1.1.2 PlayerCharacter绑定技能组件AbilitySy…...

做完这些_成为机器学习方面的专家

简单记个帖子, 用来记录学习机器学习的路线图 1. 数学分析, 高等代数, 概率论这三大件不多说, 基础中的基础. 2. 对于编程工具, b站上500集的python教程---python面向对象编程五部曲(从零到就业). 3. 对于机器学习的理论板块, 推荐b站up主---啥都会一点的研究生, 里面有一个吴恩…...

kubernetes|云原生| 如何优雅的重启和更新pod---pod生命周期管理实务

前言: kubernetes的管理维护的复杂性体现在了方方面面,例如,pod的管理,服务的管理,用户的管理(RBAC)&#xf…...

【总结】坐标变换和过渡矩阵(易忘记)

xCy,此为x到y的坐标变换。 [β1,β2,…,βn] [α1,α2,…αn]C,此为基α到基β的过渡矩阵。 这个概念经常忘记。。。alpha到beta看来就是alpha后面加一个过渡矩阵了,很直观。坐标变换就是根据过渡矩阵和基本形式推一推得到吧,记…...

第十一周任务总结

本周任务总结 本周物联网方面主要继续进行网关的二次开发与规则引擎实现设备联动的实现 非物联网方面主要复习了docker的使用与算法的学习 1.网关的二次开发,本周将实现debug调试输出的文件下载到了网关,但网关出了问题无法连接,最终跟客服…...

Java Web——JavaScript基础

1. 引入方式 JavaScript程序不能独立运行,它需要被嵌入HTML中,然后浏览器才能执行 JavaScript 代码。 通过 script 标签将 JavaScript 代码引入到 HTML 中,有3种方式: 1.1. 内嵌式(嵌入式) 直接写在html文件里,用s…...

Vue3 toRaw 和 markRaw

一、toRaw 我们可以使用ref 和 reactive 将普通对象类型的数据变为响应式的数据。 我们可以使用toRaw 将reactive 对象的数据变为一般对象类型的数据。 使用toRaw 需要先进行引入: import { toRaw } from vue; 语法格式: const xxx toRaw(数据) set…...

麒麟信安助力长沙市就业与社保数据服务中心政务系统向自主创新演进

应用场景 长沙市就业与社保数据服务中心依托长沙市“政务云”的公共基础资源和相应的支撑能力,围绕社保、就业、人事人才、劳动关系等人社全量业务服务,力求建立以“智慧服务、智慧监管、智慧决策”为核心的“智慧人社”综合服务平台,实现人…...

【LeetCode刷题-双指针】--16.最接近的三数之和

16.最接近的三数之和 方法&#xff1a;排序双指针 class Solution {public int threeSumClosest(int[] nums, int target) {Arrays.sort(nums);int ans nums[0] nums[1] nums[2];for(int i 0;i<nums.length;i){int start i1,end nums.length - 1;while(start < en…...

Mac 安装 protobuf 和Android Studio 使用

1. 安装,执行命令 brew install protoc 2. Mac 错误提示&#xff1a;zsh: command not found: brew解决方法 解决方法&#xff1a;mac 安装homebrew&#xff0c; 用以下命令安装&#xff0c;序列号选择中科大&#xff08;1&#xff09;或 阿里云 /bin/zsh -c "$(curl…...

MongoDB入门级别教程全(Windows版,保姆级教程)

下载mongodb 进入官网&#xff1a; Download MongoDB Community Server | MongoDB 选择msi&#xff0c;Windows版本 下载完后直接双击&#xff1a; 选择complete 这里建议改地方&#xff1a; 我这里直接改成d盘&#xff1a;work目录下面&#xff1a; 点击next&#xff1a; 因…...

基于机器学习的居民消费影响因子分析预测

项目视频讲解: 基于机器学习的居民消费影响因子分析预测_哔哩哔哩_bilibili 主要工作内容: 完整代码: import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns import missingno as msno import warnings warnings.filterwarnin…...

Qt HTTP 摘要认证(海康球机摄像机ISAPI开发)

接到一个需求是开发下海康的球机,控制云台,给到我的是一个开发手册,当然了是海康的私有协议 ISAPI开发手册https://download.csdn.net/download/qq_37059136/88547425关于开发这块读文档就可以理解了,海康使用的是摘要认证,当然了海康已经给出使用范例 通过libcurl就可以直接连…...

srs webrtc推拉流环境搭建(公网)

本地环境搭建 官方代码https://github.com/ossrs/srs 拉取代码&#xff1a; git clone https://github.com/ossrs/srs.gitcd ./configure make ./objs/srs -c conf/https.rtc.confsrs在公网上&#xff0c;由于srs是lite-ice端&#xff0c;导致他不会主动到srs获取自己的公网i…...

【Flutter】设计原则(2)深入解析 SOLID 原则的应用

【Flutter】设计原则(2)深入解析 SOLID 原则的应用 文章目录 一、前言二、SOLID原则三、在 Flutter 中应用单一职责原则1. 专注单一功能的 Widget2. 提高代码可维护性四、在 Flutter 中应用开闭原则1. 利用多态和基类实现可扩展的 Widget2. 增强应用的可扩展性和灵活性五、在…...

python爬虫概述及简单实践:获取豆瓣电影排行榜

目录 前言 Python爬虫概述 简单实践 - 获取豆瓣电影排行榜 1. 分析目标网页 2. 获取页面内容 3. 解析页面 4. 数据存储 5. 使用代理IP 总结 前言 Python爬虫是指通过程序自动化地对互联网上的信息进行抓取和分析的一种技术。Python作为一门易于学习且强大的编程语言&…...

ts视频文件转为mp4(FFmpeg)

有些视频资源下载下来之后发现是.ts的文件&#xff0c;除了用下载它时用的工具或是浏览器才能看&#xff0c;那有没有将ts文件转换成更加通用视频格式的方法。 几乎万能的音视频工具--ffmpeg登场 安装和环境配置可看这篇博客&#xff1a;FFmpeg指令行打开usb摄像头&#xff0…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中&#xff0c;iftop是网络管理的得力助手&#xff0c;能实时监控网络流量、连接情况等&#xff0c;帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?

大家好&#xff0c;欢迎来到《云原生核心技术》系列的第七篇&#xff01; 在上一篇&#xff0c;我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在&#xff0c;我们就像一个拥有了一块崭新数字土地的农场主&#xff0c;是时…...

ES6从入门到精通:前言

ES6简介 ES6&#xff08;ECMAScript 2015&#xff09;是JavaScript语言的重大更新&#xff0c;引入了许多新特性&#xff0c;包括语法糖、新数据类型、模块化支持等&#xff0c;显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var&#xf…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上&#xff0c;你可以使用apt包管理器来安装NFS服务器。打开终端并运行&#xff1a; sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享&#xff0c;例如/shared&#xff1a; sudo mkdir /shared sud…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

Qt Widget类解析与代码注释

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码&#xff0c;写上注释 当然可以&#xff01;这段代码是 Qt …...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用

文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么&#xff1f;1.1.2 感知机的工作原理 1.2 感知机的简单应用&#xff1a;基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...

STM32HAL库USART源代码解析及应用

STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...