当前位置: 首页 > article >正文

PyTorch 混合精度训练:FP16 与 BF16 性能对比

PyTorch 混合精度训练FP16 与 BF16 性能对比1. 技术分析1.1 浮点精度对比精度位数范围精度内存占用FP32321.2e-38 ~ 3.4e387位有效数字4字节FP16166.1e-5 ~ 6.5e43位有效数字2字节BF16161.1e-38 ~ 3.4e383位有效数字2字节1.2 混合精度训练原理混合精度训练流程 1. 参数保持 FP32 2. 前向传播使用 FP16/BF16 3. 梯度计算使用 FP16/BF16 4. 梯度转换回 FP32 更新参数1.3 AMP (Automatic Mixed Precision)PyTorch 的 AMP 自动混合精度工具from torch.cuda.amp import autocast, GradScaler scaler GradScaler() with autocast(): output model(input) loss loss_fn(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()2. 核心功能实现2.1 手动混合精度import torch import torch.nn as nn class MixedPrecisionModel(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 64, kernel_size3) self.conv2 nn.Conv2d(64, 128, kernel_size3) self.fc nn.Linear(128 * 28 * 28, 10) def forward(self, x): x x.half() x self.conv1(x).half() x torch.nn.functional.relu(x) x self.conv2(x).half() x torch.nn.functional.relu(x) x x.float() x x.view(x.size(0), -1) x self.fc(x) return x def train_mixed_precision(): model MixedPrecisionModel().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) loss_fn nn.CrossEntropyLoss() for epoch in range(10): inputs torch.randn(32, 3, 224, 224).cuda() targets torch.randint(0, 10, (32,)).cuda() optimizer.zero_grad() inputs_fp16 inputs.half() outputs model(inputs_fp16) loss loss_fn(outputs, targets) loss.backward() optimizer.step()2.2 使用 AMPfrom torch.cuda.amp import autocast, GradScaler class AMPModel(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( nn.Conv2d(3, 64, kernel_size3), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, kernel_size3), nn.ReLU(), nn.MaxPool2d(2) ) self.classifier nn.Linear(128 * 54 * 54, 10) def forward(self, x): x self.features(x) x x.view(x.size(0), -1) x self.classifier(x) return x def train_with_amp(): model AMPModel().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) loss_fn nn.CrossEntropyLoss() scaler GradScaler() for epoch in range(100): inputs torch.randn(64, 3, 224, 224).cuda() targets torch.randint(0, 10, (64,)).cuda() optimizer.zero_grad() with autocast(): outputs model(inputs) loss loss_fn(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() class GradientScaling: def __init__(self, optimizer, initial_scale2**16): self.optimizer optimizer self.scale initial_scale self._growth_factor 2.0 self._backoff_factor 0.5 self._growth_interval 1000 def scale_loss(self, loss): return loss * self.scale def step(self): self.unscale_optimizer() self.optimizer.step() def unscale_optimizer(self): for param in self.optimizer.param_groups: if param[params][0].grad is not None: param[params][0].grad.data.div_(self.scale) def update(self, success): if success: self.scale min(self.scale * self._growth_factor, 2**24) else: self.scale * self._backoff_factor2.3 BF16 训练class BF16Model(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 64, kernel_size3).bfloat16() self.conv2 nn.Conv2d(64, 128, kernel_size3).bfloat16() self.fc nn.Linear(128 * 54 * 54, 10).bfloat16() def forward(self, x): x x.bfloat16() x self.conv1(x) x torch.nn.functional.relu(x) x self.conv2(x) x torch.nn.functional.relu(x) x x.float() x x.view(x.size(0), -1) x self.fc(x) return x def train_bf16(): if not torch.cuda.is_bf16_supported(): print(BF16 not supported on this device) return model BF16Model().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) loss_fn nn.CrossEntropyLoss() for epoch in range(10): inputs torch.randn(32, 3, 224, 224).cuda() targets torch.randint(0, 10, (32,)).cuda() optimizer.zero_grad() with torch.cuda.amp.autocast(dtypetorch.bfloat16): outputs model(inputs) loss loss_fn(outputs, targets) loss.backward() optimizer.step()2.4 精度混合策略class PrecisionMixer: def __init__(self, model, strategyauto): self.model model self.strategy strategy def apply_precision(self): if self.strategy fp16: return self._apply_fp16() elif self.strategy bf16: return self._apply_bf16() elif self.strategy auto: return self._apply_auto() def _apply_fp16(self): return self.model.half() def _apply_bf16(self): if not torch.cuda.is_bf16_supported(): raise RuntimeError(BF16 not supported) return self.model.bfloat16() def _apply_auto(self): for name, param in self.model.named_parameters(): if batch_norm in name or layer_norm in name: param.data param.data.float() else: param.data param.data.half() return self.model class MixedPrecisionLossScaler: def __init__(self, optimizer, dtypetorch.float16): self.optimizer optimizer self.dtype dtype self.scaler GradScaler(dtypedtype) def scale(self, loss): return self.scaler.scale(loss) def step(self): self.scaler.step(self.optimizer) self.scaler.update()3. 性能对比3.1 精度对比指标FP32FP16BF16训练速度1x1.5-2x1.3-1.8x内存占用1x0.5x0.5x数值稳定性高中高适用GPU所有VoltaAmpere3.2 训练时间对比模型FP32FP16BF16加速比ResNet-50100s55s60sFP16: 1.8xBERT-base200s110s120sFP16: 1.8xGPT-2500s280s300sFP16: 1.8x3.3 数值精度对比任务FP32准确率FP16准确率BF16准确率差异ImageNet分类76.1%75.9%76.0%-0.2%GLUE基准82.5%82.3%82.4%-0.2%语言建模45.245.045.1-0.24. 最佳实践4.1 梯度检查点与混合精度from torch.utils.checkpoint import checkpoint class CheckpointedModel(nn.Module): def __init__(self): super().__init__() self.block1 nn.Sequential( nn.Conv2d(3, 64, 3), nn.ReLU() ) self.block2 nn.Sequential( nn.Conv2d(64, 128, 3), nn.ReLU() ) self.block3 nn.Linear(128 * 54 * 54, 10) def forward(self, x): x checkpoint(self.block1, x) x checkpoint(self.block2, x) x x.view(x.size(0), -1) x self.block3(x) return x def train_checkpoint_amp(): model CheckpointedModel().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) scaler GradScaler() for epoch in range(10): inputs torch.randn(64, 3, 224, 224).cuda() targets torch.randint(0, 10, (64,)).cuda() optimizer.zero_grad() with autocast(): outputs model(inputs) loss nn.CrossEntropyLoss()(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()4.2 精度选择策略def select_precision(): if torch.cuda.is_bf16_supported(): return torch.bfloat16 elif torch.cuda.is_available(): return torch.float16 else: return torch.float32 class PrecisionSelector: staticmethod def for_task(task_type): if task_type in [training, fine-tuning]: return select_precision() elif task_type inference: return torch.float16 else: return torch.float325. 总结混合精度训练是提升训练效率的关键技术FP16适合需要最大加速的场景BF16适合需要更好数值稳定性的场景AMP自动选择最佳精度策略梯度缩放防止梯度下溢对比数据如下FP16 可提升 1.5-2 倍训练速度BF16 数值稳定性更好适合大模型内存占用减少 50%精度损失通常在 0.2% 以内

相关文章:

PyTorch 混合精度训练:FP16 与 BF16 性能对比

PyTorch 混合精度训练:FP16 与 BF16 性能对比 1. 技术分析 1.1 浮点精度对比 精度位数范围精度内存占用FP32321.2e-38 ~ 3.4e387位有效数字4字节FP16166.1e-5 ~ 6.5e43位有效数字2字节BF16161.1e-38 ~ 3.4e383位有效数字2字节 1.2 混合精度训练原理 混合精度训练流程…...

AI意识评估:从理论到工程实践的科学探索

1. 项目概述:当AI开始“思考”,我们如何评估?“AI意识评估”这个标题,听起来像科幻小说里的概念,但事实上,它正迅速从一个哲学思辨议题,演变为一个迫在眉睫的工程与伦理挑战。作为一名长期关注前…...

医疗生成式AI的伦理挑战与GREAT PLEA治理框架实践指南

1. 项目概述:当AI开始“思考”医疗最近几年,生成式AI在医疗领域的应用,已经从实验室的“概念验证”阶段,快速渗透到临床辅助诊断、药物研发、患者教育乃至医院运营管理的方方面面。作为一名长期关注医疗科技交叉领域的从业者&…...

从信托义务到AI对齐:构建可信人工智能的技术与治理框架

1. 项目概述:当法律遇上代码最近和几位做AI产品落地的朋友聊天,大家不约而同地提到了同一个词:“对齐”。但聊着聊着,话题就从技术上的“奖励模型”和“人类反馈强化学习”,滑向了更让人头疼的领域——合规、责任和信任…...

基于Claude API的智能代码生成工具设计与实现

1. 项目概述:一个被“设计失败”命名的代码生成工具在开发者社区里,项目名称往往承载着创始人的某种情绪或愿景。当你第一次看到designfailure/claudecode这个仓库名时,可能会感到一丝困惑甚至好奇。designfailure(设计失败&#…...

自主智能体架构解析:从ReAct框架到实战应用开发指南

1. 项目概述与核心价值最近在GitHub上看到一个名为“Autonomous-Agents”的项目,作者是tmgthb。这个标题本身就充满了吸引力,它指向了当前人工智能领域一个极其热门且富有想象力的方向——自主智能体。简单来说,这个项目探讨和实现的&#xf…...

RAG-Fusion:用多查询与RRF融合提升复杂意图检索效果

1. 项目概述:RAG-Fusion,一次对搜索本质的深度探索如果你和我一样,在过去几年里一直在折腾RAG(检索增强生成)相关的项目,那你肯定经历过这种时刻:精心构建的向量数据库,配上强大的大…...

基于AI的GitHub仓库自动化管理:GHPT项目实战解析

1. 项目概述:当GitHub遇上AI,一个开源项目的新玩法最近在GitHub上闲逛,发现了一个挺有意思的项目,叫“GHPT”。光看名字,你可能会联想到GPT,没错,它确实和AI有关。但它的全称和定位,…...

Yocto与SystemReady IR构建嵌入式Linux统一镜像实践

1. 项目概述 在嵌入式Linux开发领域,Yocto Project已成为构建定制化Linux发行版的事实标准工具链。其核心价值在于模块化设计理念,通过OpenEmbedded构建系统和BitBake工具实现高效的跨平台编译。然而,传统嵌入式开发面临一个根本性挑战&#…...

AI友好型Excel知识库与自动化工具:提升数据分析与报表生成效率

1. 项目概述:一个为AI“投喂”的Excel生产力工具箱如果你和我一样,每天的工作都离不开Excel,但又不是那种能把VBA玩出花来的“表哥表姐”,那你一定经历过这种痛苦:面对一堆数据,你知道用某个公式或者透视表…...

ARM GIC IRS寄存器框架解析与性能优化

1. ARM GIC IRS寄存器框架概述中断控制器(GIC)是现代ARM处理器系统中的核心组件,负责高效管理和分发硬件中断。IRS(Interrupt Routing Service)作为GICv5架构引入的重要功能模块,通过精心设计的寄存器框架实现了对中断域(Interrupt Domain)的精确控制。与…...

ClawTeam-OpenClaw:基于文件系统的AI多智能体集群协调框架实战

1. 项目概述:从单兵作战到智能集群的进化如果你和我一样,长期在AI辅助编程和自动化领域摸爬滚打,那你一定经历过这样的场景:面对一个复杂的项目,你让一个AI代理去处理,它吭哧吭哧干半天,要么卡在…...

BrowserOS:基于现代Web技术构建的浏览器内桌面操作系统

1. 项目概述:一个运行在浏览器里的操作系统,它想做什么?最近在GitHub上看到一个挺有意思的项目,叫BrowserOS。光看名字,你可能会想,这又是个什么“玩具”或者概念验证?但当我真正花时间研究并尝…...

隐私优先的本地化个人基因组分析工具:从SNP解析到多基因风险评分

1. 项目概述:一个隐私至上的本地化个人基因组分析工具如果你和我一样,对消费级基因检测(比如23andMe、AncestryDNA)的结果感到好奇,但又对把最私密的遗传数据上传到云端服务器心存疑虑,那么你一定会对wkyle…...

基于AST的Markdown文档自动化发现工具discovery-md实战指南

1. 项目概述与核心价值 最近在整理个人知识库和项目文档时,我一直在寻找一种能兼顾简洁、强大和可移植性的文档格式。Markdown 无疑是首选,但如何高效地“发现”和组织散落在各个角落的 .md 文件,并快速理解其内容结构,却是个不…...

Haft:AI辅助开发中的工程治理与决策可追溯性实践

1. 项目概述:Haft——AI辅助软件交付的工程治理层在AI编码助手(如Claude Code、Cursor)日益普及的今天,我们正面临一个全新的工程挑战:代码生成的速度前所未有,但生成代码背后的决策质量、长期可维护性以及…...

ARM TrustZone MPC寄存器架构与安全机制解析

1. ARM TrustZone MPC寄存器架构解析在嵌入式安全领域,内存保护控制器(Memory Protection Controller, MPC)作为TrustZone技术体系的核心组件,承担着物理内存隔离的关键职责。以AHB5总线上的TrustZone MPC为例,其寄存器…...

基于MCP与ReceiptConverter的票据自动化解析与AI集成方案

1. 项目概述:让AI助手直接“看懂”你的票据 如果你和我一样,经常需要处理一堆杂乱的发票、收据,然后手动把它们录入到表格或者记账软件里,那你肯定知道这活儿有多烦人。一张张拍照、整理、对着模糊的小票辨认商品和金额&#xff…...

ARM Cortex-A9中断控制器架构与多核处理优化

1. ARM Cortex-A9中断控制器架构解析在嵌入式系统设计中,中断控制器作为处理器与外部设备通信的核心枢纽,其性能直接影响系统的实时响应能力。ARM Cortex-A9 MPCore采用的中断控制器架构,通过硬件级的中断管理和分发机制,为多核处…...

从零到一掌握提示工程:系统化方法与实战指南

1. 项目概述:从零到一掌握提示工程如果你正在使用ChatGPT、Claude或者任何基于大语言模型(LLM)的工具,并且感觉自己的提问方式总是“差那么一点意思”——要么得到的答案太笼统,要么需要反复追问才能触及核心&#xff…...

医疗AI协作实战:跨越数据科学与临床医学的沟通鸿沟

1. 项目概述:当数据科学家遇上临床医生“我们模型在测试集上的AUC达到了0.95!”数据科学家兴奋地向团队汇报。 “所以,它能告诉我明天早上查房时,3床的病人会不会发生术后感染吗?”临床主任医师平静地问道。 会议室里瞬…...

Craft Agents 爆火:Agent 工具正在从“命令行玩具”走向“工作流系统”

开源地址:GitHub 项目 lukilabs/craft-agents-oss当前 GitHub 页面显示,该项目已达到 5.8k Star、779 Fork,同时还有较活跃的 Issue 和 PR 讨论。https://github.com/lukilabs/craft-agents-oss最近,Agent 类开源项目又火了一个。…...

并行计算突破:RNN序列依赖的并行化重构与优化

1. 并行计算革命:打破RNN序列依赖的固有认知循环神经网络(RNN)长期被视为序列建模的黄金标准,但其序列依赖性导致的计算瓶颈一直困扰着研究者。传统观点认为,评估长度为T的序列必须严格遵循O(T)的时间复杂度——即使拥…...

ARM GIC中断域管理与系统指令详解

1. ARM GIC中断域管理概述在ARM架构中,通用中断控制器(GIC)是处理中断请求的核心组件。作为系统级外设,GIC负责接收来自各种硬件设备的中断信号,进行优先级仲裁后分发给处理器核心处理。现代ARM处理器通常集成GICv3或GICv4架构的中断控制器&a…...

创业团队如何利用统一API网关管理多个大模型调用与成本

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 创业团队如何利用统一API网关管理多个大模型调用与成本 对于资源有限的创业团队而言,在业务开发中引入大模型能力&…...

AI Agent自动化求职实战:基于Python与LLM的智能简历投递系统

1. 项目概述与核心价值最近在技术社区里,关于AI Agent如何自动化处理重复性工作的讨论越来越热。作为一个在招聘和自动化领域摸爬滚打了十来年的老手,我亲眼见证了求职者从海投简历到使用各种工具辅助的演变。今天想和大家深入聊聊一个让我印象深刻的开源…...

Python基础篇之初识Python必看攻略

Python简介python的创始人为吉多范罗苏姆(Guido van Rossum)。1989年的圣诞节期间,吉多范罗苏姆为了在阿姆斯特丹打发时间,决心开发一个新的脚本解释程序,作为ABC语言的一种继承。 Python和其他语言的对比:…...

CANN/HCOMM通信通道内存屏障API

HcommChannelFenceOnThread 【免费下载链接】hcomm HCOMM(Huawei Communication)是HCCL的通信基础库,提供通信域以及通信资源的管理能力。 项目地址: https://gitcode.com/cann/hcomm 产品支持情况 Ascend 950PR/Ascend 950DT&#x…...

CANN/SiP Cgemv复数矩阵向量乘法

Cgemv 【免费下载链接】sip 本项目是CANN提供的一款高效、可靠的高性能信号处理算子加速库,基于华为Ascend AI处理器,专门为信号处理领域而设计。 项目地址: https://gitcode.com/cann/sip 产品支持情况 产品是否支持Atlas 200I/500 A2 推理产品…...

集成电路设计中的关键特征分析(CFA)技术与应用

1. 关键特征分析(CFA)技术概述关键特征分析(Critical Feature Analysis, CFA)是现代集成电路设计制造(DFM)流程中的核心质量评估工具。这项技术最早由Mentor Graphics(现为Siemens EDA)在2000年代中期提出,旨在解决传统DRC(设计规则检查)仅做"通过/失败"二…...