模型(卷积、fc、attention)计算量 MAC/FLOPs 的手动统计方法
文章目录
- 简介
- 背景
- 为什么理解神经网络中的MAC和FLOPs很重要?
- 资源效率
- 内存效率
- 能耗
- 功耗效率
- 模型优化
- 性能基准
- 研究与发展
- FLOPs 和 MACs 定义
- 1. 全连接层 FLOPs 计算
- 步骤 1:识别层参数
- 步骤 2:计算 FLOPs 和 MACs
- 步骤 3:总结结果
- 使用 torchprofile 库验证
- 2. 卷积神经网络(CNNs)
- 计算卷积操作时的重要考虑因素
- 第一步:确定层参数
- 第二步:计算FLOPs和MACs
- 第三步:汇总结果
- 使用torchprofile库验证操作
- 3. 自注意力模块 (self-attention) FLOPs 计算
- 第一步:确定层参数
- 第二步:汇总结果
- 使用torchprofile库验证操作
- 总结:按不同批次大小缩放MACs和FLOPs
简介
理解神经网络中的 MAC(乘累加操作)和 FLOPs(浮点运算)对于优化网络性能和效率至关重要。通过手动计算这些指标,可以更深入地了解网络结构的计算复杂性和资源需求。这不仅能帮助设计高效的模型,还能在训练和推理阶段节省时间和资源。本文将通过实例演示如何计算全连接层(fc)、卷积层(conv) 以及 自注意力模块(self-attention) 的 FLOPs 和 MACs,并探讨其对资源效率、内存效率、能耗和模型优化的影响。
背景
为什么理解神经网络中的MAC和FLOPs很重要?
在本节中,我们将深入探讨神经网络中 MAC(乘累加操作)和 FLOPs(浮点运算)的概念。通过学习如何使用笔和纸手动计算这些指标将获得对各种网络结构的计算复杂性和效率的基本理解。
理解 MAC 和 FLOPs 不仅仅是学术练习;它是优化神经网络性能和效率的关键组成部分。它有助于设计既计算高效又有效的模型,从而在训练和推理阶段节省时间和资源。
这是一个在 Colab 笔记本中完全运行的示例
资源效率
理解 FLOPs 有助于估算神经网络的计算成本。通过优化 FLOPs 的数量,可以潜在地减少训练或运行神经网络所需的时间。
内存效率
MAC 操作通常决定了网络的内存使用情况,因为它们直接与网络中的参数和激活数量相关。减少 MACs 有助于使网络的内存使用更高效。
能耗
功耗效率
FLOPs 和 MAC 操作都对运行神经网络的硬件的功耗有贡献。通过优化这些指标,可以潜在地减少运行网络所需的能量,这对于移动设备和嵌入式设备尤为重要。
模型优化
- 剪枝和量化
理解 FLOPs 和 MACs 可以帮助通过剪枝(去除不必要的连接)和量化(降低权重和激活的精度)等技术优化神经网络,这些技术旨在减少计算和内存成本。
性能基准
-
模型间比较
FLOPs 和 MACs 提供了一种比较不同模型计算复杂性的方法,这可以作为为特定应用选择模型的标准。 -
硬件基准
这些指标还可以用于对比不同硬件平台运行神经网络的性能。 -
边缘设备上的部署
- 实时应用
对于实时应用,特别是在计算资源有限的边缘设备上,理解和优化这些指标对于确保网络能够在应用的时间限制内运行至关重要。 - 电池寿命
在电池供电的设备中,减少神经网络的计算成本(从而减少能耗)可以帮助延长电池寿命。
- 实时应用
研究与发展
- 设计新算法
在开发新算法或神经网络结构时,研究人员可以使用这些指标作为指导,目的是在不牺牲精度的情况下提高计算效率。
FLOPs 和 MACs 定义
-
FLOP(浮点运算)被认为是加法、减法、乘法或除法运算。
-
MAC(乘加运算)基本上是一次乘法加上一次加法,即 MAC = a * b + c。它算作两个FLOP(一次乘法和一次加法)。
1. 全连接层 FLOPs 计算
现在,我们将创建一个包含三层的简单神经网络,并开始计算所涉及的操作。以下是计算第一层线性层(全连接层)操作数的公式:
- 对于具有 I 个输入和 O 个输出的全连接层,操作数如下:
- MACs: I × O
- FLOPs: 2 × (I × O)(因为每个 MAC 算作两个 FLOP)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchprofile import profile_macsclass SimpleLinearModel(nn.Module):def __init__(self):super(SimpleLinearModel,self).__init__()self.fc1 = nn.Linear(in_features=10, out_features=20, bias=False)self.fc2 = nn.Linear(in_features=20, out_features=15, bias=False)self.fc3 = nn.Linear(in_features=15, out_features=1, bias=False)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)F.relu(x)x = self.fc3(x)return xlinear_model = SimpleLinearModel().cuda()
sample_data = torch.randn(1, 10).cuda()
步骤 1:识别层参数
- 对于给定的模型,我们定义了三层线性层:
fc1:10 个输入特征,20 个输出特征
fc2:20 个输入特征,15 个输出特征
fc3:15 个输入特征,1 个输出特征
步骤 2:计算 FLOPs 和 MACs
现在,计算每层的 MACs 和 FLOPs:
-
层 fc1:
MACs = 10 × 20 = 200
FLOPs = 2 × MACs = 2 × 200 = 400 -
层 fc2:
MACs = 20 × 15 = 300
FLOPs = 2 × MACs = 2 × 300 = 600 -
层 fc3:
MACs = 15 × 1 = 15
FLOPs = 2 × MACs = 2 × 15 = 30
步骤 3:总结结果
- 最后,为了找到单个输入通过整个网络的总 MACs 和 FLOPs,我们将所有层的结果相加:
- 总 MACs = MACs(fc1) + MACs(fc2) + MACs(fc3) = 200 + 300 + 15 = 515
- 总 FLOPs = FLOPs(fc1) + FLOPs(fc2) + FLOPs(fc3) = 400 + 600 + 30 = 1030
使用 torchprofile 库验证
可以使用 torchprofile 库来验证给定神经网络模型的 FLOPs 和 MACs 计算。以下是具体操作步骤:
macs = profile_macs(linear_model, sample_data)
print(macs)
# -> 515
2. 卷积神经网络(CNNs)
现在,让我们确定一个简单卷积模型的 MACs(乘加运算)和 FLOPs(浮点运算)。由于诸如步幅、填充和核大小等因素,这种计算比我们之前用密集层的例子更复杂一些。然而,我将逐步讲解以便于学习。
class SimpleConv(nn.Module):def __init__(self):super(SimpleConv, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)self.fc = nn.Linear(in_features=32*28*28, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = x.view(x.shape[0], -1)x = self.fc(x)return xx = torch.rand(1, 1, 28, 28).cuda()
conv_model = SimpleConv().cuda()
计算卷积操作时的重要考虑因素
在计算卷积核的操作时,必须记住核的通道数量应与输入的通道数量相匹配。例如,如果我们的输入是一个有三个颜色通道的 RGB 图像,则核的维度将是 3x3x3 以匹配输入的三个通道。
为了演示的目的,我们将保持图像大小在整个卷积层中一致。为此,我们将填充和步幅值都设置为1。
第一步:确定层参数
对于给定的模型,我们定义了两个卷积层和一个线性层:
- conv1: 1 个输入通道,16 个输出通道,核大小为 3
- conv2: 16 个输入通道,32 个输出通道
- fc: 32x28x28 个输入特征,10 个输出特征。因为我们的图像在卷积层中没有改变
第二步:计算FLOPs和MACs
现在,计算每层的 MACs 和 FLOPs:
公式是:output_image_size * kernel_shape * output_channels
层conv1:
- MACs = 28 * 28 * 3 * 3 * 1 * 16 = 1,12,896
- FLOPs = 2 × MACs = 2 × 1,12,896 = 2,25,792
层conv2:
- MACs = 28 × 28 * 3 * 3 * 16 * 32 = 3,612,672
- FLOPs = 2 × MACs = 2 × 3,612,672 = 7,225,344
层fc:
- MACs = 32 * 28 * 28 * 10 = 250,880
- FLOPs = 2 × MACs = 2 × 250,880 = 501,760
第三步:汇总结果
最后,为了找到单个输入通过整个网络的总MACs和FLOPs,我们汇总所有层的结果:
- 总MACs = MACs(conv1) + MACs(conv2) + MACs(fc) = 1,12,896 + 3,612,672 + 250,880 = 3,976,448
- 总FLOPs = FLOPs(conv1) + FLOPs(conv2) + FLOPs(fc) = 2,25,792 + 7,225,344 + 501,760 = 7,952,896
使用torchprofile库验证操作
macs = profile_macs(conv_model, (x,))
print(macs)
# 输出: 3976448
3. 自注意力模块 (self-attention) FLOPs 计算
在涵盖了线性和卷积层的 MACs 之后,我们的下一步是确定自注意力模块的FLOPs(浮点运算),这是大型语言模型中的一个关键组件。这个计算对于理解这些模型的计算复杂度至关重要。让我们深入探讨。
class SimpleAttentionBlock(nn.Module):def __init__(self, embed_size, heads):super(SimpleAttentionBlock, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, queries, mask):N = queries.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]print(values.shape)values = self.values(values).reshape(N, self.heads, value_len, self.head_dim)keys = self.keys(keys).reshape(N, self.heads, key_len, self.head_dim)queries = self.queries(queries).reshape(N, self.heads, query_len, self.head_dim)energy = torch.matmul(queries, keys.transpose(-2, -1)) if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.nn.functional.softmax(energy, dim=3)out = torch.matmul(attention, values).reshape(N, query_len, self.heads * self.head_dim)return self.fc_out(out)
第一步:确定层参数
线性变换
让我们定义一些超参数:
batch_size = 1
seq_len = 10
embed_size = 256
在注意力块中,我们有三个线性变换(用于查询、键和值),以及一个在末尾的线性变换(fc_out)。
输入大小: [batch_size, seq_len, embed_size]
线性变换矩阵: [embed_size, embed_size]
MACs: batch_size × seq_len × embed_size × embed_size
查询、键、值线性变换:
- 查询变换的MACs =
1 × 10 × 256 × 256 = 655,360
- 键变换的MACs =
1 × 10 × 256 × 256 = 655,360
- 值变换的MACs =
1 × 10 × 256 × 256 = 655,360
能量计算: 查询(重塑后)和键(重塑后)点积——一个点积操作。
MACs: batch_size × seq_len × seq_len × heads × head_dim
查询和键的点积
MACs = 1 × 10 × 10 × 8 × 32
[32 因为256/8] = 25,600
从注意力权重和值的计算输出: 注意力权重和值(重塑后)点积——另一个点积操作。
MACs : batch_size × seq_len × seq_len × heads × head_dim
注意力和值的点积
MACs = 1 × 10 × 10 × 8 × 32
= 25,600
全连接输出(fc_out)
MACs: batch_size × seq_len × heads × head_dim × embed_size
MACs = 1 × 10 × 8 × 32 × 256
= 655,360
第二步:汇总结果
总 MACs = MACs(conv1) + MACs(conv2) + MACs(fc)= 655,360 + 655,360 + 655,360 + 25,600 + 25,600 + 655,360 = 2,672,640
总 FLOPs = 2 × 总MACs = 5,345,280
使用torchprofile库验证操作
# 创建模型实例
model = SimpleAttentionBlock(embed_size=256, heads=8).cuda()# 生成一些样本数据(5个序列的批次,每个长度为10,嵌入大小为256)
values = torch.randn(1, 10, 256).cuda()
keys = torch.randn(1, 10, 256).cuda()
queries = torch.randn(1, 10, 256).cuda()# 简化起见,没有掩码
mask = None# 使用样本数据进行前向传递
macs = profile_macs(model, (values, keys, queries, mask))
print(macs)
# -> 2672640
总结:按不同批次大小缩放MACs和FLOPs
在我们的计算中,我们主要考虑了批次大小为 1。然而,按更大的批次大小缩放 MACs 和 FLOPs 是很简单的。
要计算批次大小大于 1 的 MACs 或 FLOPs,您可以简单地将批次大小 1 得到的总 MACs 或 FLOPs 乘以所需的批次大小值。此缩放允许您估计神经网络模型的各种批次大小的计算需求。
请记住,结果将直接线性缩放批次大小。例如,如果您的批次大小为 32,您可以通过将批次大小为 1 的值乘以 32 来获得 MACs 或 FLOPs。
原文链接: https://medium.com/@pashashaik/a-guide-to-hand-calculating-flops-and-macs-fa5221ce5ccc
相关文章:
模型(卷积、fc、attention)计算量 MAC/FLOPs 的手动统计方法
文章目录 简介背景为什么理解神经网络中的MAC和FLOPs很重要?资源效率内存效率能耗功耗效率 模型优化性能基准研究与发展 FLOPs 和 MACs 定义1. 全连接层 FLOPs 计算步骤 1:识别层参数步骤 2:计算 FLOPs 和 MACs步骤 3:总结结果使用…...

Git 删除包含敏感数据的历史记录及敏感文件
环境 Windows 10 Git 2.41.0 首先备份你需要删除的文件(如果还需要的话),因为命令会将本地也删除将项目中修改的内容撤回或直接提交到仓库中(有修改内容无法提交) 会提示Cannot rewrite branches: You have unstaged …...
vue-tabs标签页引入其他页面
tabs页面 <template> <div class"app-container"> <el-tabs v-model"activeName" type"card" tab-click"handleClick"> <el-tab-pane label"套餐用户列表" name"first"> <user-list r…...

U-net和U²-Net网络详解
目录 U-Net: Convolutional Networks for Biomedical Image Segmentation摘要U-net网络结构pixel-wise loss weight U-Net: Going Deeper with Nested U-Structure for Salient Object Detection摘要网络结构详解整体结构RSU-n结构RSU-4F结构saliency map fusion module -- 显著…...

Vue3 引入腾讯地图 包含标注简易操作
1. 引入腾讯地图API JavaScript API | 腾讯位置服务 (qq.com) 首先在官网注册账号 并正确获取并配置key后 找到合适的引入方式 本文不涉及版本操作和附加库 据体引入参数参考如下图 具体以链接中官方参数为准标题 在项目根目录 index.html 中 写入如下代码 <!-- 引入腾…...

迅狐抖音机构号授权矩阵系统源码
在数字化营销的浪潮中,抖音以其独特的短视频形式迅速崛起,成为品牌传播和用户互动的重要平台。迅狐抖音机构号授权矩阵系统源码作为一项创新技术,为品牌在抖音上的深度运营提供了强大支持。 迅狐抖音机构号授权矩阵系统源码简介 迅狐抖音机…...

数据库系统原理练习 | 作业2-第2章关系数据库(附答案)
整理自博主本科《数据库系统原理》专业课完成的课后作业,以便各位学习数据库系统概论的小伙伴们参考、学习。 *文中若存在书写不合理的地方,欢迎各位斧正。 专业课本: 目录 一、选择题 二、填空题 三、简答题 四、关系代数 1.课本p70页&…...

有向图的强连通分量——AcWing 367. 学校网络
有向图的强连通分量 定义 强连通分量(Strongly Connected Components, SCC) 是图论中的一个概念,在一个有向图中,如果存在一个子图,使得该子图中的任意两个顶点都相互可达(即从任何一个顶点出发都可以到达该子图中的其他任何顶点…...
安全开发--多语言基础知识
注释:还是要特别说明一下,想成为专业开发者不要看本文,本文是自己从业安全以来的一些经验总结,所有知识点也只限于网络安全这点事儿,再多搞不明白了。 开发语言 笼统的按照是否编译成机器码分类开发语言,…...
如何使一个盒子水平垂直居中(常用的)
目录 1. 使用Flex布局 2. 使用Grid布局 3.绝对定位 负外边距 (必须知晓盒子的具体大小) 4.绝对定位外边距 auto 5.绝对定位 transform (无须知晓盒子的具体大小) 1. 使用Flex布局 如何实现: 在父元素上添加: display: flex; align-items: center…...

安全防御-用户认证综合实验
一、拓扑图 二、实验要求 1、DMZ区的服务器,办公区仅能在办公时间内(9:00-18:00)可以访问,生产区设备全天都是可以访问的 2、生产区不允许访问互联网,办公区和游客区允许访问互联网 3、办公区设备10.0.2.20不允许访…...
uniapp安卓离线打包配置scheme url
uniapp安卓离线打包配置scheme url 打开 AndroidManifest.xml 搜索 scheme 填入 即可 <?xml version"1.0" encoding"utf-8"?> <manifest xmlns:android"http://schemas.android.com/apk/res/android" package"uni.UNI979A394…...
C++ STL std::lexicographical_compare用法和实现
一:功能 按字典顺序比较两个序列,判断第一个序列是否小于(或大于)第二个序列 二:用法 #include <compare> #include <vector> #include <string> #include <algorithm> #include <iostream> #include <fo…...
ORM Bee,如何使用Oracle的TO_DATE函数?
ORM Bee,如何使用Oracle的TO_DATE函数? 在Bee V2.4.0,可以这样使用: LocaldatetimeTable selectBeannew LocaldatetimeTable();Condition conditionBF.getCondition();condition.op("localdatetime", Op.ge, new TO_DATE("2024-07-08", "YYYY-MM-DD&…...

HTML CSS 基础复习笔记 - 框架、装饰、弹性盒子
自己复习前端基础,仅用于记忆,初学者不太适合 示例代码 - HTML <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initi…...
C++:创建线程
在C中创建线程,最直接的方式是使用C11标准引入的<thread>库。这个库提供了std::thread类,使得线程的创建和管理变得简单直接。 以下是一个简单的示例,展示了如何在C中使用std::thread来创建和启动线程: 示例1:…...

python如何查看类的函数
Python非常方便,它不需要用户查询文档,只需掌握如下两个帮助函数,即可查看Python中的所有函数(方法)以及它们的用法和功能: dir():列出指定类或模块包含的全部内容(包括函数、方法、…...
P6. 对局列表和排行榜功能
P6. 对局列表和排行榜功能 0 概述1 对局列表功能1.1 分页配置1.2 后端按页获取对局列表接口1.3 前端展示传回来的对局列表1.4 录像回放功能1.4.1 录像回放的流程1.4.2 录像回放的实现 1.5 前端分页展示 2 排行榜功能2.1 排行榜的实现 0 概述 本节主要介绍了如何实现对局列表和…...

uniapp easycom组件冲突
提示信息 easycom组件冲突:[/components/uni-icons/uni-icons.vue,/uni_modules/uni-icons/components/uni-icons/uni-icons.vue] 问题描述 老项目,在uniapp插件商城导入了一个新的uniapp官方开发的组件》uni-data-picker 数据驱动的picker选择器 …...

总结24个Python接单赚钱平台与详细教程,兼职月入5000+
如果说当下什么编程语言最靠谱或者比较适合搞副业? 答案肯定100%是:Python。 python是所有语法中最简单易上手的语言,不需要特别的的英语词汇量,逻辑思维也不需要很差就能上手。而且学会了之后就能编写代码爬取各种数据…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

TDengine 快速体验(Docker 镜像方式)
简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...
在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module
1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...

剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...

网站指纹识别
网站指纹识别 网站的最基本组成:服务器(操作系统)、中间件(web容器)、脚本语言、数据厍 为什么要了解这些?举个例子:发现了一个文件读取漏洞,我们需要读/etc/passwd,如…...

基于Java+MySQL实现(GUI)客户管理系统
客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息,对客户进行统一管理,可以把所有客户信息录入系统,进行维护和统计功能。可通过文件的方式保存相关录入数据,对…...

【笔记】WSL 中 Rust 安装与测试完整记录
#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统:Ubuntu 24.04 LTS (WSL2)架构:x86_64 (GNU/Linux)Rust 版本:rustc 1.87.0 (2025-05-09)Cargo 版本:cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...
Mysql8 忘记密码重置,以及问题解决
1.使用免密登录 找到配置MySQL文件,我的文件路径是/etc/mysql/my.cnf,有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...