2025-05-31 Python深度学习9——网络模型的加载与保存
文章目录
- 1 使用现有网络
- 2 修改网络结构
- 2.1 添加新层
- 2.2 替换现有层
- 3 保存网络模型
- 3.1 完整保存
- 3.2 参数保存(推荐)
- 4 加载网络模型
- 4.1 加载完整模型文件
- 4.2 加载参数文件
- 5 Checkpoint
- 5.1 保存 Checkpoint
- 5.2 加载 Checkpoint
本文环境:
- Pycharm 2025.1
- Python 3.12.9
- Pytorch 2.6.0+cu124
PyTorch 通过torchvision.models
提供预训练模型(如 VGG16)。
网址链接:https://docs.pytorch.org/vision/stable/models.html。
1 使用现有网络
以 VGG16 为例,进入网址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

方法一:使用随机初始化权重
将 weights 设置为 None,从 0 开始训练自己的网络。
vgg16_false = torchvision.models.vgg16(weights=None) # 权重随机初始化
方法二:加载预训练权重
也可以使用预训练好的网络参数,加载后可直接使用网络。
这将从官网上下载已训练好的模型文件。
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
可打印网络查看其模型结构:
print(vgg16_true)


2 修改网络结构
2.1 添加新层
使用add_module
在分类器(classifier
)后追加全连接层:
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

2.2 替换现有层
直接修改分类器的最后一层(如适配 CIFAR10 的 10 分类任务):
vgg16_false.classifier[6] = nn.Linear(4096, 10) # 替换第6层

3 保存网络模型
使用torch.save()
方法保存网络模型。文件扩展名推荐使用.pt
或.pth
。
3.1 完整保存
将模型类和参数一并保存到文件中。
torch.save(vgg16, 'vgg16_method1.pth') # 包含模型类和参数
- 优点:加载时无需重新定义模型结构。
- 缺点:文件较大,且依赖原始代码环境(见 4.1 节)。
3.2 参数保存(推荐)
仅保存参数字典到文件中。
torch.save(vgg16.state_dict(), 'vgg16_method2.pth') # 仅保存参数字典
- 优点:文件小,灵活性强,适合生产部署。
示例
import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型结构 + 模型参数
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')
4 加载网络模型
使用torch.load()
方法加载网络模型。
4.1 加载完整模型文件
加载完整模型时,需将 weights_only 参数设置为 False。
model = torch.load('vgg16_method1.pth', weights_only=False) # 需确保模型类已定义
模型打印结果如下:
print(model)

注意
若保存自定义模型,加载时必须确保环境中也有该模型的定义,否则会出现报错。
model_save.py
# model_save.pyimport torch from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel() torch.save(model, 'my_model_method1.pth')
model_load.py
import torchmodel = torch.load('my_model_method1.pth', weights_only=False) # 报错,找不到 MyModel 的定义
先运行 model_save.py,再运行 model_load.py,则会出现以下报错:
![]()
4.2 加载参数文件
首先,使用torch.load()
方法加载网络模型。
使用模型时,需先创建匹配的网络结构,再使用model.load_state_dict()
加载参数数据。
vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict) # 需结构匹配
模型打印结果是参数字典:
print(model_dict)

注意
模型保存时若在 GPU 上,加载时需指定 map_location 为 cup。
torch.load('model.pth', map_location=torch.device('cpu'))
将参数加载到模型后,手动迁移到 GPU:
model = MyModel() model.load_state_dict(model_dict) model.to('cuda:0')
5 Checkpoint
使用 Checkpoint 可以在训练过程中定期保存模型的状态,以便在中断后可以恢复训练,或者在测试时使用最终的模型。文件扩展名推荐使用.tar
。
5.1 保存 Checkpoint
要保存一个模型的 Checkpoint,通常需要保存以下数据:
- 模型的 state_dict(状态字典);
- 优化器的状态;
- 额外的信息,如 epoch 等。
import torch# 假设 model 是你的模型,optimizer 是你的优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')
5.2 加载 Checkpoint
加载 Checkpoint,首先需要加载文件,然后将其内容恢复到模型和优化器的状态中。
# 假设 model 和 optimizer 是你的模型和优化器实例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以继续训练
model.train() # 确保模型处于训练模式
相关文章:

2025-05-31 Python深度学习9——网络模型的加载与保存
文章目录 1 使用现有网络2 修改网络结构2.1 添加新层2.2 替换现有层 3 保存网络模型3.1 完整保存3.2 参数保存(推荐) 4 加载网络模型4.1 加载完整模型文件4.2 加载参数文件 5 Checkpoint5.1 保存 Checkpoint5.2 加载 Checkpoint 本文环境: Py…...

长安链起链调用合约时docker ps没有容器的原因
在调用这个命令的时候,发现并没有出现官方预期的合约容器,这是因为我们在起链的时候没有选择用docker的虚拟环境,实际上这不影响后续的调用,如果想要达到官方的效果那么你只需要在起链的时候输入yes即可,如图三所示...

Appium+python自动化(七)- 认识Appium- 上
简介 经过前边的各项准备工作,终于才把appium搞定。 一、appium自我介绍 appium是一款开源的自动化测试工具,可以支持iOS和安卓平台上的原生的,基于移动浏览器的,混合的应用(APP)。 1、 使用appium进…...
数据中心双活架构解决方案
数据中心双活架构解决方案 数据中心双活架构(Active-Active Data Center)旨在实现业务高可用、负载均衡和灾难自动切换。以下是完整的解决方案,涵盖架构设计、关键技术、实施步骤及最佳实践。 1. 双活架构设计 1.1 基本架构模型 同城双活(Metro Active-Active) 两个数据…...
YOLOv5 详解:从原理到实战的全方位解析
在计算机视觉领域,目标检测作为核心任务之一,始终吸引着众多研究者和开发者的目光。YOLO(You Only Look Once)系列算法凭借其高效、准确的特点,在目标检测领域占据重要地位。而 YOLOv5 作为 YOLO 系列算法的重要成员&a…...

模块联邦:更快的微前端方式!
什么是模块联邦 在前端项目中,不同团队之间的业务模块可能有耦合,比如A团队的页面里有一个富文本模块(组件),而B团队 的页面恰好也需要使用这个富文本模块。 传统模式下,B团队只能去抄A团队的代码&#x…...

前端基础学习html+css+js
HTML 区块 div标签,块级标签 span包装小部分文本,行内元素 表单 CSS css选择器 css属性 特性blockinlineinline-block是否换行✅ 换行❌ 不换行❌ 不换行可设置宽高✅ 支持❌ 不支持✅ 支持常见元素div容器 p段落 h标题span文本容器 a超链接img图片…...

手机打电话时将对方DTMF数字转为RFC2833发给局域网SIP坐席
手机打电话时将对方DTMF数字转为RFC2833发给局域网SIP坐席 --局域网SIP坐席呼叫 上一篇:手机打电话时由对方DTMF响应切换多级IVR语音菜单(完结) 下一篇:安卓App识别手机系统弹授权框包含某段文字-并自动点击确定按钮 一、前言 …...
TCP三次握手/四次握手-TCP/IP四层模型-SSL/TLS-HTTP-HTTPS
重要概念 seq ( Squence Number ) 序列号,用于数据排序、去重,防止数据包乱序 ack ( Acknowledgement Number ) 确认好,表示期望接受的下一个字节序号,用于确认数据包被对方接受 TCP三次握手是建立可靠连接的过程,确…...

SAP Business One:无锡哲讯科技助力中小企业数字化转型的智慧之选
数字化转型,中小企业的必经之路 在当今竞争激烈的商业环境中,数字化转型已不再是大型企业的专利,越来越多的中小企业开始寻求高效、灵活的管理系统来优化业务流程、提升运营效率。作为全球领先的企业管理软件,SAP Business One…...
【Ubuntu远程桌面】
Ubuntu-远程桌面 ubuntu环境rustdesk-1.4.0-aarch64.deb安装rustdesk注意事项:报错:可能会在远程连接时候显示‘No displays’解决方法1. 安装 CUDA(如果需要)2. 解决 XDG 桌面门户问题3. 检查 RustDesk 客户端日志 总结 kill --t…...
⚡ Linux 系统安装与配置 Vim 编辑器(包括 Vim 插件管理器)
⚡ Linux 系统安装与配置 Vim 编辑器(包括 Vim 插件管理器) 📌 1. Vim 简介 Vim(Vi IMproved)是一款高度可定制的文本编辑器,基于早期的 vi 编辑器扩展而来。 它支持语法高亮、插件扩展、多种编程语言&am…...

小型语言模型:为何“小”才是“大”?
当说到人工智能(AI)的时候,大家通常会想到那些拥有数十亿参数的超大型语言模型,它们能做出一些令人惊叹的事情。 厉害不厉害?绝对厉害! 但对于大多数企业和开发者来说,实用吗?可能…...
雪花算法:分布式ID生成的优雅解决方案
一、雪花算法的核心机制与设计思想 雪花算法(Snowflake)是由Twitter开源的分布式ID生成算法,它通过巧妙的位运算设计,能够在分布式系统中快速生成全局唯一且趋势递增的ID。 1. 基本结构 雪花算法生成的是一个64位(lo…...
针对PostgreSQL中pg_wal目录占用过大的系统性解决方案
一、问题现象与根本原因 当pg_wal目录占用超过预期(如数十GB甚至占满磁盘),通常由以下原因导致 长事务未提交:未完成的事务会阻塞WAL日志清理。复制槽未释放:逻辑复制或流复制槽未及时清理,导…...
git push Git远端意外挂断
git push Git远端意外挂断 枚举对象中: 99, 完成. 对象计数中: 100% (99/99), 完成. 使用 8 个线程进行压缩 压缩对象中: 100% (78/78), 完成. send-pack: unexpected disconnect while reading sideband packet 写入对象中: 100% (82/82), 2.78 MiB | 5.56 MiB/s, 完成. 总共…...
python学习day34
GPU训练及类的call方法 知识点回归: CPU性能的查看:看架构代际、核心数、线程数GPU性能的查看:看显存、看级别、看架构代际GPU训练的方法:数据和模型移动到GPU device上类的call方法:为什么定义前向传播时可以直接写作…...

秋招Day12 - 计算机网络 - 网络综合
从浏览器地址栏输入URL到显示网页的过程了解吗? 从在浏览器地址栏输入 URL 到显示网页的完整过程,并不是一个单一的数据包从头到尾、一次性地完成七层封装再七层解析的过程。 而是涉及到多次、针对不同目的、与不同服务器进行的、独立的网络通信交互&a…...

QT-JSON
#include <QJsonDocument>#include <QJsonObject>#include <QJsonArray>#include <QFile>#include <QDebug>void createJsonFile() {// 创建一个JSON对象 键值对QJsonObject jsonObj;jsonObj["name"] "John Doe";jsonObj[…...

IP 风险画像技术略解
IP 风险画像的技术定义与价值 IP 风险画像通过整合 IP 查询数据与 IP 离线库信息,结合机器学习算法,为每个 IP 地址生成多维度风险评估模型。其核心价值在于将传统的静态 IP 黑名单升级为动态风险评估体系,可实时识别新型网络威胁࿰…...

秋招Day12 - 计算机网络 - 基础
说一下计算机网络体系结构 OSI七层模型,TCP/IP四层模型和五层体系结构 说说OSI七层模型? 应用层:最靠近用户的层,用于处理特定应用程序的细节,提供了应用程序和网络服务之间的接口。表示层:确保从一个系…...

【网络安全】——Modbus协议详解:工业通信的“通用语言”
目录 一、初识Modbus:工业通信的基石 1.1 协议全称 1.2 协议简史 二、核心特性解析 2.1 架构设计 2.2 典型应用场景 三、协议族全景图 3.1 协议栈分类 3.2 版本演进对比 四、协议报文深度解析 4.1 Modbus RTU帧结构 4.2 Modbus TCP报文 五、通信机制实…...
MySQL 数据库备份与恢复利器:Percona XtraBackup 详解
一、XtraBackup 简介 1. 什么是 XtraBackup? XtraBackup 是 Percona 公司推出的免费开源工具,专为 InnoDB/XtraDB 引擎设计,支持 在线物理热备,具备以下核心特性: 非阻塞备份:备份过程中数据库仍可读写。…...

【GlobalMapper精品教程】095:如何获取无人机照片的拍摄方位角
文章目录 一、加载无人机照片二、计算方位角三、Globalmapper符号化显示方向四、arcgis符号化显示方向一、加载无人机照片 打开软件,加载无人机照片,在GLobalmapperV26中文版中,默认显示如下的航线信息。 关于航线的起止问题,可以直接从照片名称来确定。 二、计算方位角 …...

小提琴图绘制-Graph prism
在 GraphPad Prism 中为小提琴图添加显著性标记(如*P<0.05)的步骤如下: 步骤1:完成统计检验 选择数据表:确保数据已按分组排列(如A列=Group1,B列=Group2)。执行统计检验: 点击工具栏 Analyze → Column analyses → Mann-Whitney test(非参数检验,适用于非正态数…...
写作即是生活
一个问题 “我是什么时候开始写作的呢?”请你先暂停一下,别往下读,先想想这个问题。 什么才是写作? 或许在想上个问题之后,你就会开始想问另外一个问题,什么才算是写作呢? 我的回答是&#x…...
进阶知识:Selenium底层原理深度解析
Selenium底层原理深度解析:网络IO密集型系统揭秘 一、Selenium核心组件解析 1.1 三大核心角色 客户端(Client) 扮演"指挥官"角色,负责: 编写测试脚本(模拟用户点击、输入等操作)发送…...
基于 Flickr30k-Entities 数据集 的 Phrase Localization
以下示例基于 Flickr30k-Entities 数据集中的标注,以及近期(以 TransVG (Li et al. 2021)为例)在短语定位(Phrase Grounding)任务上的评测结果,展示了单张图片中若干名词短语的定位情…...

[GHCTF 2025]SQL???
打开题目在线环境: 先尝试注入: id1;show databases; 发现报错,后来看了wp才知道这个题目是SQLite注入。 我看的是这个师傅的wp: https://blog.csdn.net/2401_86190146/article/details/146164505?ops_request_misc%257B%2522request%255Fid…...

【科研绘图系列】R语言绘制GO term 富集分析图(enrichment barplot)
禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据数据预处理画图code 2code 3系统信息介绍 本文介绍了使用R语言绘制GO富集分析条形图的方法。通过加载ggplot2等R包,对GO term数据进行预处理,包括p值转换…...