当前位置: 首页 > news >正文

PyTorch 基础学习(13)- 混合精度训练

系列文章:
《PyTorch 基础学习》文章索引

基本概念

混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16torch.bfloat16)数据类型的优势,提高计算效率和内存利用率。

  • 高精度(torch.float32:适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。

  • 低精度(torch.float16torch.bfloat16:适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。

混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp 模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。

重要方法及其作用

torch.autocast

torch.autocast 是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast 将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16,而损失计算则保持为 float32

主要参数:

  • device_type:指定设备类型,如 cudacpuxpu
  • dtype:指定在 autocast 区域内使用的低精度数据类型。对于 CUDA 设备,默认是 torch.float16;对于 CPU 设备,默认是 torch.bfloat16
  • enabled:是否启用混合精度。默认为 True
  • cache_enabled:是否启用权重缓存。默认是 True,可以在某些场景下提高性能。

torch.amp.GradScaler

在低精度(如 float16)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。

主要参数:

  • init_scale:初始的缩放因子,默认是 65536.0
  • growth_factor:在没有发生下溢的情况下,缩放因子增长的倍数,默认是 2.0
  • backoff_factor:发生下溢时,缩放因子减少的倍数,默认是 0.5
  • growth_interval:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是 2000
  • enabled:是否启用梯度缩放,默认为 True

适用的场景

GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda") 能够有效地减少 GPU 的计算负载,并提高吞吐量。

CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16) 可以在推理过程中降低计算复杂度,同时保持较高的精度。

3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwdtorch.amp.custom_bwd,用户可以定义在特定设备(如 cuda)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。

应用实例

以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocasttorch.amp.GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10):  # 假设有10个epochfor input, target in data_loader:  # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")

代码说明

  • 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
  • 在每次训练循环中,我们使用 torch.autocast(device_type="cuda") 上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。
  • 使用 GradScaler 对损失进行缩放,并在缩放后的损失上调用 backward() 进行反向传播。这一步骤有助于防止梯度下溢。
  • scaler.step(optimizer) 用于更新模型参数,scaler.update() 则是调整缩放因子。

这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。

注意事项

  • 弃用警告:从 PyTorch 1.10 开始,原有的 torch.cuda.amp.autocasttorch.cpu.amp.autocast 方法被弃用,推荐使用通用的 torch.autocast 代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。

  • 数据类型匹配:在使用 autocast 时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为 float32 或其他合适的精度。

  • GradScaler 的适用性:虽然 GradScaler 对大多数模型都有效,但在某些情况下(例如使用 bf16 预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。

通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。

相关文章:

PyTorch 基础学习(13)- 混合精度训练

系列文章: 《PyTorch 基础学习》文章索引 基本概念 混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16 或 torch.bfloat16)数据类型的优势,…...

Mycat分片-垂直拆分

目录 场景 配置 测试 全局表配置 续接上篇:MySQ分库分表与MyCat安装配置-CSDN博客 续接下篇:Mycat分片-水平拆分-CSDN博客 场景 在业务系统中, 涉及以下表结构 ,但是由于用户与订单每天都会产生大量的数据, 单台服务器的数据 存储及处理能力是有限…...

一元四次方程求解-【附MATLAB代码】

目录 前言 求解方法 ​编辑 MATLAB验证 附:一元四次方程的故事 前言 最近在研究机器人的干涉(碰撞)检测,遇到了一个问题,就是在求椭圆到原点的最短距离时,构建的方程是一个一元四次方程。无论是高中的…...

【极限性能,尽在掌控】ROG NUC:游戏与创作的微型巨擘

初见ROG NUC,你或许会为它的小巧体型惊讶。然而,这看似不起眼的机身内,蕴藏着游戏、创意的强大能量。 掌中风暴,性能无界 ROG NUC搭载英特尔高性能处理器,配合高速NVMe SSD固态硬盘以及可选的高端独立显卡&#xff08…...

Ecosmos开启公测,将深度赋能CIOE中国光博会元宇宙参会新体验

如今,生成式AI技术的发展,极大地降低了3D数字资产的制作成本,元宇宙作为一种可以无缝将物理和数字资产进行融合的技术,在推动电子产业数字化进程、助力产业高质量发展的方面展现出了巨大的潜力。 当前,发展新质生产力是…...

【Kubernetes】k8s集群之包管理器Helm

目录 一.Helm概述 1.Helm的简介 2.Helm的三个重要概念 3.Helm2与Helm3的的区别 二.Helm 部署 1.安装 helm 2.使用 helm 安装 Chart 3.Helm 自定义模板 4.Helm 仓库 每个成功的软件平台都有一个优秀的打包系统,比如Debian、Ubuntu 的 apt,RedH…...

嵌入式linux系统镜像制作day3(构建镜像)

点击上方"蓝字"关注我们 01、上节回顾 嵌入式linux系统镜像制作day1嵌入式linux系统镜像制作day2提前下载好准备工具,不然失败了大眼瞪小眼。 02、构建 Poky 的 Sato 镜像1 环境: ubuntu18.04poky版本:Dizzy 工具git 在开始之前,针对不同的发行版,需要先执行…...

【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】

教师节中秋节国庆节车模特美女举牌生日视频制作教程AE模板改文字软件生成器素材 怎么如何做的【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件下载AE模板把AE模板导入AE软件修改图…...

RongCallKit iOS 端本地私有 pod 方案

RongCallKit iOS 端本地私有 pod 方案 需求背景 适用于源码集成 CallKit 时,使用 pod 管理 RTC framework 以及源码。集成 CallKit 时,需要定制化修改 CallKit 的样式以及部分 UI 功能。适用于 CallKit 源码 Debug 调试便于定位相关问题。 解决方案 从…...

C++11:可变参数模板

目录 一、概述 二、场景 1.深拷贝的类 2.浅拷贝的类 C使用指南 一、概述 // Args是一个模板参数包&#xff0c;args是一个函数形参参数包 // 声明一个参数包Args...args&#xff0c;这个参数包中可以包含0到任意个模板参数。 template <class ...Args> void ShowList(…...

C++ 与 QML 之间进行数据交互的几种方法

https://www.cnblogs.com/jzcn/p/17774676.html 一、属性绑定 这是最简单的方式&#xff0c;可以在QML中直接绑定C 对象的属性。通过在C 对象中使用Q_PROPERTY宏定义属性&#xff0c;然后在QML中使用绑定语法将属性与QML元素关联起来。 1. person.h #include <QObject&g…...

Javaweb学习之Vue项目的创建(二)

学习资料 Vue.js - 渐进式 JavaScript 框架 | Vue.js (vuejs.org) 准备工作都做完了&#xff0c;接下来开始Vue的正式学习。 第一步&#xff0c;打开VS Code 在VS Code里&#xff0c;我们也需要使用到终端&#xff0c;如果不是以管理员身份打开&#xff0c;在新建Vue项目的时候…...

『深度长文』4种有效提高LLM输出质量的方法!

LLM&#xff0c;全称Large Language Model&#xff0c;意为大型语言模型&#xff0c;是一种基于深度学习的AI技术&#xff0c;能够生成、理解和处理自然语言文本&#xff0c;也因此成为当前大多数AI工具的核心引擎。LLM通过学习海量的文本数据&#xff0c;掌握了词汇、语法、语…...

【工业机器人】工业异常检测大模型AnomalyGPT

AnomalyGPT 工业异常检测视觉大模型AnomalyGPT AnomalyGPT: Detecting Industrial Anomalies using Large Vision-Language Models AnomalyGPT是一种基于大视觉语言模型&#xff08;LVLM&#xff09;的新型工业异常检测&#xff08;IAD&#xff09;方法。它利用LVLM的能力来理…...

【PGCCC】PostgreSQL案例:planning time超长问题分析#PG初级

在使用 PostgreSQL 时&#xff0c;查询的执行计划&#xff08;planning time&#xff09;有时会出现异常长的情况&#xff0c;这可能会影响数据库的整体性能。分析和解决这种问题可以从多个角度入手&#xff0c;以下是常见原因和相应的解决思路&#xff1a; 1. 统计信息不准确…...

【图文并茂】ant design pro 如何给后端发送 json web token - 请求拦截器的使用

上一节有讲过 【图文并茂】ant design pro 如何对接后端个人信息接口 还差一个东西&#xff0c;去获取个人信息的时候&#xff0c;是要发送 token 的&#xff0c;不然会报 403. 就是说在你登录之后才去获得个人信息。这样后端才能知道是谁的信息。 token 就代码了某个人。 …...

【微信小程序】自定义组件 - behaviors

1. 什么是 behaviors 2. behaviors 的工作方式 3. 创建 behavior 调用 Behavior(Object object) 方法即可创建一个共享的 behavior 实例对象&#xff0c;供所有的组件使用&#xff1a; 4. 导入并使用 behavior 5. behavior 中所有可用的节点 6. 同名字段的覆盖和组合规则* 关…...

Linux ubuntu 24.04 安装运行《帝国时代3》免安装绿色版游戏,解决 “Could not load DATAP.BAR”等问题

Linux ubuntu 24.04 安装运行《帝国时代3》游戏&#xff0c;解决 “Could not load DATAP.BAR" 等问题 《帝国时代 3》是一款比较经典的即时战斗游戏&#xff0c;伴随了我半个高中时代&#xff0c;周末有时间就去泡网吧&#xff0c;可惜玩的都是简单人机&#xff0c;高难…...

Springboot 图片

Springboot 图片 因为 server.servlet.context-path: /api 所以 url是这个的时候 http://127.0.0.1:9100/api/staticfiles/image/dd56a59d-da84-441a-8dac-1d97f9e42090.jpeg 配置代码的前面的 /api 是不要写的 package com.gk.study.config;import org.springframework.conte…...

LIMS实验室管理系统如何实现数据自动采集

随着科研技术的不断发展&#xff0c;LIMS实验室管理系统的应用也愈来愈广&#xff0c;已经成为现代化实验室管理不可或缺的工具。LIMS实验室管理系统未与仪器设备对接前&#xff0c;仪器设备产生的数据都是通过人工录入到系统中&#xff0c;再经过人工审核形成最终的数据报告。…...

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周&#xff0c;有很多同学在写期末Java web作业时&#xff0c;运行tomcat出现乱码问题&#xff0c;经过多次解决与研究&#xff0c;我做了如下整理&#xff1a; 原因&#xff1a; IDEA本身编码与tomcat的编码与Windows编码不同导致&#xff0c;Windows 系统控制台…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下&#xff0c;无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作&#xff0c;还是游戏直播的画面实时传输&#xff0c;低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架&#xff0c;凭借其灵活的编解码、数据…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?

在建筑行业&#xff0c;项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升&#xff0c;传统的管理模式已经难以满足现代工程的需求。过去&#xff0c;许多企业依赖手工记录、口头沟通和分散的信息管理&#xff0c;导致效率低下、成本失控、风险频发。例如&#…...

【磁盘】每天掌握一个Linux命令 - iostat

目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat&#xff08;I/O Statistics&#xff09;是Linux系统下用于监视系统输入输出设备和CPU使…...

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码&#xff0c;专为学校招生场景量身打造&#xff0c;功能实用且操作便捷。 从技术架构来看&#xff0c;ThinkPHP提供稳定可靠的后台服务&#xff0c;FastAdmin加速开发流程&#xff0c;UniApp则保障小程序在多端有良好的兼…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

浅谈不同二分算法的查找情况

二分算法原理比较简单&#xff0c;但是实际的算法模板却有很多&#xff0c;这一切都源于二分查找问题中的复杂情况和二分算法的边界处理&#xff0c;以下是博主对一些二分算法查找的情况分析。 需要说明的是&#xff0c;以下二分算法都是基于有序序列为升序有序的情况&#xf…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...