【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解
文章目录
- 0. 前言
- 1. 什么是自动混合精度?
- 2. PyTorch AMP 模块
- 3. 如何使用 PyTorch AMP
- 3.1 环境准备
- 3.2 代码实例
- 3.3 代码解析
- 4. 结论
0. 前言
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。
在深度学习领域,训练大型神经网络往往需要大量的计算资源。为了提高训练效率和减少内存占用,研究人员和工程师们不断探索新的技术手段。其中,自动混合精度(Automatic Mixed Precision, AMP)是一种非常有效的技术,它能够在保证模型准确性的同时显著提高训练速度和降低内存使用。
PyTorch 1.6 版本引入了对自动混合精度的支持,通过 torch.cuda.amp
模块来实现。本文将详细介绍 PyTorch 中的 AMP 模块,并提供一个示例来演示如何使用它。
1. 什么是自动混合精度?
自动混合精度是一种训练技巧,它允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。这种方法的主要好处包括:
- 加速训练:在现代GPU上,对于16位浮点数的算术运算比32位浮点数更快。因此,使用混合精度训练可以显著提高训练速度;
- 减少内存使用:16位浮点数占用的空间是32位浮点数的一半,这意味着模型可以在有限的GPU内存中处理更大的批次大小,或者可以将更多的数据缓存到内存中,从而进一步加速训练。
- 提高计算效率:通过减少数据类型转换的需求,可以减少计算开销。在某些情况下,使用16位浮点数的运算可以利用特定硬件(如NVIDIA Tensor Cores)的优势,这些硬件专门为低精度运算进行了优化。
- 数值稳定性:虽然16位浮点数的动态范围较小,但通过适当的缩放策略(例如使用GradScaler)可以维持数值稳定性,从而避免梯度消失或爆炸的问题。
- 易于集成:PyTorch等框架提供的自动混合精度(Automatic Mixed Precision, AMP)工具使得混合精度训练变得非常简单,通常只需要添加几行代码即可实现。
2. PyTorch AMP 模块
PyTorch 的 AMP 模块主要包含两个核心组件:autocast
和 GradScaler
。
-
autocast
:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。 -
GradScaler
:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler
可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。
3. 如何使用 PyTorch AMP
接下来,将通过一个简单的示例来演示如何使用 PyTorch 的 AMP 模块来训练一个神经网络。
3.1 环境准备
确保安装了 PyTorch 1.6 或更高版本。可以使用以下命令安装:
pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio===0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3.2 代码实例
下面的示例代码演示了如何使用 PyTorch 的 AMP 模块来训练一个简单的多层感知器(MLP)。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast# 设置随机种子以保证结果的一致性
torch.manual_seed(0)# 创建一个简单的多层感知器模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# 初始化模型、损失函数和优化器
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建 GradScaler
scaler = GradScaler()# 生成一些随机数据
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# 训练循环
for epoch in range(1):print(f"inputs dtype:{inputs.dtype}")# 使用 autocast 上下文管理器with autocast(): #尝试去掉这行再看下# 前向传播outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")# 清除梯度optimizer.zero_grad(set_to_none=True)# 使用 GradScaler 缩放损失scaler.scale(loss).backward()# 更新权重scaler.step(optimizer)# 更新 GradScalerscaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
3.3 代码解析
上面实例输出为:
inputs dtype:torch.float32
outputs dtype:torch.float16
loss dtype:torch.float32
Epoch 1, Loss: 2.2972
这里可以注意到outputs
的类型自动变成了float16
。
- 模型定义:我们定义了一个简单的多层感知器模型,包含两个线性层。
- 初始化:初始化模型、损失函数和优化器,并创建
GradScaler
对象。 - 数据准备:生成一些随机输入数据和目标标签。
- 训练循环:
- 使用
with autocast()
上下文管理器自动转换张量精度。 - 前向传播计算输出和损失。
- 使用
scaler.scale(loss)
放大损失以确保数值稳定性。 - 反向传播和梯度更新。
- 更新
GradScaler
状态。
- 使用
4. 结论
通过使用 PyTorch 的自动混合精度模块,我们可以显著提高模型的训练速度并减少内存使用,尤其是在 GPU 上训练大型神经网络时。上述示例展示了如何轻松地将 AMP 集成到现有训练流程中,只需几行代码即可启用这一功能。
相关文章:
【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解
文章目录 0. 前言1. 什么是自动混合精度?2. PyTorch AMP 模块3. 如何使用 PyTorch AMP3.1 环境准备3.2 代码实例3.3 代码解析 4. 结论 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果&a…...
数据结构 --- 哈希表
哈希表(Hash Table),也叫散列表,是一种根据关键码值(Key value)而直接进行访问的数据结构。 一、基本原理 哈希函数 哈希表通过一个特定的哈希函数,将关键码映射到表中的一个位置。这个位置通常…...

Linux相关:在阿里云下载centos系统镜像
文章目录 1、镜像站2、下载方式一2.1、第一步打开镜像站地址2.2 下载地址: https://mirrors.aliyun.com/centos/2.3、选择7版本2.4、镜像文件在isos文件夹中2.5、选择合适的版本 3、下载镜像快捷方式 1、镜像站 阿里云镜像站地址 2、下载方式一 2.1、第一步打开镜像站地址 2…...
24. 线模型对象
线模型Line渲染顶点数据 下面代码是把几何体作为线模型Line (opens new window)的参数,你会发现渲染效果是从第一个点开始到最后一个点,依次连成线。 // 线材质对象 const material new THREE.LineBasicMaterial({color: 0xff0000 //线条颜色 }); //…...

EasyExcel 快速入门
目录 一、 EasyExcel简介 官网链接: 代码链接: 二、 EasyExcel快速上手 引入依赖: 设置Excel相关注解 编写对应的监听类: 简单写入数据: 简单读取数据: 不需要使用监听器: 需要使…...

Sparse4D v1
Sparse4D: Multi-view 3D Object Detection with Sparse Spatial-Temporal Fusion Abstract 基于鸟瞰图 (BEV) 的方法最近在多视图 3D 检测任务方面取得了重大进展。与基于 BEV 的方法相比,基于稀疏的方法在性能上落后,但仍然有很多不可忽略的优点。为了…...
速盾:你知道高防 IP 和高防 CDN 的区别吗?
在当今网络安全形势日益严峻的情况下,网站的安全防护成为了企业和个人关注的焦点。高防 IP 和高防 CDN 作为两种常见的网络安全防护手段,被广泛应用于网站的安全防护中。那么,高防 IP 和高防 CDN 有什么区别呢?防护网站哪个更好呢…...
HTML和CSS网页制作成品
HTML和CSS网页制作成品 一、引言 1. 背景介绍 在当今数字化时代,网页已成为信息传递和交流的重要媒介。HTML和CSS作为网页制作的基石,对于构建美观、功能丰富的网站至关重要。本文将详细介绍如何使用HTML和CSS来制作一个网页成品。 2. 目的和重要性 …...

Ai+若依(集成easyexcel实现excel表格增强)
EasyExcel 介绍 官方地址:EasyExcel官方文档 - 基于Java的Excel处理工具 | Easy Excel 官网 Java解析、生成Excel比较有名的框架有Apache poi、jxl。但他们都存在一个严重的问题就是非常的耗内存,poi有一套SAX模式的API可以一定程度的解决一些内存溢出的问题,但POI还是有一…...

钻机、塔吊等大型工程设备,如何远程维护、实时采集运行数据?
在建筑和工程领域,重型设备的应用不可或缺,无论是在道路与桥梁建设、高层建筑施工,还是在风电、石油等能源项目的开发中,都会用到塔吊、钻机等大型机械工程设备。 随着数字化升级、工业4.0成为行业发展趋势,为了进一步…...

【AutoX.js】选择器 UiSelector - 查找包名
文章目录 原文:https://blog.c12th.cn/archives/38.html选择器 UiSelector - 查找包名笔记直接查找包名双层判断(推荐)查找最外层控件的子控件 最后 原文:https://blog.c12th.cn/archives/38.html 选择器 UiSelector - 查找包名 笔记 AutoX.js UiSelec…...

ERP进销存多仓库管理系统源码 带完整的安装代码包以及搭建部署教程
系统概述 ERP进销存多仓库管理系统是一款专为中小企业量身定制的集成化管理软件,它集成了采购管理、销售管理、库存管理、财务管理以及多仓库协同作业等核心模块。通过统一的平台,企业可以实时掌握商品从入库到出库的全过程,实现库存的自动化…...
数据清洗-缺失值填充-对XGBoost参数优化填充
目录 一、安装所需的python包二、采用XGboost算法进行缺失值填充2.1可直接运行代码2.2以某个缺失值数据进行实战2.2.1 代码运行过程截屏:2.2.2 填充后的数据截屏:三、网格搜索(Grid Search)对 XGBoost 模型的超参数进行优化原理介绍3.1 说明3.2 参数优化的原理1. 网格搜索(…...

Qt_按钮类控件
目录 1、QAbstractButton 2、设置带图标的按钮 3、设置带有快捷键的按钮 4、QRadioButtion(单选按钮) 4.1 QButtonGroup 5、QCheckBox 结语 前言: 按钮类控件是Qt中最重要的控件类型之一,该类型的控件可以通过鼠标的点击…...
union 的定义和基本结构以及用途
在 C 语言中,union(联合体) 是一种数据结构,它允许多个成员共享相同的内存空间。换句话说,联合体中的所有成员都存储在同一块内存区域,不同的成员会占用相同的内存地址,但在同一时刻只能保存一个…...

混合整数规划及其MATLAB实现
目录 引言 混合整数规划的基本模型 混合整数规划的求解方法 MATLAB中的混合整数规划实现 示例:多变量系统的混合整数规划 表格总结:混合整数规划的求解方法与适用场景 结论 引言 混合整数规划(Mixed Integer Programming, MIP…...
【数据结构】6——图1,概念
数据结构6——图1,概念 文章目录 数据结构6——图1,概念基本概念图的分类图的表示方法 基本概念 由 顶点(Vertex) 和 边(Edge) 组成的集合。顶点表示图中的点,而边表示顶点之间的连接。记为 G …...
技术周总结 09.09~09.15周日(C# WinForm WPF)
文章目录 一、09.09 周一1.1) 问题01: Windows桌面开发中,WPF和WinForm的区别和联系?联系:区别: 二、09.12 周四2.1)问题01:visual studio的相关快捷键有哪些?通用快捷键编辑导航调试窗口管理 2…...

4K投影仪选购全攻略:全玻璃镜头的当贝F6,画面细节纤毫毕现
在当今的投影市场上,4K投影仪已经成了主流产品,越来越多家庭开始关注如何选择一款性价比高、口碑好的4K投影仪。4K投影仪其实指的是具备3840*2160像素分辨率投影仪,它能够提供更清晰、更细腻、更真实的画面效果。 那么4K投影仪该怎么选&…...

除了字符串前导的*号之外,将串中其它*号全部删除
要求 假定输入的字符串中只包含字母和*号。请编写函数fun,它的功能是:除了字符串前导的*号之外,将串中其它*号全部删除。在编写函数时,不得使用C语言提供的字符串函数。函数fun中给出的语句仅供参考。 例如,字符串中的内容为:-**…...

Appium+python自动化(十六)- ADB命令
简介 Android 调试桥(adb)是多种用途的工具,该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具,其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利,如安装和调试…...
Oracle查询表空间大小
1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql
智慧工地管理云平台系统,智慧工地全套源码,java版智慧工地源码,支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求,提供“平台网络终端”的整体解决方案,提供劳务管理、视频管理、智能监测、绿色施工、安全管…...
C++中string流知识详解和示例
一、概览与类体系 C 提供三种基于内存字符串的流,定义在 <sstream> 中: std::istringstream:输入流,从已有字符串中读取并解析。std::ostringstream:输出流,向内部缓冲区写入内容,最终取…...
大模型多显卡多服务器并行计算方法与实践指南
一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...
Caliper 配置文件解析:config.yaml
Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要: 近期,在使用较新版本的OpenSSH客户端连接老旧SSH服务器时,会遇到 "no matching key exchange method found", "n…...

Ubuntu系统复制(U盘-电脑硬盘)
所需环境 电脑自带硬盘:1块 (1T) U盘1:Ubuntu系统引导盘(用于“U盘2”复制到“电脑自带硬盘”) U盘2:Ubuntu系统盘(1T,用于被复制) !!!建议“电脑…...
用递归算法解锁「子集」问题 —— LeetCode 78题解析
文章目录 一、题目介绍二、递归思路详解:从决策树开始理解三、解法一:二叉决策树 DFS四、解法二:组合式回溯写法(推荐)五、解法对比 递归算法是编程中一种非常强大且常见的思想,它能够优雅地解决很多复杂的…...

【iOS】 Block再学习
iOS Block再学习 文章目录 iOS Block再学习前言Block的三种类型__ NSGlobalBlock____ NSMallocBlock____ NSStackBlock__小结 Block底层分析Block的结构捕获自由变量捕获全局(静态)变量捕获静态变量__block修饰符forwarding指针 Block的copy时机block作为函数返回值将block赋给…...