当前位置: 首页 > news >正文

加速PyTorch模型训练:自动混合精度(AMP)

在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。

AMP简单介绍

是一种训练技巧,允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。

官网地址:https://pytorch.org/docs/stable/amp.html
在这里插入图片描述

为什么使用AMP

在某些上下文中,torch.FloatTensor(FP32)有其优势,而在其他情况下,torch.HalfTensor(FP16)则更具优势。FP16的优势包括减少显存占用、加快训练和推断计算以及更好地利用CUDA设备的Tensor Core。然而,FP16也存在数值范围小和舍入误差等问题。通过混合精度训练,可以在享受FP16带来的好处的同时,避免其劣势。

两个核心组件

PyTorch 的 AMP 模块主要包含两个核心组件:autocastGradScaler

  • autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。
  • GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler 可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import time
torch.manual_seed(42)
# A simple Model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# init model
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# GradScaler
scaler = GradScaler(device='cuda')# random data
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# train
for epoch in range(1):start_time = time.time() print(f"inputs dtype:{inputs.dtype}")# autocastwith autocast('cuda'):outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")optimizer.zero_grad(set_to_none=True)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")end_time = time.time() elapsed_time = end_time - start_time allocated_memory = torch.cuda.memory_allocated() / 1024**2  reserved_memory = torch.cuda.memory_reserved() / 1024**2  print(f"Single Batch, Single Epoch with AMP, Loss: {loss.item():.4f}")print(f"Time taken: {elapsed_time:.4f} seconds")print(f"Allocated memory: {allocated_memory:.2f} MB")print(f"Reserved memory: {reserved_memory:.2f} MB")

输出如下:
Time taken for epoch 1: 0.0116 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB

不使用AMP(更快了):
Time taken for epoch 1: 0.0024 seconds
Allocated memory: 20.64 MB
Reserved memory: 44.00 MB

由于上面的示例是一个很小的模型(只有几层的小型网络),其本身的计算量不大,因此即使采用了FP16精度,也难以观察到明显的加速效果。同时,如果模型中的某些层无法有效利用Tensor Cores(例如一些自定义操作,非标准层),那么整个流程可能会受到限制。所以感受不到有计算优化。

在这里插入图片描述

相关文章:

加速PyTorch模型训练:自动混合精度(AMP)

在深度学习领域,模型训练的速度和效率尤为重要。为了提升训练速度并减少显存占用(较复杂的模型中),PyTorch自1.6版本起引入了自动混合精度(Automatic Mixed Precision, AMP)功能。 AMP简单介绍 是一种训练…...

【py】python安装教程(Windows系统,python3.13.2版本为例)

1.下载地址 官网:https://www.python.org/ 官网下载地址:https://www.python.org/downloads/ 2.64版本或者32位选择 【Stable Releases】:稳定发布版本,指的是已经测试过的版本,相对稳定。 【Pre-releases】&#…...

Django REST Framework:如何获取序列化后的ID

Django REST Framework:如何获取序列化后的ID 😄 嗨,小伙伴们!今天我们来聊一聊Django REST Framework(简称DRF)中一个非常常见的操作:如何获取序列化后的ID。对于那些刚入门的朋友们&#xff…...

QT修仙笔记 事件大圆满 闹钟大成

学习笔记 牛客刷题 闹钟 时钟显示 通过 QTimer 每秒更新一次 QLCDNumber 显示的当前时间,格式为 hh:mm:ss,实现实时时钟显示。 闹钟设置 使用 QDateTimeEdit 让用户设置闹钟时间,可通过日历选择日期,设置范围为当前时间到未来 …...

Leetcode - 149双周赛

目录 一、3438. 找到字符串中合法的相邻数字二、3439. 重新安排会议得到最多空余时间 I三、3440. 重新安排会议得到最多空余时间 II四、3441. 变成好标题的最少代价 一、3438. 找到字符串中合法的相邻数字 题目链接 本题有两个条件: 相邻数字互不相同两个数字的的…...

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题

解决 ComfyUI-Impact-Pack 中缺少 UltralyticsDetectorProvider 节点的问题 1. 安装ComfyUI-Impact-Pack 首先确保ComfyUI-Impact-Pack 已经下载 地址: https://github.com/ltdrdata/ComfyUI-Impact-Pack 2. 安装ComfyUI-Impact-Subpack 由于新版本的Impact Pack 不再提供这…...

使用Kickstart配置文件封装操作系统实现Linux的自动化安装

使用Kickstart配置文件封装操作系统实现Linux的自动化安装 创建ks.cfg配置文件 可以使用已经安装完成的Linux操作系统中的/root目录下的anaconda.cfg配置文件 注意,配置文件会因为kickstart的版本兼容性的问题导致无法安装报错需要在实际使用过程中删除某些参数 …...

Android笔记【snippet】

一、 6、Card及ConstraintLayout线性布局 //定义单独的机器人单独一行的卡片 Composable fun RobotCard(robot: Robot,navController:NavController){Card(modifier Modifier.fillMaxWidth().wrapContentHeight().padding(5.dp),colors CardDefaults.elevatedCardColors(co…...

zsh: command not found: conda

场景描述 在 Linux 服务器上使用 zsh 时,如果出现 zsh: command not found: conda 错误,说明你的系统未正确配置 conda 命令,或者你尚未安装 Anaconda/Miniconda。 解决方案 确保已安装 Anaconda 或 Miniconda conda 是 Anaconda 或 Minico…...

【知识科普】CPU,GPN,NPU知识普及

CPU,GPU,NPU CPU、GPU、NPU 详解1. CPU(中央处理器)2. GPU(图形处理器)3. NPU(神经网络处理器) **三者的核心区别****协同工作示例****总结** CPU、GPU、NPU 详解 1. CPU(中央处理器&#xff0…...

【C++八股】struct和Class的区别

1. 默认访问控制 struct:结构体中的成员默认是 public,即外部代码可以直接访问结构体的成员。class:类中的成员默认是 private,即外部代码不能直接访问类的成员,必须通过公有接口(通常是成员函数&#xff…...

鹧鸪云光伏仓储、物料管理软件详细功能

采购中心 :作为核心枢纽,能集中管理多品牌设备,企业可灵活按需采购。采购与退货流程高效便捷,审核通过后物资快速补充、问题货物及时退回,保障资金与物资顺畅周转,避免积压浪费。付款与退款环节 &#xff1…...

bazel 小白理解

Bazel命令是用于构建和测试软件项目的一个强大工具,尤其适用于大规模和多语言的软件项目。对于小白来说,可以这样理解Bazel及其命令: Bazel的基本概念 构建系统:Bazel是一个构建系统,它的主要任务是自动化地编译和链…...

MVC(Model-View-Controller)framework using Python ,Tkinter and SQLite

1.项目结构 sql: CREATE TABLE IF NOT EXISTS School (SchoolId TEXT not null, SchoolName TEXT NOT NULL,SchoolTelNo TEXT NOT NULL) 整体思路 Model:负责与 SQLite 数据库进行交互,包括创建表、插入、删除、更新和查询数据等操作。View&#xff1…...

WPF 设置宽度为 父容器 宽度的一半

方法1:使用 绑定和转换器 实现 创建类文件 HalfWidthConverter public class HalfWidthConverter : IValueConverter{public object Convert(object value, Type targetType, object parameter, CultureInfo culture){if (value is double width){return width / 4…...

java项目之在线心理评测与咨询管理系统(源码+文档)

项目简介 在线心理评测与咨询管理系统实现了以下功能: 在线心理评测与咨询管理系统的主要使用者分为: (1)在个人中心,管理员可以修改自己的用户名和登录密码。 (2)在系统前台可以查看首页&…...

【STM32系列】利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程)

ps.源码放在最后面 设计IIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程) 前言 本篇文章将介绍如何利用MATLAB与STM32的ARM-DSP库相结合,简明易懂地实现FIR低通滤波器的设计与应用。文章重点不在…...

Springboot框架扩展功能的使用

Spring Boot 提供了许多扩展点,允许开发者在应用程序的生命周期中插入自定义逻辑。这些扩展点可以帮助你更好地控制应用程序的行为,例如在启动时初始化数据、在关闭时释放资源、或者自定义配置加载逻辑。以下是 Spring Boot 中常见的扩展点: …...

yum报错 Could not resolve host: mirrorlist.centos.org

检查dns 使用ping www.baidu.com ,如果ping不通,检查/etc/resolv.conf文件中是否有: nameserver 8.8.8.8 nameserver 8.8.4.4 替换yum源 1.备份原始的 YUM 源配置文件: sudo cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.r…...

docker使用dockerfile打包镜像(docker如何打包)

文章目录 1. 编写 Dockerfile2. 构建 Docker 镜像3. 运行 Docker 容器4. 导出与导入镜像(可选) 1. 编写 Dockerfile Dockerfile 是一个文本文件,其中包含了一系列指令,这些指令定义了如何构建你的 Docker 镜像。下面以一个简单的…...

在软件开发中正确使用MySQL日期时间类型的深度解析

在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...

k8s从入门到放弃之Ingress七层负载

k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...

ESP32读取DHT11温湿度数据

芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...

大数据学习(132)-HIve数据分析

​​​​🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言&#x1f4…...

10-Oracle 23 ai Vector Search 概述和参数

一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI,使用客户端或是内部自己搭建集成大模型的终端,加速与大型语言模型(LLM)的结合,同时使用检索增强生成(Retrieval Augmented Generation &#…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中,我们经常会遇到这样的场景:一个对象的状态变化需要自动通知其他对象,比如: 电商平台中,商品库存变化时需要通知所有订阅该商品的用户;新闻网站中&#xff0…...

站群服务器的应用场景都有哪些?

站群服务器主要是为了多个网站的托管和管理所设计的,可以通过集中管理和高效资源的分配,来支持多个独立的网站同时运行,让每一个网站都可以分配到独立的IP地址,避免出现IP关联的风险,用户还可以通过控制面板进行管理功…...

GitHub 趋势日报 (2025年06月06日)

📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...

从 GreenPlum 到镜舟数据库:杭银消费金融湖仓一体转型实践

作者:吴岐诗,杭银消费金融大数据应用开发工程师 本文整理自杭银消费金融大数据应用开发工程师在StarRocks Summit Asia 2024的分享 引言:融合数据湖与数仓的创新之路 在数字金融时代,数据已成为金融机构的核心竞争力。杭银消费金…...