当前位置: 首页 > news >正文

PyTorch 基础学习(13)- 混合精度训练

系列文章:
《PyTorch 基础学习》文章索引

基本概念

混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16torch.bfloat16)数据类型的优势,提高计算效率和内存利用率。

  • 高精度(torch.float32:适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。

  • 低精度(torch.float16torch.bfloat16:适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。

混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp 模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。

重要方法及其作用

torch.autocast

torch.autocast 是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast 将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16,而损失计算则保持为 float32

主要参数:

  • device_type:指定设备类型,如 cudacpuxpu
  • dtype:指定在 autocast 区域内使用的低精度数据类型。对于 CUDA 设备,默认是 torch.float16;对于 CPU 设备,默认是 torch.bfloat16
  • enabled:是否启用混合精度。默认为 True
  • cache_enabled:是否启用权重缓存。默认是 True,可以在某些场景下提高性能。

torch.amp.GradScaler

在低精度(如 float16)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。

主要参数:

  • init_scale:初始的缩放因子,默认是 65536.0
  • growth_factor:在没有发生下溢的情况下,缩放因子增长的倍数,默认是 2.0
  • backoff_factor:发生下溢时,缩放因子减少的倍数,默认是 0.5
  • growth_interval:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是 2000
  • enabled:是否启用梯度缩放,默认为 True

适用的场景

GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda") 能够有效地减少 GPU 的计算负载,并提高吞吐量。

CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16) 可以在推理过程中降低计算复杂度,同时保持较高的精度。

3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwdtorch.amp.custom_bwd,用户可以定义在特定设备(如 cuda)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。

应用实例

以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocasttorch.amp.GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10):  # 假设有10个epochfor input, target in data_loader:  # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")

代码说明

  • 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
  • 在每次训练循环中,我们使用 torch.autocast(device_type="cuda") 上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。
  • 使用 GradScaler 对损失进行缩放,并在缩放后的损失上调用 backward() 进行反向传播。这一步骤有助于防止梯度下溢。
  • scaler.step(optimizer) 用于更新模型参数,scaler.update() 则是调整缩放因子。

这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。

注意事项

  • 弃用警告:从 PyTorch 1.10 开始,原有的 torch.cuda.amp.autocasttorch.cpu.amp.autocast 方法被弃用,推荐使用通用的 torch.autocast 代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。

  • 数据类型匹配:在使用 autocast 时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为 float32 或其他合适的精度。

  • GradScaler 的适用性:虽然 GradScaler 对大多数模型都有效,但在某些情况下(例如使用 bf16 预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。

通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。

相关文章:

PyTorch 基础学习(13)- 混合精度训练

系列文章: 《PyTorch 基础学习》文章索引 基本概念 混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16 或 torch.bfloat16)数据类型的优势,…...

Mycat分片-垂直拆分

目录 场景 配置 测试 全局表配置 续接上篇:MySQ分库分表与MyCat安装配置-CSDN博客 续接下篇:Mycat分片-水平拆分-CSDN博客 场景 在业务系统中, 涉及以下表结构 ,但是由于用户与订单每天都会产生大量的数据, 单台服务器的数据 存储及处理能力是有限…...

一元四次方程求解-【附MATLAB代码】

目录 前言 求解方法 ​编辑 MATLAB验证 附:一元四次方程的故事 前言 最近在研究机器人的干涉(碰撞)检测,遇到了一个问题,就是在求椭圆到原点的最短距离时,构建的方程是一个一元四次方程。无论是高中的…...

【极限性能,尽在掌控】ROG NUC:游戏与创作的微型巨擘

初见ROG NUC,你或许会为它的小巧体型惊讶。然而,这看似不起眼的机身内,蕴藏着游戏、创意的强大能量。 掌中风暴,性能无界 ROG NUC搭载英特尔高性能处理器,配合高速NVMe SSD固态硬盘以及可选的高端独立显卡&#xff08…...

Ecosmos开启公测,将深度赋能CIOE中国光博会元宇宙参会新体验

如今,生成式AI技术的发展,极大地降低了3D数字资产的制作成本,元宇宙作为一种可以无缝将物理和数字资产进行融合的技术,在推动电子产业数字化进程、助力产业高质量发展的方面展现出了巨大的潜力。 当前,发展新质生产力是…...

【Kubernetes】k8s集群之包管理器Helm

目录 一.Helm概述 1.Helm的简介 2.Helm的三个重要概念 3.Helm2与Helm3的的区别 二.Helm 部署 1.安装 helm 2.使用 helm 安装 Chart 3.Helm 自定义模板 4.Helm 仓库 每个成功的软件平台都有一个优秀的打包系统,比如Debian、Ubuntu 的 apt,RedH…...

嵌入式linux系统镜像制作day3(构建镜像)

点击上方"蓝字"关注我们 01、上节回顾 嵌入式linux系统镜像制作day1嵌入式linux系统镜像制作day2提前下载好准备工具,不然失败了大眼瞪小眼。 02、构建 Poky 的 Sato 镜像1 环境: ubuntu18.04poky版本:Dizzy 工具git 在开始之前,针对不同的发行版,需要先执行…...

【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】

教师节中秋节国庆节车模特美女举牌生日视频制作教程AE模板改文字软件生成器素材 怎么如何做的【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件下载AE模板把AE模板导入AE软件修改图…...

RongCallKit iOS 端本地私有 pod 方案

RongCallKit iOS 端本地私有 pod 方案 需求背景 适用于源码集成 CallKit 时,使用 pod 管理 RTC framework 以及源码。集成 CallKit 时,需要定制化修改 CallKit 的样式以及部分 UI 功能。适用于 CallKit 源码 Debug 调试便于定位相关问题。 解决方案 从…...

C++11:可变参数模板

目录 一、概述 二、场景 1.深拷贝的类 2.浅拷贝的类 C使用指南 一、概述 // Args是一个模板参数包&#xff0c;args是一个函数形参参数包 // 声明一个参数包Args...args&#xff0c;这个参数包中可以包含0到任意个模板参数。 template <class ...Args> void ShowList(…...

C++ 与 QML 之间进行数据交互的几种方法

https://www.cnblogs.com/jzcn/p/17774676.html 一、属性绑定 这是最简单的方式&#xff0c;可以在QML中直接绑定C 对象的属性。通过在C 对象中使用Q_PROPERTY宏定义属性&#xff0c;然后在QML中使用绑定语法将属性与QML元素关联起来。 1. person.h #include <QObject&g…...

Javaweb学习之Vue项目的创建(二)

学习资料 Vue.js - 渐进式 JavaScript 框架 | Vue.js (vuejs.org) 准备工作都做完了&#xff0c;接下来开始Vue的正式学习。 第一步&#xff0c;打开VS Code 在VS Code里&#xff0c;我们也需要使用到终端&#xff0c;如果不是以管理员身份打开&#xff0c;在新建Vue项目的时候…...

『深度长文』4种有效提高LLM输出质量的方法!

LLM&#xff0c;全称Large Language Model&#xff0c;意为大型语言模型&#xff0c;是一种基于深度学习的AI技术&#xff0c;能够生成、理解和处理自然语言文本&#xff0c;也因此成为当前大多数AI工具的核心引擎。LLM通过学习海量的文本数据&#xff0c;掌握了词汇、语法、语…...

【工业机器人】工业异常检测大模型AnomalyGPT

AnomalyGPT 工业异常检测视觉大模型AnomalyGPT AnomalyGPT: Detecting Industrial Anomalies using Large Vision-Language Models AnomalyGPT是一种基于大视觉语言模型&#xff08;LVLM&#xff09;的新型工业异常检测&#xff08;IAD&#xff09;方法。它利用LVLM的能力来理…...

【PGCCC】PostgreSQL案例:planning time超长问题分析#PG初级

在使用 PostgreSQL 时&#xff0c;查询的执行计划&#xff08;planning time&#xff09;有时会出现异常长的情况&#xff0c;这可能会影响数据库的整体性能。分析和解决这种问题可以从多个角度入手&#xff0c;以下是常见原因和相应的解决思路&#xff1a; 1. 统计信息不准确…...

【图文并茂】ant design pro 如何给后端发送 json web token - 请求拦截器的使用

上一节有讲过 【图文并茂】ant design pro 如何对接后端个人信息接口 还差一个东西&#xff0c;去获取个人信息的时候&#xff0c;是要发送 token 的&#xff0c;不然会报 403. 就是说在你登录之后才去获得个人信息。这样后端才能知道是谁的信息。 token 就代码了某个人。 …...

【微信小程序】自定义组件 - behaviors

1. 什么是 behaviors 2. behaviors 的工作方式 3. 创建 behavior 调用 Behavior(Object object) 方法即可创建一个共享的 behavior 实例对象&#xff0c;供所有的组件使用&#xff1a; 4. 导入并使用 behavior 5. behavior 中所有可用的节点 6. 同名字段的覆盖和组合规则* 关…...

Linux ubuntu 24.04 安装运行《帝国时代3》免安装绿色版游戏,解决 “Could not load DATAP.BAR”等问题

Linux ubuntu 24.04 安装运行《帝国时代3》游戏&#xff0c;解决 “Could not load DATAP.BAR" 等问题 《帝国时代 3》是一款比较经典的即时战斗游戏&#xff0c;伴随了我半个高中时代&#xff0c;周末有时间就去泡网吧&#xff0c;可惜玩的都是简单人机&#xff0c;高难…...

Springboot 图片

Springboot 图片 因为 server.servlet.context-path: /api 所以 url是这个的时候 http://127.0.0.1:9100/api/staticfiles/image/dd56a59d-da84-441a-8dac-1d97f9e42090.jpeg 配置代码的前面的 /api 是不要写的 package com.gk.study.config;import org.springframework.conte…...

LIMS实验室管理系统如何实现数据自动采集

随着科研技术的不断发展&#xff0c;LIMS实验室管理系统的应用也愈来愈广&#xff0c;已经成为现代化实验室管理不可或缺的工具。LIMS实验室管理系统未与仪器设备对接前&#xff0c;仪器设备产生的数据都是通过人工录入到系统中&#xff0c;再经过人工审核形成最终的数据报告。…...

内存分配函数malloc kmalloc vmalloc

内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...

vulnyx Blogger writeup

信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面&#xff0c;gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress&#xff0c;说明目标所使用的cms是wordpress&#xff0c;访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...

解读《网络安全法》最新修订,把握网络安全新趋势

《网络安全法》自2017年施行以来&#xff0c;在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂&#xff0c;网络攻击、数据泄露等事件频发&#xff0c;现行法律已难以完全适应新的风险挑战。 2025年3月28日&#xff0c;国家网信办会同相关部门起草了《网络安全…...

spring Security对RBAC及其ABAC的支持使用

RBAC (基于角色的访问控制) RBAC (Role-Based Access Control) 是 Spring Security 中最常用的权限模型&#xff0c;它将权限分配给角色&#xff0c;再将角色分配给用户。 RBAC 核心实现 1. 数据库设计 users roles permissions ------- ------…...

一些实用的chrome扩展0x01

简介 浏览器扩展程序有助于自动化任务、查找隐藏的漏洞、隐藏自身痕迹。以下列出了一些必备扩展程序&#xff0c;无论是测试应用程序、搜寻漏洞还是收集情报&#xff0c;它们都能提升工作流程。 FoxyProxy 代理管理工具&#xff0c;此扩展简化了使用代理&#xff08;如 Burp…...

MyBatis-Plus 常用条件构造方法

1.常用条件方法 方法 说明eq等于 ne不等于 <>gt大于 >ge大于等于 >lt小于 <le小于等于 <betweenBETWEEN 值1 AND 值2notBetweenNOT BETWEEN 值1 AND 值2likeLIKE %值%notLikeNOT LIKE %值%likeLeftLIKE %值likeRightLIKE 值%isNull字段 IS NULLisNotNull字段…...

简单聊下阿里云DNS劫持事件

阿里云域名被DNS劫持事件 事件总结 根据ICANN规则&#xff0c;域名注册商&#xff08;Verisign&#xff09;认定aliyuncs.com域名下的部分网站被用于非法活动&#xff08;如传播恶意软件&#xff09;&#xff1b;顶级域名DNS服务器将aliyuncs.com域名的DNS记录统一解析到shado…...