当前位置: 首页 > news >正文

加速PyTorch模型训练:自动混合精度(AMP)

在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。

AMP简单介绍

是一种训练技巧,允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。

官网地址:https://pytorch.org/docs/stable/amp.html
在这里插入图片描述

为什么使用AMP

在某些上下文中,torch.FloatTensor(FP32)有其优势,而在其他情况下,torch.HalfTensor(FP16)则更具优势。FP16的优势包括减少显存占用、加快训练和推断计算以及更好地利用CUDA设备的Tensor Core。然而,FP16也存在数值范围小和舍入误差等问题。通过混合精度训练,可以在享受FP16带来的好处的同时,避免其劣势。

两个核心组件

PyTorch 的 AMP 模块主要包含两个核心组件:autocastGradScaler

  • autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。
  • GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler 可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import time
torch.manual_seed(42)
# A simple Model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# init model
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# GradScaler
scaler = GradScaler(device='cuda')# random data
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# train
for epoch in range(1):start_time = time.time() print(f"inputs dtype:{inputs.dtype}")# autocastwith autocast('cuda'):outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")optimizer.zero_grad(set_to_none=True)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")end_time = time.time() elapsed_time = end_time - start_time allocated_memory = torch.cuda.memory_allocated() / 1024**2  reserved_memory = torch.cuda.memory_reserved() / 1024**2  print(f"Single Batch, Single Epoch with AMP, Loss: {loss.item():.4f}")print(f"Time taken: {elapsed_time:.4f} seconds")print(f"Allocated memory: {allocated_memory:.2f} MB")print(f"Reserved memory: {reserved_memory:.2f} MB")

输出如下:
Time taken for epoch 1: 0.0116 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB

不使用AMP(更快了):
Time taken for epoch 1: 0.0024 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB

由于上面的示例是一个很小的模型(只有几层的小型网络),其本身的计算量不大,因此即使采用了FP16精度,也难以观察到明显的加速效果。同时,如果模型中的某些层无法有效利用Tensor Cores(例如一些自定义操作,非标准层),那么整个流程可能会受到限制。所以感受不到有计算优化。

在这里插入图片描述

相关文章:

加速PyTorch模型训练:自动混合精度(AMP)

在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。 AMP简单介绍 是一种训练…...

【py】python安装教程(Windows系统,python3.13.2版本为例)

1.下载地址 官网:https://www.python.org/ 官网下载地址:https://www.python.org/downloads/ 2.64版本或者32位选择 【Stable Releases】:稳定发布版本,指的是已经测试过的版本,相对稳定。 【Pre-releases】&#…...

Django REST Framework:如何获取序列化后的ID

Django REST Framework:如何获取序列化后的ID 😄 嗨,小伙伴们!今天我们来聊一聊Django REST Framework(简称DRF)中一个非常常见的操作:如何获取序列化后的ID。对于那些刚入门的朋友们&#xff…...

QT修仙笔记 事件大圆满 闹钟大成

学习笔记 牛客刷题 闹钟 时钟显示 通过 QTimer 每秒更新一次 QLCDNumber 显示的当前时间,格式为 hh:mm:ss,实现实时时钟显示。 闹钟设置 使用 QDateTimeEdit 让用户设置闹钟时间,可通过日历选择日期,设置范围为当前时间到未来 …...

Leetcode - 149双周赛

目录 一、3438. 找到字符串中合法的相邻数字二、3439. 重新安排会议得到最多空余时间 I三、3440. 重新安排会议得到最多空余时间 II四、3441. 变成好标题的最少代价 一、3438. 找到字符串中合法的相邻数字 题目链接 本题有两个条件: 相邻数字互不相同两个数字的的…...

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题 1. 安装ComfyUI-Impact-Pack 首先确保ComfyUI-Impact-Pack 已经下载 地址: https://github.com/ltdrdata/ComfyUI-Impact-Pack 2. 安装ComfyUI-Impact-Subpack 由于新版本的Impact Pack 不再提供这…...

使用Kickstart配置文件封装操作系统实现Linux的自动化安装

使用Kickstart配置文件封装操作系统实现Linux的自动化安装 创建ks.cfg配置文件 可以使用已经安装完成的Linux操作系统中的/root目录下的anaconda.cfg配置文件 注意,配置文件会因为kickstart的版本兼容性的问题导致无法安装报错需要在实际使用过程中删除某些参数 …...

Android笔记【snippet】

一、 6、Card及ConstraintLayout线性布局 //定义单独的机器人单独一行的卡片 Composable fun RobotCard(robot: Robot,navController:NavController){Card(modifier Modifier.fillMaxWidth().wrapContentHeight().padding(5.dp),colors CardDefaults.elevatedCardColors(co…...

zsh: command not found: conda

场景描述 在 Linux 服务器上使用 zsh 时,如果出现 zsh: command not found: conda 错误,说明你的系统未正确配置 conda 命令,或者你尚未安装 Anaconda/Miniconda。 解决方案 确保已安装 Anaconda 或 Miniconda conda 是 Anaconda 或 Minico…...

【知识科普】CPU,GPN,NPU知识普及

CPU,GPU,NPU CPU、GPU、NPU 详解1. CPU(中央处理器)2. GPU(图形处理器)3. NPU(神经网络处理器) **三者的核心区别****协同工作示例****总结** CPU、GPU、NPU 详解 1. CPU(中央处理器&#xff0…...

【C++八股】struct和Class的区别

1. 默认访问控制 struct:结构体中的成员默认是 public,即外部代码可以直接访问结构体的成员。class:类中的成员默认是 private,即外部代码不能直接访问类的成员,必须通过公有接口(通常是成员函数&#xff…...

鹧鸪云光伏仓储、物料管理软件详细功能

采购中心 :作为核心枢纽,能集中管理多品牌设备,企业可灵活按需采购。采购与退货流程高效便捷,审核通过后物资快速补充、问题货物及时退回,保障资金与物资顺畅周转,避免积压浪费。付款与退款环节 &#xff1…...

bazel 小白理解

Bazel命令是用于构建和测试软件项目的一个强大工具,尤其适用于大规模和多语言的软件项目。对于小白来说,可以这样理解Bazel及其命令: Bazel的基本概念 构建系统:Bazel是一个构建系统,它的主要任务是自动化地编译和链…...

MVC(Model-View-Controller)framework using Python ,Tkinter and SQLite

1.项目结构 sql: CREATE TABLE IF NOT EXISTS School (SchoolId TEXT not null, SchoolName TEXT NOT NULL,SchoolTelNo TEXT NOT NULL) 整体思路 Model:负责与 SQLite 数据库进行交互,包括创建表、插入、删除、更新和查询数据等操作。View&#xff1…...

WPF 设置宽度为 父容器 宽度的一半

方法1:使用 绑定和转换器 实现 创建类文件 HalfWidthConverter public class HalfWidthConverter : IValueConverter{public object Convert(object value, Type targetType, object parameter, CultureInfo culture){if (value is double width){return width / 4…...

java项目之在线心理评测与咨询管理系统(源码+文档)

项目简介 在线心理评测与咨询管理系统实现了以下功能: 在线心理评测与咨询管理系统的主要使用者分为: (1)在个人中心,管理员可以修改自己的用户名和登录密码。 (2)在系统前台可以查看首页&…...

【STM32系列】利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程)

ps.源码放在最后面 设计IIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程) 前言 本篇文章将介绍如何利用MATLAB与STM32的ARM-DSP库相结合,简明易懂地实现FIR低通滤波器的设计与应用。文章重点不在…...

Springboot框架扩展功能的使用

Spring Boot 提供了许多扩展点,允许开发者在应用程序的生命周期中插入自定义逻辑。这些扩展点可以帮助你更好地控制应用程序的行为,例如在启动时初始化数据、在关闭时释放资源、或者自定义配置加载逻辑。以下是 Spring Boot 中常见的扩展点: …...

yum报错 Could not resolve host: mirrorlist.centos.org

检查dns 使用ping www.baidu.com ,如果ping不通,检查/etc/resolv.conf文件中是否有: nameserver 8.8.8.8 nameserver 8.8.4.4 替换yum源 1.备份原始的 YUM 源配置文件: sudo cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.r…...

docker使用dockerfile打包镜像(docker如何打包)

文章目录 1. 编写 Dockerfile2. 构建 Docker 镜像3. 运行 Docker 容器4. 导出与导入镜像(可选) 1. 编写 Dockerfile Dockerfile 是一个文本文件,其中包含了一系列指令,这些指令定义了如何构建你的 Docker 镜像。下面以一个简单的…...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外,K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案,全安装在K8S群集中。 具体可参…...

postgresql|数据库|只读用户的创建和删除(备忘)

CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...

JDK 17 新特性

#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的&#xff…...

学习STC51单片机32(芯片为STC89C52RCRC)OLED显示屏2

每日一言 今天的每一份坚持,都是在为未来积攒底气。 案例:OLED显示一个A 这边观察到一个点,怎么雪花了就是都是乱七八糟的占满了屏幕。。 解释 : 如果代码里信号切换太快(比如 SDA 刚变,SCL 立刻变&#…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为:一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

Ubuntu系统复制(U盘-电脑硬盘)

所需环境 电脑自带硬盘:1块 (1T) U盘1:Ubuntu系统引导盘(用于“U盘2”复制到“电脑自带硬盘”) U盘2:Ubuntu系统盘(1T,用于被复制) !!!建议“电脑…...

深度解析:etcd 在 Milvus 向量数据库中的关键作用

目录 🚀 深度解析:etcd 在 Milvus 向量数据库中的关键作用 💡 什么是 etcd? 🧠 Milvus 架构简介 📦 etcd 在 Milvus 中的核心作用 🔧 实际工作流程示意 ⚠️ 如果 etcd 出现问题会怎样&am…...

JUC并发编程(二)Monitor/自旋/轻量级/锁膨胀/wait/notify/锁消除

目录 一 基础 1 概念 2 卖票问题 3 转账问题 二 锁机制与优化策略 0 Monitor 1 轻量级锁 2 锁膨胀 3 自旋 4 偏向锁 5 锁消除 6 wait /notify 7 sleep与wait的对比 8 join原理 一 基础 1 概念 临界区 一段代码块内如果存在对共享资源的多线程读写操作&#xf…...

HTMLCSS 学习总结

目录 ​​​一、HTML核心概念​​ ​​三大前端技术作用​​ ​​HTML基础结构​​ 开发工具:VS Code 专业配置​​​​安装步骤​​: ​​二、HTML标签大全(含表格)​​ ​​三、CSS核心技术​​ 1. 三种引入方式对比 2.…...