PyTorch AMP 混合精度中grad_scaler.py的scale函数解析
PyTorch AMP 混合精度中的 scale
函数解析
混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler
类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale
函数是 GradScaler
的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。
本文将详细解析 scale
函数的作用、代码逻辑,以及 apply_scale
子函数的递归作用。
函数代码回顾
以下是 scale
函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py
torch 2.4.0+cu121版本
def scale(self,outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:"""Multiplies ('scales') a tensor or list of tensors by the scale factor.Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returnedunmodified.Args:outputs (Tensor or iterable of Tensors): Outputs to scale."""if not self._enabled:return outputs# Short-circuit for the common case.if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)# Invoke the more complex machinery only if we're treating multiple outputs.stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scaledef apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)
1. 函数作用
scale
函数的主要作用是将输出张量(outputs
)按当前的缩放因子(self._scale
)进行缩放。它支持以下两种输入:
- 单个张量:直接将缩放因子乘以张量。
- 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。
当 AMP 功能未启用时(即 self._enabled
为 False
),scale
函数会直接返回原始的 outputs
,不执行任何缩放操作。
使用场景
- 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
- 支持多设备:通过
_MultiDeviceReplicator
支持张量分布在多个设备(如多 GPU)的场景。
2. 核心代码解析
(1) 短路处理单个张量
当输入 outputs
是单个张量(torch.Tensor
)时,函数直接对其进行缩放:
if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
- 如果缩放因子
self._scale
尚未初始化,则调用_lazy_init_scale_growth_tracker
方法在指定设备上初始化缩放因子。 - 使用
outputs * self._scale
对张量进行缩放。这里使用了to(device=outputs.device)
确保缩放因子与张量在同一设备上。
这是单个张量输入的快速路径处理。
(2) 多张量递归处理逻辑
当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale
进行递归缩放:
stash: List[_MultiDeviceReplicator] = [] # 用于存储缩放因子对象def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)
apply_scale
子函数的作用
-
张量处理:
- 如果
val
是单个张量,检查stash
是否为空。 - 如果为空,初始化缩放因子对象
_MultiDeviceReplicator
,并存储在stash
中。 - 使用
stash[0].get(val.device)
获取对应设备上的缩放因子,并对张量进行缩放。
- 如果
-
递归处理可迭代对象:
- 如果
val
是一个可迭代对象,调用map(apply_scale, val)
,对其中的每个元素递归地调用apply_scale
。 - 如果输入是
list
或tuple
,则保持其原始类型。
- 如果
-
类型检查:
- 如果
val
既不是张量也不是可迭代对象,抛出错误。
- 如果
3. apply_scale
是递归函数吗?
是的,apply_scale
是一个递归函数。
递归逻辑
- 当输入为嵌套结构(如张量的列表或列表中的列表)时,
apply_scale
会递归调用自身,将缩放因子应用到最底层的张量。 - 递归的终止条件是
val
为单个张量(torch.Tensor
)。
示例:
假设输入为嵌套张量列表:
outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)
递归处理过程如下:
-
对
outputs
调用apply_scale
:- 第一个元素是张量
torch.tensor([1.0, 2.0])
,直接缩放。 - 第二个元素是列表,递归调用
apply_scale
。
- 第一个元素是张量
-
进入嵌套列表
[torch.tensor([3.0]), torch.tensor([4.0, 5.0])]
:- 第一个元素是张量
torch.tensor([3.0])
,缩放。 - 第二个元素是张量
torch.tensor([4.0, 5.0])
,缩放。
- 第一个元素是张量
4. _MultiDeviceReplicator
的作用
_MultiDeviceReplicator
是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。
- 当张量分布在多个设备(如 GPU)时,
_MultiDeviceReplicator
可以高效地为每个设备提供所需的缩放因子,避免重复初始化。
总结
scale
函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale
,该函数能够处理嵌套的张量结构,同时支持多设备场景。
关键点总结:
- 快速路径:单张量输入的情况下,直接进行缩放。
- 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
- 设备管理:通过
_MultiDeviceReplicator
支持多设备场景。
通过 scale
函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。
后记
2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。
相关文章:

PyTorch AMP 混合精度中grad_scaler.py的scale函数解析
PyTorch AMP 混合精度中的 scale 函数解析 混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子&…...

【Ubuntu20.04】Apollo10.0 Docker容器部署+常见错误解决
官方参考文档【点击我】 Apollo 10.0 版本开始,支持本机和Docker容器两种部署方式。 如果您使用本机部署方式,建议使用x86_64架构的Ubuntu 22.04操作系统或者aarch64架构的Ubuntu 20.04操作系统。 如果您使用Docker容器部署方式,可以使用x…...

【文献精读笔记】Explainability for Large Language Models: A Survey (大语言模型的可解释性综述)(二)
****非斜体正文为原文献内容(也包含笔者的补充),灰色块中是对文章细节的进一步详细解释! 3.1.2 基于注意力的解释(Attention-Based Explanation) 注意力机制可以揭示输入数据中各个部分之间的关系&#…...

朱姆沃尔特隐身战舰:从失败到威慑
前言 "朱姆沃尔特"号驱逐舰是美国海军雄心勃勃的项目,旨在重塑未来海战。它融合了隐身、自动化和强大火力,然而由于技术问题和预算超支,原计划建造32艘的目标被大幅缩减,最终只建造了三艘。该舰的设计特点包括“穿浪逆船…...

免费分享 | 基于极光优化算法PLO优化宽度学习BLS实现光伏数据预测算法研究附Matlab代码
研究内容 宽度学习系统(BLS)简介: BLS是一种新型的神经网络结构,由增强节点(Enhancement Nodes, ENs)和特征节点(Feature Nodes, FNs)组成,具有结构简单、训练速度快、泛…...

logback日志文件多环境配置路径
项目中遇到问题,springboot项目 本地jar包部署到现场后,经常遇到现场的日志存放的路径会更改,经过查阅,有两种方式,下面简单说明一下。 一、第一种 启动jar包时 添加参数 --logging.configF:\hgtest\config\logback.x…...

面试高频:一致性hash算法
这两天看到技术群里,有小伙伴在讨论一致性hash算法的问题,正愁没啥写的题目就来了,那就简单介绍下它的原理。下边我们以分布式缓存中经典场景举例,面试中也是经常提及的一些话题,看看什么是一致性hash算法以及它有那些…...

docker部署项目
docker部署项目 (加载tar包:docker image load -i mysql.tar) 一、jdk环境配置 1.jdk下载地址 --Java Archive | Oracle 中国 --选择好版本进入 --下载Linux x64 Compressed Archive的链接 2.解压 --创建文件夹:mkdir /ro…...

每天40分玩转Django:Django Celery
Django Celery 一、知识要点概览表 模块知识点掌握程度要求Celery基础配置、任务定义、任务执行深入理解异步任务任务状态、结果存储、错误处理熟练应用周期任务定时任务、Crontab、任务调度熟练应用监控管理Flower、任务监控、性能优化理解应用 二、基础配置实现 1. 安装和…...

df.groupby(pd.Grouper(level=1)).sum()
df.groupby(pd.Grouper(level1)).sum() 在 Python 中的作用是根据 DataFrame 的某一索引级别进行分组,并计算每个分组的总和。具体来说: df.groupby(...):这是 pandas 的分组操作,按照指定的规则将 DataFrame 分组。 pd.Grouper(…...
运动控制探针功能详细介绍(CODESYS+SV63N伺服)
汇川AM400PLC和禾川X3E伺服EtherCAT通信 汇川AM400PLC和禾川X3E伺服EtherCAT通信_汇川ethercat通信-CSDN博客文章浏览阅读1.2k次。本文详细介绍了如何使用汇川AM400PLC通过EtherCAT总线与禾川X3E伺服进行通信。包括XML硬件描述文件的下载与安装,EtherCAT总线的启用,从站添加…...

C语言基础18(GDB调试)
文章目录 GDBGDB概述什么是GDB**GDB**的主要功能 GDB的启动GDB常见的启动方式 GDB的退出GDB的常用命令GDB查看源代码指令———list(1)**GDB** 查看设置**------info****GDB** 查看内存**GDB** 设置断点**---break (b)****GDB** 设置观察点**---watch****GDB** 程序调试 GDB完整…...

《向量数据库指南》——应对ElasticSearch挑战,拥抱Mlivus Cloud的新时代
在当今数据驱动的商业环境中,向量数据库的应用正变得愈加重要。随着人工智能和机器学习的快速发展,尤其是在自然语言处理、图像识别及推荐系统等领域,向量数据库以其强大的存储和检索能力,迎来了广泛的应用机会。然而,在实际应用中,企业在选择和实施向量数据库方案时,常…...

c++的stl库中stack的解析和模拟实现
目录 1.stack的介绍和使用 1.1stack的介绍 1.2stack的使用 2.stack的模拟实现 1.stack的介绍和使用 1.1stack的介绍 1. stack 是一种容器适配器,专门用在具有后进先出操作的上下文环境中,其删除只能从容器的一端进行元素的插入与提取操作。 2. stac…...

C语言——字符函数和内存函数
目录 前言 字符函数 1strlen 模拟实现 2strcpy 模拟实现 3strcat 模拟实现 4strcmp 模拟实现 5strncpy 模拟实现 6strncat 模拟实现 7strncmp 模拟实现 8strstr 模拟实现 9strtok 10strerror 11大小写字符转换函数 内存函数 1memcpy 模拟实现 2…...

查询docker overlay2文件夹下的 c7ffc13c49xxx是哪一个容器使用的
问题背景 查询docker overlay2文件夹下的 c7ffc13c49xxx是哪一个容器使用的 [root@lnops overlay2]# du -sh * | grep G 1.7G 30046eca3e838e43d16d9febc63cc8f8bb3d327b4c9839ca791b3ddfa845e12e 435G c7ffc13c49a43f08ef9e234c6ef9fc5a3692deda3c5d42149d0070e9d8124f71 1.…...

Golang的容器编排实践
Golang的容器编排实践 一、Golang中的容器编排概述 作为一种高效的编程语言,其在容器编排领域也有着广泛的运用。容器编排是指利用自动化工具对容器化的应用进行部署、管理和扩展的过程,典型的容器编排工具包括Docker Swarm、Kubernetes等。在Golang中&a…...

【51项目】51单片机自制小霸王游戏机
视频演示效果: 纳新作品——小霸王游戏机 目录: 目录 视频演示效果: 目录: 前言:...

ArkTs之NAPI学习
1.Node-api组成架构 为了应对日常开发经的网络通信、串口访问、多媒体解码、传感器数据收集等模块,这些模块大多数是使用c接口实现的,arkts侧如果想使用这些能力,就需要使用node-api这样一套接口去桥接c代码。Node-api整体的架构图如下&…...

【数据库初阶】MySQL中表的约束(上)
🎉博主首页: 有趣的中国人 🎉专栏首页: 数据库初阶 🎉其它专栏: C初阶 | C进阶 | 初阶数据结构 亲爱的小伙伴们,大家好!在这篇文章中,我们将深入浅出地为大家讲解 MySQL…...

173. 矩阵距离 acwing -多路BFS
原题链接:173. 矩阵距离 - AcWing题库 给定一个 N行 M 列的 01矩阵 A,A[i][j] 与 A[k][l]]之间的曼哈顿距离定义为: dist(i,j,k,l)|i−k||j−l|| 输出一个 N 行 M 列的整数矩阵 B,其中: B[i][j]min1≤x≤N,1≤y≤M,A…...

Linux下部署Redis集群 - 一主二从三哨兵模式
三台服务器redis一主二从三哨兵模式搭建 最近使用到了redis集群部署,使用一主二从三哨兵集群部署redis,将自己部署的过程中的使用心得分享给大家,希望大家以后部署的过程减少一些坑。 服务器准备 3台服务器 ,确定主redis和从red…...

实战设计模式之建造者模式
概述 在实际项目中,我们有时会遇到需要创建复杂对象的情况。这些对象可能包含多个组件或属性,而且每个组件都有自己的配置选项。如果直接使用构造函数或前面介绍的工厂方法来创建这样的对象,可能会导致以下两个严重问题。 1、参数过多。当一个…...

活动预告 | Microsoft Azure 在线技术公开课:使用 Azure OpenAI 服务构建生成式应用
课程介绍 通过 Microsoft Learn 免费参加 Microsoft Azure 在线技术公开课,掌握创造新机遇所需的技能,加快对 Microsoft Cloud 技术的了解。参加我们举办的“使用 Azure OpenAI 服务构建生成式应用”活动,了解如何使用包括 GPT 在内的强大的…...

ubuntu安装firefox
firefox下载地址:https://ftp.mozilla.org/pub/firefox/releases/ 卸载 sudo apt-get update dpkg --get-selections |grep firefox apt-get purge firefox 解压 tar -xjf firefox*.tar.bz2复制文件 sudo mv firefox/ /opt/firefox30sudo mv /usr/bin/firefox /…...

计算机网络原理(谢希仁第八版)第4章课后习题答案
第四章 网络层 详细计算机网络(谢希仁-第八版)第四章习题全解_计算机网络第八版谢希仁课后答案-CSDN博客 1.网络层向上提供的服务有哪两种?是比较其优缺点。网络层向运输层提供 “面向连接”虚电路(Virtual Circuit)服…...

RabbitMQ-基本使用
RabbitMQ: One broker to queue them all | RabbitMQ 官方 安装到Docker中 docker run \-e RABBITMQ_DEFAULT_USERrabbit \-e RABBITMQ_DEFAULT_PASSrabbit \-v mq-plugins:/plugins \--name mq \--hostname mq \-p 15672:15672 \-p 5672:5672 \--network mynet\-d \rabbitmq:3…...

从零开始学架构——互联网架构的演进
1 技术演进 1.1 技术演进的动力 对于新技术,我们应该站在行业的角度上思考,哪些技术我们要采取,哪些技术我们不能用,投入成本过大会不会导致满盘皆输?市场、技术、管理三者组成的业务发展铁三角,任何一个…...

python +tkinter绘制彩虹和云朵
python tkinter绘制彩虹和云朵 彩虹,简称虹,是气象中的一种光学现象,当太阳光照射到半空中的水滴,光线被折射及反射,在天空上形成拱形的七彩光谱,由外圈至内圈呈红、橙、黄、绿、蓝、靛、紫七种颜色。事实…...

重新整理机器学习和神经网络框架
本篇重新梳理了人工智能(AI)、机器学习(ML)、神经网络(NN)和深度学习(DL)之间存在一定的包含关系,以下是它们的关系及各自内容,以及人工智能领域中深度学习分支对比整理。…...