当前位置: 首页 > news >正文

pytorch模型的保存与加载

1 pytorch保存和加载模型的三种方法

PyTorch提供了三种种方式来保存和加载模型,在这三种方式中,加载模型的代码和保存模型的代码必须相匹配,才能保证模型的加载成功。通常情况下,使用第一种方式(保存和加载模型状态字典)更加常见,因为它更轻量且不依赖于特定的模型类。

1.1 仅保存和加载模型参数

1.1.1 保存模型参数

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))# 保存整个模型
torch.save(model.state_dict(), 'sample_model.pt')

1.1.2 加载模型参数

import torch
import torch.nn as nn# 下载模型参数 并放到模型中
loaded_model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
loaded_model.load_state_dict(torch.load('sample_model.pt'))
print(loaded_model)

显示如下:

Sequential((0): Linear(in_features=128, out_features=16, bias=True)(1): ReLU()(2): Linear(in_features=16, out_features=1, bias=True)
)

net.state_dict(),在PyTorch中,Module 的可学习参数 (即权重和偏差),模块模型包含在参数中 (通过 model.parameters() 访问)。state_dict 是一个从参数名称隐射到参数 Tesnor 的有序字典对象。只有具有可学习参数的层(卷积层、线性层等) 才有 state_dict 中的条目。

1.2 保存和加载整个模型

1.2.1 保存整个模型

import torch
import torch.nn as nnnet = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))# 保存整个模型,包含模型结构和参数
torch.save(net, 'sample_model.pt')

1.2.2  加载整个模型

import torch
import torch.nn as nn# 加载整个模型,包含模型结构和参数
loaded_model = torch.load('sample_model.pt')
print(loaded_model)

显示如下:

Sequential((0): Linear(in_features=128, out_features=16, bias=True)(1): ReLU()(2): Linear(in_features=16, out_features=1, bias=True)
)

1.3 导出和加载ONNX格式模型

1.3.1 保存模型

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))input_sample = torch.randn(16, 128)  # 提供一个输入样本作为示例
torch.onnx.export(model, input_sample, 'sample_model.onnx')

1.3.2 加载模型

import torch
import torch.nn as nn
import onnx
import onnxruntimeloaded_model = onnx.load('sample_model.onnx')
session = onnxruntime.InferenceSession('sample_model.onnx')
print(session)

2 模型保存与加载使用的函数

2.1 保存模型函数torch.save

将对象序列化保存到磁盘中,该方法原理是基于python中的pickle来序列化,各种Models,tensors,dictionaries 都可以使用该方法保存。保存的模型文件名可以是.pth, .pt, .pkl

def save(obj: object,f: FILE_LIKE,pickle_module: Any = pickle,pickle_protocol: int = DEFAULT_PROTOCOL,_use_new_zipfile_serialization: bool = True
) -> None:
  • obj:保存的对象
  • f:一个类似文件的对象(必须实现写入和刷新)或字符串或操作系统。包含文件名的类似路径对象
  • pickle_module:用于挑选元数据和对象的模块
  • pickle_protocol:可以指定以覆盖默认协议

2.2 加载模型函数torch.load

def load(f: FILE_LIKE,map_location: MAP_LOCATION = None,pickle_module: Any = None,*,weights_only: bool = False,**pickle_load_args: Any
) -> Any:
  • f:类文件对象 (返回文件描述符)或一个保存文件名的字符串
  • map_location:一个函数或字典规定如何映射存储设备,torch.device对象
  • pickle_module:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )

2.3 加载模型参数torch.nn.Module.load_state_dict

序列化 (Serialization)是将对象的状态信息转换为可以存储或传输的形式的过程。 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象。

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):
  • state_dict:保存 parameters 和 persistent buffers 的字典
  • strict:可选,bool型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。

2.4 状态字典state_dict

函数作用是“获取优化器当前状态信息字典”,在神经网络中模型上训练出来的模型参数,也就是权重和偏置值。在Pytorch中,定义网络模型是通过继承torch.nn.Module来实现的。其网络模型中包含可学习的参数(weights, bias, 和一些登记的缓存如batchnorm’s running_mean 等)。模型内部的可学习参数可通过两种方式进行调用:

  • 通过model.parameters()这个生成器来访问所有参数。
  • 通过model.state_dict()来为每一层和它的参数建立一个映射关系并存储在字典中,其键值由每个网络层和其对应的参数张量构成。
def state_dict(self, destination=None, prefix='', keep_vars=False):

除模型外,优化器对象(torch.optim)同样也有一个状态字典,包含的优化器状态信息以及使用的超参数。由于状态字典属于Python 字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都比较便捷。

3 指定map_location加载模型

采用仅加载模型参数的方式,指定设备类型进行模型加载,代码如下:

model_path = '/opt/sample_model.pth'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
map_location = torch.device(device)model.load_state_dict(torch.load(self.model_path, map_location=self.map_location))

相关文章:

pytorch模型的保存与加载

1 pytorch保存和加载模型的三种方法 PyTorch提供了三种种方式来保存和加载模型,在这三种方式中,加载模型的代码和保存模型的代码必须相匹配,才能保证模型的加载成功。通常情况下,使用第一种方式(保存和加载模型状态字…...

IDE /完整分析C4819编译错误的本质原因

文章目录 概述基本概念代码页标识符字符集和字符编码方案源字符集和执行字符集 编译器使用的字符集VS字符集配置 有何作用编译器 - 源字符集编译器 -执行字符集 Qt Creator下配置MSVC编译器参数动态库DLL字符集配置不同于可执行程序EXE总结 概述 本文将从根本原因上来分析和解…...

前端学习路线(2023)

这个前端学习路线看起来很详细和全面,涵盖了从基础知识到高级框架,从单机开发到全栈项目,从混合应用到原生应用,从性能优化到架构设计的各个方面。如果你能够按照这个路线学习和实践,我相信你一定能够成为一名优秀的前…...

景区如何对旅行社进行分销管理?

旅行社的买票能力强,一般景区会跟多家旅行社合作门票分销。其中卖票下单、价格设定、财务对账结算都出现了很多问题,导致对账困难,查询困难,甚至可能有偷票漏票的情况出现,给景区收入造成损失。那要怎么处理呢&#xf…...

四步从菜鸟到高手,Python编程真的很简单(送书第一期:文末送书2本)

🍁博主简介 🏅云计算领域优质创作者   🏅华为云开发者社区专家博主   🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入! 🐋 希望大家多多支持,我…...

Thread类的常用结构(java))

1 构造器 public Thread() :分配一个新的线程对象。public Thread(String name) :分配一个指定名字的新的线程对象。public Thread(Runnable target) :指定创建线程的目标对象,它实现了Runnable接口中的run方法public Thread(Runnable target,String name) :分配一…...

CSS :nth-child

CSS :nth-child :nth-child 伪类根据元素在同级元素中的位置来匹配元素. CSS :nth-child 语法 值是关键词 odd/evenAnB最新的 [of S] 语法权重 浏览器兼容性 很简单的例子, 来直觉上理解这个伪类的意思 <ul><li class"me">Apple</li><li>B…...

国内好用的企业级在线文档有哪些?

在当今数字化时代&#xff0c;企业级在线文档已经成为了现代办公环境中不可或缺的一部分。它不仅能够提高工作效率&#xff0c;还能够实现多人协同编辑&#xff0c;满足团队协作的需求。那么&#xff0c;在国内市场上&#xff0c;哪些企业级在线文档产品备受企业青睐呢&#xf…...

P1217 [USACO1.5] 回文质数 Prime Palindromes

题目描述 因为 151 151 151 既是一个质数又是一个回文数&#xff08;从左到右和从右到左是看一样的&#xff09;&#xff0c;所以 151 151 151 是回文质数。 写一个程序来找出范围 [ a , b ] ( 5 ≤ a < b ≤ 100 , 000 , 000 ) [a,b] (5 \le a < b \le 100,000,000…...

【STM32MP1系列】DDR内存测试用例

DDRDDR内存测试 一、uboot下测试DDR内存二、Linux内核下测试DDR内存1、使用memtester测试DDR内存2、使用stressapptest测试DDR内存三、Buildroot中构建memtester软件包四、搭建stressapptest软件包五、注意事项一、uboot下测试DDR内存 输入bdinfo查看DDR起始地址以及大小: b…...

【CAS6.6源码解析】调试Rest API接口

CAS的web层默认是基于webflow实现的&#xff0c;ui和后端是耦合在一起的&#xff0c;做前后端分离调用和调试的时候不太方便。但是好在CAS已经添加了支持Rest API的support模块&#xff0c;添加相应模块即可。 文章目录 添加依赖并重新build效果 添加依赖并重新build 具体添加…...

C++设计模式之模板方法、策略模式、观察者模式

面向对象设计模式是”好的面向对象设计“&#xff0c;所谓”好的面向对象设计“指的是可以满足”应对变化&#xff0c;提高复用“的设计。 现代软件设计的特征是”需求的频繁变化“。设计模式的要点是”寻求变化点&#xff0c;然后在变化点处应用设计模式&#xff0c;从而更好地…...

【计算机网络 02】物理层基本概念 传输媒体 传输方式 编码与调制 信道极限容量 章节小结

第二章 -- 物理层 2.1 物理层基本概念2.2 物理层下的传输媒体2.3 传输方式2.4 编码与调制2.5 信道极限容量2.6 章节小结 2.1 物理层基本概念 2.2 物理层下的传输媒体 传输媒体也称为传输介质或传输媒介&#xff0c;他就是数据传输系统中在发送器和接收器之间的物理通路 传输媒…...

无涯教程-jQuery - serialize( )方法函数

serialize()方法将一组输入元素序列化为数据字符串。 serialize( ) - 语法 $.serialize( ) serialize( ) - 示例 假设无涯教程在serialize.php文件中具有以下PHP内容- <?php if( $_REQUEST["name"] ) {$name$_REQUEST[name];echo "Welcome ". $na…...

一套不错的基于uniapp实现的投票类小程序/H5

最近作者心血来潮&#xff0c;想做一个热点话题投票&#xff0c;话题相关的资讯跟踪类的小程序&#xff0c;方便自己发布一些大家比较关心的话题。 基于以上需求&#xff0c;说干就干&#xff0c;首先需要定义一个需求&#xff1a; 1、支持热门话题投票、排行榜&#xff08;日…...

Mac代码编辑器sublime text 4中文注册版下载

Sublime Text 4 for Mac简单实用功能强大&#xff0c;是程序员敲代码必备的代码编辑器&#xff0c;sublime text 4中文注册版支持多种编程语言&#xff0c;包括C、Java、Python、Ruby等&#xff0c;可以帮助程序员快速编写代码。Sublime Text的界面简洁、美观&#xff0c;支持多…...

django------模糊查询

1.常用模糊查询的方法 queryset中支持链式操作 bookBook.objects.all().order_by(-nid).first() 只要返回的是queryset对象就可以调用其他的方法,直到返回的是对象本身 大于、大于等于、小于、小于等于&#xff1a; # __gt 大于> # __gte 大于等于> # __lt 小于< …...

AVFoundation - 音视频组合编辑

文章目录 一、简要说明二、使用1、音频和视频合成2、视频的拼接一、简要说明 相关类 AVMutableCompositionAVMutableCompositionTrack二、使用 1、音频和视频合成 - (void)testCom1{AVMutableComposition *mutableComposition = [AVMutableComposition composition];AVMu...

jpa生成实体类,jpa根据数据库表生成实体类

jpa生成实体类&#xff0c;jpa根据数据库表生成实体类jpa根据数据库表结构生成实体idea下SpringbootJPA从数据库自动生成实体类JPA用数据库表直接生成实体类Spring boot整合jpa(一),根据表生成实体IDEA下SpringData-JPA根据数据库表生成实体类idea怎么根据数据库表自动生成JPA实…...

嵌入式Linux系统组成

嵌入式Linux系统的组成 文章目录 嵌入式Linux系统的组成一、发行版Linux系统VS嵌入式Linux系统二、嵌入式Linux系统架构一、发行版Linux系统VS嵌入式Linux系统 1.产品 发行版Linux系统产品:服务器、消费平板、消费手提电脑 嵌入式Linux系统产品:扫地机器人,小米机顶盒特定场…...

机器学习结合基因无关通路映射:从临床数据挖掘新药靶点

1. 项目概述&#xff1a;当机器学习遇见代谢通路&#xff0c;如何从数据中“挖”出新药靶点&#xff1f;在生物医学研究的前沿&#xff0c;我们正面临一个核心矛盾&#xff1a;一方面&#xff0c;我们拥有海量的临床数据&#xff0c;比如血糖、血压、BMI等指标&#xff1b;另一…...

对称与负电源测试:动态直流电子负载的设计、原理与应用

1. 项目概述&#xff1a;对称与负电源的静态与动态直流负载在电子实验室里&#xff0c;测试一个电源的性能&#xff0c;尤其是它的动态响应能力&#xff0c;是件既基础又关键的事。我们常说的“直流电子负载”就是这个领域的核心工具。我之前设计并分享过一个用于正电源测试的静…...

【UniApp小程序开发】解决无法使用Vue自定义指令的完美替代方案:权限组件封装

在 UniApp 开发中&#xff0c;你是否遇到过这样的困惑&#xff1a;明明在 Vue Web 项目中用得顺手的 v-permission 自定义指令&#xff0c;一到小程序端就完全失效&#xff1f;本文将深入剖析其原因&#xff0c;并提供一套可直接复用的组件化解决方案&#xff0c;让你在小程序中…...

Windows 10/11系统下,SecureCRT 8.7.2保姆级安装与激活图文指南(含Keygen使用避坑点)

Windows平台SecureCRT 8.7.2全流程部署与安全配置指南在当今远程运维与网络管理的日常工作中&#xff0c;一款可靠的终端仿真工具如同工程师的瑞士军刀。作为行业标杆的SecureCRT&#xff0c;其8.7.2版本在Windows 10/11环境下的部署却常让新手陷入各种技术陷阱——从安装路径选…...

深度解析DeTikZify:科研工作者的智能图表生成神器

深度解析DeTikZify&#xff1a;科研工作者的智能图表生成神器 【免费下载链接】DeTikZify Synthesizing Graphics Programs for Scientific Figures and Sketches with TikZ. 项目地址: https://gitcode.com/gh_mirrors/de/DeTikZify 在科研工作中&#xff0c;创建高质量…...

Unity事件系统实战:用事件驱动重构你的金币拾取逻辑(告别硬编码)

Unity事件系统实战&#xff1a;用事件驱动重构你的金币拾取逻辑&#xff08;告别硬编码&#xff09;在游戏开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;玩家拾取金币后&#xff0c;需要更新UI、播放音效、解锁成就、保存数据……如果把这些逻辑全部写在金币拾取的代…...

styled-theming 性能优化:如何避免主题切换时的性能瓶颈

styled-theming 性能优化&#xff1a;如何避免主题切换时的性能瓶颈 【免费下载链接】styled-theming Create themes for your app using styled-components 项目地址: https://gitcode.com/gh_mirrors/st/styled-theming styled-theming 是一个专为 styled-components …...

Elden Ring帧率解锁终极指南:从60帧到144+的完整教程

Elden Ring帧率解锁终极指南&#xff1a;从60帧到144的完整教程 【免费下载链接】EldenRingFpsUnlockAndMore A small utility to remove frame rate limit, change FOV, add widescreen support and more for Elden Ring 项目地址: https://gitcode.com/gh_mirrors/el/Elden…...

DeepSeek代码风格检查避坑指南(内部审计报告首次披露:37个被忽略的合规红线)

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;DeepSeek代码风格检查的合规性本质与审计背景 DeepSeek代码风格检查并非单纯的技术偏好约束&#xff0c;而是嵌入研发治理链条中的合规性控制节点。其本质是将编程实践与组织级安全策略、行业监管要求&…...

LPCM框架:大模型驱动的计算机架构设计革命

1. LPCM框架&#xff1a;计算机系统架构设计的范式革命计算机系统架构设计正站在历史性的转折点上。过去八十年来&#xff0c;从ENIAC的真空管到现代7纳米制程的异构计算芯片&#xff0c;架构设计始终遵循着"专家经验EDA工具"的传统范式。但随着摩尔定律逼近物理极限…...