当前位置: 首页 > news >正文

常用组件详解(十):保存与加载模型、检查点机制的使用

文章目录

  • 1.保存、加载模型
  • 2.torch.nn.Module.state_dict()
    • 2.1基本使用
    • 2.2保存和加载状态字典
  • 3.创建Checkpoint
    • 3.1基本使用
    • 3.2完整案例


1.保存、加载模型

  torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量和字典等(内部使用pickle模块实现对象的序列化)。数据会被保存为.pt.pth格式,可通过torch.load()从磁盘加载被保存的序列化对象,加载时会重新构造出原来的对象。
  torch.save()有两种保存模型的方式:

  • 1.保存整个模型(继承了torch.nn.Module的类),不推荐使用。
    • torch.load():利用pickle将保存的序列化对象反序列化,得到原始数据。可用于加载完整模型或状态字典。
#保存整个模型
torch.save(model, PATH)
#加载模型
model = torch.load(PATH)
  • 2.仅保存模型的参数(状态字典state_dict),推荐使用。
    • torch.nn.Module.load_state_dict():通过反序列化得到模型的state_dict()(状态字典)来加载模型,传入的参数是状态字典,而非.pt.pth文件。
#只保存模型参数
torch.save(model.state_dict(), PATH)
#加载模型
model=Model()
model.load_state_dict(torch.load(PATH))

  在实际使用中推荐第二种方式,第一种方式往往容易产生各种错误:

  • 设备错误。若在cuda:0上训练好一个模型并保存,则读取出来的模型也是默认在cuda:0上,如果训练过程的其他数据被放到了cuda:1上,那么就会发生错误:
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

此时需要将其他其他数据都保存在cuda:0上,或加载模型时指定使用cuda:1

device = torch.device("cuda:1")
model = torch.load(PATH, map_location=device)
  • 版本错误:比如使用pytorch1.0训练并保存CNN模型,再用pytorch1.1读取模型,则会出现错误:
AttributeError: 'Conv2d' object has no attribute 'padding_mode'

此时只能通过获取该模型的参数来加载新的模型:

#加载模型参数
model_state = torch.load(model_path).state_dict()
#初始化新模型并加载参数
model = Model()
model.load_state_dict(model_state)

2.torch.nn.Module.state_dict()

2.1基本使用

  torch.nn.Module.state_dict()用于返回模型的状态字典,其中保存了模型的可学习参数。其中,只有可学习参数的层(卷积层、全连接层等)和注册缓冲区(batchnorm’s running_mean)才会作为模型参数保存(优化器也有状态字典,也可进行保存)。
【例子】

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型
model = TheModelClass()# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

  查看模型与优化器的状态字典:
在这里插入图片描述

2.2保存和加载状态字典

  通过torch.save()来保存模型的状态字典(state_dict),即只保存学习到的模型参数,并通过torch.nn.Module.load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为.pt.pth

#保存模型状态字典
PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)
#根据状态字典加载模型
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
#打印新模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())

  注意,模型推理之前,需要调用model.eval()函数将dropoutbatch normalization层设置为评估模式,否则会导致模型推理结果不一致。
在这里插入图片描述

3.创建Checkpoint

3.1基本使用

  模型检查点(checkpoint)是指模型训练过程中保存的模型状态,包括模型参数(权重与偏置)、优化器状态等其他相关的训练信息。通过保存检查点,可以实现在训练过程中定期保存模型的当前状态,以便在需要时恢复训练或用于模型评估和推理。模型检查点常见的保存信息如下:

  • 1.模型权重:模型的状态字典。
  • 2.优化器状态:优化器的状态字典。
  • 3.训练状态:当前的训练轮数(epoch)、批次(batch)等。
  • 4.其他数据:如学习率调度器的状态、自定义指标等。

例如:
【保存检查点】

#将模型参数和优化器状态的状态字典保存到检查点中
checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss.item(),'epoch':epoch
}#保存检查点
torch.save(checkpoint, 'checkpoint.pth')

【加载检查点】

# 加载检查点
checkpoint = torch.load('checkpoint.pth')# 恢复模型和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])# 恢复训练状态
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果是恢复训练,可以从保存的epoch继续
for epoch in range(epoch, num_epochs):# 继续训练

3.2完整案例

import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()# 训练循环
num_epochs = 100
for epoch in range(num_epochs):# 假设有输入x和目标yx = torch.randn(64, 10)y = torch.randn(64, 1)optimizer.zero_grad()output = model(x)loss = loss_fn(output, y)loss.backward()optimizer.step()# 每10个epoch保存一次检查点if epoch % 10 == 0:checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'epoch': epoch,'loss': loss.item()}torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')# 加载检查点并继续训练
checkpoint = torch.load('checkpoint_epoch_10.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']# 从第11个epoch开始继续训练
for epoch in range(start_epoch + 1, num_epochs):# 继续训练pass

相关文章:

常用组件详解(十):保存与加载模型、检查点机制的使用

文章目录 1.保存、加载模型2.torch.nn.Module.state_dict()2.1基本使用2.2保存和加载状态字典 3.创建Checkpoint3.1基本使用3.2完整案例 1.保存、加载模型 torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量…...

基于SpringBoot+Vue+MySQL的在线学习交流平台

系统展示 用户前台界面 管理员后台界面 系统背景 随着互联网技术的飞速发展,在线学习已成为现代教育的重要组成部分。传统的面对面教学方式已无法满足广大学习者的需求,特别是在时间、地点上受限的学习者。因此,构建一个基于SpringBoot、Vue.…...

前端开发在AI时代如何保持核心竞争力

随着人工智能(AI)技术的迅猛发展,前端开发领域正经历着前所未有的变革。AI辅助开发工具、自动化测试框架、智能代码补全等技术的出现,极大地提高了开发效率,同时也对前端开发人员的技能和角色提出了新的要求。在这个背…...

ffmpeg面向对象——拉流协议匹配机制探索

目录 1.URLProtocol类2.协议匹配的核心接口3. URLContext类4. 综合调用流程图5.rtsp拉流协议匹配流程图及对象图5.1 rtsp拉流协议调用流程图5.2 rtsp拉流协议对象图 6.本地文件调用流程图及对象图6.1 本地文件调用流程图6.2 本地文件对象图 7.内存数据调用流程图及对象图7.1 内…...

R语言绘制柱状图

柱状图是一种数据可视化工具。由 x 轴和 y 轴构成,x 轴表示类别,y 轴为数据数值。以矩形柱子展示数据大小,便于直观比较不同类别数据差异及了解分布。广泛应用于销售分析、统计、项目管理、科学研究等领域。可定制颜色、宽度等属性&#xff0…...

GNU/Linux - tarball文件介绍介绍

Linux 中的 tarball 文件是将多个文件和目录归档到一个文件中的常用方法,通常用于备份、分发或打包目的。术语 “tarball ”来源于 “tar”(磁带归档的缩写)命令的使用,该命令最初设计用于将数据写入磁带等顺序存储设备。如今&…...

AppointmentController

目录 1、 AppointmentController 1.1、 删除预约单据信息 1.2、 反审核预约单 1.3、 SelectToMainten AppointmentController using QXQPS.Models; using QXQPS.Vo; using System; using System.Collections; using System.Collections.Generic; using System.L…...

网站建设完成后,切勿让公司官网成为摆设

在当今这个数字化时代,公司官网已经成为企业展示形象、传递信息、吸引客户的重要平台。然而,许多企业在网站建设完成后,往往忽视了对官网的持续运营和维护,导致官网逐渐沦为摆设,无法发挥其应有的作用。为了确保公司官…...

独孤思维:闲得蛋疼才去做副业

独孤现实中玩的要好的朋友。 他们都只在自己的社交圈,工作圈链接。 没有人知道,副业可以这么玩。 所以他们很好奇,问我,独孤,你最开始是怎么知道这些副业的? 其实,独孤最开始接触副业&#…...

vulnhub靶场之hackablell

一.环境搭建 1.靶场描述 difficulty: easy This works better with VirtualBox rather than VMware 2.靶场下载 https://download.vulnhub.com/hackable/hackableII.ova 3.靶场启动 二.信息收集 1.寻找靶场的真实ip nmap -SP 192.168.246.0/24 arp-scan -l 根据上面两个…...

《浔川社团官方通报 —— 为何明确 10 月 2 日上线的浔川 AI 翻译 v3.0 再次被告知延迟上线》

《浔川社团官方通报 —— 为何明确 10 月 2 日上线的浔川 AI 翻译 v3.0 再次被告知延迟上线》 各位关注浔川社团的朋友们: 大家好!首先,我们要向一直期待浔川 AI 翻译 v3.0 上线的朋友们致以最诚挚的歉意。原定于 10 月 2 日上线的浔川 AI 翻…...

加密与安全_HOTP一次性密码生成算法

文章目录 HOTP 的基础原理HOTP 的工作流程HOTP 的应用场景HOTP 的安全性安全性增强措施Code生成HOTP可配置项校验HOTP可拓展功能计数器(counter)计数器在客户端和服务端的作用计数器的同步机制客户端和服务端中的计数器表现服务端如何处理计数器不同步计…...

ResNet18果蔬图像识别分类

关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝,拥有2篇国家级人工智能发明专利。 社区特色…...

深度强化学习中收敛图的横坐标是steps还是episode?

在深度强化学习(Deep Reinforcement Learning, DRL)的收敛图中,横坐标选择 steps 或者 episodes 主要取决于算法的设计和实验的需求,两者的差异和使用场景如下: Steps(步数): 定义&a…...

一个真实可用的登录界面!

需要工具: MySQL数据库、vscode上的php插件PHP Server等 项目结构: login | --backend | --database.sql |--login.php |--welcome.php |--index.html |--script.js |--style.css 项目开展 index.html: 首先需要一个静态网页&#x…...

Vue中watch监听属性的一些应用总结

【1】vue2中watch的应用 ① 简单监视 在 Vue 2 中,如果你不需要深度监视,即只需监听顶层属性的变化,可以使用简写形式来定义 watch。这种方式更加简洁,适用于大多数基本场景。 示例代码 假设你有一个 Vue 组件,其中…...

MongoDB-aggregate流式计算:带条件的关联查询使用案例分析

在数据库的查询中,是一定会遇到表关联查询的。当两张大表关联时,时常会遇到性能和资源问题。这篇文章就是用一个例子来分享MongoDB带条件的关联查询发挥的作用。 假设工作环境中有两张MongoDB集合:SC_DATA(学生基本信息集合&…...

Redis数据库与GO(一):安装,string,hash

安装包地址:https://github.com/tporadowski/redis/releases 建议下载zip版本,解压即可使用。解压后,依次打开目录下的redis-server.exe和redis-cli.exe,redis-cli.exe用于输入指令。 一、基本结构 如图,redis对外有个…...

expressjs,实现上传图片,返回图片链接

在 Express.js 中实现图片上传并返回图片链接,你通常需要使用一个中间件来处理文件上传,比如 multer。multer 是一个 node.js 的中间件,用于处理 multipart/form-data 类型的表单数据,主要用于上传文件。 以下是一个简单的示例&a…...

爬虫——XPath基本用法

第一章XML 一、xml简介 1.什么是XML? 1,XML指可扩展标记语言 2,XML是一种标记语言,类似于HTML 3,XML的设计宗旨是传输数据,而非显示数据 4,XML标签需要我们自己自定义 5,XML被…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

Robots.txt 文件

什么是robots.txt? robots.txt 是一个位于网站根目录下的文本文件(如:https://example.com/robots.txt),它用于指导网络爬虫(如搜索引擎的蜘蛛程序)如何抓取该网站的内容。这个文件遵循 Robots…...

Python 包管理器 uv 介绍

Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)

题目描述 给定一个整型数组,请从该数组中选择3个元素 组成最小数字并输出 (如果数组长度小于3,则选择数组中所有元素来组成最小数字)。 输入描述 行用半角逗号分割的字符串记录的整型数组,0<数组长度<= 100,0<整数的取值范围<= 10000。 输出描述 由3个元素组成…...

【Post-process】【VBA】ETABS VBA FrameObj.GetNameList and write to EXCEL

ETABS API实战:导出框架元素数据到Excel 在结构工程师的日常工作中,经常需要从ETABS模型中提取框架元素信息进行后续分析。手动复制粘贴不仅耗时,还容易出错。今天我们来用简单的VBA代码实现自动化导出。 🎯 我们要实现什么? 一键点击,就能将ETABS中所有框架元素的基…...

python基础语法Ⅰ

python基础语法Ⅰ 常量和表达式变量是什么变量的语法1.定义变量使用变量 变量的类型1.整数2.浮点数(小数)3.字符串4.布尔5.其他 动态类型特征注释注释是什么注释的语法1.行注释2.文档字符串 注释的规范 常量和表达式 我们可以把python当作一个计算器&#xff0c;来进行一些算术…...