YOLOv8由pt文件中读取模型信息
Pytorch的pt模型文件中保存了许多模型信息,如模型结构、模型参数、任务类型、批次、数据集等
在先前的YOLOv8实验中,博主发现YOLOv8在预测时并不需要指定任务类型,因为这些信息便保存在pt模型中,那么,今天我们便来看看,其到底是如何加载这些参数的。
我们首先对pt文件进行一个简单介绍:
pt文格式
pt格式文件是PyTorch中用于保存张量数据的文件格式。与pth文件类似,pt文件也常用于模型的保存和加载,但更侧重于保存单个张量或一组张量数据。通过pt文件,我们可以方便地将张量数据持久化,并在需要时重新加载使用。
张量(Tensor)是PyTorch中的核心数据结构,用于表示多维数组。在深度学习中,张量常用于存储模型的参数、输入数据、中间结果等。因此,掌握pt文件的保存和加载方法对于PyTorch的使用者来说至关重要。
pt文件与pth的区别
pt和.pth都是PyTorch模型文件的扩展名,但是它们的区别在于.pt文件是保存整个PyTorch模型的,而.pth文件只保存模型的参数。(其实现在似乎并没有区别了)
因此,如果要加载一个,pth文件,需要先定义模型的结构,然后再加载参数;而如果要加载一个,pt文件,则可以直接加载整个模型。
如何保存pt格式文件
在PyTorch中,我们可以使用torch.save()函数将张量数据保存到pt文件中。
下面是一个简单的示例:
import torch
# 创建一个张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将张量保存到pt文件中
torch.save(tensor, 'tensor.pt')
在上面的代码中,我们首先创建了一个二维张量tensor,然后使用torch.save()函数将其保存到名为tensor.pt的文件中。保存的文件将包含张量的数据和元数据,以便在加载时能够准确地恢复张量的结构和内容。
除了保存单个张量外,我们还可以保存多个张量到一个pt文件中。这可以通过将多个张量放入一个字典或列表中,然后将整个字典或列表保存到文件中实现。
例如:
# 创建多个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([[4, 5], [6, 7]])# 将张量放入字典中
tensors_dict = {'tensor1': tensor1, 'tensor2': tensor2}# 将字典保存到pt文件中
torch.save(tensors_dict, 'tensors_dict.pt')
如何加载pt格式文件
加载pt文件同样使用torch.load()函数。
下面是一个加载pt文件的示例:
# 加载单个张量的pt文件
loaded_tensor = torch.load('tensor.pt')
print(loaded_tensor)# 加载包含多个张量的字典的pt文件
loaded_dict = torch.load('tensors_dict.pt')
print(loaded_dict['tensor1'])
print(loaded_dict['tensor2'])
在加载单个张量的pt文件时,我们直接调用torch.load()函数并传入文件名即可。加载得到的loaded_tensor将是一个与原始张量结构和内容相同的张量对象。
当加载包含多个张量的字典的pt文件时,我们同样使用torch.load()函数。加载得到的loaded_dict将是一个字典对象,其中包含了我们在保存时放入的所有张量。我们可以通过字典的键来访问这些张量。
强烈建议只保存模型参数,而非保存整个网络。PyTorch 官方也是这么建议的。
torch.save(net.state_dict(),path2)#只保留模型参数
(只保存模型参数)是官方推荐的方法,运行速度快,且占空间较小。需要注意的是 net.state_dict() 是将网络参数保存为字典形式(OrderedDict),load_state_dict() 加载的并不是网络参数的pth文件,而是字典。
pt文件保存神经网络
在评估时,记住一定要使用model.eval()来固定dropout和归一化层,否则每次推理会生成不同的结果。
import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module): # 神经网络部分用你自己的def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 2, 1) # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)self.conv3 = nn.Conv2d(64, 128, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(6272, 128) # 6272=128*7*7self.fc2 = nn.Linear(128, 8)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = self.conv3(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)self.output = F.log_softmax(x, dim=1)out1 = xreturn self.output,out1def predict_mine():model=Net()model.load_state_dict(torch.load("model.pt"))print(model)images=torch.rand((1,1,64,64))x=model(images)print(x)
def torch_script_save():model=Net()
if __name__ == '__main__':save_model()predict_mine()

可以看到,我们可以通过pt文件读取出来下面的信息:

同时,我们也看到,我们虽然可以使用pt文件保存模型结构,但我们在推理时,依旧需要我们能够生成Net对象才能加载其数据,这其实很不方便,那么,有什么办法可以真正的将模型结构保存进去,让我们在推理过程中不需要再定义相关的类与对象呢,先前博主所使用的ONNX便是其中的一种,但它其实是另一种文件结构了,pt文件真的就不能摆脱环境吗,答案是否定的,TorchScript模型便解决了这个问题。
TorchScript模型
事实上,PyTorch提供了两种主要的模型保存和加载机制,一种是基于Python的序列化,另一种是TorchScript。
普通的PyTorch模型(基于Python的序列化):
- 保存: 使用
torch.save(model.state_dict(), 'model_path.pth'),它保存了模型的权重和参数,但不保存模型的结构。(当然也是可以保存的,但我们需要处理一下才能用,比如定义好Net类)- 加载: 首先,您需要有模型的类定义。 创建该类的一个实例。 使用
model.load_state_dict(torch.load('model_path.pth'))来加载权重。
特点:- 需要
Python环境和模型的原始代码来加载和运行模型。 保存的文件是Python特定的,并且依赖于特定的类结构。 主要用于继续训练或在Python环境中进行推断。
TorchScript模型:
TorchScript是PyTorch的一个子集,它创建了一个可以独立于Python运行的序列化模型。 生成方法:
Tracing: 使用torch.jit.trace方法。这涉及到通过模型运行一个输入示例,从而跟踪模型的执行路径。Scripting: 使用torch.jit.script方法。这转化Python代码到TorchScript,允许更复杂的模型和控制流。- 保存: 使用
torch.jit.save(traced_model, 'model_path.pt')。- 加载: 使用torch.jit.load(‘model_path.pt’)。注意,加载不需要原始的模型类定义。
特点:
- 可以在没有
Python运行时的环境中运行,如C++。- 提供了一种方法,将模型从
Python转移到其他平台或部署环境。- 包含模型的完整定义,包括结构、权重和参数。
Tracing方法:
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_model.pt')
Scripting方法:
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pt')
加载模型:
loaded_model = torch.jit.load('model_path.pt')
例程:
import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 10)def forward(self, x):return self.fc(x)model = SimpleModel()# Tracing
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_simple_model.pt')# Scripting
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_simple_model.pt')# 加载模型
loaded_model = torch.jit.load('traced_simple_model.pt')
我们采用TorchScript结构去执行先前的Net
def torch_script_save():model=Net()example_input =torch.rand((1,1,64,64))traced_model = torch.jit.trace(model, example_input)torch.jit.save(traced_model, 'traced_simple_model.pt')# Scriptingscripted_model = torch.jit.script(model)torch.jit.save(scripted_model, 'scripted_simple_model.pt')
def predict_script():model1=torch.jit.load("traced_simple_model.pt")image =torch.rand((1,1,64,64))print(model1)model1.eval()x=model1(image)print(x)model2=torch.jit.load("scripted_simple_model.pt")image =torch.rand((1,1,64,64))print(model2)model2.eval()x=model2(image)print(x)

yaml文件内容如下:
{'nc': 1000,
'scales': {'n': [0.33, 0.25, 1024], 's': [0.33, 0.5, 1024], 'm': [0.67, 0.75, 1024], 'l': [1.0, 1.0, 1024], 'x': [1.0, 1.25, 1024]},
'backbone': [[-1, 1, 'Conv', [64, 3, 2]], [-1, 1, 'Conv', [128, 3, 2]], [-1, 3, 'C2f', [128, True]], [-1, 1, 'Conv', [256, 3, 2]], [-1, 6, 'C2f', [256, True]], [-1, 1, 'Conv', [512, 3, 2]], [-1, 6, 'C2f', [512, True]], [-1, 1, 'Conv', [1024, 3, 2]], [-1, 3, 'C2f', [1024, True]]],'head': [[-1, 1, 'Classify', ['nc']]], 'scale': 'n','yaml_file': 'yolov8n-cls.yaml', 'ch': 3}
相关文章:
YOLOv8由pt文件中读取模型信息
Pytorch的pt模型文件中保存了许多模型信息,如模型结构、模型参数、任务类型、批次、数据集等 在先前的YOLOv8实验中,博主发现YOLOv8在预测时并不需要指定任务类型,因为这些信息便保存在pt模型中,那么,今天我们便来看看…...
js遍历效率
1w条数据,遍历效率 1、for 15s let t(new Date()).getTime()let a[]for(var i 0; i < 100000; i){a.push({id:i,val:i})}let ts[]for(var i 0; i < a.length; i){if(a[i].val!2 && a[i].val!4 && a[i].val!8){ts.push(a[i])}}let c(new D…...
QModbus例程分析
由于有一个Modebus上位机的需要,分析一下QModbus Slave的源代码,方便后面的开发。 什么是Modbus Modbus是一种常用的串行通信协议,被广泛应用于工业自动化领域。它最初由Modicon(目前属于施耐德电气公司)于1979年开发…...
Vue万字学习笔记(入门1)
目录 简介 Vue是什么 渐进式框架 单文件组件 API 风格 选项式 API (Options API) 组合式 API (Composition API) 创建一个 Vue 应用 挂载应用 DOM 中的根组件模板 应用配置 多个应用实例 模板语法 文本插值 原始 HTML Attribute 绑定 简写…...
Cesium手动建模模型用Cesiumlab转3D Tiles模型位置不对,调整模型位置至指定经纬度
Cesium加载3Dtiles模型的平移和旋转_3dtiles先旋转再平移示例-CSDN博客 Cesium 平移cesiumlab生产的3Dtiles切片模型到目标经纬度-CSDN博客 【ArcGISCityEngine】自行制作Lod1城市大尺度白膜数据_cityengine 生成指定坐标集指定区域的白模-CSDN博客 以上次ArcGISCityEngine制…...
学习C语言第23天(程序环境和预处理)
1. 程序的翻译环境和执行环境 在ANSIC的任何一种实现中,存在两个不同的环境 第1种是翻译环境,在这个环境中源代码被转换为可执行的机器指令。 第2种是执行环境,它用于实际执行代码。 2. 详解编译链接 2.1 翻译环境 每个源文件单独经过编…...
Ubuntu22.04安装
使用Vmware安装好后 首先执行下面命令,不然每次打开终端会出现To run a command as administrator (user root)… touch ~/.sudo_as_admin_successful换源 参考 sudo cp /etc/apt/sources.list /etc/apt/sources.list.baksudo gedit /etc/apt/sources.list清空…...
从入门到自动化:一篇文章掌握Python的80%
Python作为一种高级编程语言,以其简洁明了的语法和强大的功能性,在全球编程社区内享有极高的声誉。本文将带领你从Python的基础语法入手,介绍其常用库的应用,以及如何将Python用于数据分析、网络爬虫和简单的自动化任务࿰…...
开源的主流机器学习框架
主流的开源机器学习框架包括: 1. TensorFlow:由Google开发和维护的深度学习框架,广泛用于生产环境和研究。支持多种平台,并具有丰富的工具和库支持。 2. PyTorch:由Facebook开发的深度学习框架,以其动态计…...
RabbitMQ:发送者的可靠性之配置发送者重试机制
文章目录 为什么需要重试机制?如何配置重试机制?测试重试机制使用重试机制的注意事项 在使用消息队列(MQ)系统时,网络故障是不可避免的问题,尤其是在与RabbitMQ等服务交互时。如果生产者在发送消息时遇到网…...
基于深度学习的大规模MIMO信道状态信息反馈
MIMO系统 MIMO系统利用多个天线在发送端和接收端之间建立多条独立的信道,从而使得同一时间可以传输多个数据流,从而使得同一之间可以传输多个数据流,提高数据传输速率。 优势 增加传输速率和容量,提高信号覆盖范围和抗干扰能力…...
在Docker中部署Rasa NLU服务
最近因为项目需要将rasa nlu配置到docker容器中供系统调用,本篇主要整理该服务的docker配置过程。 本篇的重点在于docker的使用,不在Rasa NLU。 系统环境:Ubuntu 18.04.6 1. Rasa介绍 Rasa是一个开源的机器学习框架,专为构建基于文…...
SQL语句创建数据库(增删查改)
SQL语句 一.数据库的基础1.1 什么是数据库1.2 基本使用1.2.1 连接服务器1.2.2 使用案例 1.2 SQL分类 二.库的操作2.1 创建数据库2.2 创建数据库示例2.3 字符集和校验规则2.3.1 查看系统默认字符集以及校验规则2.3.2查看数据库支持的字符集2.3.3查看数据库支持的字符集校验规则2…...
微信小程序-Vant组件库的使用
一. 在app.json里面删除style:v2 为了避免使用Vant组件库和微信小程序组件样式的相互影响 二.在app.json里面usingComponents注册Vant组件库的自定义组件 "usingComponents": {"van-icon": "./miniprogram_npm/vant-weapp/icon/index&qu…...
为什么企业需要进行能源体系认证?
通过能源体系认证,企业可以向公众和利益相关方展示其在节能减排方面的承诺和成就。这不仅提升了企业的社会责任形象,还增强了品牌的信誉度。在当今消费者更加关注环境问题的背景下,绿色企业形象有助于赢得市场和客户的认可与信任。 能源体系认…...
【日常记录-MySQL】EVENT
Author:赵志乾 Date:2024-08-07 Declaration:All Right Reserved!!! 1. 简介 在MySQL中,EVENT是一种数据库对象,其用于设定数据库任务自动执行。这些任务可以是任意有效的SQL语句&a…...
嵌入式学习day12(LinuxC高级)
由于C高级部分比较零碎,各部分之间没有联系,所以学起来比较累,多练习就好了 一丶Linux起源 寻科普|第二期:聊聊Linux的前世今生 UNIX和linux的区别: (1)linux是开发源代码的自由软件.而unix是…...
pytorch中的hook机制register_forward_hook
上篇文章主要介绍了hook钩子函数的大致使用流程,本篇文章主要介绍pytorch中的hook机制register_forward_hook,手动在forward之前注册hook,hook在forward执行以后被自动执行。 1、hook背景 Hook被成为钩子机制,pytorch中包含forwa…...
使用Gin框架返回JSON、XML和HTML数据
简介 Gin是一个高性能的Go语言Web框架,它不仅提供了简洁的API,还支持快速的路由和中间件处理。在Web开发中,返回JSON、XML和HTML数据是非常常见的需求。本文将介绍如何使用Gin框架来返回这三种类型的数据。 环境准备 在开始之前࿰…...
网工内推 | 国企运维工程师,华为认证优先,最高年薪20w
01 上海陆家嘴物业管理有限公司 🔷招聘岗位:IT运维工程师 🔷岗位职责: 1、负责对公司软、硬件系统、周边设备、桌面系统、服务器、网络基础环境运行维护、故障排除。 2、负责对各部门软件操作、网络安全进行检查、指导。 3、负责…...
基于FPGA的PID算法学习———实现PID比例控制算法
基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...
通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)
概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...
AtCoder 第409场初级竞赛 A~E题解
A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
浅谈不同二分算法的查找情况
二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...
【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案
目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后,迭代器会失效,因为顺序迭代器在内存中是连续存储的,元素删除后,后续元素会前移。 但一些场景中,我们又需要在执行删除操作…...
Leetcode33( 搜索旋转排序数组)
题目表述 整数数组 nums 按升序排列,数组中的值 互不相同 。 在传递给函数之前,nums 在预先未知的某个下标 k(0 < k < nums.length)上进行了 旋转,使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...
土建施工员考试:建筑施工技术重点知识有哪些?
《管理实务》是土建施工员考试中侧重实操应用与管理能力的科目,核心考查施工组织、质量安全、进度成本等现场管理要点。以下是结合考试大纲与高频考点整理的重点内容,附学习方向和应试技巧: 一、施工组织与进度管理 核心目标: 规…...
