当前位置: 首页 > news >正文

transformer系列5---transformer显存占用分析

Transformer显存占用分析

  • 1 影响因素概述
  • 2 前向计算临时Tensor显存占用
    • 2.1 self-attention显存占用
    • 2.2 MLP显存占用
  • 3 梯度和优化器显存占用
    • 3.1 模型训练过程两者显存占用
    • 3.2 模型推理过程两者显存占用

1 影响因素概述

  1. 模型训练框架:例如pytorch框架的cuda context会占用大约几百MB显存,与版本有关;
  2. 模型参数大小,比如7B的模型以FP16格式要占用14GB显存;
  3. 前向计算过程中产生的临时Tensor:这部分Tensor需要被临时保存,以便在反向传播计算梯度时使用
  4. 反向传播计算得到的梯度:
  5. 优化器状态:全量微调的情况下,梯度与参数一样大,普通SGD没有动量,一阶动量优化器的自身参数大小与模型大小一样,比如momentum-SGD,二阶动量优化器一般为模型大小的两倍,比如Adam, transformer系列的大模型最常用的是Adam优化器

2 前向计算临时Tensor显存占用

2.1 self-attention显存占用

这部分Tensor的大小和模型的每一层结构形状有关(必须根据具体模型的每层形状来计算)也和具体的batch_size大小以及输入数据input_data的大小有关。

  1. 输入矩阵I:首先计算 Q = I ∗ W q Q =I * W^{q} Q=IWq K = I ∗ W k K = I * W^{k} K=IWk V = I ∗ W v V = I * W^{v} V=IWv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
  2. Q K T QK^{T} QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
  3. softmax: A = Q K T A=QK^{T} A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes* b h s 2 bhs^{2} bhs2= 2 b h s 2 2bhs^{2} 2bhs2 bytes。
  4. dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。
  5. score* V加权:score矩阵的形状与A相同,占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd bytes。
  6. W O W^{O} WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
  7. dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
    综上步骤,self-attention块的占用显存大小为2bsd+4bsd+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd+2bsd+2bsd=11bsd+ 5 b h s 2 5bhs^{2} 5bhs2

2.2 MLP显存占用

  1. 第一个线性层需要保存其输入,输入形状为[b, s, d],占用显存大小为 2bytes*bsd=2bsd bytes。
  2. 激活函数需要保存其输入,为第一步的输出形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  3. 第二个线性层需要保存其输入,输入形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  4. 最后有一个dropout操作,需要保存mask矩阵,形状是上一步的输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。

综上步骤,MLP的占用显存大小为2bsd+8bsd+8bsd+bsd=19bsd.

3 梯度和优化器显存占用

3.1 模型训练过程两者显存占用

参数占用显存 = 参数数目 × n
n = 2 : float16
n = 4 : float32
n = 8 : double64
其中,float32是最常用的类型,n是数据类型占用的bytes。
训练过程通常为模型参数前向传播,反向传播计算梯度,优化器更新,以Adam优化器为例分析,假如模型参数量为P:

  1. 混合精度训练:
    1)使用float16的模型参数进行前向传递和反向传播,计算得到float16的梯度;
    2)在优化器更新模型参数时,使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
    3)对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是2bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是2bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(2+4)+(2+4)+8 = 20bytes。模型参数量M时总计20P bytes。

  1. 普通训练:
    上述步骤1)2)均使用float32类型。对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是4bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是4bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(4+4)+(4+4)+8 = 24bytes,模型参数量M时总计24P bytes。

3.2 模型推理过程两者显存占用

推理占用显存主要是模型参数,假如模型参数量为P,使用float16来进行推理,推理阶段模型参数占用的显存约2P bytes,使用float32来进行推理,推理阶段模型参数占用的显存约 4P bytes。

参考文章:https://zhuanlan.zhihu.com/p/624740065?utm_id=0

相关文章:

transformer系列5---transformer显存占用分析

Transformer显存占用分析 1 影响因素概述2 前向计算临时Tensor显存占用2.1 self-attention显存占用2.2 MLP显存占用 3 梯度和优化器显存占用3.1 模型训练过程两者显存占用3.2 模型推理过程两者显存占用 1 影响因素概述 模型训练框架:例如pytorch框架的cuda context…...

Docker项目部署

目录 一、前端项目部署 1、上传文件 2、开启容器 3、测试 二、后端项目部署 1、打包java项目 2、将jar包和Dockerfile文件长传到Linux系统 3、构建镜像 4、开启容器 5、测试 三、DockerCompose快速部署 基本语法 一、前端项目部署 1、上传文件 里面包括页面和配置文…...

vue3实现文本超出鼠标移入的时候文本滚动

判断文本长度是否大于容器长度 鼠标移入的时候判断&#xff0c;此处使用了tailwindcss&#xff0c;注意一下要设置文本不换行。 <divref"functionsItems"mouseenter"enterFunctionsItem($event, index)"><img class"w-5 h-5" :src&quo…...

光伏系统MPPT、恒功率控制切换Simulink仿真

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

mysql双主互从通过KeepAlived虚拟IP实现高可用

mysql双主互从通过KeepAlived虚拟IP实现高可用 在mysql 双主互从的基础上&#xff0c; 架构图&#xff1a; Keepalived有两个主要的功能&#xff1a; 提供虚拟IP&#xff0c;实现双机热备通过LVS&#xff0c;实现负载均衡 安装 # 安装 yum -y install keepalived # 卸载 …...

​苹果应用高版本出现:“无法安装此app,因为无法验证其完整性”是怎么回事?竟然是错误的?

最近经常有同学私聊我问苹果应用签名后用落地页下载出现高版本是什么意思&#xff1f;我一脸懵&#xff01;还有这个操作&#xff1f;高版本是个啥玩意&#xff01;所以我就上了一下科技去搜索引擎搜索了下&#xff0c;哈哈哈&#xff0c;然后了解下来发现是这样的首先我们确定…...

AF_UNIX和127.0.0.1(AF_INET)回环地址写数据速度对比

在linux下&#xff0c;存在着这样的情况&#xff0c;本地的进程间通信&#xff0c;并且其中一个是服务端&#xff0c;另外的都是客户端。 服务端通过绑定端口&#xff0c;客户端往127.0.0.1的对应端口发送&#xff0c;即可办到&#xff0c;不过这样会浪费一个端口&#xff0c;同…...

我在 NPM 发布了新包: con-colors

链接地址&#xff1a;npmjs.com con-colors 安装依赖 yarn add con-colors使用 导入&#xff1a; import { print } from "con-colors";使用&#xff1a; print.succ("成功的消息"); print.err("失败的消息")例子&#xff1a; import { p…...

【python数据建模】Scipy库

常用模块列表 模块名功能scipy.constants数学常量scipy.fft离散傅里叶变换scipy.integrate积分scipy.interpolate插值scipy.interpolate线性代数scipy.cluster聚类分析、向量量化scipy.io数据输入输出scipy.misc图像处理scipy.ndimagen维图像scipy.odr正交距离回归scipy.optim…...

C# App.xaml.cs的一些操作

一、保证只有一个进程 1.1 关闭旧的&#xff0c;打开新的 protected override void OnStartup(StartupEventArgs e) {base.OnStartup(e);var process Process.GetProcessesByName("Dog");if (process.Count() > 1) {var list process.ToList();list.Sort((p1,p2…...

【ORACLE】ORA-00972:标识符过长

问题 执行创建表结构sql&#xff0c;提示 ORA-00972&#xff1a;标识符过长&#xff1b; 如图所示&#xff0c;约束名称超过30个字符了 原因 一、11G and before 在使用11G数据库时&#xff0c;经常会遇到报错ORA-00972&#xff0c;原因是因为对象名称定义太长&#xff0c…...

【Vue】Vue快速入门、Vue常用指令、Vue的生命周期

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 Redis 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 Vue 一、 Vue快速入门二、Vue常用指令2.1 v…...

Pandas 数据处理 类别数据和数值数据

要是作深度学习的话&#xff0c;可以直接用tensoflow框架的预处理层&#xff0c;我试过&#xff0c;比PyTorch自己写出来的会好一点&#xff0c;主要是简单好用。处理CSV文件 它类别的处理逻辑是onehot&#xff0c;比较标准稀疏&#xff0c;数值的话就是归一化了。 有时候不需…...

Android攻城狮学鸿蒙 -- 点击事件

具体参考&#xff1a;华为官网学习地址 1、点击事件&#xff0c;界面跳转 对于一个按钮设置点击事件&#xff0c;跳转页面。但是onclick中&#xff0c;如果pages前边加上“/”&#xff0c;就没法跳转。但是开发工具加上“/”才会给出提示。不知道是不是开发工具的bug。&#…...

jmeter性能测试常见的一些问题

一、request 请求超时设置 timeout 超时时间是可以手动设置的&#xff0c;新建一个 http 请求&#xff0c;在“高级”设置中找到“超时”设置&#xff0c;设置连接、响应时间为2000ms。 1. 请求连接超时&#xff0c;连不上服务器。 现象&#xff1a; Jmeter表现形式为&#xff…...

利用国外 vps 为 switch 设置代理服务器加速游戏下载

switch 在国内通过 wifi 连网后如果直接下载游戏的话速度特别慢&#xff0c;据说要挂一个晚上才能下载成功一个游戏。当我尝试下载时发现进度条基本不动&#xff0c;怀疑软件源是在国外的原因&#xff0c;于是想到可以通过国外 vps 代理中转的方式。具体步骤如下&#xff08;以…...

云计算安全的新挑战:零信任架构的应用

文章目录 云计算的安全挑战什么是零信任架构&#xff1f;零信任架构的应用1. 多因素身份验证&#xff08;MFA&#xff09;2. 访问控制和策略3. 安全信息和事件管理&#xff08;SIEM&#xff09;4. 安全的应用程序开发 零信任架构的未来 &#x1f389;欢迎来到云计算技术应用专栏…...

基于SSM的药房药品采购集中管理系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用Vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…...

【GIT版本控制】--远程仓库

一、连接远程仓库 连接到远程仓库是在GIT中进行协作和备份的关键步骤。以下是连接到远程仓库的基本步骤&#xff1a; 获取远程仓库的URL&#xff1a;首先&#xff0c;你需要获得远程仓库的URL。通常&#xff0c;这是远程仓库提供给你的&#xff0c;可以是HTTPS或SSH URL。例如…...

1:Allotment,2:FeeSell,3:混合Allotment+FreeSell

根据您的描述&#xff0c;这似乎是与酒店预订相关的三种不同的方式。下面是对这三种方式的解释&#xff1a; Allotment&#xff08;配额&#xff09;&#xff1a;这是一种酒店预订方式&#xff0c;其中您可以与酒店签订协议&#xff0c;并购买其一定数量的房间或床位。在此之后…...

2026最权威的十大降AI率神器实际效果

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 随着人工智能生成内容也就是 AIGC 被广泛应用&#xff0c;文本的机器化特征越发明显地呈现出…...

2026届必备的五大AI辅助论文助手实际效果

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 基于大语言模型与自然语言处理技术的 AI 写作软件&#xff0c;是内容生产领域新兴工具&…...

【米家IoT开发】巧用Charles抓包,高效定位与调试网络接口

1. 为什么Charles是米家IoT开发的调试神器 当你开发米家扩展程序时&#xff0c;最头疼的莫过于接口返回异常数据&#xff0c;或者请求莫名其妙失败。这时候如果只能靠猜问题出在哪里&#xff0c;那简直就是在黑暗中摸索。我刚开始做米家IoT开发时&#xff0c;就经常被这种问题困…...

告别效率黑洞:AOSP构建降本增效实战!更有最新技术报告免费领!

近年来&#xff0c;AI模型训练与大型软件构建的复杂度持续攀升&#xff0c;企业级操作系统的多分支、多产品构建正成为工程团队的“效率黑洞”。在 Android 平台&#xff0c;AOSP 构建尤为突出&#xff1a;全量构建耗时长、增量改动触发大规模重建、CI 队列冗长、资源消耗高等问…...

Ostrakon-VL零售AI降本方案:替代人工巡检,单店年省8万元

Ostrakon-VL零售AI降本方案&#xff1a;替代人工巡检&#xff0c;单店年省8万元 1. 零售巡检的痛点与AI解决方案 在传统零售运营中&#xff0c;门店巡检是一项耗时耗力的日常工作。店长或督导人员需要每天检查&#xff1a; 商品陈列是否整齐货架缺货情况价签是否正确店铺环境…...

管理员命令提示符 命令提示符 cmd

命令提示符区别...

手把手教你用VSCode给Ai-WB2-12F烧录固件(含串口调试技巧)

手把手教你用VSCode给Ai-WB2-12F烧录固件&#xff08;含串口调试技巧&#xff09; 在物联网开发中&#xff0c;固件烧录是最基础也是最重要的环节之一。对于Ai-WB2-12F这款热门Wi-Fi/BLE双模模组&#xff0c;掌握高效的烧录方法能显著提升开发效率。本文将详细介绍如何利用VSC…...

数据仓库核心建模:星型模型与雪花模型全面对比与实战选择

数据仓库核心建模&#xff1a;星型模型与雪花模型全面对比与实战选择一、引言二、定义&#xff1a;什么是星型模型&#xff1f;什么是雪花模型&#xff1f;2.1 星型模型&#xff1a;定义2.2 雪花模型&#xff1a;定义三、结构流程图&#xff1a;直观对比两种模型3.1 星型模型流…...

[STM32问题解决(2)]编译错误:Error: L6218E的深度解析与实战排查指南

1. 认识Error: L6218E编译错误 当你正在Keil MDK环境下开发STM32项目时&#xff0c;突然弹出一个红色错误提示&#xff1a;"Error: L6218E: Undefined symbol xxx (referred from xxx.o)"&#xff0c;这可能是每个STM32开发者都会遇到的经典问题。我第一次遇到这个错…...

【测试之道】第四篇:分层测试论 —— 金字塔、奖杯与蜂巢:构建你的质量防御阵型

专栏进度&#xff1a;04 / 10 (测试理论专题) 在不同的架构&#xff08;单体、微服务、前端驱动&#xff09;下&#xff0c;测试资源的分配比例是完全不同的。盲目套用模板是测试经理最容易犯的错误。 一、 经典模型&#xff1a;测试金字塔 (Testing Pyramid) 由 Mike Cohn 提出…...