加速PyTorch模型训练:自动混合精度(AMP)
在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。
AMP简单介绍
是一种训练技巧,允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。
官网地址:https://pytorch.org/docs/stable/amp.html
为什么使用AMP
在某些上下文中,torch.FloatTensor
(FP32)有其优势,而在其他情况下,torch.HalfTensor
(FP16)则更具优势。FP16的优势包括减少显存占用、加快训练和推断计算以及更好地利用CUDA设备的Tensor Core。然而,FP16也存在数值范围小和舍入误差等问题。通过混合精度训练,可以在享受FP16带来的好处的同时,避免其劣势。
两个核心组件
PyTorch 的 AMP 模块主要包含两个核心组件:autocast
和 GradScaler
。
autocast
:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。GradScaler
:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler
可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。
代码示例
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import time
torch.manual_seed(42)
# A simple Model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# init model
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# GradScaler
scaler = GradScaler(device='cuda')# random data
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# train
for epoch in range(1):start_time = time.time() print(f"inputs dtype:{inputs.dtype}")# autocastwith autocast('cuda'):outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")optimizer.zero_grad(set_to_none=True)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")end_time = time.time() elapsed_time = end_time - start_time allocated_memory = torch.cuda.memory_allocated() / 1024**2 reserved_memory = torch.cuda.memory_reserved() / 1024**2 print(f"Single Batch, Single Epoch with AMP, Loss: {loss.item():.4f}")print(f"Time taken: {elapsed_time:.4f} seconds")print(f"Allocated memory: {allocated_memory:.2f} MB")print(f"Reserved memory: {reserved_memory:.2f} MB")
输出如下:
Time taken for epoch 1: 0.0116 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB
不使用AMP(更快了):
Time taken for epoch 1: 0.0024 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB
由于上面的示例是一个很小的模型(只有几层的小型网络),其本身的计算量不大,因此即使采用了FP16精度,也难以观察到明显的加速效果。同时,如果模型中的某些层无法有效利用Tensor Cores(例如一些自定义操作,非标准层),那么整个流程可能会受到限制。所以感受不到有计算优化。
相关文章:

加速PyTorch模型训练:自动混合精度(AMP)
在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。 AMP简单介绍 是一种训练…...

【py】python安装教程(Windows系统,python3.13.2版本为例)
1.下载地址 官网:https://www.python.org/ 官网下载地址:https://www.python.org/downloads/ 2.64版本或者32位选择 【Stable Releases】:稳定发布版本,指的是已经测试过的版本,相对稳定。 【Pre-releases】&#…...

Django REST Framework:如何获取序列化后的ID
Django REST Framework:如何获取序列化后的ID 😄 嗨,小伙伴们!今天我们来聊一聊Django REST Framework(简称DRF)中一个非常常见的操作:如何获取序列化后的ID。对于那些刚入门的朋友们ÿ…...

QT修仙笔记 事件大圆满 闹钟大成
学习笔记 牛客刷题 闹钟 时钟显示 通过 QTimer 每秒更新一次 QLCDNumber 显示的当前时间,格式为 hh:mm:ss,实现实时时钟显示。 闹钟设置 使用 QDateTimeEdit 让用户设置闹钟时间,可通过日历选择日期,设置范围为当前时间到未来 …...

Leetcode - 149双周赛
目录 一、3438. 找到字符串中合法的相邻数字二、3439. 重新安排会议得到最多空余时间 I三、3440. 重新安排会议得到最多空余时间 II四、3441. 变成好标题的最少代价 一、3438. 找到字符串中合法的相邻数字 题目链接 本题有两个条件: 相邻数字互不相同两个数字的的…...

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题
解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题 1. 安装ComfyUI-Impact-Pack 首先确保ComfyUI-Impact-Pack 已经下载 地址: https://github.com/ltdrdata/ComfyUI-Impact-Pack 2. 安装ComfyUI-Impact-Subpack 由于新版本的Impact Pack 不再提供这…...

使用Kickstart配置文件封装操作系统实现Linux的自动化安装
使用Kickstart配置文件封装操作系统实现Linux的自动化安装 创建ks.cfg配置文件 可以使用已经安装完成的Linux操作系统中的/root目录下的anaconda.cfg配置文件 注意,配置文件会因为kickstart的版本兼容性的问题导致无法安装报错需要在实际使用过程中删除某些参数 …...

Android笔记【snippet】
一、 6、Card及ConstraintLayout线性布局 //定义单独的机器人单独一行的卡片 Composable fun RobotCard(robot: Robot,navController:NavController){Card(modifier Modifier.fillMaxWidth().wrapContentHeight().padding(5.dp),colors CardDefaults.elevatedCardColors(co…...

zsh: command not found: conda
场景描述 在 Linux 服务器上使用 zsh 时,如果出现 zsh: command not found: conda 错误,说明你的系统未正确配置 conda 命令,或者你尚未安装 Anaconda/Miniconda。 解决方案 确保已安装 Anaconda 或 Miniconda conda 是 Anaconda 或 Minico…...

【知识科普】CPU,GPN,NPU知识普及
CPU,GPU,NPU CPU、GPU、NPU 详解1. CPU(中央处理器)2. GPU(图形处理器)3. NPU(神经网络处理器) **三者的核心区别****协同工作示例****总结** CPU、GPU、NPU 详解 1. CPU(中央处理器࿰…...

【C++八股】struct和Class的区别
1. 默认访问控制 struct:结构体中的成员默认是 public,即外部代码可以直接访问结构体的成员。class:类中的成员默认是 private,即外部代码不能直接访问类的成员,必须通过公有接口(通常是成员函数ÿ…...

鹧鸪云光伏仓储、物料管理软件详细功能
采购中心 :作为核心枢纽,能集中管理多品牌设备,企业可灵活按需采购。采购与退货流程高效便捷,审核通过后物资快速补充、问题货物及时退回,保障资金与物资顺畅周转,避免积压浪费。付款与退款环节 ࿱…...

bazel 小白理解
Bazel命令是用于构建和测试软件项目的一个强大工具,尤其适用于大规模和多语言的软件项目。对于小白来说,可以这样理解Bazel及其命令: Bazel的基本概念 构建系统:Bazel是一个构建系统,它的主要任务是自动化地编译和链…...

MVC(Model-View-Controller)framework using Python ,Tkinter and SQLite
1.项目结构 sql: CREATE TABLE IF NOT EXISTS School (SchoolId TEXT not null, SchoolName TEXT NOT NULL,SchoolTelNo TEXT NOT NULL) 整体思路 Model:负责与 SQLite 数据库进行交互,包括创建表、插入、删除、更新和查询数据等操作。View࿱…...

WPF 设置宽度为 父容器 宽度的一半
方法1:使用 绑定和转换器 实现 创建类文件 HalfWidthConverter public class HalfWidthConverter : IValueConverter{public object Convert(object value, Type targetType, object parameter, CultureInfo culture){if (value is double width){return width / 4…...

java项目之在线心理评测与咨询管理系统(源码+文档)
项目简介 在线心理评测与咨询管理系统实现了以下功能: 在线心理评测与咨询管理系统的主要使用者分为: (1)在个人中心,管理员可以修改自己的用户名和登录密码。 (2)在系统前台可以查看首页&…...

【STM32系列】利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程)
ps.源码放在最后面 设计IIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程) 前言 本篇文章将介绍如何利用MATLAB与STM32的ARM-DSP库相结合,简明易懂地实现FIR低通滤波器的设计与应用。文章重点不在…...

Springboot框架扩展功能的使用
Spring Boot 提供了许多扩展点,允许开发者在应用程序的生命周期中插入自定义逻辑。这些扩展点可以帮助你更好地控制应用程序的行为,例如在启动时初始化数据、在关闭时释放资源、或者自定义配置加载逻辑。以下是 Spring Boot 中常见的扩展点: …...

yum报错 Could not resolve host: mirrorlist.centos.org
检查dns 使用ping www.baidu.com ,如果ping不通,检查/etc/resolv.conf文件中是否有: nameserver 8.8.8.8 nameserver 8.8.4.4 替换yum源 1.备份原始的 YUM 源配置文件: sudo cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.r…...

docker使用dockerfile打包镜像(docker如何打包)
文章目录 1. 编写 Dockerfile2. 构建 Docker 镜像3. 运行 Docker 容器4. 导出与导入镜像(可选) 1. 编写 Dockerfile Dockerfile 是一个文本文件,其中包含了一系列指令,这些指令定义了如何构建你的 Docker 镜像。下面以一个简单的…...

去中心化AGI网络架构:下一代人工智能的范式革命
文章目录 引言:当AGI遇到去中心化一、中心化AI架构的四大困境1.1 算力垄断与资源错配1.2 数据孤岛与隐私悖论1.3 模型暴政与单点故障1.4 创新抑制与价值捕获二、去中心化AGI网络的架构设计2.1 分层架构总览2.2 网络层:混合拓扑结构2.3 计算层:动态算力编排2.4 数据层:零知识…...

gitlab无法登录问题
在我第一次安装gitlab的时候发现登录页面是 正常的页面应该是 这种情况的主要原因是不是第一次登录,所以我们要找到原先的密码 解决方式: [rootgitlab ~]# vim /etc/gitlab/initial_root_password# WARNING: This value is valid only in the followin…...

单向链表在实际项目中的应用
前言 在实际项目中,单向链表经常被用来解决排队问题,因为链表允许动态地添加和移除元素,非常适合模拟队列(FIFO,先进先出)的行为。 这里的链表包含头节点,头结点的数据用来记录链表长度&#x…...

【系统架构设计师】操作系统 ③ ( 存储管理 | 页式存储弊端 - 段式存储引入 | 段式存储 | 段表 | 段表结构 | 逻辑地址 的 合法段地址判断 )
文章目录 一、页式存储弊端 - 段式存储引入1、页式存储弊端 - 内存碎片2、页式存储弊端 - 逻辑结构不匹配3、段式存储引入 二、段式存储 简介1、段式存储2、段表3、段表 结构4、段内地址 / 段内偏移5、段式存储 优缺点6、段式存储 与 页式存储 对比 三、逻辑地址 的 合法段地址…...

PDF另存为图片的一个方法
说明 有时需要把PDF的每一页另存为图片。用Devexpress可以很方便的完成这个功能。 窗体上放置一个PdfViewer。 然后循环每一页 for (int i 1; i < pdfViewer1.PageCount; i) 调用 chg_pdf_to_bmp函数获得图片并保存 chg_pdf_to_bmp中调用了PdfViewer的CreateBitmap函数…...

HTML之JavaScript运算符
HTML之JavaScript运算符 1.算术运算符 - * / %除以0,结果为Infinity取余数,如果除数为0,结果为NaN NAN:Not A Number2.复合赋值运算符 - * / %/ 除以0,结果为Infinity% 如果除数为0,结果为NaN NaN:No…...

借助 ListWise 提升推荐系统精排效能:技术、案例与优化策略
目录 一、引言二、ListWise 方法概述三、ListWise 用于精排的优势四、ListWise 样本具体的构建过程4.1 确定样本的上下文4.2 收集候选物品及相关特征4.3 确定物品的真实排序标签4.4 构建样本列表4.5 划分训练集、验证集和测试集 五、ListWise 方法案例分析六、ListWise 方法在精…...

C++中什么时候用. 什么时候用->
学了一年C今天出了一个大岔子,因为太久没有做链表类型题目了,并且STL用惯了今天遇到一题,写的时候发现完全不对劲,搞慌了,首先我们看题目 2. 两数相加 再看我第一次的解答,先不论结果对不对 错的行为有很多…...

从云原生到 AI 原生,谈谈我经历的网关发展历程和趋势
作者:谢吉宝(唐三) 编者按: 云原生 API 网关系列教程即将推出,欢迎文末查看教程内容。本文整理自阿里云智能集团资深技术专家,云原生产品线中间件负责人谢吉宝(唐三) 在云栖大会的精…...

【Python深入浅出】Python3正则表达式:开启高效字符串处理大门
目录 一、正则表达式基础入门1.1 什么是正则表达式1.2 正则表达式的语法规则1.3 特殊字符与转义 二、Python 中的 re 模块2.1 re 模块概述2.2 常用函数与方法2.2.1 re.match()2.2.2 re.search()2.2.3 re.findall()2.2.4 re.sub() 2.3 修饰符(Flags)的使用…...