当前位置: 首页 > news >正文

「Pytorch」BF16 Mixed Precision Training

在深度学习领域,神经网络的训练性能瓶颈常常出现在 GPU显存的使用上。主要表现为两方面:

  1. 单卡上可容纳的模型和数据量有限;
  2. 显存与计算单元之间的带宽和延迟限制了运算速度;

为了解决显卡瓶颈的问题,涌现了不同的解决方法。


模型参数量估计

为了更好地估算模型所要占用的显存,首先需要分析模型训练过程中有哪些部分需要消耗存储空间。在 “ZeRO: Memory Optimizations Toward Training Trillion Parameter Model” 中提出,模型在训练时,主要有两大部分的空间占用。

  1. 对于大模型来说,主要的空间占用是模型状态,包括优化器状态(eg:Adam优化器的动量和方差)、模型参数和模型参数的梯度;
  2. 剩余的空间主要被模型训练中间激活值、临时缓冲区和不可用的内存碎片占用,统称为剩余状态。

以 3.5B 大模型为例。①35亿个参数,如果使用 FP16进行存储的话,即 70亿个字节,约7GB左右;②前向传播的激活值和反向传播的梯度大小跟模型参数保持一致,约 7GB;③以 Adam优化器为例,包括三部分,分别为 FP32格式模型参数的备份,FP32的动量和方差,加起来约28GB;因此,从理论上,要微调此模型的话,至少需要 49GB的空间。


1、BF16 半精度浮点数

双精度浮点 Float64;单浮点精度 Float32;半浮点精度 Float16,被广泛应用于model推理
为了进一步降低计算量和显存占用,可以考虑整数int4和int8量化推理

之前深度学习模型的训练通常都采用 Float32(FP32)的精度,而作者发现,使用较低的精度来进行模型的训练也是可行的,并且能够显著提升速度。通过采用混合精度训练,一般可以实现 2~3 倍的速度提升,极大的优化了模型的训练流程。

  • FP32 整体长度为4个字节,即32位,其中有8位的指数位宽,23位的尾数精度和1位的符号位,能够表示的数值范围是 1 × 2 − 126 ∼ ( 2 − ϵ ) × 2 127 1\times 2^{-126} \sim (2-\epsilon)\times 2^{127} 1×2126(2ϵ)×2127
  • 在一些不太需要高精度计算的应用中,eg:图像处理和神经网络中,32位的空间其实有一些浪费,因此又出现了新的数据类型,半精度浮点数 FP16,使用16位(2个字节)来存储浮点值,有5位的指数位宽,10位的尾数精度和1位的符号位,能够表示的数值范围是 1 × 2 − 14 ∼ ( 2 − ϵ ) × 2 15 1\times 2^{-14} \sim (2-\epsilon)\times 2^{15} 1×214(2ϵ)×215
格式位数/位指数位宽/位尾数精度/位符号位/位数值范围
FP32328231 1 × 2 − 126 ∼ ( 2 − ϵ ) × 2 127 1\times 2^{-126} \sim (2-\epsilon)\times 2^{127} 1×2126(2ϵ)×2127
FP16165101 1 × 2 − 14 ∼ ( 2 − ϵ ) × 2 15 1\times 2^{-14} \sim (2-\epsilon)\times 2^{15} 1×214(2ϵ)×215
BP3216871 1 × 2 − 126 ∼ ( 2 − ϵ ) × 2 127 1\times 2^{-126} \sim (2-\epsilon)\times 2^{127} 1×2126(2ϵ)×2127

混合精度训练,即在模型训练时同时采用 FP32 和 FP16 两种精度。在实践过程中,研究人员发现在大语言模型的训练中直接使用 FP16会有一些问题,在训练过程中 loss 会非常不稳定,因此使用 FP16 训练大模型非常困难。问题在于 FP16的指数位宽只有 5位,能表示的最大整数为 65504,一旦权重超过这个值就会发生溢出,因此只能进行较小数的乘法,eg:可以计算 250 × 250 = 62500 250\times250=62500 250×250=62500,但如果计算 255 × 255 = 65025 255\times 255=65025 255×255=65025 就会溢出,这是导致训练出现问题的主要原因。这也意味着模型权重必须保持很小。一种成为损失缩放的技术可以缓解这个问题,但是当模型变得非常大时,FP16 较小的数值范围依旧是一个问题。

  • 为了更好地解决 FP16的问题,谷歌开发了一种新的浮点数格式 BF16(Brain Floating Point, 2个字节),用于降低存储需求,提高机器学习算法的计算速度。BF16 的指数位宽为8位,于 FP32相同,尾数精度采用7位,因此当使用 BF16时,精度非常差。然而,在训练模型时一般采用随机梯度下降法及其变体,其过程像蹒跚而行,即使某一步没有找到最优方向也没关系,模型会在后续调整纠正。

将模型参数类型从 FP16换为 BF16,训练的大模型 loss值的下降也会变得更加稳定。

这种低精度和 混合精度训练的方法逐渐被广泛接受和应用,深度学习框架、GPU以及 神经网络加速器的设计也因此受到了深渊的影响。可以说,混合精度训练的提出,对深度学习领域起到了关键的推动作用,有效地解决了 GPU显存的使用问题,提升了模型训练的效率。


2、混合精度训练

paper:Mixed Precision Training

  1. 维护一个权重的单精度副本,在每个优化器步骤后累计梯度(对于前向和反向传播,此副本四舍五入到半精度);
  2. 提出了损失缩放来保持小幅度的梯度值;
  3. 使用半精度算法,该算法累积为单精度输出,在存储到内存之前将其转化为半精度;

在这里插入图片描述

FP32 为主副本权重,

在混合精度训练时,权重、激活函数、梯度被保存为 FP16,为了与 FP32模型的精度相匹配,在optimizer step时,维持 FP32权重为主线,并使用权重梯度进行更新。在每次迭代时,主权重的 FP16副本用于前向和反向传播,将 FP32训练所需的存储和带宽减半,如上图所示。

虽然对 FP32住权重的需求并不普遍,但许多模型还需要 FP32的两个可能原因是:

  1. 权重更新变得太小,无法在 FP16中表示,任何大小小于 2 − 24 2^{-24} 224 的值在 FP16中都将变为 零,当与学习率相乘时,这些小值梯度在优化器中都会变为零,并对模型的准确性产生不利影响。使用单精度进行更新可以解决这一问题;
  2. 权重值 与 权重更新的比例非常大。在这种情况下,即使权重更新可以在 FP16中表示,当加法操作将其右移以使二进制点与权重对齐时,它仍然可能变为零。当归一化权重值的幅度比权重更新的新幅度大 至少2048倍,就会发生这种情况。由于 FP16有10位尾数,隐式位必须右移11位或更多位置,才能潜在地创建一个零。在比例大于2048时,隐式比特将右移12位或更多位。这将导致权重更新变得无法恢复的零。更大的比例将导致非标注化数字的效果。同样,可以通过计算 FP32中的更新来抵消这种影响。

图2-a所示,在 FP16前后传播更新 FP32权重主线时,匹配FP32训练结果,而更新FP16权重会导致 80%的相对精度损失。

由于更大的 batch-size 和每层的激活被保存以在反向传播过程中重复使用,因此训练内存消耗主要由激活决定。由于激活也以半精度格式存储,因此训练深度神经网络的整体内存消耗大约减半。

2.1 损失缩放

FP16指数偏差将归一化指数的范围集中到 [-14,15],而实践中的梯度值往往由小幅度(负指数)主导,如图3所示,显示了Multibox SSD模型的 FP32训练期间在所有层上的急活梯度值的直方图,FP16 可表示范围的大部分未必使用,而许多值低于最小可表示范围 变为指数为0。放大梯度将使它们占据更多的可表示范围,并保留否则会丢失为0的值。当梯度未被缩放时,这个特定的网络会发散,但将其压缩8倍(指数为3)就足以匹配 FP32训练所达到的精度。这说明激活

在这里插入图片描述

在这里插入图片描述

2.2运算精度

神经网络模型分为三类:vector dot-products, reductions, point-wise operations。当涉及到降低精度的算法时,这些类别受益于不同的处理,为了保持模型的准确性,发现一些模型要求 FP16矢量点积将部分累加成 FP32,在写入内存之前将其转换为 FP16。如果 FP32中没有这种积累,一些 FP16模型与基线模型的精度不匹配。

之前的GPU只支持 FP16 乘/加法运算,而 NVIDIA Volta GPU引入了 Tensor Core,可以将 FP16输入矩阵相乘,并将乘积累加到 FP16或 FP32输出中。

FP32中应进行大幅缩减(向量各元素之和)。在累积统计数据和 softmax层时,这种建好主要出现在 batch-normalization 层中。在两种层类型种,仍然从内存中读取和写入FP16张量,在 FP32中执行算术运算,这并没有减缓训练过程,因为这些层的内存带宽有限,对算术速度不敏感。

逐点操作,如非线性和逐像素矩阵运算,是内存带宽有限,由于算术精度不影响这些运算的速度,因此可以使用FP16 或 FP32.

相关文章:

「Pytorch」BF16 Mixed Precision Training

在深度学习领域,神经网络的训练性能瓶颈常常出现在 GPU显存的使用上。主要表现为两方面: 单卡上可容纳的模型和数据量有限;显存与计算单元之间的带宽和延迟限制了运算速度; 为了解决显卡瓶颈的问题,涌现了不同的解决…...

论文阅读:Efficient Core Maintenance in Large Bipartite Graphs | SIGMOD 2024

还记得我们昨天讨论的《Querying Historical Cohesive Subgraphs over Temporal Bipartite Graphs》这篇论文吗? https://blog.csdn.net/m0_62361730/article/details/141003301 这篇(还没看的快去看) 这篇论文主要研究如何在时间双向图上查询历史凝聚子图,而《E…...

LLMOps — 使用 BentoML 为 Llama-3 模型提供服务

使用 BentoML 和 Runpod 快速设置 LLM API 经常看到数据科学家对 LLM 的开发感兴趣,包括模型架构、训练技术或数据收集。然而,我注意到,很多时候,除了理论方面,许多人在以用户实际使用的方式提供这些模型时遇到了问题…...

微软蓝屏事件揭秘:有问题的数据引发内存读取越界

讲动人的故事,写懂人的代码 CrowdStrike前一阵在官网上发布了上周爆发的全球企业微软蓝屏事件的官方初步复盘结果。其中谈到了这次事件的根本原因: 2024年7月19日,我们部署了两个额外的IPC模板实例。由于内容验证器中的一个bug,使…...

NASA:北极ARCTAS差分吸收激光雷达(DIAL)遥感数据

ARCTAS Differential Absorption Lidar (DIAL) Remotely Sensed Data ARCTAS差分吸收激光雷达(DIAL)遥感数据 简介 ARCTAS差分吸收激光雷达(DIAL)遥感数据是一种远程感测技术,用于测量大气中不同波长的激光辐射被大…...

Android 文件上传与下载

在实际开发涉及文件上传不会自己写上传代码,一般 会集成第三网络库来做图片上传,比如android-async-http,okhttp等,另外还有七牛也提供 了下载和上传的API。 1.项目用到的图片上传的关键方法: 这里用到一个第三方的库…...

Java语言的充电桩系统Charging station system

介绍 SpringBoot 框架,充电桩平台充电桩系统充电平台充电桩互联互通协议云快充协议1.5-1.6协议新能源汽车二轮车公交车二轮车充电-四轮车充电充电源代码充电平台源码Java源码-共享充电桩-充电桩软件 软件介绍 小程序端:城市切换、附近电站、电桩详情页…...

RCE之无参数读取文件

什么是无参数&#xff1f; 顾名思义&#xff0c;就是只使用函数&#xff0c;且函数不能带有参数&#xff0c;这里有种种限制&#xff1a;比如我们选择的函数必须能接受其括号内函数的返回值&#xff1b;使用的函数规定必须参数为空或者为一个参数等 例题&#xff1a; <?…...

Python GUI开发必看:Tkinter Button控件使用详解

Button&#xff08;按钮&#xff09;组件用于实现各种各样的按钮。 Button组件可以包含文本或图像&#xff0c;你可以将一个Python的函数或方法与之相关联&#xff0c;当按钮被按下时&#xff0c;对应的函数或方法将被自动执行。 Button组件仅能显示单一字体的文本&#xff0c…...

上海市计算机学会竞赛平台2024年7月月赛丙组得分排名

题目描述 给定 nn 名学生的考试得分&#xff0c;这些学生的学号为 11 到 nn&#xff0c;其第 ii 号学生的得分为 aiai​&#xff0c;请将这些学生按照分数从大到小的顺序排列并输出学号序列。 若两个学生得分相同&#xff0c;则先输出较小的学号。 输入格式 第一行&#xf…...

Can GPT-3 Perform Statutory Reasoning?

文章目录 题目摘要相关工作SARAGPT-3 对美国法典的了解GPT-3 在对合成法规进行简单推理时遇到困难结论 题目 GPT-3 可以进行法定推理吗&#xff1f; 论文地址&#xff1a;https://arxiv.org/abs/2302.06100 摘要 法定推理是用事实和法规进行推理的任务&#xff0c;法规是立法机…...

redis面试(十一)锁超时

boolean res lock.tryLock(100, 10, TimeUnit.SECONDS); RedissonLock里面有这样一个方法tryLock()&#xff0c;意思是尝试获取锁的结果。 最大等待时间100s&#xff0c;并且获取到锁之后&#xff0c;10s之内没有释放的话&#xff0c;锁会自动失效。 尝试获取锁超时 time …...

C代码做底层及Matlab_SimuLink做应用层设计单片机程序

前言:SimuLink工具极其强大,但是能直接支持单片机自主开发的很少,造成这个问题的原因主要是我们使用的芯片底层多是C代码工程,芯片厂家也只提供C代码库,很少能提供SimuLink的支持库,即使提供也不是很不完善,如NXP的一些芯片提供的SimuLink库不含盖高级应用,再比如意法半…...

Cloud Kernel SIG 月度动态:ANCK OOT 驱动基线更新,发布 2 个 ANCK 版本

Cloud Kernel SIG&#xff08;Special Interest Group&#xff09;&#xff1a;支撑龙蜥内核版本的研发、发布和服务&#xff0c;提供生产可用的高性价比内核产品。 01 SIG 整体进展 1. 发布 ANCK 5.10-016.4 小版本。 2. 发布 ANCK 5.10-017.1 小版本。 3. ANCK 新增海光平…...

vue3仿飞书头像,根据不同名称生成不同的头像背景色

效果展示&#xff1a; 传递三个参数&#xff1a; name&#xff1a;要显示的名称&#xff1b;size&#xff1a;头像的大小&#xff1b;cutNum&#xff1a;分割当前名称的最后几位数&#xff1b; 代码如下&#xff1a; <template><div:style"{color: #fff,borde…...

SpringBoot整合三方

SpringBoot整合redis 引入redis依赖包 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency><dependency><groupId>redis.clients</groupId&g…...

React之组件的使用

Vue、React和Angular是三个流行的前端框架&#xff0c;采用组件化的开发方式。支持虚拟DOM&#xff08;Virtual DOM&#xff09;技术&#xff0c;有丰富的生态系统、大量的插件和工具可以使用。Vue的语法是传统的HTML和JavaScript&#xff0c;React使用JSX语法&#xff0c;Angu…...

深度学习--长短期记忆网络

1.引入 RNN 可以将以前的信息与当前的信息进行连接。例如&#xff0c;在视频中&#xff0c;可以用前面的帧来 帮助理解当前帧的内容&#xff1b;在文本中&#xff0c;可以用前面半句话的内容来预测后面的内容。但是&#xff0c; RNN 存在一个记忆消失的问题。例如&#xff0c;…...

研0 冲刺算法竞赛 day29 P2249 【深基13.例1】查找

P2249 【深基13.例1】查找 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路&#xff1a; ①二分查找 ②stl函数&#xff1a;lower_bound(a.begin(),a.end(),x) 返回第一个大于等于 x的数的地址 代码&#xff1a; #include<iostream> #include<algorithm> …...

基于vue框架的CKD电子病历系统nfa2e(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;患者,医生,药品信息,电子病历,临时医嘱,长期医嘱,健康科普 开题报告内容 基于Vue框架的CKD电子病历系统 开题报告 一、选题背景 随着信息技术的飞速发展和医疗信息化的深入推进&#xff0c;电子病历系统&#xff08;Electronic Medic…...

笔记:python 安装tar包报错

报错信息 ERROR: Could not find a version that satisfies the requirement setuptools>40.8.0 (from versions: none)ERROR: No matching distribution found for setuptools>40.8.0分析 1&#xff0c;当前已安装 setuptools 并且版本超过40.8.0 解决方案 缺包了&am…...

575. 分糖果

哈喽&#xff01;大家好&#xff0c;我是奇哥&#xff0c;一位专门给面试官添堵的职业面试员 文章持续更新&#xff0c;可以微信搜索【小奇JAVA面试】第一时间阅读&#xff0c;回复【资料】更有我为大家准备的福利哟&#xff01; 文章目录 一、题目二、答案三、总结 一、题目 …...

手机电量消耗分析工具 Battery Historian 指南

阅读五分钟&#xff0c;每日十点&#xff0c;和您一起终身学习&#xff0c;这里是程序员Android 本篇文章主要介绍 Android 开发中 电量 的部分知识点&#xff0c;通过阅读本篇文章&#xff0c;您将收获以下内容: 一、安装Battery Historian二、收集Batterystats 数据三、使用B…...

笔试练习day4

目录 WY22 Fibonacci数列题目解析解法暴力解法贪心代码 NC242 单词搜索题目解析例子1解析例子2解析例子3解析解法深度优先遍历dfs实现最终代码 BC140 杨辉三角解法线性dp问题代码 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 &#x1f412;&#x1f412;&am…...

公主少爷都爱看的haproxy七层代理详细介绍及常见实验详解

目录 一、负载均衡 1.1什么是负载均衡 1.2为什么要实验负载均衡 1.3四层负载均衡 1.4七层负载均衡 1.5四层负载均衡和七层负载均衡的对比 二、什么是haproxy 2.1定义 2. 2功能和特点 2.3应用场景 2.4haproxy的分类 三、安装及基本配置的信息 3.1软件的安装 3.2ha…...

Android笔试面试题AI答之Kotlin(6)

文章目录 24. 以下代码执行的结果是什么&#xff1f;25. 解释一下下述Kotlin 代码有什么问题&#xff1f;26. 如何在 Kotlin 中创建常量&#xff1f;示例注意事项总结 27. Koltin 可以互换使用 IntArray 和 Kotlin 中的 Array 吗&#xff1f;IntArrayArray<Int>互换使用从…...

“tcp控制协议”的理解

情景解释&#xff1a; 1.过程&#xff1a; 在用户进行网络间通信时&#xff0c;不管是客户端还是服务端&#xff0c;都会有两个缓冲区——发送缓冲区和接受缓冲区。 通过4个缓冲区进行数据交流。 用户通过write()将数据发送到他的发送缓冲区中&#xff0c;再传输到服务端的…...

考试成绩查询系统,发成绩不再愁

在考试季落幕后&#xff0c;家长们总是怀着紧张的心情&#xff0c;期待着孩子们的成绩单。传统上&#xff0c;老师们需要手动发送成绩单给每位家长&#xff0c;这个过程不仅耗时耗力&#xff0c;而且容易出错&#xff0c;给老师们带来了不小的压力。 随着科技的进步&#xff0c…...

【扒代码】图像数据 Transformer

def forward(self, x, bboxes):# 确定对象的数量&#xff0c;如果不是零样本学习场景&#xff0c;则根据bboxes的数量确定num_objects bboxes.size(1) if not self.zero_shot else self.num_objects# backbone# 通过主干网络提取特征backbone_features self.backbone(x)# pre…...

光伏仿真系统的使用流程

为了更有效地规划设计并评估光伏电站的性能与经济性&#xff0c;光伏仿真系统应运而生。该系统通过模拟实际运行场景&#xff0c;为项目开发者提供全面的数据支持和决策依据。 1、气象数据分析 气象数据是光伏系统性能评估的基础。此阶段&#xff0c;仿真系统需收集并处理项目…...