Pytorch框架06-现有网络模型(修改/使用/保存/加载)
3.9 PyTorch网络模型的修改及使用
PyTorch 提供了多个预训练的网络模型,涵盖了广泛的计算机视觉任务,如图像分类、目标检测、语义分割等。这些预训练模型在 ImageNet 等大型数据集上进行了训练,并可以直接用于迁移学习或微调。
3.9.1 常见的现有Pytorch网络模型
- 分类模型
这些模型在图像分类任务中表现非常好,可以在 ImageNet 数据集上进行微调。
- ResNet 系列:
- ResNet18、ResNet34、ResNet50、ResNet101、ResNet110、ResNet152
- ResNet 引入了残差连接,解决了深度网络训练中的退化问题,广泛应用于各种视觉任务。
- VGG 系列:
- VGG11、VGG13、VGG16、VGG19
- VGG 模型有着简单而深度的结构,通常用于图像分类任务。
- DenseNet 系列:
- DenseNet121、DenseNet169、DenseNet201、DenseNet161
- DenseNet 使用密集连接,每一层与前面的所有层相连,能够更有效地利用特征。
- Inception 系列:
- Inception v3
- Inception 模型通过并行的卷积核不同尺寸的卷积来捕捉图像的多尺度特征。
- MobileNet 系列:
- MobileNetV2、MobileNetV3
- 适用于轻量级应用,能够在移动设备或嵌入式设备上进行高效推理。
- AlexNet:
- 经典的卷积神经网络,曾经在 ImageNet 上取得突破性进展,适用于图像分类。
- EfficientNet 系列:
- EfficientNet B0 到 B7
- 高效的神经网络架构,通过自动化搜索来优化网络的深度、宽度和分辨率。
- SqueezeNet:
- 一个轻量级的网络模型,通过减少模型的参数量来降低计算需求。
- 目标检测模型
PyTorch 提供了目标检测的预训练模型,适用于检测图像中的物体。
- Faster R-CNN:
- 一种基于区域卷积神经网络 (R-CNN) 的目标检测模型,广泛应用于检测任务。
- Mask R-CNN:
- 扩展了 Faster R-CNN,在目标检测的基础上加入了实例分割功能。
- RetinaNet:
- 采用焦点损失函数,改进了处理前景和背景样本的不均衡问题。
- YOLO(You Only Look Once):
- 通过一个网络直接输出目标的位置和类别,实现快速且准确的目标检测。
- 语义分割模型
语义分割任务是将图像中的每个像素分配一个标签。
- FCN (Fully Convolutional Network):
- 用于像素级的图像分割。
- U-Net:
- U-Net 架构常用于医学图像的语义分割,具有编码-解码的结构。
- DeepLabV3:
- 用于语义分割的深度学习模型,特别是在捕捉图像中的边缘和细节上表现出色。
- 图像生成模型
这些模型常用于图像生成、修复、增强等任务。
- Generative Adversarial Networks (GAN):
- 用于生成与训练数据分布相似的图像。PyTorch 提供了许多不同种类的 GAN 模型,例如 DCGAN、WGAN 等。
- Pix2Pix:
- 用于图像到图像的转换任务,如图像修复或风格迁移。
- CycleGAN:
- 用于无监督的图像到图像的转换,如将马的照片转换成斑马的照片等。
- Transformer 模型
随着 Transformer 在自然语言处理(NLP)中的成功,越来越多的 Transformer 网络架构也被应用于计算机视觉任务。
- Vision Transformer (ViT):
- 使用 Transformer 架构来处理图像,近年来在多个任务中表现出色。
- DeiT (Data-efficient Image Transformer):
- 通过优化训练过程使得 Vision Transformer 可以在较少数据下进行有效训练。
3.9.2 现有网络模型的修改与使用
为了方便讲解,我们选取分类模型Vgg16来进行讲解:
VGG16 是一种卷积神经网络(CNN)架构,由牛津大学视觉几何组(Visual Geometry Group,简称 VGG)提出。它在 2014 年的 ImageNet 图像分类挑战中表现非常出色,因此被广泛使用。VGG16 网络模型的名字中的 “16” 表示该网络包含 16 层权重(包括卷积层和全连接层)。
- 方法:
torchvision.models.vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any)
- 参数:
1.weights (VGG16_Weights, optional):指定是否使用预训练的权重。可以选择使用 ImageNet 预训练的权重(VGG16_Weights.DEFAULT)或不使用(None)。
2.progress (bool, optional):控制是否显示下载进度条。默认为 True,表示显示进度条,False 表示不显示。
3.kwargs:传递给 VGG 基类的额外参数,常见的包括:
num_classes: 分类数目(默认为 1000)。
batch_norm: 是否使用批归一化(默认为 False)。
- 使用和修改步骤
1. 选择是否训练好的模型
# 加载预训练的 VGG16 模型
model = model.vgg16(weights=model.VGG16_Weights.DEFAULT)
# 加载没有预训练权重的 VGG16 模型(从头开始训练)
model_no_weights = model.vgg16(weights=None)
Vgg(...
(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
2. 选择要训练的数据集:CIFAR10
data = torchvision.datasets.CIFAR10('data-train', train=False, transform=torchvision.transforms.ToTensor(), download=True)
data_loader = DataLoader(data, batch_size=1)
3. 修改模型:对于要训练的CIFAR10只有10个特征,因此需要修改原有模型
方法一:使用 add_module 添加新的全连接层
model.classifier.add_module('add_Linear', nn.Linear(4096, 10))
"""add_module 允许你动态地向模型的某个部分(这里是 classifier)添加新层,通常是为了扩展或修改模型。
这种方式是将新层添加到模型结构的末尾。"""方法二:直接修改模型的输出层
model.classifier[6] = nn.Linear(4096, 10)
在这里,model.classifier[6] 是指 VGG16 模型中的第七个层(注意,索引从0开始)。这个位置是VGG16模型的最后一个全连接层。我们直接将该层替换为一个新的全连接层,输出大小为 10,即CIFAR-10的类别数。
3.9.3 网络模型的保存与读取
在 PyTorch 中,保存模型的方式有两种:保存整个模型(包括结构和参数)和仅保存模型的参数(state_dict)。以下是对你提供的两种方式的详细解释:
class Net(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 16, 3)def forward(self, input):return self.conv(input)
net = Net()
1. 方式一:保存整个模型(结构和参数)
torch.save(net, 'model/save_model_method1.pth')
解释:
- 这种方式会保存整个模型(包括模型的架构、参数以及优化器状态等)。这意味着你可以直接加载这个文件,并恢复整个模型。
- 保存后,加载时不需要重新定义模型架构,PyTorch 会自动恢复整个模型。
- 但缺点是它依赖于特定的模型类定义,意味着如果你迁移到不同的环境或修改了模型的定义,可能会出现问题。
加载模型:
# 加载整个模型
model = torch.load('model/saveModel1.pth')
2. 方式二:仅保存模型的参数(state_dict)
torch.save(net.state_dict(), 'model/save_model_method2.pth')
解释:
- 这种方式只保存模型的参数(即
state_dict),不保存模型的架构。state_dict包含了模型的所有可训练参数,如权重、偏置等。 - 保存
state_dict后,加载时你需要先定义模型的结构,然后加载参数到模型中。
加载模型:
# 重新初始化模型
model = MyModel() # 确保你定义了模型的结构# 加载保存的 state_dict
model.load_state_dict(torch.load('model/saveModel2.pth'))
- 方式一(保存整个模型):
- 适用于需要方便地保存和恢复整个模型的情况。
- 依赖于模型类的定义,跨环境迁移时可能会出现问题。
- 方式二(保存
state_dict):- 推荐的做法,通常用于保存训练好的模型的参数。
- 需要在加载时手动定义模型结构。
- 更加灵活,能够避免依赖于模型定义的版本问题。
相关文章:
Pytorch框架06-现有网络模型(修改/使用/保存/加载)
3.9 PyTorch网络模型的修改及使用 PyTorch 提供了多个预训练的网络模型,涵盖了广泛的计算机视觉任务,如图像分类、目标检测、语义分割等。这些预训练模型在 ImageNet 等大型数据集上进行了训练,并可以直接用于迁移学习或微调。 3.9.1 常见的…...
【亲测有效】百度Ueditor富文本编辑器添加插入视频、视频不显示、和插入视频后二次编辑视频标签不显示,显示成img标签,二次保存视频被替换问题,解决方案
【亲测有效】项目使用百度Ueditor富文本编辑器上传视频相关操作问题 1.百度Ueditor富文本编辑器添加插入视频、视频不显示 2.百度Ueditor富文本编辑器插入视频后二次编辑视频标签不显示,在编辑器内显示成img标签,二次保存视频被替换问题 问题1࿱…...
MySQL 使用 `WHERE` 子句时 `COUNT(*)`、`COUNT(1)` 和 `COUNT(column)` 的区别解析
文章目录 1. COUNT() 函数的基本作用2. COUNT(*)、COUNT(1) 和 COUNT(column) 的详细对比2.1 COUNT(*) —— 统计所有符合条件的行2.2 COUNT(1) —— 统计所有符合条件的行2.3 COUNT(column) —— 统计某一列非 NULL 的记录数 3. 性能对比3.1 EXPLAIN 分析 4. 哪种方式更好&…...
laravel11设置中文语言包
安装中文语言包 Laravel 11 默认没有内置完整中文语言包,推荐使用第三方维护的完整翻译: # 通过 Composer 安装语言包 composer require laravel-lang/common --dev# 发布中文语言文件到项目 php artisan lang:add zh_CN这会自动将中文语言文件生成到 l…...
二、IDE集成DeepSeek保姆级教学(使用篇)
各位看官老爷好,如果还没有安装DeepSeek请查阅前一篇 一、IDE集成DeepSeek保姆级教学(安装篇) 一、DeepSeek在CodeGPT中使用教学 1.1、Edit Code 编辑代码 选中代码片段 —> 右键 —> CodeGPT —> Edit Code, 输入自然语言可编辑代码,点击S…...
网络七层模型—OSI参考模型详解
网络七层模型:OSI参考模型详解 引言 在网络通信的世界中,OSI(Open Systems Interconnection)参考模型是一个基础且核心的概念。它由国际标准化组织(ISO)于1984年提出,旨在为不同厂商的设备和应…...
四、Redis主从复制与读写分离
一、环境搭建 准备环境 IP角色192.168.10.101Master192.168.10.102Slave192.168.10.103Slave 创建配置/数据/日志目录 # 创建配置目录 mkdir -p /usr/local/redis/conf # 创建数据目录 mkdir -p /usr/local/redis/data # 创建日志目录 mkdir -p /usr/local/redis/log修改配置…...
爬虫框架与库
爬虫框架与库是用于网络数据抓取的核心工具,帮助开发者高效地从网页中提取结构化数据。 Requests:用于发送HTTP请求。 BeautifulSoup:用于解析HTML和XML。 Scrapy:强大的爬虫框架,适合大规模爬取。 Selenium&#…...
【保姆级视频教程(二)】YOLOv12训练数据集构建:标签格式转换-划分-YAML 配置 避坑指南 | 小白也能轻松玩转目标检测!
【2025全站首发】YOLOv12训练数据集构建:标签格式转换-划分-YAML 配置 避坑指南 | 小白也能轻松玩转目标检测! 文章目录 1. 数据集准备1.1 标签格式转换1.2 数据集划分1.3 yaml配置文件创建 2. 训练验证 1. 数据集准备 示例数据集下载链接:P…...
数据如何安全“过桥”?分类分级与风险评估,守护数据流通安全
信息化高速发展,数据已成为企业的核心资产,驱动着业务决策、创新与市场竞争力。随着数据开发利用不断深入,常态化的数据流通不仅促进了信息的快速传递与共享,还能帮助企业快速响应市场变化,把握商业机遇,实…...
本地大模型编程实战(24)用智能体(Agent)实现智能纠错的SQL数据库问答系统(3)
本文将实现这样一个 智能体(Agent) : 可以使用自然语言对 SQLite 数据库进行查询。即:用户用自然语言提出问题,智能体也用自然语言根据数据库的查询结果回答问题。增加一个自动对查询中的专有名词进行纠错的工具,这将明显提升查询…...
Apache DolphinScheduler系列1-单节点部署及测试报告
文章目录 整体说明一、部署环境二、版本号三、部署方案四、部署步骤4.1、上传部署包4.2、创建外部数据库4.3、修改元数据库配置4.4、上传MySQLl驱动程序4.5、初始化外部数据库4.6、启停服务4.7、访问页面五、常见问题及解决方式5.1、时间不一致5.2、异常终止5.3、大量日志5.4、…...
Java+SpringBoot+Vue+数据可视化的音乐推荐与可视化平台(程序+论文+讲解+安装+调试+售后)
感兴趣的可以先收藏起来,还有大家在毕设选题,项目以及论文编写等相关问题都可以给我留言咨询,我会一一回复,希望帮助更多的人。 系统介绍 在互联网技术以日新月异之势迅猛发展的浪潮下,5G 通信技术的普及、云计算能力…...
LVS+Keepalived 高可用集群搭建
一、高可用集群: 1.什么是高可用集群: 高可用集群(High Availability Cluster)是以减少服务中断时间为目地的服务器集群技术它通过保护用户的业务程序对外不间断提供的服务,把因软件、硬件、人为造成的故障对业务的影响…...
跟着AI学vue第十二章
第十二章:技术引领与社区共建 在熟练掌握Vue开发技能,并将其与前沿技术融合应用后, 第十二章是一个更具使命感与影响力的阶段,着重于引领技术发展方向和为社区贡献力量。 1. 推动Vue技术创新与实践 探索前沿技术融合࿱…...
PydanticToolsParser 工具(tool call)把 LLM 生成的文本转成结构化的数据(Pydantic 模型)过程中遇到的坑
PydanticToolsParser 的作用 PydanticToolsParser 是一个工具,主要作用是 把 LLM 生成的文本转成结构化的数据(Pydantic 模型),让代码更容易使用这些数据进行自动化处理。 换句话说,AI 生成的文本通常是自然语言&…...
python-leetcode-乘积最大子数组
152. 乘积最大子数组 - 力扣(LeetCode) class Solution:def maxProduct(self, nums: List[int]) -> int:if not nums:return 0max_prod nums[0]min_prod nums[0]result nums[0]for i in range(1, len(nums)):if nums[i] < 0:max_prod, min_prod…...
江协科技/江科大-51单片机入门教程——P[1-1] 课程简介P[1-2] 开发工具介绍及软件安装
本教程也力求在玩好单片机的同时了解一些计算机的基本概念,了解电脑的一些基本操作,了解电路及其元器件的基本理论,为我们学习更高级的单片机,入门IT和信息技术行业,打下一定的基础。 目录 1.课程简介 2.开发工具及…...
简单介绍JVM
1.什么是JVM? JVM就是Java虚拟机【Java Virtual Machine】,简称JVM。主要部分包括类加载子系统,运行时数据区,执行引擎,本地方法库等,接下来我们一一介绍 2.类加载子系统 JVM中运行的就是我们日常写的JA…...
【对话推荐系统】Towards Topic-Guided Conversational Recommender System 论文阅读
Towards Topic-Guided Conversational Recommender System 论文阅读 Abstract1 Introduction2 Related Work2.1 Conversation System2.2 Conversational Recommender System2.3 Dataset for Conversational Recommendation 3 Dataset Construction3.1 Collecting Movies for Re…...
当下弹幕互动游戏源码开发教程及功能逻辑分析
当下很多游戏开发者或者想学习游戏开发的人,想要了解如何制作弹幕互动游戏,比如直播平台上常见的那种,观众通过发送弹幕来影响游戏进程。需要涵盖教程的步骤和功能逻辑的分析。 首先,弹幕互动游戏源码开发教程部分应该分步骤&…...
STM32——HAL库开发笔记21(定时器2—输出比较)(参考来源:b站铁头山羊)
本文主要讲述输出比较及PWM信号相关知识。 一、概念 所谓输出比较,就是通过单片机的定时器向外输出精确定时的方波信号。 1.1 PWM信号 PWM信号即脉冲宽度调制信号。PWM信号的占空比 (高电压 所占周期 / 整个周期) * 100% 。所以PWM信号…...
YOLOv12 ——基于卷积神经网络的快速推理速度与注意力机制带来的增强性能结合
概述 实时目标检测对于许多实际应用来说已经变得至关重要,而Ultralytics公司开发的YOLO(You Only Look Once,只看一次)系列一直是最先进的模型系列,在速度和准确性之间提供了稳健的平衡。注意力机制的低效阻碍了它们在…...
动态内容加载的解决方案:Selenium与Playwright对比故障排查实录
方案进程 2024-09-01 09:00 | 接到亚航航班数据采集需求 2024-09-01 11:30 | 首次尝试使用Selenium遭遇Cloudflare验证 2024-09-01 14:00 | 切换Playwright方案仍触发反爬机制 2024-09-01 16:30 | 引入爬虫代理IPUA轮换策略 2024-09-02 10:00 | 双方案完整实现并通过压力测试故…...
NLP学习记录十:多头注意力
一、单头注意力 单头注意力的大致流程如下: ① 查询编码向量、键编码向量和值编码向量分别经过自己的全连接层(Wq、Wk、Wv)后得到查询Q、键K和值V; ② 查询Q和键K经过注意力评分函数(如:缩放点积运算&am…...
Spring基础01
Spring基础01 软件开发原则 OCP开闭原则:七大开发原则当中最基本的原则,其他的六个原则是为这个原则服务的。 对扩展开放,对修改关闭。在扩展系统功能的时候,没有修改之前写好的代码,就符合OCP原则,反之&a…...
Gurobi 并行计算的一些问题
最近尝试用 gurobi 进行并行计算,即同时用多个 cpu 核计算 gurobi 的 model,但是发现了不少问题。总体来看,gurobi 对并行计算的支持并不是那么好。 gurobi 官方对于并行计算的使用在这个网址,并有下面的大致代码: i…...
2025年2月,TVBOX接口最新汇总版
这里写自定义目录标题 1、离线版很必要2、关于在线版好还是离线版更实在,作个总结:★ 离线版的优点:★ 离线版的缺点: 3.1、 针对FM内置的写法;3.2、 如果是用在YSC,那么格式也要有些小小的改变3.2.1、 YSC…...
Dubbo RPC 原理
一、Dubbo 简介 Apache Dubbo 是一款高性能、轻量级的开源 RPC 框架,支持服务治理、协议扩展、负载均衡、容错机制等核心功能,广泛应用于微服务架构。其核心目标是解决分布式服务之间的高效通信与服务治理问题。 二、Dubbo 架构设计 1. 核心组件 Prov…...
qt5的中文乱码问题,QString、QStringLiteral 为 UTF-16 编码
qt5的中文乱码问题一直没有很明确的处理方案。 今天处理进程间通信时,也遇到了qt5乱码问题,一边是设置的GBK,一边设置的是UTF8,单向通信约定采用UTF8。 发送端保证发的是UTF8字符串,因为UTF8在网络数据包中没有字节序…...
