当前位置: 首页 > news >正文

PyTorch 基础学习(13)- 混合精度训练

系列文章:
《PyTorch 基础学习》文章索引

基本概念

混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16torch.bfloat16)数据类型的优势,提高计算效率和内存利用率。

  • 高精度(torch.float32:适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。

  • 低精度(torch.float16torch.bfloat16:适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。

混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp 模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。

重要方法及其作用

torch.autocast

torch.autocast 是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast 将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16,而损失计算则保持为 float32

主要参数:

  • device_type:指定设备类型,如 cudacpuxpu
  • dtype:指定在 autocast 区域内使用的低精度数据类型。对于 CUDA 设备,默认是 torch.float16;对于 CPU 设备,默认是 torch.bfloat16
  • enabled:是否启用混合精度。默认为 True
  • cache_enabled:是否启用权重缓存。默认是 True,可以在某些场景下提高性能。

torch.amp.GradScaler

在低精度(如 float16)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。

主要参数:

  • init_scale:初始的缩放因子,默认是 65536.0
  • growth_factor:在没有发生下溢的情况下,缩放因子增长的倍数,默认是 2.0
  • backoff_factor:发生下溢时,缩放因子减少的倍数,默认是 0.5
  • growth_interval:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是 2000
  • enabled:是否启用梯度缩放,默认为 True

适用的场景

GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda") 能够有效地减少 GPU 的计算负载,并提高吞吐量。

CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16) 可以在推理过程中降低计算复杂度,同时保持较高的精度。

3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwdtorch.amp.custom_bwd,用户可以定义在特定设备(如 cuda)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。

应用实例

以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocasttorch.amp.GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10):  # 假设有10个epochfor input, target in data_loader:  # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")

代码说明

  • 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
  • 在每次训练循环中,我们使用 torch.autocast(device_type="cuda") 上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。
  • 使用 GradScaler 对损失进行缩放,并在缩放后的损失上调用 backward() 进行反向传播。这一步骤有助于防止梯度下溢。
  • scaler.step(optimizer) 用于更新模型参数,scaler.update() 则是调整缩放因子。

这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。

注意事项

  • 弃用警告:从 PyTorch 1.10 开始,原有的 torch.cuda.amp.autocasttorch.cpu.amp.autocast 方法被弃用,推荐使用通用的 torch.autocast 代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。

  • 数据类型匹配:在使用 autocast 时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为 float32 或其他合适的精度。

  • GradScaler 的适用性:虽然 GradScaler 对大多数模型都有效,但在某些情况下(例如使用 bf16 预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。

通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。

相关文章:

PyTorch 基础学习(13)- 混合精度训练

系列文章: 《PyTorch 基础学习》文章索引 基本概念 混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16 或 torch.bfloat16)数据类型的优势,…...

Mycat分片-垂直拆分

目录 场景 配置 测试 全局表配置 续接上篇:MySQ分库分表与MyCat安装配置-CSDN博客 续接下篇:Mycat分片-水平拆分-CSDN博客 场景 在业务系统中, 涉及以下表结构 ,但是由于用户与订单每天都会产生大量的数据, 单台服务器的数据 存储及处理能力是有限…...

一元四次方程求解-【附MATLAB代码】

目录 前言 求解方法 ​编辑 MATLAB验证 附:一元四次方程的故事 前言 最近在研究机器人的干涉(碰撞)检测,遇到了一个问题,就是在求椭圆到原点的最短距离时,构建的方程是一个一元四次方程。无论是高中的…...

【极限性能,尽在掌控】ROG NUC:游戏与创作的微型巨擘

初见ROG NUC,你或许会为它的小巧体型惊讶。然而,这看似不起眼的机身内,蕴藏着游戏、创意的强大能量。 掌中风暴,性能无界 ROG NUC搭载英特尔高性能处理器,配合高速NVMe SSD固态硬盘以及可选的高端独立显卡&#xff08…...

Ecosmos开启公测,将深度赋能CIOE中国光博会元宇宙参会新体验

如今,生成式AI技术的发展,极大地降低了3D数字资产的制作成本,元宇宙作为一种可以无缝将物理和数字资产进行融合的技术,在推动电子产业数字化进程、助力产业高质量发展的方面展现出了巨大的潜力。 当前,发展新质生产力是…...

【Kubernetes】k8s集群之包管理器Helm

目录 一.Helm概述 1.Helm的简介 2.Helm的三个重要概念 3.Helm2与Helm3的的区别 二.Helm 部署 1.安装 helm 2.使用 helm 安装 Chart 3.Helm 自定义模板 4.Helm 仓库 每个成功的软件平台都有一个优秀的打包系统,比如Debian、Ubuntu 的 apt,RedH…...

嵌入式linux系统镜像制作day3(构建镜像)

点击上方"蓝字"关注我们 01、上节回顾 嵌入式linux系统镜像制作day1嵌入式linux系统镜像制作day2提前下载好准备工具,不然失败了大眼瞪小眼。 02、构建 Poky 的 Sato 镜像1 环境: ubuntu18.04poky版本:Dizzy 工具git 在开始之前,针对不同的发行版,需要先执行…...

【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】

教师节中秋节国庆节车模特美女举牌生日视频制作教程AE模板改文字软件生成器素材 怎么如何做的【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件下载AE模板把AE模板导入AE软件修改图…...

RongCallKit iOS 端本地私有 pod 方案

RongCallKit iOS 端本地私有 pod 方案 需求背景 适用于源码集成 CallKit 时,使用 pod 管理 RTC framework 以及源码。集成 CallKit 时,需要定制化修改 CallKit 的样式以及部分 UI 功能。适用于 CallKit 源码 Debug 调试便于定位相关问题。 解决方案 从…...

C++11:可变参数模板

目录 一、概述 二、场景 1.深拷贝的类 2.浅拷贝的类 C使用指南 一、概述 // Args是一个模板参数包&#xff0c;args是一个函数形参参数包 // 声明一个参数包Args...args&#xff0c;这个参数包中可以包含0到任意个模板参数。 template <class ...Args> void ShowList(…...

C++ 与 QML 之间进行数据交互的几种方法

https://www.cnblogs.com/jzcn/p/17774676.html 一、属性绑定 这是最简单的方式&#xff0c;可以在QML中直接绑定C 对象的属性。通过在C 对象中使用Q_PROPERTY宏定义属性&#xff0c;然后在QML中使用绑定语法将属性与QML元素关联起来。 1. person.h #include <QObject&g…...

Javaweb学习之Vue项目的创建(二)

学习资料 Vue.js - 渐进式 JavaScript 框架 | Vue.js (vuejs.org) 准备工作都做完了&#xff0c;接下来开始Vue的正式学习。 第一步&#xff0c;打开VS Code 在VS Code里&#xff0c;我们也需要使用到终端&#xff0c;如果不是以管理员身份打开&#xff0c;在新建Vue项目的时候…...

『深度长文』4种有效提高LLM输出质量的方法!

LLM&#xff0c;全称Large Language Model&#xff0c;意为大型语言模型&#xff0c;是一种基于深度学习的AI技术&#xff0c;能够生成、理解和处理自然语言文本&#xff0c;也因此成为当前大多数AI工具的核心引擎。LLM通过学习海量的文本数据&#xff0c;掌握了词汇、语法、语…...

【工业机器人】工业异常检测大模型AnomalyGPT

AnomalyGPT 工业异常检测视觉大模型AnomalyGPT AnomalyGPT: Detecting Industrial Anomalies using Large Vision-Language Models AnomalyGPT是一种基于大视觉语言模型&#xff08;LVLM&#xff09;的新型工业异常检测&#xff08;IAD&#xff09;方法。它利用LVLM的能力来理…...

【PGCCC】PostgreSQL案例:planning time超长问题分析#PG初级

在使用 PostgreSQL 时&#xff0c;查询的执行计划&#xff08;planning time&#xff09;有时会出现异常长的情况&#xff0c;这可能会影响数据库的整体性能。分析和解决这种问题可以从多个角度入手&#xff0c;以下是常见原因和相应的解决思路&#xff1a; 1. 统计信息不准确…...

【图文并茂】ant design pro 如何给后端发送 json web token - 请求拦截器的使用

上一节有讲过 【图文并茂】ant design pro 如何对接后端个人信息接口 还差一个东西&#xff0c;去获取个人信息的时候&#xff0c;是要发送 token 的&#xff0c;不然会报 403. 就是说在你登录之后才去获得个人信息。这样后端才能知道是谁的信息。 token 就代码了某个人。 …...

【微信小程序】自定义组件 - behaviors

1. 什么是 behaviors 2. behaviors 的工作方式 3. 创建 behavior 调用 Behavior(Object object) 方法即可创建一个共享的 behavior 实例对象&#xff0c;供所有的组件使用&#xff1a; 4. 导入并使用 behavior 5. behavior 中所有可用的节点 6. 同名字段的覆盖和组合规则* 关…...

Linux ubuntu 24.04 安装运行《帝国时代3》免安装绿色版游戏,解决 “Could not load DATAP.BAR”等问题

Linux ubuntu 24.04 安装运行《帝国时代3》游戏&#xff0c;解决 “Could not load DATAP.BAR" 等问题 《帝国时代 3》是一款比较经典的即时战斗游戏&#xff0c;伴随了我半个高中时代&#xff0c;周末有时间就去泡网吧&#xff0c;可惜玩的都是简单人机&#xff0c;高难…...

Springboot 图片

Springboot 图片 因为 server.servlet.context-path: /api 所以 url是这个的时候 http://127.0.0.1:9100/api/staticfiles/image/dd56a59d-da84-441a-8dac-1d97f9e42090.jpeg 配置代码的前面的 /api 是不要写的 package com.gk.study.config;import org.springframework.conte…...

LIMS实验室管理系统如何实现数据自动采集

随着科研技术的不断发展&#xff0c;LIMS实验室管理系统的应用也愈来愈广&#xff0c;已经成为现代化实验室管理不可或缺的工具。LIMS实验室管理系统未与仪器设备对接前&#xff0c;仪器设备产生的数据都是通过人工录入到系统中&#xff0c;再经过人工审核形成最终的数据报告。…...

cf2117E

原题链接&#xff1a;https://codeforces.com/contest/2117/problem/E 题目背景&#xff1a; 给定两个数组a,b&#xff0c;可以执行多次以下操作&#xff1a;选择 i (1 < i < n - 1)&#xff0c;并设置 或&#xff0c;也可以在执行上述操作前执行一次删除任意 和 。求…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

企业如何增强终端安全?

在数字化转型加速的今天&#xff0c;企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机&#xff0c;到工厂里的物联网设备、智能传感器&#xff0c;这些终端构成了企业与外部世界连接的 “神经末梢”。然而&#xff0c;随着远程办公的常态化和设备接入的爆炸式…...

【Go语言基础【12】】指针:声明、取地址、解引用

文章目录 零、概述&#xff1a;指针 vs. 引用&#xff08;类比其他语言&#xff09;一、指针基础概念二、指针声明与初始化三、指针操作符1. &&#xff1a;取地址&#xff08;拿到内存地址&#xff09;2. *&#xff1a;解引用&#xff08;拿到值&#xff09; 四、空指针&am…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...

Docker拉取MySQL后数据库连接失败的解决方案

在使用Docker部署MySQL时&#xff0c;拉取并启动容器后&#xff0c;有时可能会遇到数据库连接失败的问题。这种问题可能由多种原因导致&#xff0c;包括配置错误、网络设置问题、权限问题等。本文将分析可能的原因&#xff0c;并提供解决方案。 一、确认MySQL容器的运行状态 …...

深入解析光敏传感技术:嵌入式仿真平台如何重塑电子工程教学

一、光敏传感技术的物理本质与系统级实现挑战 光敏电阻作为经典的光电传感器件&#xff0c;其工作原理根植于半导体材料的光电导效应。当入射光子能量超过材料带隙宽度时&#xff0c;价带电子受激发跃迁至导带&#xff0c;形成电子-空穴对&#xff0c;导致材料电导率显著提升。…...

python读取SQLite表个并生成pdf文件

代码用于创建含50列的SQLite数据库并插入500行随机浮点数据&#xff0c;随后读取数据&#xff0c;通过ReportLab生成横向PDF表格&#xff0c;包含格式化&#xff08;两位小数&#xff09;及表头、网格线等美观样式。 # 导入所需库 import sqlite3 # 用于操作…...

CppCon 2015 学习:Simple, Extensible Pattern Matching in C++14

什么是 Pattern Matching&#xff08;模式匹配&#xff09; ❝ 模式匹配就是一种“描述式”的写法&#xff0c;不需要你手动判断、提取数据&#xff0c;而是直接描述你希望的数据结构是什么样子&#xff0c;系统自动判断并提取。❞ 你给的定义拆解&#xff1a; ✴ Instead of …...