加速PyTorch模型训练:自动混合精度(AMP)
在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。
AMP简单介绍
是一种训练技巧,允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。
官网地址:https://pytorch.org/docs/stable/amp.html

为什么使用AMP
在某些上下文中,torch.FloatTensor(FP32)有其优势,而在其他情况下,torch.HalfTensor(FP16)则更具优势。FP16的优势包括减少显存占用、加快训练和推断计算以及更好地利用CUDA设备的Tensor Core。然而,FP16也存在数值范围小和舍入误差等问题。通过混合精度训练,可以在享受FP16带来的好处的同时,避免其劣势。
两个核心组件
PyTorch 的 AMP 模块主要包含两个核心组件:autocast 和 GradScaler。
autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。
代码示例
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import time
torch.manual_seed(42)
# A simple Model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# init model
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# GradScaler
scaler = GradScaler(device='cuda')# random data
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# train
for epoch in range(1):start_time = time.time() print(f"inputs dtype:{inputs.dtype}")# autocastwith autocast('cuda'):outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")optimizer.zero_grad(set_to_none=True)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")end_time = time.time() elapsed_time = end_time - start_time allocated_memory = torch.cuda.memory_allocated() / 1024**2 reserved_memory = torch.cuda.memory_reserved() / 1024**2 print(f"Single Batch, Single Epoch with AMP, Loss: {loss.item():.4f}")print(f"Time taken: {elapsed_time:.4f} seconds")print(f"Allocated memory: {allocated_memory:.2f} MB")print(f"Reserved memory: {reserved_memory:.2f} MB")
输出如下:
Time taken for epoch 1: 0.0116 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB
不使用AMP(更快了):
Time taken for epoch 1: 0.0024 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB
由于上面的示例是一个很小的模型(只有几层的小型网络),其本身的计算量不大,因此即使采用了FP16精度,也难以观察到明显的加速效果。同时,如果模型中的某些层无法有效利用Tensor Cores(例如一些自定义操作,非标准层),那么整个流程可能会受到限制。所以感受不到有计算优化。

相关文章:
加速PyTorch模型训练:自动混合精度(AMP)
在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。 AMP简单介绍 是一种训练…...
【py】python安装教程(Windows系统,python3.13.2版本为例)
1.下载地址 官网:https://www.python.org/ 官网下载地址:https://www.python.org/downloads/ 2.64版本或者32位选择 【Stable Releases】:稳定发布版本,指的是已经测试过的版本,相对稳定。 【Pre-releases】&#…...
Django REST Framework:如何获取序列化后的ID
Django REST Framework:如何获取序列化后的ID 😄 嗨,小伙伴们!今天我们来聊一聊Django REST Framework(简称DRF)中一个非常常见的操作:如何获取序列化后的ID。对于那些刚入门的朋友们ÿ…...
QT修仙笔记 事件大圆满 闹钟大成
学习笔记 牛客刷题 闹钟 时钟显示 通过 QTimer 每秒更新一次 QLCDNumber 显示的当前时间,格式为 hh:mm:ss,实现实时时钟显示。 闹钟设置 使用 QDateTimeEdit 让用户设置闹钟时间,可通过日历选择日期,设置范围为当前时间到未来 …...
Leetcode - 149双周赛
目录 一、3438. 找到字符串中合法的相邻数字二、3439. 重新安排会议得到最多空余时间 I三、3440. 重新安排会议得到最多空余时间 II四、3441. 变成好标题的最少代价 一、3438. 找到字符串中合法的相邻数字 题目链接 本题有两个条件: 相邻数字互不相同两个数字的的…...
解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题
解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题 1. 安装ComfyUI-Impact-Pack 首先确保ComfyUI-Impact-Pack 已经下载 地址: https://github.com/ltdrdata/ComfyUI-Impact-Pack 2. 安装ComfyUI-Impact-Subpack 由于新版本的Impact Pack 不再提供这…...
使用Kickstart配置文件封装操作系统实现Linux的自动化安装
使用Kickstart配置文件封装操作系统实现Linux的自动化安装 创建ks.cfg配置文件 可以使用已经安装完成的Linux操作系统中的/root目录下的anaconda.cfg配置文件 注意,配置文件会因为kickstart的版本兼容性的问题导致无法安装报错需要在实际使用过程中删除某些参数 …...
Android笔记【snippet】
一、 6、Card及ConstraintLayout线性布局 //定义单独的机器人单独一行的卡片 Composable fun RobotCard(robot: Robot,navController:NavController){Card(modifier Modifier.fillMaxWidth().wrapContentHeight().padding(5.dp),colors CardDefaults.elevatedCardColors(co…...
zsh: command not found: conda
场景描述 在 Linux 服务器上使用 zsh 时,如果出现 zsh: command not found: conda 错误,说明你的系统未正确配置 conda 命令,或者你尚未安装 Anaconda/Miniconda。 解决方案 确保已安装 Anaconda 或 Miniconda conda 是 Anaconda 或 Minico…...
【知识科普】CPU,GPN,NPU知识普及
CPU,GPU,NPU CPU、GPU、NPU 详解1. CPU(中央处理器)2. GPU(图形处理器)3. NPU(神经网络处理器) **三者的核心区别****协同工作示例****总结** CPU、GPU、NPU 详解 1. CPU(中央处理器࿰…...
【C++八股】struct和Class的区别
1. 默认访问控制 struct:结构体中的成员默认是 public,即外部代码可以直接访问结构体的成员。class:类中的成员默认是 private,即外部代码不能直接访问类的成员,必须通过公有接口(通常是成员函数ÿ…...
鹧鸪云光伏仓储、物料管理软件详细功能
采购中心 :作为核心枢纽,能集中管理多品牌设备,企业可灵活按需采购。采购与退货流程高效便捷,审核通过后物资快速补充、问题货物及时退回,保障资金与物资顺畅周转,避免积压浪费。付款与退款环节 ࿱…...
bazel 小白理解
Bazel命令是用于构建和测试软件项目的一个强大工具,尤其适用于大规模和多语言的软件项目。对于小白来说,可以这样理解Bazel及其命令: Bazel的基本概念 构建系统:Bazel是一个构建系统,它的主要任务是自动化地编译和链…...
MVC(Model-View-Controller)framework using Python ,Tkinter and SQLite
1.项目结构 sql: CREATE TABLE IF NOT EXISTS School (SchoolId TEXT not null, SchoolName TEXT NOT NULL,SchoolTelNo TEXT NOT NULL) 整体思路 Model:负责与 SQLite 数据库进行交互,包括创建表、插入、删除、更新和查询数据等操作。View࿱…...
WPF 设置宽度为 父容器 宽度的一半
方法1:使用 绑定和转换器 实现 创建类文件 HalfWidthConverter public class HalfWidthConverter : IValueConverter{public object Convert(object value, Type targetType, object parameter, CultureInfo culture){if (value is double width){return width / 4…...
java项目之在线心理评测与咨询管理系统(源码+文档)
项目简介 在线心理评测与咨询管理系统实现了以下功能: 在线心理评测与咨询管理系统的主要使用者分为: (1)在个人中心,管理员可以修改自己的用户名和登录密码。 (2)在系统前台可以查看首页&…...
【STM32系列】利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程)
ps.源码放在最后面 设计IIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程) 前言 本篇文章将介绍如何利用MATLAB与STM32的ARM-DSP库相结合,简明易懂地实现FIR低通滤波器的设计与应用。文章重点不在…...
Springboot框架扩展功能的使用
Spring Boot 提供了许多扩展点,允许开发者在应用程序的生命周期中插入自定义逻辑。这些扩展点可以帮助你更好地控制应用程序的行为,例如在启动时初始化数据、在关闭时释放资源、或者自定义配置加载逻辑。以下是 Spring Boot 中常见的扩展点: …...
yum报错 Could not resolve host: mirrorlist.centos.org
检查dns 使用ping www.baidu.com ,如果ping不通,检查/etc/resolv.conf文件中是否有: nameserver 8.8.8.8 nameserver 8.8.4.4 替换yum源 1.备份原始的 YUM 源配置文件: sudo cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.r…...
docker使用dockerfile打包镜像(docker如何打包)
文章目录 1. 编写 Dockerfile2. 构建 Docker 镜像3. 运行 Docker 容器4. 导出与导入镜像(可选) 1. 编写 Dockerfile Dockerfile 是一个文本文件,其中包含了一系列指令,这些指令定义了如何构建你的 Docker 镜像。下面以一个简单的…...
QRCoder:开发者必备的二维码生成解决方案全攻略
QRCoder:开发者必备的二维码生成解决方案全攻略 【免费下载链接】QRCoder A pure C# Open Source QR Code implementation 项目地址: https://gitcode.com/gh_mirrors/qr/QRCoder 在数字化时代,二维码已成为信息传递的重要桥梁,但如何…...
保姆级教程:用ESP32-P4和ST7703屏打造24fps高清视频轮播器(附完整代码)
ESP32-P4与ST7703屏实战:24fps高清视频轮播系统全流程解析 当一块性能强劲的嵌入式开发板遇到高分辨率显示屏,会碰撞出怎样的火花?本文将带您从零构建一个基于ESP32-P4和ST7703屏幕的高清视频轮播系统,实现稳定的24fps播放效果。不…...
4个关键步骤:开源散热控制解决Dell G15温度难题
4个关键步骤:开源散热控制解决Dell G15温度难题 【免费下载链接】tcc-g15 Thermal Control Center for Dell G15 - open source alternative to AWCC 项目地址: https://gitcode.com/gh_mirrors/tc/tcc-g15 在游戏本使用过程中,散热控制往往是影响…...
Linux进程,存储,软件,日志004
目录一、进程管理二、磁盘与存储管理三、软件包管理四、系统日志管理一、进程管理1.1 进程概念与状态进程定义:进程是正在执行的程序实例,包含程序代码、数据和系统资源。进程状态转换:● 运行(RUNNING):进程正在CPU上执行● 就绪…...
仅剩最后3家银行未完成Java Istio全面替换——这份含12类Java Agent冲突检测脚本、4种Sidecar注入模式对比的适配手册即将下线
第一章:Java Istio适配现状与收官倒计时Istio 1.20 是最后一个官方支持 Java 客户端(istio-java-api)的版本,自 1.21 起,Istio 社区正式移除了对 Java SDK 的维护和 CI 验证。这一决策标志着 Java 生态在 Istio 原生控…...
Cyber Engine Tweaks:解锁《赛博朋克2077》终极模组开发能力的5大核心功能 [特殊字符]
Cyber Engine Tweaks:解锁《赛博朋克2077》终极模组开发能力的5大核心功能 🚀 【免费下载链接】CyberEngineTweaks Cyberpunk 2077 tweaks, hacks and scripting framework 项目地址: https://gitcode.com/gh_mirrors/cy/CyberEngineTweaks Cyber…...
7个效率倍增技巧:StarRailAssistant自动化工具解放崩坏星穹铁道玩家双手
7个效率倍增技巧:StarRailAssistant自动化工具解放崩坏星穹铁道玩家双手 【免费下载链接】StarRailAssistant 崩坏:星穹铁道自动化 | 崩坏:星穹铁道自动锄大地 | 崩坏:星穹铁道锄大地 | 自动锄大地 | 基于模拟按键 项目地址: ht…...
MATLAB MultiDIC/Ncorr实战:从图像采集到应力应变云图生成的全流程解析
1. 数字图像相关技术入门指南 第一次接触数字图像相关(DIC)技术时,我完全被那些专业术语搞晕了。后来在实际项目中摸爬滚打才发现,这套技术本质上就是用相机"看"材料变形的过程。想象一下橡皮筋被拉伸时表面的斑点移动—…...
HeliOS:面向嵌入式设备的零上下文切换RTOS
1. 项目概述HeliOS 是一款面向资源受限嵌入式设备的轻量级、开源、免费使用的实时内核(RTOS),其定位并非传统意义上的通用操作系统,而是一个高度可裁剪、零上下文切换开销的多任务调度内核。它专为 Arduino、ARM Cortex-M 等低功耗…...
知识蒸馏(Knowledge Distillation)完全指南:原理、实践与进阶
一句话概括:知识蒸馏是一种模型压缩技术,它让一个轻量级的“学生模型”模仿一个高性能的“教师模型”的输出行为,从而在保持小体积、低延迟的同时,获得接近大模型的能力。一、为什么需要知识蒸馏?—— 大模型的“奢侈”…...
