当前位置: 首页 > article >正文

2025-05-31 Python深度学习9——网络模型的加载与保存

文章目录

  • 1 使用现有网络
  • 2 修改网络结构
    • 2.1 添加新层
    • 2.2 替换现有层
  • 3 保存网络模型
    • 3.1 完整保存
    • 3.2 参数保存(推荐)
  • 4 加载网络模型
    • 4.1 加载完整模型文件
    • 4.2 加载参数文件
  • 5 Checkpoint
    • 5.1 保存 Checkpoint
    • 5.2 加载 Checkpoint

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ PyTorch 通过torchvision.models提供预训练模型(如 VGG16)。

​ 网址链接:https://docs.pytorch.org/vision/stable/models.html。

1 使用现有网络

​ 以 VGG16 为例,进入网址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

image-20250531103635500

方法一:使用随机初始化权重

​ 将 weights 设置为 None,从 0 开始训练自己的网络。

vgg16_false = torchvision.models.vgg16(weights=None)  # 权重随机初始化

方法二:加载预训练权重

​ 也可以使用预训练好的网络参数,加载后可直接使用网络。
这将从官网上下载已训练好的模型文件。

vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

​ 可打印网络查看其模型结构:

print(vgg16_true)
image-20250531104433678
...
image-20250531104447912

2 修改网络结构

2.1 添加新层

​ 使用add_module在分类器(classifier)后追加全连接层:

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
image-20250531104536554

2.2 替换现有层

​ 直接修改分类器的最后一层(如适配 CIFAR10 的 10 分类任务):

vgg16_false.classifier[6] = nn.Linear(4096, 10)  # 替换第6层
image-20250531104551228

3 保存网络模型

​ 使用torch.save()方法保存网络模型。文件扩展名推荐使用.pt.pth

3.1 完整保存

​ 将模型类和参数一并保存到文件中。

torch.save(vgg16, 'vgg16_method1.pth')  # 包含模型类和参数
  • 优点:加载时无需重新定义模型结构。
  • 缺点:文件较大,且依赖原始代码环境(见 4.1 节)。

3.2 参数保存(推荐)

​ 仅保存参数字典到文件中。

torch.save(vgg16.state_dict(), 'vgg16_method2.pth')  # 仅保存参数字典
  • 优点:文件小,灵活性强,适合生产部署。

示例

import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型结构 + 模型参数
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

4 加载网络模型

​ 使用torch.load()方法加载网络模型。

4.1 加载完整模型文件

​ 加载完整模型时,需将 weights_only 参数设置为 False。

model = torch.load('vgg16_method1.pth', weights_only=False)  # 需确保模型类已定义

​ 模型打印结果如下:

print(model)
image-20250531111142517

注意

​ 若保存自定义模型,加载时必须确保环境中也有该模型的定义,否则会出现报错。

  • model_save.py

    # model_save.pyimport torch
    from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel()
    torch.save(model, 'my_model_method1.pth')
    
  • model_load.py

    import torchmodel = torch.load('my_model_method1.pth', weights_only=False)  # 报错,找不到 MyModel 的定义
    

    先运行 model_save.py,再运行 model_load.py,则会出现以下报错:

image-20250531110244566

4.2 加载参数文件

​ 首先,使用torch.load()方法加载网络模型。

​ 使用模型时,需先创建匹配的网络结构,再使用model.load_state_dict()加载参数数据。

vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict)  # 需结构匹配

​ 模型打印结果是参数字典:

print(model_dict)
image-20250531111411199

注意

​ 模型保存时若在 GPU 上,加载时需指定 map_location 为 cup。

torch.load('model.pth', map_location=torch.device('cpu'))

​ 将参数加载到模型后,手动迁移到 GPU:

model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda:0')

5 Checkpoint

​ 使用 Checkpoint 可以在训练过程中定期保存模型的状态,以便在中断后可以恢复训练,或者在测试时使用最终的模型。文件扩展名推荐使用.tar

5.1 保存 Checkpoint

​ 要保存一个模型的 Checkpoint,通常需要保存以下数据:

  • 模型的 state_dict(状态字典);
  • 优化器的状态;
  • 额外的信息,如 epoch 等。
import torch# 假设 model 是你的模型,optimizer 是你的优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')

5.2 加载 Checkpoint

​ 加载 Checkpoint,首先需要加载文件,然后将其内容恢复到模型和优化器的状态中。

# 假设 model 和 optimizer 是你的模型和优化器实例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以继续训练
model.train()  # 确保模型处于训练模式

相关文章:

2025-05-31 Python深度学习9——网络模型的加载与保存

文章目录 1 使用现有网络2 修改网络结构2.1 添加新层2.2 替换现有层 3 保存网络模型3.1 完整保存3.2 参数保存(推荐) 4 加载网络模型4.1 加载完整模型文件4.2 加载参数文件 5 Checkpoint5.1 保存 Checkpoint5.2 加载 Checkpoint 本文环境: Py…...

长安链起链调用合约时docker ps没有容器的原因

在调用这个命令的时候,发现并没有出现官方预期的合约容器,这是因为我们在起链的时候没有选择用docker的虚拟环境,实际上这不影响后续的调用,如果想要达到官方的效果那么你只需要在起链的时候输入yes即可,如图三所示...

Appium+python自动化(七)- 认识Appium- 上

简介 经过前边的各项准备工作,终于才把appium搞定。 一、appium自我介绍 appium是一款开源的自动化测试工具,可以支持iOS和安卓平台上的原生的,基于移动浏览器的,混合的应用(APP)。 1、 使用appium进…...

数据中心双活架构解决方案

数据中心双活架构解决方案 数据中心双活架构(Active-Active Data Center)旨在实现业务高可用、负载均衡和灾难自动切换。以下是完整的解决方案,涵盖架构设计、关键技术、实施步骤及最佳实践。 1. 双活架构设计 1.1 基本架构模型 同城双活(Metro Active-Active) 两个数据…...

YOLOv5 详解:从原理到实战的全方位解析

在计算机视觉领域,目标检测作为核心任务之一,始终吸引着众多研究者和开发者的目光。YOLO(You Only Look Once)系列算法凭借其高效、准确的特点,在目标检测领域占据重要地位。而 YOLOv5 作为 YOLO 系列算法的重要成员&a…...

模块联邦:更快的微前端方式!

什么是模块联邦 在前端项目中,不同团队之间的业务模块可能有耦合,比如A团队的页面里有一个富文本模块(组件),而B团队 的页面恰好也需要使用这个富文本模块。 传统模式下,B团队只能去抄A团队的代码&#x…...

前端基础学习html+css+js

HTML 区块 div标签,块级标签 span包装小部分文本,行内元素 表单 CSS css选择器 css属性 特性blockinlineinline-block是否换行✅ 换行❌ 不换行❌ 不换行可设置宽高✅ 支持❌ 不支持✅ 支持常见元素div容器 p段落 h标题span文本容器 a超链接img图片…...

手机打电话时将对方DTMF数字转为RFC2833发给局域网SIP坐席

手机打电话时将对方DTMF数字转为RFC2833发给局域网SIP坐席 --局域网SIP坐席呼叫 上一篇:手机打电话时由对方DTMF响应切换多级IVR语音菜单(完结) 下一篇:安卓App识别手机系统弹授权框包含某段文字-并自动点击确定按钮 一、前言 …...

TCP三次握手/四次握手-TCP/IP四层模型-SSL/TLS-HTTP-HTTPS

重要概念 seq ( Squence Number ) 序列号,用于数据排序、去重,防止数据包乱序 ack ( Acknowledgement Number ) 确认好,表示期望接受的下一个字节序号,用于确认数据包被对方接受 TCP三次握手是建立可靠连接的过程,确…...

SAP Business One:无锡哲讯科技助力中小企业数字化转型的智慧之选

数字化转型,中小企业的必经之路 在当今竞争激烈的商业环境中,数字化转型已不再是大型企业的专利,越来越多的中小企业开始寻求高效、灵活的管理系统来优化业务流程、提升运营效率。作为全球领先的企业管理软件,SAP Business One…...

【Ubuntu远程桌面】

Ubuntu-远程桌面 ubuntu环境rustdesk-1.4.0-aarch64.deb安装rustdesk注意事项:报错:可能会在远程连接时候显示‘No displays’解决方法1. 安装 CUDA(如果需要)2. 解决 XDG 桌面门户问题3. 检查 RustDesk 客户端日志 总结 kill --t…...

⚡ Linux 系统安装与配置 Vim 编辑器(包括 Vim 插件管理器)

⚡ Linux 系统安装与配置 Vim 编辑器(包括 Vim 插件管理器) 📌 1. Vim 简介 Vim(Vi IMproved)是一款高度可定制的文本编辑器,基于早期的 vi 编辑器扩展而来。 它支持语法高亮、插件扩展、多种编程语言&am…...

小型语言模型:为何“小”才是“大”?

当说到人工智能(AI)的时候,大家通常会想到那些拥有数十亿参数的超大型语言模型,它们能做出一些令人惊叹的事情。 厉害不厉害?绝对厉害! 但对于大多数企业和开发者来说,实用吗?可能…...

雪花算法:分布式ID生成的优雅解决方案

一、雪花算法的核心机制与设计思想 雪花算法(Snowflake)是由Twitter开源的分布式ID生成算法,它通过巧妙的位运算设计,能够在分布式系统中快速生成全局唯一且趋势递增的ID。 1. 基本结构 雪花算法生成的是一个64位(lo…...

针对PostgreSQL中pg_wal目录占用过大的系统性解决方案

​一、问题现象与根本原因​ 当pg_wal目录占用超过预期(如数十GB甚至占满磁盘),通常由以下原因导致 ​长事务未提交​:未完成的事务会阻塞WAL日志清理。​复制槽未释放​:逻辑复制或流复制槽未及时清理,导…...

git push Git远端意外挂断

git push Git远端意外挂断 枚举对象中: 99, 完成. 对象计数中: 100% (99/99), 完成. 使用 8 个线程进行压缩 压缩对象中: 100% (78/78), 完成. send-pack: unexpected disconnect while reading sideband packet 写入对象中: 100% (82/82), 2.78 MiB | 5.56 MiB/s, 完成. 总共…...

python学习day34

GPU训练及类的call方法 知识点回归: CPU性能的查看:看架构代际、核心数、线程数GPU性能的查看:看显存、看级别、看架构代际GPU训练的方法:数据和模型移动到GPU device上类的call方法:为什么定义前向传播时可以直接写作…...

秋招Day12 - 计算机网络 - 网络综合

从浏览器地址栏输入URL到显示网页的过程了解吗? 从在浏览器地址栏输入 URL 到显示网页的完整过程,并不是一个单一的数据包从头到尾、一次性地完成七层封装再七层解析的过程。 而是涉及到多次、针对不同目的、与不同服务器进行的、独立的网络通信交互&a…...

QT-JSON

#include <QJsonDocument>#include <QJsonObject>#include <QJsonArray>#include <QFile>#include <QDebug>void createJsonFile() {// 创建一个JSON对象 键值对QJsonObject jsonObj;jsonObj["name"] "John Doe";jsonObj[…...

IP 风险画像技术略解

IP 风险画像的技术定义与价值 IP 风险画像通过整合 IP 查询数据与 IP 离线库信息&#xff0c;结合机器学习算法&#xff0c;为每个 IP 地址生成多维度风险评估模型。其核心价值在于将传统的静态 IP 黑名单升级为动态风险评估体系&#xff0c;可实时识别新型网络威胁&#xff0…...

秋招Day12 - 计算机网络 - 基础

说一下计算机网络体系结构 OSI七层模型&#xff0c;TCP/IP四层模型和五层体系结构 说说OSI七层模型&#xff1f; 应用层&#xff1a;最靠近用户的层&#xff0c;用于处理特定应用程序的细节&#xff0c;提供了应用程序和网络服务之间的接口。表示层&#xff1a;确保从一个系…...

【网络安全】——Modbus协议详解:工业通信的“通用语言”

目录 一、初识Modbus&#xff1a;工业通信的基石 1.1 协议全称 1.2 协议简史 二、核心特性解析 2.1 架构设计 2.2 典型应用场景 三、协议族全景图 3.1 协议栈分类 3.2 版本演进对比 四、协议报文深度解析 4.1 Modbus RTU帧结构 4.2 Modbus TCP报文 五、通信机制实…...

MySQL 数据库备份与恢复利器:Percona XtraBackup 详解

一、XtraBackup 简介 1. 什么是 XtraBackup&#xff1f; XtraBackup 是 Percona 公司推出的免费开源工具&#xff0c;专为 InnoDB/XtraDB 引擎设计&#xff0c;支持 在线物理热备&#xff0c;具备以下核心特性&#xff1a; 非阻塞备份&#xff1a;备份过程中数据库仍可读写。…...

【GlobalMapper精品教程】095:如何获取无人机照片的拍摄方位角

文章目录 一、加载无人机照片二、计算方位角三、Globalmapper符号化显示方向四、arcgis符号化显示方向一、加载无人机照片 打开软件,加载无人机照片,在GLobalmapperV26中文版中,默认显示如下的航线信息。 关于航线的起止问题,可以直接从照片名称来确定。 二、计算方位角 …...

小提琴图绘制-Graph prism

在 GraphPad Prism 中为小提琴图添加显著性标记(如*P<0.05)的步骤如下: 步骤1:完成统计检验 选择数据表:确保数据已按分组排列(如A列=Group1,B列=Group2)。执行统计检验: 点击工具栏 Analyze → Column analyses → Mann-Whitney test(非参数检验,适用于非正态数…...

写作即是生活

一个问题 “我是什么时候开始写作的呢&#xff1f;”请你先暂停一下&#xff0c;别往下读&#xff0c;先想想这个问题。 什么才是写作&#xff1f; 或许在想上个问题之后&#xff0c;你就会开始想问另外一个问题&#xff0c;什么才算是写作呢&#xff1f; 我的回答是&#x…...

进阶知识:Selenium底层原理深度解析

Selenium底层原理深度解析&#xff1a;网络IO密集型系统揭秘 一、Selenium核心组件解析 1.1 三大核心角色 客户端&#xff08;Client&#xff09; 扮演"指挥官"角色&#xff0c;负责&#xff1a; 编写测试脚本&#xff08;模拟用户点击、输入等操作&#xff09;发送…...

基于 Flickr30k-Entities 数据集 的 Phrase Localization

以下示例基于 Flickr30k-Entities 数据集中的标注&#xff0c;以及近期&#xff08;以 TransVG &#xff08;Li et al. 2021&#xff09;为例&#xff09;在短语定位&#xff08;Phrase Grounding&#xff09;任务上的评测结果&#xff0c;展示了单张图片中若干名词短语的定位情…...

[GHCTF 2025]SQL???

打开题目在线环境&#xff1a; 先尝试注入&#xff1a; id1;show databases; 发现报错&#xff0c;后来看了wp才知道这个题目是SQLite注入。 我看的是这个师傅的wp: https://blog.csdn.net/2401_86190146/article/details/146164505?ops_request_misc%257B%2522request%255Fid…...

【科研绘图系列】R语言绘制GO term 富集分析图(enrichment barplot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据数据预处理画图code 2code 3系统信息介绍 本文介绍了使用R语言绘制GO富集分析条形图的方法。通过加载ggplot2等R包,对GO term数据进行预处理,包括p值转换…...