模型(卷积、fc、attention)计算量 MAC/FLOPs 的手动统计方法
文章目录
- 简介
- 背景
- 为什么理解神经网络中的MAC和FLOPs很重要?
- 资源效率
- 内存效率
- 能耗
- 功耗效率
- 模型优化
- 性能基准
- 研究与发展
- FLOPs 和 MACs 定义
- 1. 全连接层 FLOPs 计算
- 步骤 1:识别层参数
- 步骤 2:计算 FLOPs 和 MACs
- 步骤 3:总结结果
- 使用 torchprofile 库验证
- 2. 卷积神经网络(CNNs)
- 计算卷积操作时的重要考虑因素
- 第一步:确定层参数
- 第二步:计算FLOPs和MACs
- 第三步:汇总结果
- 使用torchprofile库验证操作
- 3. 自注意力模块 (self-attention) FLOPs 计算
- 第一步:确定层参数
- 第二步:汇总结果
- 使用torchprofile库验证操作
- 总结:按不同批次大小缩放MACs和FLOPs
简介
理解神经网络中的 MAC(乘累加操作)和 FLOPs(浮点运算)对于优化网络性能和效率至关重要。通过手动计算这些指标,可以更深入地了解网络结构的计算复杂性和资源需求。这不仅能帮助设计高效的模型,还能在训练和推理阶段节省时间和资源。本文将通过实例演示如何计算全连接层(fc)、卷积层(conv) 以及 自注意力模块(self-attention) 的 FLOPs 和 MACs,并探讨其对资源效率、内存效率、能耗和模型优化的影响。
背景
为什么理解神经网络中的MAC和FLOPs很重要?
在本节中,我们将深入探讨神经网络中 MAC(乘累加操作)和 FLOPs(浮点运算)的概念。通过学习如何使用笔和纸手动计算这些指标将获得对各种网络结构的计算复杂性和效率的基本理解。
理解 MAC 和 FLOPs 不仅仅是学术练习;它是优化神经网络性能和效率的关键组成部分。它有助于设计既计算高效又有效的模型,从而在训练和推理阶段节省时间和资源。
这是一个在 Colab 笔记本中完全运行的示例
资源效率
理解 FLOPs 有助于估算神经网络的计算成本。通过优化 FLOPs 的数量,可以潜在地减少训练或运行神经网络所需的时间。
内存效率
MAC 操作通常决定了网络的内存使用情况,因为它们直接与网络中的参数和激活数量相关。减少 MACs 有助于使网络的内存使用更高效。
能耗
功耗效率
FLOPs 和 MAC 操作都对运行神经网络的硬件的功耗有贡献。通过优化这些指标,可以潜在地减少运行网络所需的能量,这对于移动设备和嵌入式设备尤为重要。
模型优化
- 剪枝和量化
理解 FLOPs 和 MACs 可以帮助通过剪枝(去除不必要的连接)和量化(降低权重和激活的精度)等技术优化神经网络,这些技术旨在减少计算和内存成本。
性能基准
-
模型间比较
FLOPs 和 MACs 提供了一种比较不同模型计算复杂性的方法,这可以作为为特定应用选择模型的标准。 -
硬件基准
这些指标还可以用于对比不同硬件平台运行神经网络的性能。 -
边缘设备上的部署
- 实时应用
对于实时应用,特别是在计算资源有限的边缘设备上,理解和优化这些指标对于确保网络能够在应用的时间限制内运行至关重要。 - 电池寿命
在电池供电的设备中,减少神经网络的计算成本(从而减少能耗)可以帮助延长电池寿命。
- 实时应用
研究与发展
- 设计新算法
在开发新算法或神经网络结构时,研究人员可以使用这些指标作为指导,目的是在不牺牲精度的情况下提高计算效率。
FLOPs 和 MACs 定义
-
FLOP(浮点运算)被认为是加法、减法、乘法或除法运算。
-
MAC(乘加运算)基本上是一次乘法加上一次加法,即 MAC = a * b + c。它算作两个FLOP(一次乘法和一次加法)。
1. 全连接层 FLOPs 计算
现在,我们将创建一个包含三层的简单神经网络,并开始计算所涉及的操作。以下是计算第一层线性层(全连接层)操作数的公式:
- 对于具有 I 个输入和 O 个输出的全连接层,操作数如下:
- MACs: I × O
- FLOPs: 2 × (I × O)(因为每个 MAC 算作两个 FLOP)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchprofile import profile_macsclass SimpleLinearModel(nn.Module):def __init__(self):super(SimpleLinearModel,self).__init__()self.fc1 = nn.Linear(in_features=10, out_features=20, bias=False)self.fc2 = nn.Linear(in_features=20, out_features=15, bias=False)self.fc3 = nn.Linear(in_features=15, out_features=1, bias=False)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)F.relu(x)x = self.fc3(x)return xlinear_model = SimpleLinearModel().cuda()
sample_data = torch.randn(1, 10).cuda()
步骤 1:识别层参数
- 对于给定的模型,我们定义了三层线性层:
fc1:10 个输入特征,20 个输出特征
fc2:20 个输入特征,15 个输出特征
fc3:15 个输入特征,1 个输出特征
步骤 2:计算 FLOPs 和 MACs
现在,计算每层的 MACs 和 FLOPs:
-
层 fc1:
MACs = 10 × 20 = 200
FLOPs = 2 × MACs = 2 × 200 = 400 -
层 fc2:
MACs = 20 × 15 = 300
FLOPs = 2 × MACs = 2 × 300 = 600 -
层 fc3:
MACs = 15 × 1 = 15
FLOPs = 2 × MACs = 2 × 15 = 30
步骤 3:总结结果
- 最后,为了找到单个输入通过整个网络的总 MACs 和 FLOPs,我们将所有层的结果相加:
- 总 MACs = MACs(fc1) + MACs(fc2) + MACs(fc3) = 200 + 300 + 15 = 515
- 总 FLOPs = FLOPs(fc1) + FLOPs(fc2) + FLOPs(fc3) = 400 + 600 + 30 = 1030
使用 torchprofile 库验证
可以使用 torchprofile 库来验证给定神经网络模型的 FLOPs 和 MACs 计算。以下是具体操作步骤:
macs = profile_macs(linear_model, sample_data)
print(macs)
# -> 515
2. 卷积神经网络(CNNs)
现在,让我们确定一个简单卷积模型的 MACs(乘加运算)和 FLOPs(浮点运算)。由于诸如步幅、填充和核大小等因素,这种计算比我们之前用密集层的例子更复杂一些。然而,我将逐步讲解以便于学习。
class SimpleConv(nn.Module):def __init__(self):super(SimpleConv, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)self.fc = nn.Linear(in_features=32*28*28, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = x.view(x.shape[0], -1)x = self.fc(x)return xx = torch.rand(1, 1, 28, 28).cuda()
conv_model = SimpleConv().cuda()
计算卷积操作时的重要考虑因素
在计算卷积核的操作时,必须记住核的通道数量应与输入的通道数量相匹配。例如,如果我们的输入是一个有三个颜色通道的 RGB 图像,则核的维度将是 3x3x3 以匹配输入的三个通道。
为了演示的目的,我们将保持图像大小在整个卷积层中一致。为此,我们将填充和步幅值都设置为1。
第一步:确定层参数
对于给定的模型,我们定义了两个卷积层和一个线性层:
- conv1: 1 个输入通道,16 个输出通道,核大小为 3
- conv2: 16 个输入通道,32 个输出通道
- fc: 32x28x28 个输入特征,10 个输出特征。因为我们的图像在卷积层中没有改变
第二步:计算FLOPs和MACs
现在,计算每层的 MACs 和 FLOPs:
公式是:output_image_size * kernel_shape * output_channels
层conv1:
- MACs = 28 * 28 * 3 * 3 * 1 * 16 = 1,12,896
- FLOPs = 2 × MACs = 2 × 1,12,896 = 2,25,792
层conv2:
- MACs = 28 × 28 * 3 * 3 * 16 * 32 = 3,612,672
- FLOPs = 2 × MACs = 2 × 3,612,672 = 7,225,344
层fc:
- MACs = 32 * 28 * 28 * 10 = 250,880
- FLOPs = 2 × MACs = 2 × 250,880 = 501,760
第三步:汇总结果
最后,为了找到单个输入通过整个网络的总MACs和FLOPs,我们汇总所有层的结果:
- 总MACs = MACs(conv1) + MACs(conv2) + MACs(fc) = 1,12,896 + 3,612,672 + 250,880 = 3,976,448
- 总FLOPs = FLOPs(conv1) + FLOPs(conv2) + FLOPs(fc) = 2,25,792 + 7,225,344 + 501,760 = 7,952,896
使用torchprofile库验证操作
macs = profile_macs(conv_model, (x,))
print(macs)
# 输出: 3976448
3. 自注意力模块 (self-attention) FLOPs 计算
在涵盖了线性和卷积层的 MACs 之后,我们的下一步是确定自注意力模块的FLOPs(浮点运算),这是大型语言模型中的一个关键组件。这个计算对于理解这些模型的计算复杂度至关重要。让我们深入探讨。
class SimpleAttentionBlock(nn.Module):def __init__(self, embed_size, heads):super(SimpleAttentionBlock, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, queries, mask):N = queries.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]print(values.shape)values = self.values(values).reshape(N, self.heads, value_len, self.head_dim)keys = self.keys(keys).reshape(N, self.heads, key_len, self.head_dim)queries = self.queries(queries).reshape(N, self.heads, query_len, self.head_dim)energy = torch.matmul(queries, keys.transpose(-2, -1)) if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.nn.functional.softmax(energy, dim=3)out = torch.matmul(attention, values).reshape(N, query_len, self.heads * self.head_dim)return self.fc_out(out)
第一步:确定层参数
线性变换
让我们定义一些超参数:
batch_size = 1
seq_len = 10
embed_size = 256
在注意力块中,我们有三个线性变换(用于查询、键和值),以及一个在末尾的线性变换(fc_out)。
输入大小: [batch_size, seq_len, embed_size]
线性变换矩阵: [embed_size, embed_size]
MACs: batch_size × seq_len × embed_size × embed_size
查询、键、值线性变换:
- 查询变换的MACs =
1 × 10 × 256 × 256 = 655,360
- 键变换的MACs =
1 × 10 × 256 × 256 = 655,360
- 值变换的MACs =
1 × 10 × 256 × 256 = 655,360
能量计算: 查询(重塑后)和键(重塑后)点积——一个点积操作。
MACs: batch_size × seq_len × seq_len × heads × head_dim
查询和键的点积
MACs = 1 × 10 × 10 × 8 × 32
[32 因为256/8] = 25,600
从注意力权重和值的计算输出: 注意力权重和值(重塑后)点积——另一个点积操作。
MACs : batch_size × seq_len × seq_len × heads × head_dim
注意力和值的点积
MACs = 1 × 10 × 10 × 8 × 32
= 25,600
全连接输出(fc_out)
MACs: batch_size × seq_len × heads × head_dim × embed_size
MACs = 1 × 10 × 8 × 32 × 256
= 655,360
第二步:汇总结果
总 MACs = MACs(conv1) + MACs(conv2) + MACs(fc)= 655,360 + 655,360 + 655,360 + 25,600 + 25,600 + 655,360 = 2,672,640
总 FLOPs = 2 × 总MACs = 5,345,280
使用torchprofile库验证操作
# 创建模型实例
model = SimpleAttentionBlock(embed_size=256, heads=8).cuda()# 生成一些样本数据(5个序列的批次,每个长度为10,嵌入大小为256)
values = torch.randn(1, 10, 256).cuda()
keys = torch.randn(1, 10, 256).cuda()
queries = torch.randn(1, 10, 256).cuda()# 简化起见,没有掩码
mask = None# 使用样本数据进行前向传递
macs = profile_macs(model, (values, keys, queries, mask))
print(macs)
# -> 2672640
总结:按不同批次大小缩放MACs和FLOPs
在我们的计算中,我们主要考虑了批次大小为 1。然而,按更大的批次大小缩放 MACs 和 FLOPs 是很简单的。
要计算批次大小大于 1 的 MACs 或 FLOPs,您可以简单地将批次大小 1 得到的总 MACs 或 FLOPs 乘以所需的批次大小值。此缩放允许您估计神经网络模型的各种批次大小的计算需求。
请记住,结果将直接线性缩放批次大小。例如,如果您的批次大小为 32,您可以通过将批次大小为 1 的值乘以 32 来获得 MACs 或 FLOPs。
原文链接: https://medium.com/@pashashaik/a-guide-to-hand-calculating-flops-and-macs-fa5221ce5ccc
相关文章:

模型(卷积、fc、attention)计算量 MAC/FLOPs 的手动统计方法
文章目录 简介背景为什么理解神经网络中的MAC和FLOPs很重要?资源效率内存效率能耗功耗效率 模型优化性能基准研究与发展 FLOPs 和 MACs 定义1. 全连接层 FLOPs 计算步骤 1:识别层参数步骤 2:计算 FLOPs 和 MACs步骤 3:总结结果使用…...

Git 删除包含敏感数据的历史记录及敏感文件
环境 Windows 10 Git 2.41.0 首先备份你需要删除的文件(如果还需要的话),因为命令会将本地也删除将项目中修改的内容撤回或直接提交到仓库中(有修改内容无法提交) 会提示Cannot rewrite branches: You have unstaged …...

vue-tabs标签页引入其他页面
tabs页面 <template> <div class"app-container"> <el-tabs v-model"activeName" type"card" tab-click"handleClick"> <el-tab-pane label"套餐用户列表" name"first"> <user-list r…...

U-net和U²-Net网络详解
目录 U-Net: Convolutional Networks for Biomedical Image Segmentation摘要U-net网络结构pixel-wise loss weight U-Net: Going Deeper with Nested U-Structure for Salient Object Detection摘要网络结构详解整体结构RSU-n结构RSU-4F结构saliency map fusion module -- 显著…...

Vue3 引入腾讯地图 包含标注简易操作
1. 引入腾讯地图API JavaScript API | 腾讯位置服务 (qq.com) 首先在官网注册账号 并正确获取并配置key后 找到合适的引入方式 本文不涉及版本操作和附加库 据体引入参数参考如下图 具体以链接中官方参数为准标题 在项目根目录 index.html 中 写入如下代码 <!-- 引入腾…...

迅狐抖音机构号授权矩阵系统源码
在数字化营销的浪潮中,抖音以其独特的短视频形式迅速崛起,成为品牌传播和用户互动的重要平台。迅狐抖音机构号授权矩阵系统源码作为一项创新技术,为品牌在抖音上的深度运营提供了强大支持。 迅狐抖音机构号授权矩阵系统源码简介 迅狐抖音机…...

数据库系统原理练习 | 作业2-第2章关系数据库(附答案)
整理自博主本科《数据库系统原理》专业课完成的课后作业,以便各位学习数据库系统概论的小伙伴们参考、学习。 *文中若存在书写不合理的地方,欢迎各位斧正。 专业课本: 目录 一、选择题 二、填空题 三、简答题 四、关系代数 1.课本p70页&…...

有向图的强连通分量——AcWing 367. 学校网络
有向图的强连通分量 定义 强连通分量(Strongly Connected Components, SCC) 是图论中的一个概念,在一个有向图中,如果存在一个子图,使得该子图中的任意两个顶点都相互可达(即从任何一个顶点出发都可以到达该子图中的其他任何顶点…...

安全开发--多语言基础知识
注释:还是要特别说明一下,想成为专业开发者不要看本文,本文是自己从业安全以来的一些经验总结,所有知识点也只限于网络安全这点事儿,再多搞不明白了。 开发语言 笼统的按照是否编译成机器码分类开发语言,…...

如何使一个盒子水平垂直居中(常用的)
目录 1. 使用Flex布局 2. 使用Grid布局 3.绝对定位 负外边距 (必须知晓盒子的具体大小) 4.绝对定位外边距 auto 5.绝对定位 transform (无须知晓盒子的具体大小) 1. 使用Flex布局 如何实现: 在父元素上添加: display: flex; align-items: center…...

安全防御-用户认证综合实验
一、拓扑图 二、实验要求 1、DMZ区的服务器,办公区仅能在办公时间内(9:00-18:00)可以访问,生产区设备全天都是可以访问的 2、生产区不允许访问互联网,办公区和游客区允许访问互联网 3、办公区设备10.0.2.20不允许访…...

uniapp安卓离线打包配置scheme url
uniapp安卓离线打包配置scheme url 打开 AndroidManifest.xml 搜索 scheme 填入 即可 <?xml version"1.0" encoding"utf-8"?> <manifest xmlns:android"http://schemas.android.com/apk/res/android" package"uni.UNI979A394…...

C++ STL std::lexicographical_compare用法和实现
一:功能 按字典顺序比较两个序列,判断第一个序列是否小于(或大于)第二个序列 二:用法 #include <compare> #include <vector> #include <string> #include <algorithm> #include <iostream> #include <fo…...

ORM Bee,如何使用Oracle的TO_DATE函数?
ORM Bee,如何使用Oracle的TO_DATE函数? 在Bee V2.4.0,可以这样使用: LocaldatetimeTable selectBeannew LocaldatetimeTable();Condition conditionBF.getCondition();condition.op("localdatetime", Op.ge, new TO_DATE("2024-07-08", "YYYY-MM-DD&…...

HTML CSS 基础复习笔记 - 框架、装饰、弹性盒子
自己复习前端基础,仅用于记忆,初学者不太适合 示例代码 - HTML <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initi…...

C++:创建线程
在C中创建线程,最直接的方式是使用C11标准引入的<thread>库。这个库提供了std::thread类,使得线程的创建和管理变得简单直接。 以下是一个简单的示例,展示了如何在C中使用std::thread来创建和启动线程: 示例1:…...

python如何查看类的函数
Python非常方便,它不需要用户查询文档,只需掌握如下两个帮助函数,即可查看Python中的所有函数(方法)以及它们的用法和功能: dir():列出指定类或模块包含的全部内容(包括函数、方法、…...

P6. 对局列表和排行榜功能
P6. 对局列表和排行榜功能 0 概述1 对局列表功能1.1 分页配置1.2 后端按页获取对局列表接口1.3 前端展示传回来的对局列表1.4 录像回放功能1.4.1 录像回放的流程1.4.2 录像回放的实现 1.5 前端分页展示 2 排行榜功能2.1 排行榜的实现 0 概述 本节主要介绍了如何实现对局列表和…...

uniapp easycom组件冲突
提示信息 easycom组件冲突:[/components/uni-icons/uni-icons.vue,/uni_modules/uni-icons/components/uni-icons/uni-icons.vue] 问题描述 老项目,在uniapp插件商城导入了一个新的uniapp官方开发的组件》uni-data-picker 数据驱动的picker选择器 …...

总结24个Python接单赚钱平台与详细教程,兼职月入5000+
如果说当下什么编程语言最靠谱或者比较适合搞副业? 答案肯定100%是:Python。 python是所有语法中最简单易上手的语言,不需要特别的的英语词汇量,逻辑思维也不需要很差就能上手。而且学会了之后就能编写代码爬取各种数据…...

macOS 的电源适配器设置
在 macOS 的电源适配器设置中,有四个选项,每个选项都有特定的功能: Prevent your Mac from automatically sleeping when the display is off(当显示屏关闭时,防止你的 Mac 自动进入睡眠状态):…...

视觉SLAM与定位之一前端特征点及匹配
视觉SLAM中的特征点及匹配 参考文章或链接特征点性能的评估传统特征点和描述子(仅特征点或者特征点描述子)传统描述子 基于深度学习的特征点基于深度学习的描述子基于深度学习的特征点描述子特征匹配 参考文章或链接 Image Matching from Handcrafted t…...

开源项目的认识理解
目录 开源项目有哪些机遇与挑战? 1.开源项目的发展趋势 2.开源的经验分享(向大佬请教与上网查询) 3.开源项目的挑战 开源项目有哪些机遇与挑战? 1.开源项目的发展趋势 1. 持续增长与普及 - 开源项目将继续增长,…...

37.哀家要长脑子了!--层序遍历
gongmi层序遍历模板 vector<vector<int>> levelOrder(TreeNode *root){queue<TreeNode*> que;vector<vector<int>> res;if(root ! nullptr)que.push(root);while(!que.empty()){int size que.size();vector<int> storey;for(int i 0; i …...

【从零开始AI绘画6】StableDiffusionWebUI拓展的安装方法以及推荐的几个拓展
这里写自定义目录标题 拓展Extention安装方法(以双语对照插件为例)1、WebUI内置的下载方式(推荐)2、git clone安装(更推荐)3、github下载安装包后解压(不推荐) 强力推荐安装的几个插…...

HTML5表单的自动验证、取消验证、自定义错误信息
1、自动验证 通过在元素中使用属性的方法,该属性可以实现在表单提交时执行自动验证的功能。下面是关于对元素内输入内容进行限制的属性的指定。 属性说明required输入内容是否不为空pattern输入的内容是否符合指定格式min、max输入的数值是否在min~max范围step判断…...

SpringMVC系列九: 数据格式化与验证及国际化
SpringMVC 数据格式化基本介绍基本数据类型和字符串自动转换应用实例-页面演示方式Postman完成测试 特殊数据类型和字符串自动转换应用实例-页面演示方式Postman完成测试 验证及国际化概述应用实例代码实现注意事项和使用细节 注解的结合使用先看一个问题解决问题 数据类型转换…...

判断链表中是否有环(力扣141.环形链表)
这道题要用到快慢指针。 先解释一下什么是快慢指针。 快慢指针有两个指针,走得慢的是慢指针,走得快的是快指针。 在这道题,我们规定慢指针一次走一步,快指针一次走2步。 如果该链表有环,快慢指针最终会在环中相遇&a…...

Kubernetes基于helm部署jenkins
Kubernetes基于helm安装jenkins jenkins支持war包、docker镜像、系统安装包、helm安装等。在Kubernetes上使用Helm安装Jenkins可以简化安装和管理Jenkins的过程。同时借助Kubernetes,jenkins可以实现工作节点的动态调用伸缩,更好的提高资源利用率。通过…...

【Linux】vim详解
1.什么是vi/vim? 简单来说,vi是老式的文本编辑器,不过功能已经很齐全了,但是还是有可以进步的地方。vim则可以说是程序开发者的一项很好用的工具,就连 vim的官方网站( http://www.vim.org)自己也说vim是一…...