PyTorch系列教程:编写高效模型训练流程
当使用PyTorch开发机器学习模型时,建立一个有效的训练循环是至关重要的。这个过程包括组织和执行对数据、参数和计算资源的操作序列。让我们深入了解关键组件,并演示如何构建一个精细的训练循环流程,有效地处理数据处理,向前和向后传递以及参数更新。
模型训练流程
PyTorch训练循环流程通常包括:
- 加载数据
- 批量处理
- 执行正向传播
- 计算损失
- 反向传播
- 更新权重
一个典型的训练流程将这些步骤合并到一个迭代过程中,在数据集上迭代多次,或者在训练的上下文中迭代多个epoch。

1. 搭建环境
在编写代码之前,请确保在本地环境中设置了PyTorch。这通常需要安装PyTorch和其他依赖项:
pip install torch torchvision
下面演示为建立一个有效的训练循环奠定了基本路径的示例。
2. 数据加载
数据加载是使用DataLoader完成的,它有助于数据的批量处理:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
data_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
DataLoader在这里被设计为以64个为单位的批量获取数据,在数据传递中进行随机混淆。
3. 模型初始化
一个使用PyTorch的简单神经网络定义如下:
import torch.nn as nn
import torch.nn.functional as Fclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return F.log_softmax(x, dim=1)
这里,784指的是输入维度(28x28个图像),并创建一个输出大小为10个类别的顺序前馈网络。
4. 建立训练循环
定义损失函数和优化器:为了改进模型的预测,必须定义损失和优化器:
import torch.optim as optimmodel = SimpleNN()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. 实现训练循环
有效的训练循环的本质在于正确的步骤顺序:
epochs = 5
for epoch in range(epochs):running_loss = 0for images, labels in train_loader:optimizer.zero_grad() # Zero the parameter gradientsoutput = model(images) # Forward passloss = criterion(output, labels) # Calculate lossloss.backward() # Backward passoptimizer.step() # Optimize weightsrunning_loss += loss.item()print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader)}")
注意,每次迭代都需要重置梯度、通过网络处理输入、计算误差以及调整权重以减少该误差。
性能优化
使用以下策略提高循环效率:
-
使用GPU:将计算转移到GPU上,以获得更快的处理速度。如果GPU可用,使用to(‘cuda’)转换模型和输入。
-
数据并行:利用多gpu设置与dataparlele模块来分发批处理。
-
FP16训练:使用自动混合精度(AMP)来加速训练并减少内存使用,而不会造成明显的精度损失。
在 PyTorch 中使用 FP16(半精度浮点数)训练 可以显著减少显存占用、加速计算,同时保持模型精度接近 FP32。以下是详细指南:
1. FP16 的优势
- 显存节省:FP16 占用显存是 FP32 的一半(例如,1024MB 显存在 FP32 下可容纳约 2000 万参数,在 FP16 下可容纳约 4000 万)。
- 计算加速:NVIDIA 的 Tensor Core 支持 FP16 矩阵运算,速度比 FP32 快数倍至数十倍。
- 适合大规模模型:如 Transformer、Vision Transformer(ViT)等参数量大的模型。
2. 实现 FP16 训练的两种方式
(1) 自动混合精度(Automatic Mixed Precision, AMP)
PyTorch 的 torch.cuda.amp 自动管理 FP16 和 FP32,减少手动转换的复杂性。
python
import torch
from torch.cuda.amp import autocast, GradScalermodel = model.to("cuda") # 确保模型在 GPU 上
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler() # 梯度缩放器for data, target in dataloader:data = data.to("cuda").half() # 输入转为 FP16target = target.to("cuda")with autocast(): # 自动切换 FP16/FP32 计算output = model(data)loss = criterion(output, target)scaler.scale(loss).backward() # 梯度缩放scaler.step(optimizer) # 更新参数scaler.update() # 重置缩放器
关键点:
autocast()内部自动将计算转换为 FP16(若 GPU 支持),梯度累积在 FP32。GradScaler()解决 FP16 下梯度下溢问题。
(2) 手动转换(低级用法)
直接将模型参数、输入和输出转为 FP16,但需手动管理精度和稳定性。
python
model = model.half() # 模型参数转为 FP16
for data, target in dataloader:data = data.to("cuda").half() # 输入转为 FP16target = target.to("cuda")output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()
缺点:
- 可能因数值不稳定导致训练失败(如梯度消失)。
- 不支持动态精度切换(如部分层用 FP32)。
3. FP16 训练的注意事项
(1) 设备支持
- NVIDIA GPU:需支持 Tensor Core(如 Volta 架构以上的 GPU,包括 Tesla V100、A100、RTX 3090 等)。
- AMD GPU:部分型号支持 FP16 计算,但 AMP 功能受限(需使用
torch.backends.cudnn.enabled = False)。
(2) 学习率调整
- FP16 的初始学习率通常设为 FP32 的 2~4 倍(因梯度放大),需配合学习率调度器(如
CosineAnnealingLR)。
(3) 损失缩放(Loss Scaling)
-
FP16 的梯度可能过小,导致
update()时下溢。解决方案:- 自动缩放:使用
GradScaler()(推荐)。 - 手动缩放:将损失乘以一个固定因子(如
1e4),反向传播后再除以该因子。
- 自动缩放:使用
(4) 模型初始化
- FP16 参数初始化值不宜过大,否则可能导致
nan。建议初始化时用 FP32,再转为 FP16。
(5) 检查数值稳定性
- 训练过程中监控损失是否为
nan或无穷大。 - 可通过
torch.set_printoptions(precision=10)打印中间结果。
4. FP16 vs FP32 精度对比
| 模型 | FP32 精度损失 | FP16 精度损失 |
|---|---|---|
| ResNet-18 | 微小 | 可忽略 |
| BERT-base | 微小 | ~1-2% |
| GPT-2 | 微小 | ~3-5% |
结论:多数任务中 FP16 的精度损失可接受,但需通过实验验证。
5. 常见错误及解决
| 错误现象 | 解决方案 |
|---|---|
RuntimeError: CUDA error: out of memory | 减少 batch size 或清理缓存 (torch.cuda.empty_cache()) |
nan 或 inf | 调整学习率、检查数据预处理、启用梯度缩放 |
InvalidArgumentError | 确保输入数据已正确转换为 FP16 |
- 推荐使用
autocast+GradScaler:平衡易用性和性能。 - 优先在 NVIDIA GPU 上使用:AMD GPU 的 FP16 支持较弱。
- 从小批量开始测试:避免显存不足或数值不稳定。
通过合理配置,FP16 可以在几乎不损失精度的情况下显著提升训练速度和显存利用率。
最后总结
高效的训练循环为优化PyTorch模型奠定了坚实的基础。通过遵循适当的数据加载过程,模型初始化过程和系统的训练步骤,你的训练设置将有效地利用GPU资源,并通过数据集快速迭代,以构建健壮的模型。
相关文章:
PyTorch系列教程:编写高效模型训练流程
当使用PyTorch开发机器学习模型时,建立一个有效的训练循环是至关重要的。这个过程包括组织和执行对数据、参数和计算资源的操作序列。让我们深入了解关键组件,并演示如何构建一个精细的训练循环流程,有效地处理数据处理,向前和向后…...
10 【HarmonyOS NEXT】 仿uv-ui组件开发之Avatar头像组件开发教程(一)
温馨提示:本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦! 目录 第一篇:Avatar 组件基础概念与设计1. 组件概述2. 接口设计2.1 形状类型定义2.2 尺寸类型定义2.3 组件属性接口 3. 设计原则4. 使用…...
C++编程指南24 - 避免线程频繁的创建和销毁
一:概述 线程的创建和销毁是昂贵的操作,尤其在多线程程序中频繁创建和销毁线程时,可能会导致性能问题。 二:示例 这段代码中,dispatcher 每收到一个 Message 就创建一个新的线程来处理这个消息。这种方式虽然简单&…...
C语言——【全局变量和局部变量】
🚀个人主页:fasdfdaslsfadasdadf 📖收入专栏:C语言 🌍文章目入 1.🚀 全局变量2.🚀 局部变量3.🚀 局部和全局变量,名字相同呢? 1.🚀 全局变量 全局变量&…...
浅谈 DeepSeek 对 DBA 的影响
引言: 在人工智能技术飞速发展的背景下,DeepSeek 作为一款基于混合专家模型(MoE)和强化学习技术的大语言模型,正在重塑传统数据库管理(DBA)的工作模式。通过结合其强大的自然语言处理能力、推理…...
Web服务器配置
配置虚拟主机:启动XAMPP的Apache,在htdocs目录中创建www.php.test目录 创建index.html,内容为“Welcome www.php.test”,访问两个虚拟主机 访问权限控制 在Apache的主配置文件httpd.conf中,默认已经添加了一些目录的…...
DeepSeek-R1本地化部署(Mac)
一、下载 Ollama 本地化部署需要用到 Ollama,它能支持很多大模型。官方网站:https://ollama.com/ 点击 Download 即可,支持macOS,Linux 和 Windows;我下载的是 mac 版本,要求macOS 11 Big Sur or later,Ol…...
Java面试第九山!《SpringBoot框架》
引言 你是否经历过这样的场景?想快速开发一个Java Web应用,却被XML配置、依赖冲突、服务器部署搞得焦头烂额。Spring Boot的诞生,正是为了解决这些"配置地狱"问题。 对比项Spring Boot传统 Spring配置复杂度自动配置,…...
Java 中数据脱敏的实现
数据脱敏 首先,要思考一个问题,SpringBoot 查询到的一条数据是一个 Java 对象,为什么返回给前端时候,前端拿到的却是 JSON 格式的数据呢? 是因为 SpringBoot 默认采用了 Jackson 作为序列化器,而 Jackson…...
PyQt组件间的通信方式
PyQt组件间的通信方式 PyQt组件间的通信方式 1. 组件介绍 1.1 组件的定义1.2 组件的分类 2. 组件的通信方式 2.1 信号与槽(Signal & Slot) 1. 组件介绍 在 Qt 框架中,组件(Component)是构建图形用户界面&am…...
视频理解开山之作 “双流网络”
1 论文核心信息 1.1核心问题 任务:如何利用深度学习方法进行视频中的动作识别(Action Recognition)。挑战: 视频包含时空信息,既需要捕捉静态外观特征(Spatial Information),也需要…...
基于Matlab的人脸识别的二维PCA
一、基本原理 传统 PCA 在处理图像数据时,需将二维图像矩阵拉伸为一维向量,这使得数据维度剧增,引发高计算成本与存储压力。与之不同,2DPCA 直接基于二维图像矩阵展开运算。 它着眼于图像矩阵的列向量,构建协方差矩阵…...
Java直通车系列15【Spring MVC】(ModelAndView 使用)
目录 1. ModelAndView 概述 2. ModelAndView 的主要属性和方法 主要属性 主要方法 3. 场景示例 示例 1:简单的 ModelAndView 使用 示例 2:使用 ModelAndView 处理列表数据 示例 3:使用 ModelAndView 处理异常情况 1. ModelAndView 概…...
考研数一非数竞赛复习之Stolz定理求解数列极限
在非数类大学生数学竞赛中,Stolz定理作为一种强大的工具,经常被用来解决和式数列极限的问题,也被誉为离散版的’洛必达’方法,它提供了一种简洁而有效的方法,使得原本复杂繁琐的极限计算过程变得直观明了。本文&#x…...
Java在小米SU7 Ultra汽车中的技术赋能
目录 一、智能驾驶“大脑”与实时数据 场景一:海量数据的分布式计算 场景二:实时决策的毫秒级响应 场景三:弹性扩展与容错机制 技术隐喻: 二、车载信息系统(IVI)的交互 场景一:Android Automo…...
【简单的C++围棋游戏开发示例】
C围棋游戏开发简单示例(控制台版) 核心代码实现 #include <iostream> #include <vector> #include <queue> using namespace std;const int SIZE 9; // 简化棋盘为9x9:ml-citation{ref"1" data"citationList&…...
DeepSeek R1-7B 医疗大模型微调实战全流程分析(全码版)
DeepSeek R1-7B 医疗大模型微调实战全流程指南 目录 环境配置与硬件优化医疗数据工程微调策略详解训练监控与评估模型部署与安全持续优化与迭代多模态扩展伦理与合规体系故障排除与调试行业应用案例进阶调优技巧版本管理与迭代法律风险规避成本控制方案文档与知识传承1. 环境配…...
HCIA-DHCP
1、定义:DHCP即动态主机配置协议,通过C/S模型架构,无需主机配置IP地址,自动分配网络配置参数的网络协议。 2、作用 对比项目无 DHCP有 DHCP配置难度配置多,容易出错自动为客户端分配 IP 地址及其他网络配置参数&…...
前端面试题 口语化复述解答(从2025.3.8 开始频繁更新中)
背景 看了很多面试题及其答案。但是过于标准化,一般不能直接用于回复面试官,这里我将总结一系列面试题,用于自我复习也乐于分享给大家,欢迎大家提供建议,我必不断完善之。 Javascript ES6 1. var let const 的区别…...
网络安全技术和协议(高软43)
系列文章目录 网络安全技术和协议 文章目录 系列文章目录前言一、网络安全技术1.防火墙2.入侵检测系统IDS3.入侵防御系统IPS 二、网络攻击和威胁三、网络安全协议四、真题在这里插入图片描述 总结 前言 本节讲明网络安全技术和协议方面的相关知识。 一、网络安全技术 1.防火…...
K8S学习之基础十七:k8s的蓝绿部署
蓝绿部署概述 蓝绿部署中,一共有两套系统,一套是正在提供服务的系统,一套是准备发布的系统。两套系统都是功能完善、正在运行的系统,只是版本和对外服务情况不同。 开发新版本,要用新版本替换线上的旧版本&…...
【统计至简】【古典概率模型】联合概率、边缘概率、条件概率、全概率
联合概率、边缘概率、条件概率 联合概率边缘概率条件概率全概率 一副标准扑克牌有 54 张,包括 52 张常规牌(13 个点数,每个点数有 4 种花色)和 2 张王(大、小王)。我们从中随机抽取一张牌,定义以…...
批量插入对比-mysql-oracle-sqlserver
单个插入mysql //单个 根据有值就插入,无值不改动 <insert id"insertOne" keyColumn"id" keyProperty"id"parameterType"com.test.log" useGeneratedKeys"true">insert into test_mysql_tab<trim p…...
saltstack通过master下发脚本批量修改minion_id,修改为IP
通过master下发脚本批量修改minion_id,以修改为IP为例 通过cmd.script远程执行shell脚本修改minion_id,步骤如下: # 下发脚本并执行 >> salt old_minion_id cmd.script salt://modify_minion_id.sh saltenvdev #输出结果 old_minion_id:Minion di…...
【网络】TCP常考知识点详解
TCP报文结构 TCP报文由**首部(Header)和数据(Data)**两部分组成。首部包括固定部分(20字节)和可选选项(最多40字节),总长度最大为60字节。 1. 首部固定部分 源端口&…...
LeetCode1137 第N个泰波那契数
泰波那契数列求解:从递归到迭代的优化之路 在算法的世界里,数列问题常常是我们锻炼思维、提升编程能力的重要途径。今天,让我们一同深入探讨泰波那契数列这一有趣的话题。 泰波那契数列的定义 泰波那契序列 Tn 有着独特的定义方式…...
六十天前端强化训练之第十四天之深入理解JavaScript异步编程
欢迎来到编程星辰海的博客讲解 目录 一、异步编程的本质与必要性 1.1 单线程的JavaScript运行时 1.2 阻塞与非阻塞的微观区别 1.3 异步操作的性能代价 二、事件循环机制深度解析 2.1 浏览器环境的事件循环架构 核心组件详解: 2.2 执行顺序实战分析 2.3 Nod…...
利用EasyCVR平台打造化工园区视频+AI智能化监控管理系统
化工园区作为化工产业的重要聚集地,其安全问题一直是社会关注的焦点。传统的人工监控方式效率低下且容易出现疏漏,已经难以满足日益增长的安全管理需求。 基于EasyCVR视频汇聚平台构建的化工园区视频AI智能化应用方案,能够有效解决这些问题&…...
【VUE2】第三期——样式冲突、组件通信、异步更新
目录 1 scoped解决样式冲突 2 data写法 3 组件通信 3.1 父子关系 3.1.1 父向子传值 props 3.1.2 子向父传值 $emit 3.2 非父子关系 3.2.1 event bus 事件总线 3.2.2 跨层级共享数据 provide&inject 4 props 4.1 介绍 4.2 props校验完整写法 5 v-model原理 …...
深度学习分类回归(衣帽数据集)
一、步骤 1 加载数据集fashion_minst 2 搭建class NeuralNetwork模型 3 设置损失函数,优化器 4 编写评估函数 5 编写训练函数 6 开始训练 7 绘制损失,准确率曲线 二、代码 导包,打印版本号: import matplotlib as mpl im…...
