当前位置: 首页 > news >正文

PyTorch AMP 混合精度中grad_scaler.py的scale函数解析

PyTorch AMP 混合精度中的 scale 函数解析

混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale 函数是 GradScaler 的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。

本文将详细解析 scale 函数的作用、代码逻辑,以及 apply_scale 子函数的递归作用。


函数代码回顾

以下是 scale 函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

torch 2.4.0+cu121版本

def scale(self,outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:"""Multiplies ('scales') a tensor or list of tensors by the scale factor.Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returnedunmodified.Args:outputs (Tensor or iterable of Tensors):  Outputs to scale."""if not self._enabled:return outputs# Short-circuit for the common case.if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)# Invoke the more complex machinery only if we're treating multiple outputs.stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scaledef apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)

1. 函数作用

scale 函数的主要作用是将输出张量(outputs)按当前的缩放因子(self._scale)进行缩放。它支持以下两种输入:

  1. 单个张量:直接将缩放因子乘以张量。
  2. 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。

当 AMP 功能未启用时(即 self._enabledFalse),scale 函数会直接返回原始的 outputs,不执行任何缩放操作。

使用场景

  • 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
  • 支持多设备:通过 _MultiDeviceReplicator 支持张量分布在多个设备(如多 GPU)的场景。

2. 核心代码解析

(1) 短路处理单个张量

当输入 outputs 是单个张量(torch.Tensor)时,函数直接对其进行缩放:

if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
  1. 如果缩放因子 self._scale 尚未初始化,则调用 _lazy_init_scale_growth_tracker 方法在指定设备上初始化缩放因子。
  2. 使用 outputs * self._scale 对张量进行缩放。这里使用了 to(device=outputs.device) 确保缩放因子与张量在同一设备上。

这是单个张量输入的快速路径处理。


(2) 多张量递归处理逻辑

当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale 进行递归缩放:

stash: List[_MultiDeviceReplicator] = []  # 用于存储缩放因子对象def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)
apply_scale 子函数的作用
  1. 张量处理

    • 如果 val 是单个张量,检查 stash 是否为空。
    • 如果为空,初始化缩放因子对象 _MultiDeviceReplicator,并存储在 stash 中。
    • 使用 stash[0].get(val.device) 获取对应设备上的缩放因子,并对张量进行缩放。
  2. 递归处理可迭代对象

    • 如果 val 是一个可迭代对象,调用 map(apply_scale, val),对其中的每个元素递归地调用 apply_scale
    • 如果输入是 listtuple,则保持其原始类型。
  3. 类型检查

    • 如果 val 既不是张量也不是可迭代对象,抛出错误。

3. apply_scale 是递归函数吗?

是的,apply_scale 是一个递归函数。

递归逻辑

  • 当输入为嵌套结构(如张量的列表或列表中的列表)时,apply_scale 会递归调用自身,将缩放因子应用到最底层的张量。
  • 递归的终止条件是 val 为单个张量(torch.Tensor)。
示例:

假设输入为嵌套张量列表:

outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)

递归处理过程如下:

  1. outputs 调用 apply_scale

    • 第一个元素是张量 torch.tensor([1.0, 2.0]),直接缩放。
    • 第二个元素是列表,递归调用 apply_scale
  2. 进入嵌套列表 [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]

    • 第一个元素是张量 torch.tensor([3.0]),缩放。
    • 第二个元素是张量 torch.tensor([4.0, 5.0]),缩放。

4. _MultiDeviceReplicator 的作用

_MultiDeviceReplicator 是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。

  • 当张量分布在多个设备(如 GPU)时,_MultiDeviceReplicator 可以高效地为每个设备提供所需的缩放因子,避免重复初始化。

总结

scale 函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale,该函数能够处理嵌套的张量结构,同时支持多设备场景。

关键点总结:

  1. 快速路径:单张量输入的情况下,直接进行缩放。
  2. 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
  3. 设备管理:通过 _MultiDeviceReplicator 支持多设备场景。

通过 scale 函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。

后记

2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。

相关文章:

PyTorch AMP 混合精度中grad_scaler.py的scale函数解析

PyTorch AMP 混合精度中的 scale 函数解析 混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子&…...

【Ubuntu20.04】Apollo10.0 Docker容器部署+常见错误解决

官方参考文档【点击我】 Apollo 10.0 版本开始,支持本机和Docker容器两种部署方式。 如果您使用本机部署方式,建议使用x86_64架构的Ubuntu 22.04操作系统或者aarch64架构的Ubuntu 20.04操作系统。 如果您使用Docker容器部署方式,可以使用x…...

【文献精读笔记】Explainability for Large Language Models: A Survey (大语言模型的可解释性综述)(二)

****非斜体正文为原文献内容(也包含笔者的补充),灰色块中是对文章细节的进一步详细解释! 3.1.2 基于注意力的解释(Attention-Based Explanation) 注意力机制可以揭示输入数据中各个部分之间的关系&#…...

朱姆沃尔特隐身战舰:从失败到威慑

前言 "朱姆沃尔特"号驱逐舰是美国海军雄心勃勃的项目,旨在重塑未来海战。它融合了隐身、自动化和强大火力,然而由于技术问题和预算超支,原计划建造32艘的目标被大幅缩减,最终只建造了三艘。该舰的设计特点包括“穿浪逆船…...

免费分享 | 基于极光优化算法PLO优化宽度学习BLS实现光伏数据预测算法研究附Matlab代码

研究内容 宽度学习系统(BLS)简介: BLS是一种新型的神经网络结构,由增强节点(Enhancement Nodes, ENs)和特征节点(Feature Nodes, FNs)组成,具有结构简单、训练速度快、泛…...

logback日志文件多环境配置路径

项目中遇到问题,springboot项目 本地jar包部署到现场后,经常遇到现场的日志存放的路径会更改,经过查阅,有两种方式,下面简单说明一下。 一、第一种 启动jar包时 添加参数 --logging.configF:\hgtest\config\logback.x…...

面试高频:一致性hash算法

这两天看到技术群里,有小伙伴在讨论一致性hash算法的问题,正愁没啥写的题目就来了,那就简单介绍下它的原理。下边我们以分布式缓存中经典场景举例,面试中也是经常提及的一些话题,看看什么是一致性hash算法以及它有那些…...

docker部署项目

docker部署项目 (加载tar包:docker image load -i mysql.tar) 一、jdk环境配置 1.jdk下载地址 --Java Archive | Oracle 中国 --选择好版本进入 --下载Linux x64 Compressed Archive的链接 2.解压 --创建文件夹:mkdir /ro…...

每天40分玩转Django:Django Celery

Django Celery 一、知识要点概览表 模块知识点掌握程度要求Celery基础配置、任务定义、任务执行深入理解异步任务任务状态、结果存储、错误处理熟练应用周期任务定时任务、Crontab、任务调度熟练应用监控管理Flower、任务监控、性能优化理解应用 二、基础配置实现 1. 安装和…...

df.groupby(pd.Grouper(level=1)).sum()

df.groupby(pd.Grouper(level1)).sum() 在 Python 中的作用是根据 DataFrame 的某一索引级别进行分组,并计算每个分组的总和。具体来说: df.groupby(...):这是 pandas 的分组操作,按照指定的规则将 DataFrame 分组。 pd.Grouper(…...

运动控制探针功能详细介绍(CODESYS+SV63N伺服)

汇川AM400PLC和禾川X3E伺服EtherCAT通信 汇川AM400PLC和禾川X3E伺服EtherCAT通信_汇川ethercat通信-CSDN博客文章浏览阅读1.2k次。本文详细介绍了如何使用汇川AM400PLC通过EtherCAT总线与禾川X3E伺服进行通信。包括XML硬件描述文件的下载与安装,EtherCAT总线的启用,从站添加…...

C语言基础18(GDB调试)

文章目录 GDBGDB概述什么是GDB**GDB**的主要功能 GDB的启动GDB常见的启动方式 GDB的退出GDB的常用命令GDB查看源代码指令———list(1)**GDB** 查看设置**------info****GDB** 查看内存**GDB** 设置断点**---break (b)****GDB** 设置观察点**---watch****GDB** 程序调试 GDB完整…...

《向量数据库指南》——应对ElasticSearch挑战,拥抱Mlivus Cloud的新时代

在当今数据驱动的商业环境中,向量数据库的应用正变得愈加重要。随着人工智能和机器学习的快速发展,尤其是在自然语言处理、图像识别及推荐系统等领域,向量数据库以其强大的存储和检索能力,迎来了广泛的应用机会。然而,在实际应用中,企业在选择和实施向量数据库方案时,常…...

c++的stl库中stack的解析和模拟实现

目录 1.stack的介绍和使用 1.1stack的介绍 1.2stack的使用 2.stack的模拟实现 1.stack的介绍和使用 1.1stack的介绍 1. stack 是一种容器适配器,专门用在具有后进先出操作的上下文环境中,其删除只能从容器的一端进行元素的插入与提取操作。 2. stac…...

C语言——字符函数和内存函数

目录 前言 字符函数 1strlen 模拟实现 2strcpy 模拟实现 3strcat 模拟实现 4strcmp 模拟实现 5strncpy 模拟实现 6strncat 模拟实现 7strncmp 模拟实现 8strstr 模拟实现 9strtok 10strerror 11大小写字符转换函数 内存函数 1memcpy 模拟实现 2…...

查询docker overlay2文件夹下的 c7ffc13c49xxx是哪一个容器使用的

问题背景 查询docker overlay2文件夹下的 c7ffc13c49xxx是哪一个容器使用的 [root@lnops overlay2]# du -sh * | grep G 1.7G 30046eca3e838e43d16d9febc63cc8f8bb3d327b4c9839ca791b3ddfa845e12e 435G c7ffc13c49a43f08ef9e234c6ef9fc5a3692deda3c5d42149d0070e9d8124f71 1.…...

Golang的容器编排实践

Golang的容器编排实践 一、Golang中的容器编排概述 作为一种高效的编程语言,其在容器编排领域也有着广泛的运用。容器编排是指利用自动化工具对容器化的应用进行部署、管理和扩展的过程,典型的容器编排工具包括Docker Swarm、Kubernetes等。在Golang中&a…...

【51项目】51单片机自制小霸王游戏机

视频演示效果: 纳新作品——小霸王游戏机 目录: 目录 视频演示效果: 目录: 前言:...

ArkTs之NAPI学习

1.Node-api组成架构 为了应对日常开发经的网络通信、串口访问、多媒体解码、传感器数据收集等模块,这些模块大多数是使用c接口实现的,arkts侧如果想使用这些能力,就需要使用node-api这样一套接口去桥接c代码。Node-api整体的架构图如下&…...

【数据库初阶】MySQL中表的约束(上)

🎉博主首页: 有趣的中国人 🎉专栏首页: 数据库初阶 🎉其它专栏: C初阶 | C进阶 | 初阶数据结构 亲爱的小伙伴们,大家好!在这篇文章中,我们将深入浅出地为大家讲解 MySQL…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

循环冗余码校验CRC码 算法步骤+详细实例计算

通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)&#xff0…...

Java 加密常用的各种算法及其选择

在数字化时代,数据安全至关重要,Java 作为广泛应用的编程语言,提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景,有助于开发者在不同的业务需求中做出正确的选择。​ 一、对称加密算法…...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

人机融合智能 | “人智交互”跨学科新领域

本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...

鸿蒙(HarmonyOS5)实现跳一跳小游戏

下面我将介绍如何使用鸿蒙的ArkUI框架,实现一个简单的跳一跳小游戏。 1. 项目结构 src/main/ets/ ├── MainAbility │ ├── pages │ │ ├── Index.ets // 主页面 │ │ └── GamePage.ets // 游戏页面 │ └── model │ …...