【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解
文章目录
- 0. 前言
- 1. 什么是自动混合精度?
- 2. PyTorch AMP 模块
- 3. 如何使用 PyTorch AMP
- 3.1 环境准备
- 3.2 代码实例
- 3.3 代码解析
- 4. 结论
0. 前言
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。
在深度学习领域,训练大型神经网络往往需要大量的计算资源。为了提高训练效率和减少内存占用,研究人员和工程师们不断探索新的技术手段。其中,自动混合精度(Automatic Mixed Precision, AMP)是一种非常有效的技术,它能够在保证模型准确性的同时显著提高训练速度和降低内存使用。
PyTorch 1.6 版本引入了对自动混合精度的支持,通过 torch.cuda.amp 模块来实现。本文将详细介绍 PyTorch 中的 AMP 模块,并提供一个示例来演示如何使用它。
1. 什么是自动混合精度?
自动混合精度是一种训练技巧,它允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。这种方法的主要好处包括:
- 加速训练:在现代GPU上,对于16位浮点数的算术运算比32位浮点数更快。因此,使用混合精度训练可以显著提高训练速度;
- 减少内存使用:16位浮点数占用的空间是32位浮点数的一半,这意味着模型可以在有限的GPU内存中处理更大的批次大小,或者可以将更多的数据缓存到内存中,从而进一步加速训练。
- 提高计算效率:通过减少数据类型转换的需求,可以减少计算开销。在某些情况下,使用16位浮点数的运算可以利用特定硬件(如NVIDIA Tensor Cores)的优势,这些硬件专门为低精度运算进行了优化。
- 数值稳定性:虽然16位浮点数的动态范围较小,但通过适当的缩放策略(例如使用GradScaler)可以维持数值稳定性,从而避免梯度消失或爆炸的问题。
- 易于集成:PyTorch等框架提供的自动混合精度(Automatic Mixed Precision, AMP)工具使得混合精度训练变得非常简单,通常只需要添加几行代码即可实现。
2. PyTorch AMP 模块
PyTorch 的 AMP 模块主要包含两个核心组件:autocast 和 GradScaler。
-
autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。 -
GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。
3. 如何使用 PyTorch AMP
接下来,将通过一个简单的示例来演示如何使用 PyTorch 的 AMP 模块来训练一个神经网络。
3.1 环境准备
确保安装了 PyTorch 1.6 或更高版本。可以使用以下命令安装:
pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio===0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3.2 代码实例
下面的示例代码演示了如何使用 PyTorch 的 AMP 模块来训练一个简单的多层感知器(MLP)。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast# 设置随机种子以保证结果的一致性
torch.manual_seed(0)# 创建一个简单的多层感知器模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# 初始化模型、损失函数和优化器
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建 GradScaler
scaler = GradScaler()# 生成一些随机数据
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# 训练循环
for epoch in range(1):print(f"inputs dtype:{inputs.dtype}")# 使用 autocast 上下文管理器with autocast(): #尝试去掉这行再看下# 前向传播outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")# 清除梯度optimizer.zero_grad(set_to_none=True)# 使用 GradScaler 缩放损失scaler.scale(loss).backward()# 更新权重scaler.step(optimizer)# 更新 GradScalerscaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
3.3 代码解析
上面实例输出为:
inputs dtype:torch.float32
outputs dtype:torch.float16
loss dtype:torch.float32
Epoch 1, Loss: 2.2972
这里可以注意到outputs的类型自动变成了float16。
- 模型定义:我们定义了一个简单的多层感知器模型,包含两个线性层。
- 初始化:初始化模型、损失函数和优化器,并创建
GradScaler对象。 - 数据准备:生成一些随机输入数据和目标标签。
- 训练循环:
- 使用
with autocast()上下文管理器自动转换张量精度。 - 前向传播计算输出和损失。
- 使用
scaler.scale(loss)放大损失以确保数值稳定性。 - 反向传播和梯度更新。
- 更新
GradScaler状态。
- 使用
4. 结论
通过使用 PyTorch 的自动混合精度模块,我们可以显著提高模型的训练速度并减少内存使用,尤其是在 GPU 上训练大型神经网络时。上述示例展示了如何轻松地将 AMP 集成到现有训练流程中,只需几行代码即可启用这一功能。
相关文章:
【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解
文章目录 0. 前言1. 什么是自动混合精度?2. PyTorch AMP 模块3. 如何使用 PyTorch AMP3.1 环境准备3.2 代码实例3.3 代码解析 4. 结论 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果&a…...
数据结构 --- 哈希表
哈希表(Hash Table),也叫散列表,是一种根据关键码值(Key value)而直接进行访问的数据结构。 一、基本原理 哈希函数 哈希表通过一个特定的哈希函数,将关键码映射到表中的一个位置。这个位置通常…...
Linux相关:在阿里云下载centos系统镜像
文章目录 1、镜像站2、下载方式一2.1、第一步打开镜像站地址2.2 下载地址: https://mirrors.aliyun.com/centos/2.3、选择7版本2.4、镜像文件在isos文件夹中2.5、选择合适的版本 3、下载镜像快捷方式 1、镜像站 阿里云镜像站地址 2、下载方式一 2.1、第一步打开镜像站地址 2…...
24. 线模型对象
线模型Line渲染顶点数据 下面代码是把几何体作为线模型Line (opens new window)的参数,你会发现渲染效果是从第一个点开始到最后一个点,依次连成线。 // 线材质对象 const material new THREE.LineBasicMaterial({color: 0xff0000 //线条颜色 }); //…...
EasyExcel 快速入门
目录 一、 EasyExcel简介 官网链接: 代码链接: 二、 EasyExcel快速上手 引入依赖: 设置Excel相关注解 编写对应的监听类: 简单写入数据: 简单读取数据: 不需要使用监听器: 需要使…...
Sparse4D v1
Sparse4D: Multi-view 3D Object Detection with Sparse Spatial-Temporal Fusion Abstract 基于鸟瞰图 (BEV) 的方法最近在多视图 3D 检测任务方面取得了重大进展。与基于 BEV 的方法相比,基于稀疏的方法在性能上落后,但仍然有很多不可忽略的优点。为了…...
速盾:你知道高防 IP 和高防 CDN 的区别吗?
在当今网络安全形势日益严峻的情况下,网站的安全防护成为了企业和个人关注的焦点。高防 IP 和高防 CDN 作为两种常见的网络安全防护手段,被广泛应用于网站的安全防护中。那么,高防 IP 和高防 CDN 有什么区别呢?防护网站哪个更好呢…...
HTML和CSS网页制作成品
HTML和CSS网页制作成品 一、引言 1. 背景介绍 在当今数字化时代,网页已成为信息传递和交流的重要媒介。HTML和CSS作为网页制作的基石,对于构建美观、功能丰富的网站至关重要。本文将详细介绍如何使用HTML和CSS来制作一个网页成品。 2. 目的和重要性 …...
Ai+若依(集成easyexcel实现excel表格增强)
EasyExcel 介绍 官方地址:EasyExcel官方文档 - 基于Java的Excel处理工具 | Easy Excel 官网 Java解析、生成Excel比较有名的框架有Apache poi、jxl。但他们都存在一个严重的问题就是非常的耗内存,poi有一套SAX模式的API可以一定程度的解决一些内存溢出的问题,但POI还是有一…...
钻机、塔吊等大型工程设备,如何远程维护、实时采集运行数据?
在建筑和工程领域,重型设备的应用不可或缺,无论是在道路与桥梁建设、高层建筑施工,还是在风电、石油等能源项目的开发中,都会用到塔吊、钻机等大型机械工程设备。 随着数字化升级、工业4.0成为行业发展趋势,为了进一步…...
【AutoX.js】选择器 UiSelector - 查找包名
文章目录 原文:https://blog.c12th.cn/archives/38.html选择器 UiSelector - 查找包名笔记直接查找包名双层判断(推荐)查找最外层控件的子控件 最后 原文:https://blog.c12th.cn/archives/38.html 选择器 UiSelector - 查找包名 笔记 AutoX.js UiSelec…...
ERP进销存多仓库管理系统源码 带完整的安装代码包以及搭建部署教程
系统概述 ERP进销存多仓库管理系统是一款专为中小企业量身定制的集成化管理软件,它集成了采购管理、销售管理、库存管理、财务管理以及多仓库协同作业等核心模块。通过统一的平台,企业可以实时掌握商品从入库到出库的全过程,实现库存的自动化…...
数据清洗-缺失值填充-对XGBoost参数优化填充
目录 一、安装所需的python包二、采用XGboost算法进行缺失值填充2.1可直接运行代码2.2以某个缺失值数据进行实战2.2.1 代码运行过程截屏:2.2.2 填充后的数据截屏:三、网格搜索(Grid Search)对 XGBoost 模型的超参数进行优化原理介绍3.1 说明3.2 参数优化的原理1. 网格搜索(…...
Qt_按钮类控件
目录 1、QAbstractButton 2、设置带图标的按钮 3、设置带有快捷键的按钮 4、QRadioButtion(单选按钮) 4.1 QButtonGroup 5、QCheckBox 结语 前言: 按钮类控件是Qt中最重要的控件类型之一,该类型的控件可以通过鼠标的点击…...
union 的定义和基本结构以及用途
在 C 语言中,union(联合体) 是一种数据结构,它允许多个成员共享相同的内存空间。换句话说,联合体中的所有成员都存储在同一块内存区域,不同的成员会占用相同的内存地址,但在同一时刻只能保存一个…...
混合整数规划及其MATLAB实现
目录 引言 混合整数规划的基本模型 混合整数规划的求解方法 MATLAB中的混合整数规划实现 示例:多变量系统的混合整数规划 表格总结:混合整数规划的求解方法与适用场景 结论 引言 混合整数规划(Mixed Integer Programming, MIP…...
【数据结构】6——图1,概念
数据结构6——图1,概念 文章目录 数据结构6——图1,概念基本概念图的分类图的表示方法 基本概念 由 顶点(Vertex) 和 边(Edge) 组成的集合。顶点表示图中的点,而边表示顶点之间的连接。记为 G …...
技术周总结 09.09~09.15周日(C# WinForm WPF)
文章目录 一、09.09 周一1.1) 问题01: Windows桌面开发中,WPF和WinForm的区别和联系?联系:区别: 二、09.12 周四2.1)问题01:visual studio的相关快捷键有哪些?通用快捷键编辑导航调试窗口管理 2…...
4K投影仪选购全攻略:全玻璃镜头的当贝F6,画面细节纤毫毕现
在当今的投影市场上,4K投影仪已经成了主流产品,越来越多家庭开始关注如何选择一款性价比高、口碑好的4K投影仪。4K投影仪其实指的是具备3840*2160像素分辨率投影仪,它能够提供更清晰、更细腻、更真实的画面效果。 那么4K投影仪该怎么选&…...
除了字符串前导的*号之外,将串中其它*号全部删除
要求 假定输入的字符串中只包含字母和*号。请编写函数fun,它的功能是:除了字符串前导的*号之外,将串中其它*号全部删除。在编写函数时,不得使用C语言提供的字符串函数。函数fun中给出的语句仅供参考。 例如,字符串中的内容为:-**…...
Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
简易版抽奖活动的设计技术方案
1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...
Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...
工程地质软件市场:发展现状、趋势与策略建议
一、引言 在工程建设领域,准确把握地质条件是确保项目顺利推进和安全运营的关键。工程地质软件作为处理、分析、模拟和展示工程地质数据的重要工具,正发挥着日益重要的作用。它凭借强大的数据处理能力、三维建模功能、空间分析工具和可视化展示手段&…...
Spring AI与Spring Modulith核心技术解析
Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...
GC1808高性能24位立体声音频ADC芯片解析
1. 芯片概述 GC1808是一款24位立体声音频模数转换器(ADC),支持8kHz~96kHz采样率,集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器,适用于高保真音频采集场景。 2. 核心特性 高精度:24位分辨率,…...
10-Oracle 23 ai Vector Search 概述和参数
一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI,使用客户端或是内部自己搭建集成大模型的终端,加速与大型语言模型(LLM)的结合,同时使用检索增强生成(Retrieval Augmented Generation &#…...
Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?
在大数据处理领域,Hive 作为 Hadoop 生态中重要的数据仓库工具,其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式,很多开发者常常陷入选择困境。本文将从底…...
