PyTorch 基础学习(13)- 混合精度训练
系列文章:
《PyTorch 基础学习》文章索引
基本概念
混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32
)和低精度(如 torch.float16
或 torch.bfloat16
)数据类型的优势,提高计算效率和内存利用率。
-
高精度(
torch.float32
):适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。 -
低精度(
torch.float16
或torch.bfloat16
):适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。
混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp
模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。
重要方法及其作用
torch.autocast
torch.autocast
是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast
将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16
,而损失计算则保持为 float32
。
主要参数:
device_type
:指定设备类型,如cuda
、cpu
或xpu
。dtype
:指定在autocast
区域内使用的低精度数据类型。对于 CUDA 设备,默认是torch.float16
;对于 CPU 设备,默认是torch.bfloat16
。enabled
:是否启用混合精度。默认为True
。cache_enabled
:是否启用权重缓存。默认是True
,可以在某些场景下提高性能。
torch.amp.GradScaler
在低精度(如 float16
)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler
,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。
主要参数:
init_scale
:初始的缩放因子,默认是65536.0
。growth_factor
:在没有发生下溢的情况下,缩放因子增长的倍数,默认是2.0
。backoff_factor
:发生下溢时,缩放因子减少的倍数,默认是0.5
。growth_interval
:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是2000
。enabled
:是否启用梯度缩放,默认为True
。
适用的场景
GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda")
能够有效地减少 GPU 的计算负载,并提高吞吐量。
CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16)
可以在推理过程中降低计算复杂度,同时保持较高的精度。
3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwd
和 torch.amp.custom_bwd
,用户可以定义在特定设备(如 cuda
)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。
应用实例
以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocast
和 torch.amp.GradScaler
。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10): # 假设有10个epochfor input, target in data_loader: # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")
代码说明:
- 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
- 在每次训练循环中,我们使用
torch.autocast(device_type="cuda")
上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。 - 使用
GradScaler
对损失进行缩放,并在缩放后的损失上调用backward()
进行反向传播。这一步骤有助于防止梯度下溢。 scaler.step(optimizer)
用于更新模型参数,scaler.update()
则是调整缩放因子。
这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。
注意事项
-
弃用警告:从 PyTorch 1.10 开始,原有的
torch.cuda.amp.autocast
和torch.cpu.amp.autocast
方法被弃用,推荐使用通用的torch.autocast
代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。 -
数据类型匹配:在使用
autocast
时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为float32
或其他合适的精度。 -
GradScaler
的适用性:虽然GradScaler
对大多数模型都有效,但在某些情况下(例如使用bf16
预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。
通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。
相关文章:
PyTorch 基础学习(13)- 混合精度训练
系列文章: 《PyTorch 基础学习》文章索引 基本概念 混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16 或 torch.bfloat16)数据类型的优势,…...

Mycat分片-垂直拆分
目录 场景 配置 测试 全局表配置 续接上篇:MySQ分库分表与MyCat安装配置-CSDN博客 续接下篇:Mycat分片-水平拆分-CSDN博客 场景 在业务系统中, 涉及以下表结构 ,但是由于用户与订单每天都会产生大量的数据, 单台服务器的数据 存储及处理能力是有限…...

一元四次方程求解-【附MATLAB代码】
目录 前言 求解方法 编辑 MATLAB验证 附:一元四次方程的故事 前言 最近在研究机器人的干涉(碰撞)检测,遇到了一个问题,就是在求椭圆到原点的最短距离时,构建的方程是一个一元四次方程。无论是高中的…...

【极限性能,尽在掌控】ROG NUC:游戏与创作的微型巨擘
初见ROG NUC,你或许会为它的小巧体型惊讶。然而,这看似不起眼的机身内,蕴藏着游戏、创意的强大能量。 掌中风暴,性能无界 ROG NUC搭载英特尔高性能处理器,配合高速NVMe SSD固态硬盘以及可选的高端独立显卡(…...

Ecosmos开启公测,将深度赋能CIOE中国光博会元宇宙参会新体验
如今,生成式AI技术的发展,极大地降低了3D数字资产的制作成本,元宇宙作为一种可以无缝将物理和数字资产进行融合的技术,在推动电子产业数字化进程、助力产业高质量发展的方面展现出了巨大的潜力。 当前,发展新质生产力是…...

【Kubernetes】k8s集群之包管理器Helm
目录 一.Helm概述 1.Helm的简介 2.Helm的三个重要概念 3.Helm2与Helm3的的区别 二.Helm 部署 1.安装 helm 2.使用 helm 安装 Chart 3.Helm 自定义模板 4.Helm 仓库 每个成功的软件平台都有一个优秀的打包系统,比如Debian、Ubuntu 的 apt,RedH…...
嵌入式linux系统镜像制作day3(构建镜像)
点击上方"蓝字"关注我们 01、上节回顾 嵌入式linux系统镜像制作day1嵌入式linux系统镜像制作day2提前下载好准备工具,不然失败了大眼瞪小眼。 02、构建 Poky 的 Sato 镜像1 环境: ubuntu18.04poky版本:Dizzy 工具git 在开始之前,针对不同的发行版,需要先执行…...

【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】
教师节中秋节国庆节车模特美女举牌生日视频制作教程AE模板改文字软件生成器素材 怎么如何做的【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件下载AE模板把AE模板导入AE软件修改图…...

RongCallKit iOS 端本地私有 pod 方案
RongCallKit iOS 端本地私有 pod 方案 需求背景 适用于源码集成 CallKit 时,使用 pod 管理 RTC framework 以及源码。集成 CallKit 时,需要定制化修改 CallKit 的样式以及部分 UI 功能。适用于 CallKit 源码 Debug 调试便于定位相关问题。 解决方案 从…...

C++11:可变参数模板
目录 一、概述 二、场景 1.深拷贝的类 2.浅拷贝的类 C使用指南 一、概述 // Args是一个模板参数包,args是一个函数形参参数包 // 声明一个参数包Args...args,这个参数包中可以包含0到任意个模板参数。 template <class ...Args> void ShowList(…...

C++ 与 QML 之间进行数据交互的几种方法
https://www.cnblogs.com/jzcn/p/17774676.html 一、属性绑定 这是最简单的方式,可以在QML中直接绑定C 对象的属性。通过在C 对象中使用Q_PROPERTY宏定义属性,然后在QML中使用绑定语法将属性与QML元素关联起来。 1. person.h #include <QObject&g…...

Javaweb学习之Vue项目的创建(二)
学习资料 Vue.js - 渐进式 JavaScript 框架 | Vue.js (vuejs.org) 准备工作都做完了,接下来开始Vue的正式学习。 第一步,打开VS Code 在VS Code里,我们也需要使用到终端,如果不是以管理员身份打开,在新建Vue项目的时候…...

『深度长文』4种有效提高LLM输出质量的方法!
LLM,全称Large Language Model,意为大型语言模型,是一种基于深度学习的AI技术,能够生成、理解和处理自然语言文本,也因此成为当前大多数AI工具的核心引擎。LLM通过学习海量的文本数据,掌握了词汇、语法、语…...

【工业机器人】工业异常检测大模型AnomalyGPT
AnomalyGPT 工业异常检测视觉大模型AnomalyGPT AnomalyGPT: Detecting Industrial Anomalies using Large Vision-Language Models AnomalyGPT是一种基于大视觉语言模型(LVLM)的新型工业异常检测(IAD)方法。它利用LVLM的能力来理…...
【PGCCC】PostgreSQL案例:planning time超长问题分析#PG初级
在使用 PostgreSQL 时,查询的执行计划(planning time)有时会出现异常长的情况,这可能会影响数据库的整体性能。分析和解决这种问题可以从多个角度入手,以下是常见原因和相应的解决思路: 1. 统计信息不准确…...

【图文并茂】ant design pro 如何给后端发送 json web token - 请求拦截器的使用
上一节有讲过 【图文并茂】ant design pro 如何对接后端个人信息接口 还差一个东西,去获取个人信息的时候,是要发送 token 的,不然会报 403. 就是说在你登录之后才去获得个人信息。这样后端才能知道是谁的信息。 token 就代码了某个人。 …...

【微信小程序】自定义组件 - behaviors
1. 什么是 behaviors 2. behaviors 的工作方式 3. 创建 behavior 调用 Behavior(Object object) 方法即可创建一个共享的 behavior 实例对象,供所有的组件使用: 4. 导入并使用 behavior 5. behavior 中所有可用的节点 6. 同名字段的覆盖和组合规则* 关…...

Linux ubuntu 24.04 安装运行《帝国时代3》免安装绿色版游戏,解决 “Could not load DATAP.BAR”等问题
Linux ubuntu 24.04 安装运行《帝国时代3》游戏,解决 “Could not load DATAP.BAR" 等问题 《帝国时代 3》是一款比较经典的即时战斗游戏,伴随了我半个高中时代,周末有时间就去泡网吧,可惜玩的都是简单人机,高难…...
Springboot 图片
Springboot 图片 因为 server.servlet.context-path: /api 所以 url是这个的时候 http://127.0.0.1:9100/api/staticfiles/image/dd56a59d-da84-441a-8dac-1d97f9e42090.jpeg 配置代码的前面的 /api 是不要写的 package com.gk.study.config;import org.springframework.conte…...
LIMS实验室管理系统如何实现数据自动采集
随着科研技术的不断发展,LIMS实验室管理系统的应用也愈来愈广,已经成为现代化实验室管理不可或缺的工具。LIMS实验室管理系统未与仪器设备对接前,仪器设备产生的数据都是通过人工录入到系统中,再经过人工审核形成最终的数据报告。…...

JavaSec-RCE
简介 RCE(Remote Code Execution),可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景:Groovy代码注入 Groovy是一种基于JVM的动态语言,语法简洁,支持闭包、动态类型和Java互操作性,…...

C++初阶-list的底层
目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...

练习(含atoi的模拟实现,自定义类型等练习)
一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...

遍历 Map 类型集合的方法汇总
1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI
前一阵子在百度 AI 开发者大会上,看到基于小智 AI DIY 玩具的演示,感觉有点意思,想着自己也来试试。 如果只是想烧录现成的固件,乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外,还提供了基于网页版的 ESP LA…...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合
在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...

深度学习水论文:mamba+图像增强
🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...
GitHub 趋势日报 (2025年06月06日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...