当前位置: 首页 > news >正文

【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解

文章目录

      • 0. 前言
      • 1. 什么是自动混合精度?
      • 2. PyTorch AMP 模块
      • 3. 如何使用 PyTorch AMP
        • 3.1 环境准备
        • 3.2 代码实例
        • 3.3 代码解析
      • 4. 结论

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在深度学习领域,训练大型神经网络往往需要大量的计算资源。为了提高训练效率和减少内存占用,研究人员和工程师们不断探索新的技术手段。其中,自动混合精度(Automatic Mixed Precision, AMP)是一种非常有效的技术,它能够在保证模型准确性的同时显著提高训练速度和降低内存使用。

PyTorch 1.6 版本引入了对自动混合精度的支持,通过 torch.cuda.amp 模块来实现。本文将详细介绍 PyTorch 中的 AMP 模块,并提供一个示例来演示如何使用它。

1. 什么是自动混合精度?

自动混合精度是一种训练技巧,它允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。这种方法的主要好处包括:

  1. 加速训练:在现代GPU上,对于16位浮点数的算术运算比32位浮点数更快。因此,使用混合精度训练可以显著提高训练速度;
  2. 减少内存使用:16位浮点数占用的空间是32位浮点数的一半,这意味着模型可以在有限的GPU内存中处理更大的批次大小,或者可以将更多的数据缓存到内存中,从而进一步加速训练。
  3. 提高计算效率:通过减少数据类型转换的需求,可以减少计算开销。在某些情况下,使用16位浮点数的运算可以利用特定硬件(如NVIDIA Tensor Cores)的优势,这些硬件专门为低精度运算进行了优化。
  4. 数值稳定性:虽然16位浮点数的动态范围较小,但通过适当的缩放策略(例如使用GradScaler)可以维持数值稳定性,从而避免梯度消失或爆炸的问题。
  5. 易于集成:PyTorch等框架提供的自动混合精度(Automatic Mixed Precision, AMP)工具使得混合精度训练变得非常简单,通常只需要添加几行代码即可实现。

2. PyTorch AMP 模块

PyTorch 的 AMP 模块主要包含两个核心组件:autocastGradScaler

  • autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。

  • GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler 可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。

3. 如何使用 PyTorch AMP

接下来,将通过一个简单的示例来演示如何使用 PyTorch 的 AMP 模块来训练一个神经网络。

3.1 环境准备

确保安装了 PyTorch 1.6 或更高版本。可以使用以下命令安装:

pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio===0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3.2 代码实例

下面的示例代码演示了如何使用 PyTorch 的 AMP 模块来训练一个简单的多层感知器(MLP)。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast# 设置随机种子以保证结果的一致性
torch.manual_seed(0)# 创建一个简单的多层感知器模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# 初始化模型、损失函数和优化器
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建 GradScaler
scaler = GradScaler()# 生成一些随机数据
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# 训练循环
for epoch in range(1):print(f"inputs dtype:{inputs.dtype}")# 使用 autocast 上下文管理器with autocast():  #尝试去掉这行再看下# 前向传播outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")# 清除梯度optimizer.zero_grad(set_to_none=True)# 使用 GradScaler 缩放损失scaler.scale(loss).backward()# 更新权重scaler.step(optimizer)# 更新 GradScalerscaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
3.3 代码解析

上面实例输出为:

inputs dtype:torch.float32
outputs dtype:torch.float16   
loss dtype:torch.float32
Epoch 1, Loss: 2.2972

这里可以注意到outputs的类型自动变成了float16

  1. 模型定义:我们定义了一个简单的多层感知器模型,包含两个线性层。
  2. 初始化:初始化模型、损失函数和优化器,并创建 GradScaler 对象。
  3. 数据准备:生成一些随机输入数据和目标标签。
  4. 训练循环
    • 使用 with autocast() 上下文管理器自动转换张量精度。
    • 前向传播计算输出和损失。
    • 使用 scaler.scale(loss) 放大损失以确保数值稳定性。
    • 反向传播和梯度更新。
    • 更新 GradScaler 状态。

4. 结论

通过使用 PyTorch 的自动混合精度模块,我们可以显著提高模型的训练速度并减少内存使用,尤其是在 GPU 上训练大型神经网络时。上述示例展示了如何轻松地将 AMP 集成到现有训练流程中,只需几行代码即可启用这一功能。

相关文章:

【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解

文章目录 0. 前言1. 什么是自动混合精度?2. PyTorch AMP 模块3. 如何使用 PyTorch AMP3.1 环境准备3.2 代码实例3.3 代码解析 4. 结论 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果&a…...

数据结构 --- 哈希表

哈希表(Hash Table),也叫散列表,是一种根据关键码值(Key value)而直接进行访问的数据结构。 一、基本原理 哈希函数 哈希表通过一个特定的哈希函数,将关键码映射到表中的一个位置。这个位置通常…...

Linux相关:在阿里云下载centos系统镜像

文章目录 1、镜像站2、下载方式一2.1、第一步打开镜像站地址2.2 下载地址: https://mirrors.aliyun.com/centos/2.3、选择7版本2.4、镜像文件在isos文件夹中2.5、选择合适的版本 3、下载镜像快捷方式 1、镜像站 阿里云镜像站地址 2、下载方式一 2.1、第一步打开镜像站地址 2…...

24. 线模型对象

线模型Line渲染顶点数据 下面代码是把几何体作为线模型Line (opens new window)的参数,你会发现渲染效果是从第一个点开始到最后一个点,依次连成线。 // 线材质对象 const material new THREE.LineBasicMaterial({color: 0xff0000 //线条颜色 }); //…...

EasyExcel 快速入门

目录 一、 EasyExcel简介 官网链接: 代码链接: 二、 EasyExcel快速上手 引入依赖: 设置Excel相关注解 编写对应的监听类: 简单写入数据: 简单读取数据: 不需要使用监听器: 需要使…...

Sparse4D v1

Sparse4D: Multi-view 3D Object Detection with Sparse Spatial-Temporal Fusion Abstract 基于鸟瞰图 (BEV) 的方法最近在多视图 3D 检测任务方面取得了重大进展。与基于 BEV 的方法相比,基于稀疏的方法在性能上落后,但仍然有很多不可忽略的优点。为了…...

速盾:你知道高防 IP 和高防 CDN 的区别吗?

在当今网络安全形势日益严峻的情况下,网站的安全防护成为了企业和个人关注的焦点。高防 IP 和高防 CDN 作为两种常见的网络安全防护手段,被广泛应用于网站的安全防护中。那么,高防 IP 和高防 CDN 有什么区别呢?防护网站哪个更好呢…...

HTML和CSS网页制作成品

HTML和CSS网页制作成品 一、引言 1. 背景介绍 在当今数字化时代,网页已成为信息传递和交流的重要媒介。HTML和CSS作为网页制作的基石,对于构建美观、功能丰富的网站至关重要。本文将详细介绍如何使用HTML和CSS来制作一个网页成品。 2. 目的和重要性 …...

Ai+若依(集成easyexcel实现excel表格增强)

EasyExcel 介绍 官方地址:EasyExcel官方文档 - 基于Java的Excel处理工具 | Easy Excel 官网 Java解析、生成Excel比较有名的框架有Apache poi、jxl。但他们都存在一个严重的问题就是非常的耗内存,poi有一套SAX模式的API可以一定程度的解决一些内存溢出的问题,但POI还是有一…...

钻机、塔吊等大型工程设备,如何远程维护、实时采集运行数据?

在建筑和工程领域,重型设备的应用不可或缺,无论是在道路与桥梁建设、高层建筑施工,还是在风电、石油等能源项目的开发中,都会用到塔吊、钻机等大型机械工程设备。 随着数字化升级、工业4.0成为行业发展趋势,为了进一步…...

【AutoX.js】选择器 UiSelector - 查找包名

文章目录 原文:https://blog.c12th.cn/archives/38.html选择器 UiSelector - 查找包名笔记直接查找包名双层判断(推荐)查找最外层控件的子控件 最后 原文:https://blog.c12th.cn/archives/38.html 选择器 UiSelector - 查找包名 笔记 AutoX.js UiSelec…...

ERP进销存多仓库管理系统源码 带完整的安装代码包以及搭建部署教程

系统概述 ERP进销存多仓库管理系统是一款专为中小企业量身定制的集成化管理软件,它集成了采购管理、销售管理、库存管理、财务管理以及多仓库协同作业等核心模块。通过统一的平台,企业可以实时掌握商品从入库到出库的全过程,实现库存的自动化…...

数据清洗-缺失值填充-对XGBoost参数优化填充

目录 一、安装所需的python包二、采用XGboost算法进行缺失值填充2.1可直接运行代码2.2以某个缺失值数据进行实战2.2.1 代码运行过程截屏:2.2.2 填充后的数据截屏:三、网格搜索(Grid Search)对 XGBoost 模型的超参数进行优化原理介绍3.1 说明3.2 参数优化的原理1. 网格搜索(…...

Qt_按钮类控件

目录 1、QAbstractButton 2、设置带图标的按钮 3、设置带有快捷键的按钮 4、QRadioButtion(单选按钮) 4.1 QButtonGroup 5、QCheckBox 结语 前言: 按钮类控件是Qt中最重要的控件类型之一,该类型的控件可以通过鼠标的点击…...

union 的定义和基本结构以及用途

在 C 语言中,union(联合体) 是一种数据结构,它允许多个成员共享相同的内存空间。换句话说,联合体中的所有成员都存储在同一块内存区域,不同的成员会占用相同的内存地址,但在同一时刻只能保存一个…...

混合整数规划及其MATLAB实现

目录 引言 混合整数规划的基本模型 混合整数规划的求解方法 MATLAB中的混合整数规划实现 示例:多变量系统的混合整数规划 表格总结:混合整数规划的求解方法与适用场景 结论 引言 混合整数规划(Mixed Integer Programming, MIP&#xf…...

【数据结构】6——图1,概念

数据结构6——图1,概念 文章目录 数据结构6——图1,概念基本概念图的分类图的表示方法 基本概念 由 顶点(Vertex) 和 边(Edge) 组成的集合。顶点表示图中的点,而边表示顶点之间的连接。记为 G …...

技术周总结 09.09~09.15周日(C# WinForm WPF)

文章目录 一、09.09 周一1.1) 问题01: Windows桌面开发中,WPF和WinForm的区别和联系?联系:区别: 二、09.12 周四2.1)问题01:visual studio的相关快捷键有哪些?通用快捷键编辑导航调试窗口管理 2…...

4K投影仪选购全攻略:全玻璃镜头的当贝F6,画面细节纤毫毕现

在当今的投影市场上,4K投影仪已经成了主流产品,越来越多家庭开始关注如何选择一款性价比高、口碑好的4K投影仪。4K投影仪其实指的是具备3840*2160像素分辨率投影仪,它能够提供更清晰、更细腻、更真实的画面效果。 那么4K投影仪该怎么选&…...

除了字符串前导的*号之外,将串中其它*号全部删除

要求 假定输入的字符串中只包含字母和*号。请编写函数fun,它的功能是:除了字符串前导的*号之外,将串中其它*号全部删除。在编写函数时,不得使用C语言提供的字符串函数。函数fun中给出的语句仅供参考。 例如,字符串中的内容为:-**…...

从深圳崛起的“机器之眼”:赴港乐动机器人的万亿赛道赶考路

进入2025年以来,尽管围绕人形机器人、具身智能等机器人赛道的质疑声不断,但全球市场热度依然高涨,入局者持续增加。 以国内市场为例,天眼查专业版数据显示,截至5月底,我国现存在业、存续状态的机器人相关企…...

如何在看板中有效管理突发紧急任务

在看板中有效管理突发紧急任务需要:设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP(Work-in-Progress)弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中,设立专门的紧急任务通道尤为重要,这能…...

EtherNet/IP转DeviceNet协议网关详解

一,设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络,本网关连接到EtherNet/IP总线中做为从站使用,连接到DeviceNet总线中做为从站使用。 在自动…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,社区养老保险系统小程序被用户普遍使用,为方…...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

智能AI电话机器人系统的识别能力现状与发展水平

一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...

虚拟电厂发展三大趋势:市场化、技术主导、车网互联

市场化:从政策驱动到多元盈利 政策全面赋能 2025年4月,国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》,首次明确虚拟电厂为“独立市场主体”,提出硬性目标:2027年全国调节能力≥2000万千瓦&#xff0…...

STM32---外部32.768K晶振(LSE)无法起振问题

晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...

字符串哈希+KMP

P10468 兔子与兔子 #include<bits/stdc.h> using namespace std; typedef unsigned long long ull; const int N 1000010; ull a[N], pw[N]; int n; ull gethash(int l, int r){return a[r] - a[l - 1] * pw[r - l 1]; } signed main(){ios::sync_with_stdio(false), …...