当前位置: 首页 > news >正文

加速PyTorch模型训练:自动混合精度(AMP)

在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。

AMP简单介绍

是一种训练技巧,允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。

官网地址:https://pytorch.org/docs/stable/amp.html
在这里插入图片描述

为什么使用AMP

在某些上下文中,torch.FloatTensor(FP32)有其优势,而在其他情况下,torch.HalfTensor(FP16)则更具优势。FP16的优势包括减少显存占用、加快训练和推断计算以及更好地利用CUDA设备的Tensor Core。然而,FP16也存在数值范围小和舍入误差等问题。通过混合精度训练,可以在享受FP16带来的好处的同时,避免其劣势。

两个核心组件

PyTorch 的 AMP 模块主要包含两个核心组件:autocastGradScaler

  • autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。
  • GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler 可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import time
torch.manual_seed(42)
# A simple Model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# init model
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# GradScaler
scaler = GradScaler(device='cuda')# random data
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# train
for epoch in range(1):start_time = time.time() print(f"inputs dtype:{inputs.dtype}")# autocastwith autocast('cuda'):outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")optimizer.zero_grad(set_to_none=True)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")end_time = time.time() elapsed_time = end_time - start_time allocated_memory = torch.cuda.memory_allocated() / 1024**2  reserved_memory = torch.cuda.memory_reserved() / 1024**2  print(f"Single Batch, Single Epoch with AMP, Loss: {loss.item():.4f}")print(f"Time taken: {elapsed_time:.4f} seconds")print(f"Allocated memory: {allocated_memory:.2f} MB")print(f"Reserved memory: {reserved_memory:.2f} MB")

输出如下:
Time taken for epoch 1: 0.0116 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB

不使用AMP(更快了):
Time taken for epoch 1: 0.0024 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB

由于上面的示例是一个很小的模型(只有几层的小型网络),其本身的计算量不大,因此即使采用了FP16精度,也难以观察到明显的加速效果。同时,如果模型中的某些层无法有效利用Tensor Cores(例如一些自定义操作,非标准层),那么整个流程可能会受到限制。所以感受不到有计算优化。

在这里插入图片描述

相关文章:

加速PyTorch模型训练:自动混合精度(AMP)

在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。 AMP简单介绍 是一种训练…...

【py】python安装教程(Windows系统,python3.13.2版本为例)

1.下载地址 官网:https://www.python.org/ 官网下载地址:https://www.python.org/downloads/ 2.64版本或者32位选择 【Stable Releases】:稳定发布版本,指的是已经测试过的版本,相对稳定。 【Pre-releases】&#…...

Django REST Framework:如何获取序列化后的ID

Django REST Framework:如何获取序列化后的ID 😄 嗨,小伙伴们!今天我们来聊一聊Django REST Framework(简称DRF)中一个非常常见的操作:如何获取序列化后的ID。对于那些刚入门的朋友们&#xff…...

QT修仙笔记 事件大圆满 闹钟大成

学习笔记 牛客刷题 闹钟 时钟显示 通过 QTimer 每秒更新一次 QLCDNumber 显示的当前时间,格式为 hh:mm:ss,实现实时时钟显示。 闹钟设置 使用 QDateTimeEdit 让用户设置闹钟时间,可通过日历选择日期,设置范围为当前时间到未来 …...

Leetcode - 149双周赛

目录 一、3438. 找到字符串中合法的相邻数字二、3439. 重新安排会议得到最多空余时间 I三、3440. 重新安排会议得到最多空余时间 II四、3441. 变成好标题的最少代价 一、3438. 找到字符串中合法的相邻数字 题目链接 本题有两个条件: 相邻数字互不相同两个数字的的…...

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题 1. 安装ComfyUI-Impact-Pack 首先确保ComfyUI-Impact-Pack 已经下载 地址: https://github.com/ltdrdata/ComfyUI-Impact-Pack 2. 安装ComfyUI-Impact-Subpack 由于新版本的Impact Pack 不再提供这…...

使用Kickstart配置文件封装操作系统实现Linux的自动化安装

使用Kickstart配置文件封装操作系统实现Linux的自动化安装 创建ks.cfg配置文件 可以使用已经安装完成的Linux操作系统中的/root目录下的anaconda.cfg配置文件 注意,配置文件会因为kickstart的版本兼容性的问题导致无法安装报错需要在实际使用过程中删除某些参数 …...

Android笔记【snippet】

一、 6、Card及ConstraintLayout线性布局 //定义单独的机器人单独一行的卡片 Composable fun RobotCard(robot: Robot,navController:NavController){Card(modifier Modifier.fillMaxWidth().wrapContentHeight().padding(5.dp),colors CardDefaults.elevatedCardColors(co…...

zsh: command not found: conda

场景描述 在 Linux 服务器上使用 zsh 时,如果出现 zsh: command not found: conda 错误,说明你的系统未正确配置 conda 命令,或者你尚未安装 Anaconda/Miniconda。 解决方案 确保已安装 Anaconda 或 Miniconda conda 是 Anaconda 或 Minico…...

【知识科普】CPU,GPN,NPU知识普及

CPU,GPU,NPU CPU、GPU、NPU 详解1. CPU(中央处理器)2. GPU(图形处理器)3. NPU(神经网络处理器) **三者的核心区别****协同工作示例****总结** CPU、GPU、NPU 详解 1. CPU(中央处理器&#xff0…...

【C++八股】struct和Class的区别

1. 默认访问控制 struct:结构体中的成员默认是 public,即外部代码可以直接访问结构体的成员。class:类中的成员默认是 private,即外部代码不能直接访问类的成员,必须通过公有接口(通常是成员函数&#xff…...

鹧鸪云光伏仓储、物料管理软件详细功能

采购中心 :作为核心枢纽,能集中管理多品牌设备,企业可灵活按需采购。采购与退货流程高效便捷,审核通过后物资快速补充、问题货物及时退回,保障资金与物资顺畅周转,避免积压浪费。付款与退款环节 &#xff1…...

bazel 小白理解

Bazel命令是用于构建和测试软件项目的一个强大工具,尤其适用于大规模和多语言的软件项目。对于小白来说,可以这样理解Bazel及其命令: Bazel的基本概念 构建系统:Bazel是一个构建系统,它的主要任务是自动化地编译和链…...

MVC(Model-View-Controller)framework using Python ,Tkinter and SQLite

1.项目结构 sql: CREATE TABLE IF NOT EXISTS School (SchoolId TEXT not null, SchoolName TEXT NOT NULL,SchoolTelNo TEXT NOT NULL) 整体思路 Model:负责与 SQLite 数据库进行交互,包括创建表、插入、删除、更新和查询数据等操作。View&#xff1…...

WPF 设置宽度为 父容器 宽度的一半

方法1:使用 绑定和转换器 实现 创建类文件 HalfWidthConverter public class HalfWidthConverter : IValueConverter{public object Convert(object value, Type targetType, object parameter, CultureInfo culture){if (value is double width){return width / 4…...

java项目之在线心理评测与咨询管理系统(源码+文档)

项目简介 在线心理评测与咨询管理系统实现了以下功能: 在线心理评测与咨询管理系统的主要使用者分为: (1)在个人中心,管理员可以修改自己的用户名和登录密码。 (2)在系统前台可以查看首页&…...

【STM32系列】利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程)

ps.源码放在最后面 设计IIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程) 前言 本篇文章将介绍如何利用MATLAB与STM32的ARM-DSP库相结合,简明易懂地实现FIR低通滤波器的设计与应用。文章重点不在…...

Springboot框架扩展功能的使用

Spring Boot 提供了许多扩展点,允许开发者在应用程序的生命周期中插入自定义逻辑。这些扩展点可以帮助你更好地控制应用程序的行为,例如在启动时初始化数据、在关闭时释放资源、或者自定义配置加载逻辑。以下是 Spring Boot 中常见的扩展点: …...

yum报错 Could not resolve host: mirrorlist.centos.org

检查dns 使用ping www.baidu.com ,如果ping不通,检查/etc/resolv.conf文件中是否有: nameserver 8.8.8.8 nameserver 8.8.4.4 替换yum源 1.备份原始的 YUM 源配置文件: sudo cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.r…...

docker使用dockerfile打包镜像(docker如何打包)

文章目录 1. 编写 Dockerfile2. 构建 Docker 镜像3. 运行 Docker 容器4. 导出与导入镜像(可选) 1. 编写 Dockerfile Dockerfile 是一个文本文件,其中包含了一系列指令,这些指令定义了如何构建你的 Docker 镜像。下面以一个简单的…...

黑马Mybatis

Mybatis 表现层&#xff1a;页面展示 业务层&#xff1a;逻辑处理 持久层&#xff1a;持久数据化保存 在这里插入图片描述 Mybatis快速入门 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6501c2109c4442118ceb6014725e48e4.png //logback.xml <?xml ver…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

基于Docker Compose部署Java微服务项目

一. 创建根项目 根项目&#xff08;父项目&#xff09;主要用于依赖管理 一些需要注意的点&#xff1a; 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件&#xff0c;否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

HTML前端开发:JavaScript 常用事件详解

作为前端开发的核心&#xff0c;JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例&#xff1a; 1. onclick - 点击事件 当元素被单击时触发&#xff08;左键点击&#xff09; button.onclick function() {alert("按钮被点击了&#xff01;&…...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库&#xff0c;例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体&#xff0c;比如 SnowballFight、Huggy the Do…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案

目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后&#xff0c;迭代器会失效&#xff0c;因为顺序迭代器在内存中是连续存储的&#xff0c;元素删除后&#xff0c;后续元素会前移。 但一些场景中&#xff0c;我们又需要在执行删除操作…...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案

引言 在分布式系统的事务处理中&#xff0c;如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议&#xff08;2PC&#xff09;通过准备阶段与提交阶段的协调机制&#xff0c;以同步决策模式确保事务原子性。其改进版本三阶段提交协议&#xff08;3PC&#xf…...

聚六亚甲基单胍盐酸盐市场深度解析:现状、挑战与机遇

根据 QYResearch 发布的市场报告显示&#xff0c;全球市场规模预计在 2031 年达到 9848 万美元&#xff0c;2025 - 2031 年期间年复合增长率&#xff08;CAGR&#xff09;为 3.7%。在竞争格局上&#xff0c;市场集中度较高&#xff0c;2024 年全球前十强厂商占据约 74.0% 的市场…...