PyTorch系列教程:编写高效模型训练流程
当使用PyTorch开发机器学习模型时,建立一个有效的训练循环是至关重要的。这个过程包括组织和执行对数据、参数和计算资源的操作序列。让我们深入了解关键组件,并演示如何构建一个精细的训练循环流程,有效地处理数据处理,向前和向后传递以及参数更新。
模型训练流程
PyTorch训练循环流程通常包括:
- 加载数据
- 批量处理
- 执行正向传播
- 计算损失
- 反向传播
- 更新权重
一个典型的训练流程将这些步骤合并到一个迭代过程中,在数据集上迭代多次,或者在训练的上下文中迭代多个epoch。

1. 搭建环境
在编写代码之前,请确保在本地环境中设置了PyTorch。这通常需要安装PyTorch和其他依赖项:
pip install torch torchvision
下面演示为建立一个有效的训练循环奠定了基本路径的示例。
2. 数据加载
数据加载是使用DataLoader完成的,它有助于数据的批量处理:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
data_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
DataLoader在这里被设计为以64个为单位的批量获取数据,在数据传递中进行随机混淆。
3. 模型初始化
一个使用PyTorch的简单神经网络定义如下:
import torch.nn as nn
import torch.nn.functional as Fclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return F.log_softmax(x, dim=1)
这里,784指的是输入维度(28x28个图像),并创建一个输出大小为10个类别的顺序前馈网络。
4. 建立训练循环
定义损失函数和优化器:为了改进模型的预测,必须定义损失和优化器:
import torch.optim as optimmodel = SimpleNN()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. 实现训练循环
有效的训练循环的本质在于正确的步骤顺序:
epochs = 5
for epoch in range(epochs):running_loss = 0for images, labels in train_loader:optimizer.zero_grad() # Zero the parameter gradientsoutput = model(images) # Forward passloss = criterion(output, labels) # Calculate lossloss.backward() # Backward passoptimizer.step() # Optimize weightsrunning_loss += loss.item()print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader)}")
注意,每次迭代都需要重置梯度、通过网络处理输入、计算误差以及调整权重以减少该误差。
性能优化
使用以下策略提高循环效率:
-
使用GPU:将计算转移到GPU上,以获得更快的处理速度。如果GPU可用,使用to(‘cuda’)转换模型和输入。
-
数据并行:利用多gpu设置与dataparlele模块来分发批处理。
-
FP16训练:使用自动混合精度(AMP)来加速训练并减少内存使用,而不会造成明显的精度损失。
在 PyTorch 中使用 FP16(半精度浮点数)训练 可以显著减少显存占用、加速计算,同时保持模型精度接近 FP32。以下是详细指南:
1. FP16 的优势
- 显存节省:FP16 占用显存是 FP32 的一半(例如,1024MB 显存在 FP32 下可容纳约 2000 万参数,在 FP16 下可容纳约 4000 万)。
- 计算加速:NVIDIA 的 Tensor Core 支持 FP16 矩阵运算,速度比 FP32 快数倍至数十倍。
- 适合大规模模型:如 Transformer、Vision Transformer(ViT)等参数量大的模型。
2. 实现 FP16 训练的两种方式
(1) 自动混合精度(Automatic Mixed Precision, AMP)
PyTorch 的 torch.cuda.amp 自动管理 FP16 和 FP32,减少手动转换的复杂性。
python
import torch
from torch.cuda.amp import autocast, GradScalermodel = model.to("cuda") # 确保模型在 GPU 上
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler() # 梯度缩放器for data, target in dataloader:data = data.to("cuda").half() # 输入转为 FP16target = target.to("cuda")with autocast(): # 自动切换 FP16/FP32 计算output = model(data)loss = criterion(output, target)scaler.scale(loss).backward() # 梯度缩放scaler.step(optimizer) # 更新参数scaler.update() # 重置缩放器
关键点:
autocast()内部自动将计算转换为 FP16(若 GPU 支持),梯度累积在 FP32。GradScaler()解决 FP16 下梯度下溢问题。
(2) 手动转换(低级用法)
直接将模型参数、输入和输出转为 FP16,但需手动管理精度和稳定性。
python
model = model.half() # 模型参数转为 FP16
for data, target in dataloader:data = data.to("cuda").half() # 输入转为 FP16target = target.to("cuda")output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()
缺点:
- 可能因数值不稳定导致训练失败(如梯度消失)。
- 不支持动态精度切换(如部分层用 FP32)。
3. FP16 训练的注意事项
(1) 设备支持
- NVIDIA GPU:需支持 Tensor Core(如 Volta 架构以上的 GPU,包括 Tesla V100、A100、RTX 3090 等)。
- AMD GPU:部分型号支持 FP16 计算,但 AMP 功能受限(需使用
torch.backends.cudnn.enabled = False)。
(2) 学习率调整
- FP16 的初始学习率通常设为 FP32 的 2~4 倍(因梯度放大),需配合学习率调度器(如
CosineAnnealingLR)。
(3) 损失缩放(Loss Scaling)
-
FP16 的梯度可能过小,导致
update()时下溢。解决方案:- 自动缩放:使用
GradScaler()(推荐)。 - 手动缩放:将损失乘以一个固定因子(如
1e4),反向传播后再除以该因子。
- 自动缩放:使用
(4) 模型初始化
- FP16 参数初始化值不宜过大,否则可能导致
nan。建议初始化时用 FP32,再转为 FP16。
(5) 检查数值稳定性
- 训练过程中监控损失是否为
nan或无穷大。 - 可通过
torch.set_printoptions(precision=10)打印中间结果。
4. FP16 vs FP32 精度对比
| 模型 | FP32 精度损失 | FP16 精度损失 |
|---|---|---|
| ResNet-18 | 微小 | 可忽略 |
| BERT-base | 微小 | ~1-2% |
| GPT-2 | 微小 | ~3-5% |
结论:多数任务中 FP16 的精度损失可接受,但需通过实验验证。
5. 常见错误及解决
| 错误现象 | 解决方案 |
|---|---|
RuntimeError: CUDA error: out of memory | 减少 batch size 或清理缓存 (torch.cuda.empty_cache()) |
nan 或 inf | 调整学习率、检查数据预处理、启用梯度缩放 |
InvalidArgumentError | 确保输入数据已正确转换为 FP16 |
- 推荐使用
autocast+GradScaler:平衡易用性和性能。 - 优先在 NVIDIA GPU 上使用:AMD GPU 的 FP16 支持较弱。
- 从小批量开始测试:避免显存不足或数值不稳定。
通过合理配置,FP16 可以在几乎不损失精度的情况下显著提升训练速度和显存利用率。
最后总结
高效的训练循环为优化PyTorch模型奠定了坚实的基础。通过遵循适当的数据加载过程,模型初始化过程和系统的训练步骤,你的训练设置将有效地利用GPU资源,并通过数据集快速迭代,以构建健壮的模型。
相关文章:
PyTorch系列教程:编写高效模型训练流程
当使用PyTorch开发机器学习模型时,建立一个有效的训练循环是至关重要的。这个过程包括组织和执行对数据、参数和计算资源的操作序列。让我们深入了解关键组件,并演示如何构建一个精细的训练循环流程,有效地处理数据处理,向前和向后…...
【面试】Zookeeper
Zookeeper 1、ZooKeeper 介绍2、znode 节点里面的存储3、znode 节点上监听机制4、ZooKeeper 集群部署5、ZooKeeper 选举机制6、何为集群脑裂7、如何保证数据一致性8、讲一下 zk 分布式锁实现原理吧9、Eureka 与 Zk 有什么区别 1、ZooKeeper 介绍 ZooKeeper 的核心特性 高可用…...
电力系统中各参数的详细解释【智能电表】
一、核心电力参数 电压 (Voltage) 单位:伏特(V) 含义:电势差,推动电流流动的动力 类型:线电压(三相系统)、相电压,如220V(家用)或380Vÿ…...
前端系统测试(单元、集成、数据|性能|回归)
有关前端测试的面试题 系统测试 首先,功能测试部分。根据资料,单元测试是验证最小可测试单元的正确性,比如函数或组件。都提到了单元测试的重要性,强调其在开发早期发现问题,并通过自动化提高效率。需要整合我搜索到的资料中的观点,比如单元测试的方法(接口测试、路径覆…...
软件开发过程总揽
开发模型 传统开发模型 瀑布模型 #mermaid-svg-yDNBSwh3gDYETWou {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-yDNBSwh3gDYETWou .error-icon{fill:#552222;}#mermaid-svg-yDNBSwh3gDYETWou .error-text{fill:#…...
VBA第二十期 VBA最简单复制整张表格Cells的用法
前面讲过复制整张表格的方法,使用语句Workbooks("实例.xlsm").Sheets("表格1").Copy Workbooks(wjm).Sheets(1)实现,这里用我们熟悉的Cells属性也可以实现整表复制。实例如下: Sheets("全部").Activate Cells…...
Redis为什么要自定义序列化?如何实现自定义序列化器?
在 Redis中,通常会使用自定义序列化器,那么,Redis为什么需要自定义序列化器,该如何实现它? 1、为什么需要自定义序列化器? 整体来说,Redis需要自定义序列化器,主要有以下几个原因&…...
Matlab:矩阵运算篇——矩阵数学运算
目录 1.矩阵的加法运算 实例——验证加法法则 实例——矩阵求和 实例——矩阵求差 2.矩阵的乘法运算 1.数乘运算 2.乘运算 3.点乘运算 实例——矩阵乘法运算 3.矩阵的除法运算 1.左除运算 实例——验证矩阵的除法 2.右除运算 实例——矩阵的除法 ヾ( ̄…...
手写一个Tomcat
Tomcat 是一个广泛使用的开源 Java Servlet 容器,用于运行 Java Web 应用程序。虽然 Tomcat 本身功能强大且复杂,但通过手写一个简易版的 Tomcat,我们可以更好地理解其核心工作原理。本文将带你一步步实现一个简易版的 Tomcat,并深…...
开发ai模型最佳的系统是Ubuntu还是linux?
在 AI/ML 开发中,Ubuntu 是更优选的 Linux 发行版,原因如下: 1. 开箱即用的 AI 工具链支持 Ubuntu 预装了主流的 AI 框架(如 TensorFlow、PyTorch)和依赖库,且通过 apt 包管理器可快速部署开发环境。 提…...
Scala 中生成一个RDD的方法
在 Scala 中,生成 RDD(弹性分布式数据集)的主要方法是通过 SparkContext(或 SparkSession)提供的 API。以下是生成 RDD 的常见方法: 1. 从本地集合创建 RDD 使用 parallelize 方法将本地集合(如…...
【redis】慢查询分析与优化
慢查询指在Redis中执行时间超过预设阈值的命令,其日志记录是排查性能瓶颈的核心工具。Redis采用单线程模型,任何耗时操作都可能阻塞后续请求,导致整体性能下降。 命令的执行流程 根据Redis的核心机制,命令执行流程可分为以下步骤…...
P8925 「GMOI R1-T2」Light 题解
P8925 「GMOI R1-T2」Light 让我们好好观察样例解释的这一张图: 左边第 1 1 1 个像到 O O O 点的距离 : L 2 2 L L\times22L L22L 右边第 1 1 1 个像到 O O O 点的距离 : R 2 2 R R\times22R R22R 左边第 2 2 2 个像到 O O O 点…...
Spring Boot + MyBatis + MySQL:快速搭建CRUD应用
一、引言 1. 项目背景与目标 在现代Web开发中,CRUD(创建、读取、更新、删除)操作是几乎所有应用程序的核心功能。本项目旨在通过Spring Boot、MyBatis和MySQL技术栈,快速搭建一个高效、简洁的CRUD应用。我们将从零开始ÿ…...
python中os库的常用举例
os 库是Python中用于与操作系统进行交互的标准库,以下是一些 os 库的常用示例: 获取当前工作目录 python import os current_dir os.getcwd() print(current_dir) os.getcwd() 函数用于获取当前工作目录的路径。 列出目录内容 python import os …...
Unity 通用UI界面逻辑总结
概述 在游戏开发中,常常会遇到一些通用的界面逻辑,它不论在什么类型的游戏中都会出现。为了避免重复造轮子,本文总结并提供了一些常用UI界面的实现逻辑。希望可以帮助大家快速开发通用界面模块,也可以在次基础上进行扩展修改&…...
Python3 与 VSCode:深度对比分析
Python3 与 VSCode:深度对比分析 引言 Python3 和 Visual Studio Code(VSCode)在软件开发领域扮演着举足轻重的角色。Python3 作为一门强大的编程语言,拥有丰富的库和框架,广泛应用于数据科学、人工智能、网络开发等多个领域。而 VSCode 作为一款轻量级且功能强大的代码…...
第五课:Express框架与RESTful API设计:技术实践与探索
在使用Node.js进行企业应用开发,常用的开发框架Express,其中的中间件、路由配置与参数解析、RESTful API核心技术尤为重要,本文将深入探讨它们在应用开发中的具体使用方法,最后通过Postman来对开发的接口进行测试。 一、Express中…...
Linux 内核自定义协议族开发:从 “No buffer space available“ 错误到解决方案
引言 在 Linux 内核网络协议栈开发中,自定义协议族(Address Family, AF)是实现新型通信协议或扩展内核功能的关键步骤。然而,开发者常因对内核地址族管理机制理解不足,遇到如 insmod: No buffer space available 的错误。本文将以实际案例为基础,深入分析错误根源,并提…...
html-列表标签和表单标签
一、列表标签 表格是用来显示数据的,那么列表就是用来布局的 列表最大的特点就是整齐、整洁、有序,它作为布局会更加自由和方便。 根据使用情景不同,列表可以分为三大类:无序列表、有序列表和自定义列表。 1.无序列表(重…...
LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明
LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造,完美适配AGV和无人叉车。同时,集成以太网与语音合成技术,为各类高级系统(如MES、调度系统、库位管理、立库等)提供高效便捷的语音交互体验。 L…...
观成科技:隐蔽隧道工具Ligolo-ng加密流量分析
1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...
<6>-MySQL表的增删查改
目录 一,create(创建表) 二,retrieve(查询表) 1,select列 2,where条件 三,update(更新表) 四,delete(删除表…...
中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试
作者:Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位:中南大学地球科学与信息物理学院论文标题:BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接:https://arxiv.…...
什么是EULA和DPA
文章目录 EULA(End User License Agreement)DPA(Data Protection Agreement)一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA(End User License Agreement) 定义: EULA即…...
JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作
一、上下文切换 即使单核CPU也可以进行多线程执行代码,CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短,所以CPU会不断地切换线程执行,从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...
#Uniapp篇:chrome调试unapp适配
chrome调试设备----使用Android模拟机开发调试移动端页面 Chrome://inspect/#devices MuMu模拟器Edge浏览器:Android原生APP嵌入的H5页面元素定位 chrome://inspect/#devices uniapp单位适配 根路径下 postcss.config.js 需要装这些插件 “postcss”: “^8.5.…...
【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...
Vue 模板语句的数据来源
🧩 Vue 模板语句的数据来源:全方位解析 Vue 模板(<template> 部分)中的表达式、指令绑定(如 v-bind, v-on)和插值({{ }})都在一个特定的作用域内求值。这个作用域由当前 组件…...
