【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解
文章目录
- 0. 前言
- 1. 什么是自动混合精度?
- 2. PyTorch AMP 模块
- 3. 如何使用 PyTorch AMP
- 3.1 环境准备
- 3.2 代码实例
- 3.3 代码解析
- 4. 结论
0. 前言
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。
在深度学习领域,训练大型神经网络往往需要大量的计算资源。为了提高训练效率和减少内存占用,研究人员和工程师们不断探索新的技术手段。其中,自动混合精度(Automatic Mixed Precision, AMP)是一种非常有效的技术,它能够在保证模型准确性的同时显著提高训练速度和降低内存使用。
PyTorch 1.6 版本引入了对自动混合精度的支持,通过 torch.cuda.amp
模块来实现。本文将详细介绍 PyTorch 中的 AMP 模块,并提供一个示例来演示如何使用它。
1. 什么是自动混合精度?
自动混合精度是一种训练技巧,它允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。这种方法的主要好处包括:
- 加速训练:在现代GPU上,对于16位浮点数的算术运算比32位浮点数更快。因此,使用混合精度训练可以显著提高训练速度;
- 减少内存使用:16位浮点数占用的空间是32位浮点数的一半,这意味着模型可以在有限的GPU内存中处理更大的批次大小,或者可以将更多的数据缓存到内存中,从而进一步加速训练。
- 提高计算效率:通过减少数据类型转换的需求,可以减少计算开销。在某些情况下,使用16位浮点数的运算可以利用特定硬件(如NVIDIA Tensor Cores)的优势,这些硬件专门为低精度运算进行了优化。
- 数值稳定性:虽然16位浮点数的动态范围较小,但通过适当的缩放策略(例如使用GradScaler)可以维持数值稳定性,从而避免梯度消失或爆炸的问题。
- 易于集成:PyTorch等框架提供的自动混合精度(Automatic Mixed Precision, AMP)工具使得混合精度训练变得非常简单,通常只需要添加几行代码即可实现。
2. PyTorch AMP 模块
PyTorch 的 AMP 模块主要包含两个核心组件:autocast
和 GradScaler
。
-
autocast
:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。 -
GradScaler
:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler
可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。
3. 如何使用 PyTorch AMP
接下来,将通过一个简单的示例来演示如何使用 PyTorch 的 AMP 模块来训练一个神经网络。
3.1 环境准备
确保安装了 PyTorch 1.6 或更高版本。可以使用以下命令安装:
pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio===0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3.2 代码实例
下面的示例代码演示了如何使用 PyTorch 的 AMP 模块来训练一个简单的多层感知器(MLP)。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast# 设置随机种子以保证结果的一致性
torch.manual_seed(0)# 创建一个简单的多层感知器模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# 初始化模型、损失函数和优化器
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建 GradScaler
scaler = GradScaler()# 生成一些随机数据
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# 训练循环
for epoch in range(1):print(f"inputs dtype:{inputs.dtype}")# 使用 autocast 上下文管理器with autocast(): #尝试去掉这行再看下# 前向传播outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")# 清除梯度optimizer.zero_grad(set_to_none=True)# 使用 GradScaler 缩放损失scaler.scale(loss).backward()# 更新权重scaler.step(optimizer)# 更新 GradScalerscaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
3.3 代码解析
上面实例输出为:
inputs dtype:torch.float32
outputs dtype:torch.float16
loss dtype:torch.float32
Epoch 1, Loss: 2.2972
这里可以注意到outputs
的类型自动变成了float16
。
- 模型定义:我们定义了一个简单的多层感知器模型,包含两个线性层。
- 初始化:初始化模型、损失函数和优化器,并创建
GradScaler
对象。 - 数据准备:生成一些随机输入数据和目标标签。
- 训练循环:
- 使用
with autocast()
上下文管理器自动转换张量精度。 - 前向传播计算输出和损失。
- 使用
scaler.scale(loss)
放大损失以确保数值稳定性。 - 反向传播和梯度更新。
- 更新
GradScaler
状态。
- 使用
4. 结论
通过使用 PyTorch 的自动混合精度模块,我们可以显著提高模型的训练速度并减少内存使用,尤其是在 GPU 上训练大型神经网络时。上述示例展示了如何轻松地将 AMP 集成到现有训练流程中,只需几行代码即可启用这一功能。
相关文章:
【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解
文章目录 0. 前言1. 什么是自动混合精度?2. PyTorch AMP 模块3. 如何使用 PyTorch AMP3.1 环境准备3.2 代码实例3.3 代码解析 4. 结论 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果&a…...
数据结构 --- 哈希表
哈希表(Hash Table),也叫散列表,是一种根据关键码值(Key value)而直接进行访问的数据结构。 一、基本原理 哈希函数 哈希表通过一个特定的哈希函数,将关键码映射到表中的一个位置。这个位置通常…...

Linux相关:在阿里云下载centos系统镜像
文章目录 1、镜像站2、下载方式一2.1、第一步打开镜像站地址2.2 下载地址: https://mirrors.aliyun.com/centos/2.3、选择7版本2.4、镜像文件在isos文件夹中2.5、选择合适的版本 3、下载镜像快捷方式 1、镜像站 阿里云镜像站地址 2、下载方式一 2.1、第一步打开镜像站地址 2…...
24. 线模型对象
线模型Line渲染顶点数据 下面代码是把几何体作为线模型Line (opens new window)的参数,你会发现渲染效果是从第一个点开始到最后一个点,依次连成线。 // 线材质对象 const material new THREE.LineBasicMaterial({color: 0xff0000 //线条颜色 }); //…...

EasyExcel 快速入门
目录 一、 EasyExcel简介 官网链接: 代码链接: 二、 EasyExcel快速上手 引入依赖: 设置Excel相关注解 编写对应的监听类: 简单写入数据: 简单读取数据: 不需要使用监听器: 需要使…...

Sparse4D v1
Sparse4D: Multi-view 3D Object Detection with Sparse Spatial-Temporal Fusion Abstract 基于鸟瞰图 (BEV) 的方法最近在多视图 3D 检测任务方面取得了重大进展。与基于 BEV 的方法相比,基于稀疏的方法在性能上落后,但仍然有很多不可忽略的优点。为了…...
速盾:你知道高防 IP 和高防 CDN 的区别吗?
在当今网络安全形势日益严峻的情况下,网站的安全防护成为了企业和个人关注的焦点。高防 IP 和高防 CDN 作为两种常见的网络安全防护手段,被广泛应用于网站的安全防护中。那么,高防 IP 和高防 CDN 有什么区别呢?防护网站哪个更好呢…...
HTML和CSS网页制作成品
HTML和CSS网页制作成品 一、引言 1. 背景介绍 在当今数字化时代,网页已成为信息传递和交流的重要媒介。HTML和CSS作为网页制作的基石,对于构建美观、功能丰富的网站至关重要。本文将详细介绍如何使用HTML和CSS来制作一个网页成品。 2. 目的和重要性 …...

Ai+若依(集成easyexcel实现excel表格增强)
EasyExcel 介绍 官方地址:EasyExcel官方文档 - 基于Java的Excel处理工具 | Easy Excel 官网 Java解析、生成Excel比较有名的框架有Apache poi、jxl。但他们都存在一个严重的问题就是非常的耗内存,poi有一套SAX模式的API可以一定程度的解决一些内存溢出的问题,但POI还是有一…...

钻机、塔吊等大型工程设备,如何远程维护、实时采集运行数据?
在建筑和工程领域,重型设备的应用不可或缺,无论是在道路与桥梁建设、高层建筑施工,还是在风电、石油等能源项目的开发中,都会用到塔吊、钻机等大型机械工程设备。 随着数字化升级、工业4.0成为行业发展趋势,为了进一步…...

【AutoX.js】选择器 UiSelector - 查找包名
文章目录 原文:https://blog.c12th.cn/archives/38.html选择器 UiSelector - 查找包名笔记直接查找包名双层判断(推荐)查找最外层控件的子控件 最后 原文:https://blog.c12th.cn/archives/38.html 选择器 UiSelector - 查找包名 笔记 AutoX.js UiSelec…...

ERP进销存多仓库管理系统源码 带完整的安装代码包以及搭建部署教程
系统概述 ERP进销存多仓库管理系统是一款专为中小企业量身定制的集成化管理软件,它集成了采购管理、销售管理、库存管理、财务管理以及多仓库协同作业等核心模块。通过统一的平台,企业可以实时掌握商品从入库到出库的全过程,实现库存的自动化…...
数据清洗-缺失值填充-对XGBoost参数优化填充
目录 一、安装所需的python包二、采用XGboost算法进行缺失值填充2.1可直接运行代码2.2以某个缺失值数据进行实战2.2.1 代码运行过程截屏:2.2.2 填充后的数据截屏:三、网格搜索(Grid Search)对 XGBoost 模型的超参数进行优化原理介绍3.1 说明3.2 参数优化的原理1. 网格搜索(…...

Qt_按钮类控件
目录 1、QAbstractButton 2、设置带图标的按钮 3、设置带有快捷键的按钮 4、QRadioButtion(单选按钮) 4.1 QButtonGroup 5、QCheckBox 结语 前言: 按钮类控件是Qt中最重要的控件类型之一,该类型的控件可以通过鼠标的点击…...
union 的定义和基本结构以及用途
在 C 语言中,union(联合体) 是一种数据结构,它允许多个成员共享相同的内存空间。换句话说,联合体中的所有成员都存储在同一块内存区域,不同的成员会占用相同的内存地址,但在同一时刻只能保存一个…...

混合整数规划及其MATLAB实现
目录 引言 混合整数规划的基本模型 混合整数规划的求解方法 MATLAB中的混合整数规划实现 示例:多变量系统的混合整数规划 表格总结:混合整数规划的求解方法与适用场景 结论 引言 混合整数规划(Mixed Integer Programming, MIP…...
【数据结构】6——图1,概念
数据结构6——图1,概念 文章目录 数据结构6——图1,概念基本概念图的分类图的表示方法 基本概念 由 顶点(Vertex) 和 边(Edge) 组成的集合。顶点表示图中的点,而边表示顶点之间的连接。记为 G …...
技术周总结 09.09~09.15周日(C# WinForm WPF)
文章目录 一、09.09 周一1.1) 问题01: Windows桌面开发中,WPF和WinForm的区别和联系?联系:区别: 二、09.12 周四2.1)问题01:visual studio的相关快捷键有哪些?通用快捷键编辑导航调试窗口管理 2…...

4K投影仪选购全攻略:全玻璃镜头的当贝F6,画面细节纤毫毕现
在当今的投影市场上,4K投影仪已经成了主流产品,越来越多家庭开始关注如何选择一款性价比高、口碑好的4K投影仪。4K投影仪其实指的是具备3840*2160像素分辨率投影仪,它能够提供更清晰、更细腻、更真实的画面效果。 那么4K投影仪该怎么选&…...

除了字符串前导的*号之外,将串中其它*号全部删除
要求 假定输入的字符串中只包含字母和*号。请编写函数fun,它的功能是:除了字符串前导的*号之外,将串中其它*号全部删除。在编写函数时,不得使用C语言提供的字符串函数。函数fun中给出的语句仅供参考。 例如,字符串中的内容为:-**…...
【杂谈】-递归进化:人工智能的自我改进与监管挑战
递归进化:人工智能的自我改进与监管挑战 文章目录 递归进化:人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管?3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

TDengine 快速体验(Docker 镜像方式)
简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...
Objective-C常用命名规范总结
【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名(Class Name)2.协议名(Protocol Name)3.方法名(Method Name)4.属性名(Property Name)5.局部变量/实例变量(Local / Instance Variables&…...
06 Deep learning神经网络编程基础 激活函数 --吴恩达
深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)
前言: 最近在做行为检测相关的模型,用的是时空图卷积网络(STGCN),但原有kinetic-400数据集数据质量较低,需要进行细粒度的标注,同时粗略搜了下已有开源工具基本都集中于图像分割这块,…...

人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的
修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

C/C++ 中附加包含目录、附加库目录与附加依赖项详解
在 C/C 编程的编译和链接过程中,附加包含目录、附加库目录和附加依赖项是三个至关重要的设置,它们相互配合,确保程序能够正确引用外部资源并顺利构建。虽然在学习过程中,这些概念容易让人混淆,但深入理解它们的作用和联…...

AI+无人机如何守护濒危物种?YOLOv8实现95%精准识别
【导读】 野生动物监测在理解和保护生态系统中发挥着至关重要的作用。然而,传统的野生动物观察方法往往耗时耗力、成本高昂且范围有限。无人机的出现为野生动物监测提供了有前景的替代方案,能够实现大范围覆盖并远程采集数据。尽管具备这些优势…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...