pytorch软件封装
封装代码,通过传入文件名,即可输出类别信息
上一章节,我们做了关于动物图像的分类,接下来我们把程序封装,然后进行预测。
单张图片的predict文件
predict.py
'''按着路径,导入单张图片做预测
'''
from torchvision.models import resnet18
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import cv2 as cv
import os
import numpy as np'''加载图片与格式转化
'''# 图片标准化
transform_BZ = transforms.Normalize(mean=[0.5062653, 0.46558657, 0.37899864], # 取决于数据集std=[0.22566116, 0.20558165, 0.21950442]
)img_size = 224
val_tf = transforms.Compose([ ##简单把图片压缩了变成Tensor模式transforms.ToPILImage(), # 将numpy数组转换为PIL图像transforms.Resize((img_size, img_size)),transforms.ToTensor(),transform_BZ # 标准化操作
])def cv_imread(file_path):cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)return cv_imgdef predict(img_path):'''获取标签名字'''# # 增加类别标签# dir_names = []# for root, dirs, files in os.walk("dataset"):# if dirs:# dir_names = dirs# 将输出保存到exel中,方便后续分析label_names = ['cat', 'chicken', 'cow', 'dog', 'duck','goldfish', 'lion', 'pig', 'sheep','snake']# 指定设备device = "cuda" if torch.cuda.is_available() else "cpu"print(f"Using {device} device")"""加载模型"""model = resnet18(weights=None)num_ftrs = model.fc.in_features # 获取全连接层的输入model.fc = nn.Linear(num_ftrs, 10) # 全连接层改为不同的输出torch_data = torch.load('./logs_resnet18_adam/best.pth',map_location=torch.device(device))model.load_state_dict(torch_data)model.to(device)'''读取图片'''img = cv_imread(img_path)img = cv.cvtColor(img, cv.COLOR_BGR2RGB)img_tensor = val_tf(img)# 增加batch_size维度img_tensor = Variable(torch.unsqueeze(img_tensor, dim=0).float(),requires_grad=False).to(device)'''数据输入与模型输出转换'''model.eval()with torch.no_grad():output_tensor = model(img_tensor)# 将输出通过softmax变为概率值output = torch.softmax(output_tensor, dim=1)# 输出可能性最大的那位pred_value, pred_index = torch.max(output, 1)# 将数据从cuda转回cpuif torch.cuda.is_available() == False:pred_value = pred_value.detach().cpu().numpy()pred_index = pred_index.detach().cpu().numpy()result = "预测类别为: " + str(label_names[pred_index[0]]) + " 可能性为: " + str(pred_value[0].item() * 100)[:5] + "%"return resultif __name__ == "__main__":img_path = r'dataset/cat/10.jpg'result = predict(img_path)print(result)

这里可以看出,我们用的cat数据集中的图片,预测出来的结果却是是cat,虽然可能性不是很高。
torch_data=torch.load('./logs_resnet18_adam/best.pth',map_location=torch.device(device))
使用 PyTorch 加载一个保存的模型权重文件(best.pth),并将其映射到指定的设备(device)
img_tensor=Variable(torch.unsqueeze(img_tensor,dim=0).float(),requires_grad=False).to(device)
将一个图像张量(img_tensor)进行处理,使其成为适合输入到神经网络模型中的格式,并将其移动到指定的设备(CPU 或 GPU)上
1. torch.unsqueeze(img_tensor, dim=0)
-
作用:在张量的第 0 维(即最外层)添加一个维度。
-
背景:神经网络模型通常期望输入数据是一个四维张量,形状为
[batch_size, channels, height, width]。如果img_tensor是一个三维张量(例如[channels, height, width]),则需要在第 0 维添加一个维度,使其形状变为[1, channels, height, width],其中1表示批量大小(batch_size)为 1。
2. .float()
-
作用:将张量的数据类型转换为
float32。 -
背景:许多神经网络模型在训练和推理时使用
float32数据类型。如果img_tensor的数据类型不是float32,则需要显式转换。
3. Variable(..., requires_grad=False)
-
作用:将张量封装为
Variable对象,并设置requires_grad属性。 -
背景:
-
Variable是 PyTorch 中的一个旧类,用于封装张量并支持自动求导。在较新的 PyTorch 版本中,Variable已经与Tensor合并,因此这一步在现代代码中通常是多余的。 -
requires_grad=False表示这个张量不需要计算梯度。这在推理阶段非常常见,因为输入数据不需要参与梯度计算。
-
转成ONNX,兼容各种设备

ONNX是什么?
ONNX(Open Neural Network Exchange)是一种开放的神经网络交换格式,它为深度学习模型提供了一种标准化的表示方式,使得模型可以在不同的深度学习框架之间进行转换和共享。
ONNX的作用是什么?
-
模型转换:开发者可以将训练好的模型从一个框架(如PyTorch)转换为ONNX格式,然后在另一个框架(如TensorFlow)中加载和使用。这使得开发者可以在不同的框架之间灵活切换,利用不同框架的优势。
-
模型部署:ONNX模型可以被导出到多种推理引擎,如ONNX Runtime。ONNX Runtime是一个高性能的推理引擎,支持多种硬件平台(如CPU、GPU、FPGA等),可以用于将模型部署到生产环境中。
-
模型优化:通过ONNX,开发者可以对模型进行优化和量化等操作。例如,可以将模型从浮点数量化为整数,以提高模型的推理速度和降低存储需求。
import torch
from torch import nn
from torchvision.models import resnet18
# pip install onnx
# pip install onnxruntimeif __name__ == '__main__':# 指定设备device = "cuda" if torch.cuda.is_available() else "cpu"print(f"Using {device} device")# 指定模型model = resnet18(pretrained=False)num_ftrs = model.fc.in_features # 获取全连接层的输入model.fc = nn.Linear(num_ftrs, 10) # 全连接层改为不同的输出# 模型加载权重torch_data = torch.load('logs_resnet18_pretrain/best.pth',map_location=torch.device(device))model.load_state_dict(torch_data)model.to(device)# 创建一个示例输入dummy_input = torch.randn(1,3,224,224, device=device)# 指定输出文件路径onnx_file_path = "logs_resnet18_pretrain/model.onnx"# 导出onnxtorch.onnx.export(model, dummy_input, onnx_file_path,verbose=True, # 屏幕中打印日志信息input_names=['input'],output_names=['output'])print("Model Exported Success")
Netron模型可视化
NETRON查看网络结构

如何下载可以看这篇文章网络可视化工具netron安装流程-CSDN博客
下载过后打开文件

ONNX单张图片预测
# -*- coding: utf-8 -*-
'''按着路径,导入单张图片做预测
'''
import onnxruntime as ort # pip install onnxruntime onnx
import numpy as np
import torchvision.transforms as transforms
import cv2 as cv
import osdef softmax(x):e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=1, keepdims=True)def cv_imread(file_path):cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)return cv_imgdef predict(img_path):'''获取标签名字'''# dir_names = []# for root, dirs, files in os.walk("dataset"):# if dirs:# dir_names = dirs# label_names = dir_nameslabel_names = ['cat', 'chicken', 'cow', 'dog', 'duck','goldfish', 'lion', 'pig', 'sheep','snake']'''加载图片与格式转化'''# 图片标准化transform_BZ = transforms.Normalize(mean=[0.5062653, 0.46558657, 0.37899864], # 取决于数据集std=[0.22566116, 0.20558165, 0.21950442])img_size = 224val_tf = transforms.Compose([ # 简单把图片压缩了变成Tensor模式transforms.ToPILImage(), # 将numpy数组转换为PIL图像transforms.Resize((img_size, img_size)),transforms.ToTensor(),transform_BZ # 标准化操作])# 读取图片img = cv_imread(img_path)img = cv.cvtColor(img, cv.COLOR_BGR2RGB)img_tensor = val_tf(img)# 将图片转换为ONNX运行时所需的格式img_numpy = img_tensor.numpy()img_numpy = np.expand_dims(img_numpy, axis=0) # 增加batch_size维度# 加载ONNX模型onnx_model_path = r'logs_resnet18_pretrain/model.onnx' # 替换为ONNX模型的路径ort_session = ort.InferenceSession(onnx_model_path)# 运行ONNX模型outputs = ort_session.run(None, {'input': img_numpy})output = outputs[0]# 应用softmaxprobabilities = softmax(output)# 获得预测结果pred_index = np.argmax(probabilities, axis=1)pred_value = probabilities[0][pred_index[0]]result = "预测类别为: " + str(label_names[pred_index[0]]) + " 可能性为: " + str(pred_value * 100)[:5] + "%"return resultif __name__ == "__main__":img_path = r'dataset/cat/10.jpg'result = predict(img_path)print(result)
这个没什么好讲的,就是可以直接封装成了一个onnx,可以不用安装pytorch库
PyQt5做预测模型
接下来先请大家准备一些库,看一看下面这篇文章PyCharm配置外部工具PyQtDesigner、PyUIC、Pyrcc_pycharm外部工具-CSDN博客
我把所有的文件封装了一下,大家要记得改一改路径

main_one_thread,py
# -*- coding: utf-8 -*-
from mainwindow import Ui_MainWindow
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog
import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import *
from predict封装 import predictclass UiMain(QMainWindow, Ui_MainWindow):def __init__(self, parent=None):super(UiMain, self).__init__(parent)self.setupUi(self)self.fileBtn.clicked.connect(self.loadImage)# 打开文件功能def loadImage(self):self.fname, _ = QFileDialog.getOpenFileName(self, '请选择图片','.','图像文件(*.jpg *.jpeg *.png)')if self.fname:print(self.fname)self.Infolabel.setText("文件打开成功\n"+self.fname)jpg = QtGui.QPixmap(self.fname).scaled(self.Imglabel.width(),self.Imglabel.height())self.Imglabel.setPixmap(jpg)result = predict(self.fname)self.Infolabel.setText(result)else:# print("打开文件失败")self.Infolabel.setText("打开文件失败")if __name__ == '__main__':app = QApplication(sys.argv)ui = UiMain()ui.show()sys.exit(app.exec_())
运行结果

但是这个文件如果打包,别人不一定能用
main_one_thread_onnx.py
from mainwindow import Ui_MainWindow
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog
import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import *
from predict_onnx import predictclass UiMain(QMainWindow, Ui_MainWindow):def __init__(self, parent=None):super(UiMain, self).__init__(parent)self.setupUi(self)self.fileBtn.clicked.connect(self.loadImage)# 打开文件功能def loadImage(self):self.fname, _ = QFileDialog.getOpenFileName(self, '请选择图片','.','图像文件(*.jpg *.jpeg *.png)')if self.fname:print(self.fname)self.Infolabel.setText("文件打开成功\n"+self.fname)jpg = QtGui.QPixmap(self.fname).scaled(self.Imglabel.width(),self.Imglabel.height())self.Imglabel.setPixmap(jpg)result = predict(self.fname)self.Infolabel.setText(result)else:# print("打开文件失败")self.Infolabel.setText("打开文件失败")if __name__ == '__main__':app = QApplication(sys.argv)ui = UiMain()ui.show()sys.exit(app.exec_())
相关文章:
pytorch软件封装
封装代码,通过传入文件名,即可输出类别信息 上一章节,我们做了关于动物图像的分类,接下来我们把程序封装,然后进行预测。 单张图片的predict文件 predict.py 按着路径,导入单张图片做预测from torchvis…...
【多线程-第四天-自己模拟SDWebImage的下载图片功能-看SDWebImage的Demo Objective-C语言】
一、我们打开之前我们写的异步下载网络图片的项目,把刚刚我们写好的分类拖进来 1.我们这个分类包含哪些文件: 1)HMDownloaderOperation类, 2)HMDownloaderOperationManager类, 3)NSString+Sandbox分类, 4)UIImageView+WebCache分类, 这四个文件吧,把它们拖过来…...
电脑提示“找不到mfc140u.dll“的完整解决方案:从原因分析到彻底修复
当你启动某个软件或游戏时,突然遭遇"无法启动程序,因为计算机中丢失mfc140u.dll"的错误提示,这确实令人沮丧。mfc140u.dll是Microsoft Foundation Classes(MFC)库的重要组成部分,属于Visual C Re…...
图像变换方式区别对比(Opencv)
1. 变换示例 import cv2 import matplotlib.pyplot as plotimg cv2.imread(url) img_cut img[100:200, 200:300] img_rsize cv2.resize(img, (50, 50)) (hight,width) img.shape[:2] rotate_matrix cv2.getRotationMatrix2D((hight//2, width//2), 50, 1) img_wa cv2.wa…...
图像颜色空间对比(Opencv)
1. 颜色转换 import cv2 import matplotlib.pyplot as plotimg cv2.imread("tmp.jpg") img_r cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_g cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img_h cv2.cvtColor(img, cv2.COLOR_BGR2HSV) img_l cv2.cvtColor(img, cv2.C…...
【NLP】24. spaCy 教程:自然语言处理核心操作指南(进阶)
spaCy 中文教程:自然语言处理核心操作指南(进阶) 1. 识别文本中带有“百分号”的数字 import spacy# 创建一个空的英文语言模型 nlp spacy.blank("en")# 处理输入文本 doc nlp("In 1990, more than 60% of people in East…...
每天学一个 Linux 命令(15):man
可访问网站查看,视觉品味拉满:http://www.616vip.cn/15/index.html 每天学一个 Linux 命令(15):man 命令简介 man(Manual)是 Linux 中最核心的命令之一,用于查看命令、系统调用、库函数等的手册文档。它是用户和开发者获取帮助的核心工具,几乎覆盖了系统中的所有功…...
必刷算法100题之计算右侧小于当前元素的个数
题目链接 315. 计算右侧小于当前元素的个数 - 力扣(LeetCode) 题目解析 计算数组里面所有元素右侧比它小的数的个数, 并且组成一个数组,进行返回 算法原理 归并解法(分治) 当前元素的后面, 有多少个比我小(降序) 我们要找到第一比左边小的元素, 这样…...
Python依赖注入完全指南:高效解耦、技术深析与实践落地
Python依赖注入完全指南:高效解耦、技术深析与实践落地 摘要 依赖注入(DI)不仅是一种设计技术,更是一种解耦的艺术。它通过削减模块间的强耦合性,为系统提供了更高的灵活性和可测试性,特别是在 FastAPI 等…...
android弱网环境数据丢失解决方案(3万字长文)
在移动互联网时代,Android 应用已经成为人们日常生活中不可或缺的一部分。从社交媒体到在线购物,从移动办公到娱乐游戏,用户对应用的依赖程度与日俱增。然而,尽管网络基础设施在全球范围内得到了显著改善,弱网环境依然是一个普遍存在且难以完全避免的现实。特别是在一些发…...
答案之书和源代码
答案之书是一个神秘而神奇的工具,它可以帮助你在遇到问题或犹豫不决的时候找到答案或暗示。这个程序模拟了答案之书的功能,让你随机生成一个简短而有启发性的答案,让你在困境中找到一丝希望。 在这个程序中,你会看到一个画布上显…...
Spring Cloud主要组件介绍
一、Spring Cloud 1、Spring Cloud技术概览 分为:服务治理,链路追踪,消息组件,配置中心,安全控制,分布式任务管理、调度,Cluster工具,Spring Cloud CLI,测试 2、注册中心:常用注册中心(Euerka[AP]、Zookeeper[CP]) 1)Euerka Client(服务提供者)=》注册=》Eue…...
深度学习ResNet模型提取影响特征
大家好,我是带我去滑雪! 影像组学作为近年来医学影像分析领域的重要研究方向,致力于通过从医学图像中高通量提取大量定量特征,以辅助疾病诊断、分型、预后评估及治疗反应预测。这些影像特征涵盖了形状、纹理、灰度统计及波形变换等…...
【Qt】Qt Creator开发基础:项目创建、界面解析与核心概念入门
🍑个人主页:Jupiter. 🚀 所属专栏:QT 欢迎大家点赞收藏评论😊 目录 Qt Creator 新建项⽬认识 Qt Creator 界⾯项⽬⽂件解析Qt 编程注意事项认识对象模型(对象树)Qt 窗⼝坐标体系 Qt Creator 新…...
SimpleITK (sitk) 中查看 DICOM 文件的像素位深(8位或16位)
在 SimpleITK (sitk) 中查看 DICOM 文件的像素位深(8位或16位),可以通过以下方法实现: 方法一:通过 图像像素数组的数据类型 判断 读取 DICOM 文件: 使用 sitk.ReadImage() 加载文件,生成图像对…...
Unity IL2CPP内存泄漏追踪方案(基于Memory Profiler)技术详解
一、IL2CPP内存管理特性与泄漏根源 1. IL2CPP内存架构特点 内存区域管理方式常见泄漏类型托管堆(Managed)GC自动回收静态引用/事件订阅未取消原生堆(Native)手动管理非托管资源未释放桥接层GCHandle/PInvoke跨语言引用未正确释放 对惹,这里有一个游戏开发交流小组…...
制造业项目管理如何做才能更高效?制造企业如何选择适配的数字化项目管理系统工具?
一、制造企业项目管理过程中面临的痛点有哪些? 制造企业在项目管理过程中面临的痛点通常涉及跨部门协作、资源调配、数据整合、风险控制等多个维度,且与行业特性(如离散制造vs流程制造)紧密相关。 进度失控多项目资源冲突信息孤…...
Python批量处理PDF图片详解(插入、压缩、提取、替换、分页、旋转、删除)
目录 一、概述 二、 使用工具 三、Python 在 PDF 中插入图片 3.1 插入图片到现有PDF 3.2 插入图片到新建PDF 3.3 批量插入多张图片到PDF 四、Python 提取 PDF 图片及其元数据 五、Python 替换 PDF 图片 5.1 使用图片替换图片 5.2 使用文字替换图片 六、Python 实现 …...
让 Python 脚本在后台持续运行:架构级解决方案与工业级实践指南
让 Python 脚本在后台持续运行:架构级解决方案与工业级实践指南 一、生产环境需求全景分析 1.1 后台进程的工业级要求矩阵 维度开发环境要求生产环境要求容灾要求可靠性单点运行集群部署跨机房容灾可观测性控制台输出集中式日志分布式追踪资源管理无限制CPU/Memo…...
【后端开发】Spring配置文件
文章目录 配置文件properties配置文件基本语法读取配置文件 yml配置文件基本语法读取配置文件配置空字符串及null单双引号配置对象配置集合配置Map 优缺点优点缺点 配置文件 硬编码是将数据直接嵌入到程序或其他可执行对象的源代码中,也就是常说的"代码写死&q…...
七种驱动器综合对比——《器件手册--驱动器》
九、驱动器 名称 功能与作用 工作原理 优势 应用 隔离式栅极驱动器 隔离式栅极驱动器用于控制功率晶体管(如MOSFET、IGBT、SiC或GaN等)的开关,其核心功能是将控制信号从低压侧传输到高压侧的功率器件栅极,同时在输入和输出之…...
996引擎-源码学习:PureMVC Lua 中的系统启动,初始化并注册 Mediator
996引擎-源码学习:PureMVC Lua 中的系统启动,初始化并注册 Mediator 一、PureMVC 核心架构二、系统启动流程系统启动注册 StartUp 通知发送 StartUp 通知,开始初始化三、Mediator 初始化1. gameStateInit.lua2. LoadingBeginCommand.lua3. RegisterWorldMediatorCommand.lua…...
redis系列--1.redis是什么
国际惯例,想了解一个东西,首先就要看看官方提供了什么。redis的官网是https://redis.io 。以下这段话就是redis的简介了: Redis is an open source (BSD licensed), in-memory data structure store, used as a database, cache, and message…...
CSS 过渡与变形:让交互更丝滑
在网页设计中,动效能让用户交互更自然、流畅,提升使用体验。本文将通过 CSS 的 transition(过渡)和 transform(变形)属性,带你入门基础动效设计,结合案例演示如何实现颜色渐变、元素…...
linuxbash原理
3417 1647 0 04:17 ? 00:00:21 /usr/libexec/gnome-terminal-server yangang 3425 3417 0 04:17 pts/0 00:00:00 bash yangang 4524 3417 0 04:26 pts/1 00:00:00 bash 控制台创建是通过/usr/libexec/gnome-terminal-server 进行创建 rea…...
MecAgent Copilot:机械设计师的AI助手,开启“氛围建模”新时代
MecAgent Copilot作为机械设计师的AI助手,正通过多项核心技术推动机械设计进入“氛围建模”新时代。以下从功能特性、技术支撑和应用场景三方面解析其创新价值: 一、核心功能特性 智能草图生成与参数化建模 支持自然语言输入生成设计草图和3D模型,如输入“剖面透视…...
[Python基础速成]2-模块与包与OOP
上篇➡️[Python基础速成]1-Python规范与核心语法 目录 Python模块创建模块与导入属性__name__dir()函数标准模块 Python包类类的专有方法 对象继承多态拷贝 Python模块 Python 中的模块(Module)是一个包含 Python 定义和语句的文件,文件名就…...
【prometheus+Grafana篇】Prometheus与Grafana:深入了解监控架构与数据可视化分析平台
💫《博主主页》:奈斯DB-CSDN博客 🔥《擅长领域》:擅长阿里云AnalyticDB for MySQL(分布式数据仓库)、Oracle、MySQL、Linux、prometheus监控;并对SQLserver、NoSQL(MongoDB)有了解 💖如果觉得文章对你有所帮…...
Web前端开发——超链接与浮动框架(下)
本节说明: 上一节,我们了解了超链接概述与超链接的语法、路径及分类两大部分内容,本节我们将了解超链接的应用与浮动框架。 三、超链接的应用 在网络上能够通过链接访问不同的资源或网页。链接对象多种多样,可分为文件、FTP站点…...
【后端开发】初识Spring IoC与SpringDI、图书管理系统
文章目录 图书管理系统用户登录需求分析接口定义前端页面代码服务器代码 图书列表展示需求分析接口定义前端页面部分代码服务器代码Controller层service层Dao层modle层 Spring IoC定义传统程序开发解决方案IoC优势 Spring DIIoC &DI使用主要注解 Spring IoC详解bean的存储五…...
