模型压缩——基于粒度剪枝
1.引言
模型剪枝本质上是一种利用稀疏性来减少模型大小和计算量,从而提高训练和推理效率的技术。它为何会有效呢?
理论依据:有研究发现,在许多深度神经网络中,大部分参数是接近于0的,这些参数对模型最终的性能贡献较小。这也就意味着,识别并移除那些对模型性能影响较小的参数,可以减少模型的复杂度和计算成本,并且不会影响到模型的准确性。
根据粒度不同,剪枝可以有不同的方法,并且不同的模型由于内部结构不同剪枝方法也有所区别。就像卷积神经网络中可以对卷积核剪枝和通道进行剪枝,而transformer模型中则可以针对不活跃的自注意力头进行剪枝。
但是,不论哪种模型,以下三种粒度的剪枝是都适用的:
- 权重级剪枝
- 基于模式剪枝
- 向量级剪枝
本文将以一个二维矩阵为例,来分别介绍这三种基本粒度的剪枝,为了方便观察,我们先来讨论下矩阵的可视化。
2.矩阵可视化
首先,用随机数创建一个二维权重矩阵。
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3Dweight = torch.rand(8, 8)
封装一个可视化函数来直观的显示二维权重矩阵。
注:为了后面剪枝的可视化需要,我们会在显示时将0值元素与其它元素区分开。
def plot_tensor(tensor, title):fig, ax = plt.subplots()# 将矩阵转换为0和非0两个类别,并设置两个类别的颜色映射为tab20cax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')ax.set_title(title)ax.set_yticklabels([])ax.set_xticklabels([])# 遍历矩阵为每个文本元素添加文本标签rows, cols = tensor.shapefor i in range(rows):for j in range(cols):ax.text(i, j, f"{tensor[j, i].item():.2f}", ha='center', va='center', color='k')plt.show()plot_tensor(weight, "weight")
3.权重级剪枝
权重级剪枝又称为细粒度剪枝,以单个权重为剪枝单位,在具体操作时,一般会定义一个规则来决定移除哪些值。
在下面这个方法中,会通过目标张量与掩码相乘的方式,将小于threshold阀值的权重都置为0。
def weight_level_pruning(tensor, threshold: float) -> torch.Tensor:mask = torch.ge(tensor, threshold)return tensor.mul(mask)
注1:torch.ge函数的作用是对张量中的每个权重值与给定阀值threshold进行逐元素比较,大于阀值的会置为True,反之则置为False,函数运算结果是一个0(False)和1(True)组成的掩码。
注2:mul 用于对张量进行按元素乘法,要求两个矩阵的形状完全相同(请与矩阵乘法运算符@区分开)。
以threshold=0.2为例进行剪枝,剪枝后的结果如下所示。
pruned_weight = weight_level_pruning(weight, 0.2)
plot_tensor(pruned_weight, "pruned_weight")
权重级剪枝不关心权重在网络中的位置,灵活度最高,可实现高压缩比。但是,它破坏了原有模型的结构,现有硬件架构的计算方式通常无法对它进行加速,需要特殊的硬件或软件才能利用剪枝后模型的稀疏性,所以在目前通用的硬件上运行时速度并不能得到提升。
4.基于模式剪枝
基于模式剪枝通常是基于非常规则的N:M稀疏性进行剪枝,它要求在M个连续权重中固定有N个非零值,而其余元素均置为0。
下面我们将以2:4稀疏性为例子,一步一步说明如何实现N:M稀疏性。
4.1 稀疏模式计算
首先,创建一个长度为 4 的一维张量 sequence 并初始化为 0,表示总共有4个元素。
sequence = torch.zeros(4)
sequence
tensor([0., 0., 0., 0.])
对张量 patterns 的前 2 个元素设置为 1,表示4个元素中固定有2个非零值。
sequence[:2] = 1
sequence
tensor([1., 1., 0., 0.])
用permutations
函数生成patterns列表的所有可能排列,并用set去重。
from itertools import permutations
patterns = set(permutations(sequence.tolist()))
list(patterns)
[(0.0, 1.0, 0.0, 1.0),(1.0, 1.0, 0.0, 0.0),(0.0, 1.0, 1.0, 0.0),(1.0, 0.0, 1.0, 0.0),(1.0, 0.0, 0.0, 1.0),(0.0, 0.0, 1.0, 1.0)]
这个patterns中包含了长度为4恰好有两个1的所有可能排列模式。
为了方便复数,将上面的计算过程封装成一个函数。
def compute_valid_1d_patterns(m, n):patterns = torch.zeros(m)patterns[:n] = 1# permutations: 用于生成给定序列的所有可能排列valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternspatterns = compute_valid_1d_patterns(4,2)
patterns
tensor([[0., 1., 0., 1.],[1., 1., 0., 0.],[0., 1., 1., 0.],[1., 0., 1., 0.],[1., 0., 0., 1.],[0., 0., 1., 1.]])
4.2 生成掩码
计算掩码的目的是为了找到每个权重分组(每M个连续权重为一组)的最佳稀疏模式。
首先,生成一个初始掩码,它与权重矩阵的形状相同,并用 1 填充。然后将其视图更改为形状 (-1, 4),以便我们可以处理每 4 个权重一组。
tensor = weight
mask = torch.IntTensor(tensor.shape).fill_(1).view(-1, 4)
mask
tensor([[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1]], dtype=torch.int32)
重塑权重矩阵,使其形状与掩码一致。
mat = tensor.view(-1, 4)
mat
tensor([[0.6689, 0.4118, 0.9726, 0.9845],[0.8126, 0.4900, 0.8162, 0.0835],[0.5984, 0.1732, 0.7412, 0.2995],[0.7361, 0.1535, 0.9121, 0.1895],[0.8570, 0.1778, 0.1318, 0.5525],[0.0492, 0.5464, 0.4381, 0.2630],[0.9935, 0.0955, 0.6935, 0.7049],[0.1594, 0.5785, 0.9095, 0.8378],[0.0899, 0.0569, 0.7214, 0.3372],[0.3512, 0.9062, 0.0120, 0.7077],[0.1819, 0.6778, 0.7691, 0.5124],[0.3399, 0.4008, 0.2745, 0.2768],[0.9185, 0.1250, 0.9466, 0.5318],[0.9118, 0.1470, 0.6657, 0.6492],[0.1116, 0.8223, 0.7062, 0.2872],[0.1826, 0.4946, 0.5415, 0.8882]])
对于每一行权重,我们计算其绝对值与所有稀疏模式的点积,得到每一行权重在每一种稀疏模式下的加权和。
mat_patterns = torch.matmul(mat.abs(), patterns.t())
mat_patterns
tensor([[1.3963, 1.0807, 1.3844, 1.6415, 1.6534, 1.9571],[0.5735, 1.3027, 1.3062, 1.6288, 0.8961, 0.8997],[0.4727, 0.7716, 0.9144, 1.3396, 0.8979, 1.0406],[0.3429, 0.8896, 1.0656, 1.6481, 0.9255, 1.1015],[0.7304, 1.0348, 0.3096, 0.9888, 1.4095, 0.6843],[0.8095, 0.5956, 0.9845, 0.4873, 0.3122, 0.7012],[0.8005, 1.0890, 0.7891, 1.6870, 1.6984, 1.3985],[1.4162, 0.7379, 1.4880, 1.0689, 0.9972, 1.7473],[0.3941, 0.1468, 0.7783, 0.8113, 0.4271, 1.0585],[1.6139, 1.2574, 0.9183, 0.3632, 1.0589, 0.7197],[1.1902, 0.8597, 1.4469, 0.9510, 0.6943, 1.2815],[0.6776, 0.7407, 0.6752, 0.6144, 0.6167, 0.5513],[0.6569, 1.0435, 1.0717, 1.8651, 1.4503, 1.4785],[0.7962, 1.0587, 0.8126, 1.5774, 1.5610, 1.3149],[1.1095, 0.9339, 1.5285, 0.8177, 0.3988, 0.9934],[1.3828, 0.6772, 1.0361, 0.7240, 1.0708, 1.4296]])
我们只需要保留加权和最大的模式即可,为每一行权重选择加权和最大的模式索引:
pmax = torch.argmax(mat_patterns, dim=1)
pmax
tensor([5, 3, 3, 3, 4, 2, 4, 5, 5, 0, 2, 1, 3, 3, 2, 5])
使用 pmax 索引从稀疏模式中为每一行权重选择相应的模式,然后将其赋值给掩码:
mask[:] = patterns[pmax[:]]
mask
tensor([[0, 0, 1, 1],[1, 0, 1, 0],[1, 0, 1, 0],[1, 0, 1, 0],[1, 0, 0, 1],[0, 1, 1, 0],[1, 0, 0, 1],[0, 0, 1, 1],[0, 0, 1, 1],[0, 1, 0, 1],[0, 1, 1, 0],[1, 1, 0, 0],[1, 0, 1, 0],[1, 0, 1, 0],[0, 1, 1, 0],[0, 0, 1, 1]], dtype=torch.int32)
最后,将掩码视图重塑回原始权重的形状。
mask = mask.view(tensor.shape)
mask
tensor([[0, 0, 1, 1, 1, 0, 1, 0],[1, 0, 1, 0, 1, 0, 1, 0],[1, 0, 0, 1, 0, 1, 1, 0],[1, 0, 0, 1, 0, 0, 1, 1],[0, 0, 1, 1, 0, 1, 0, 1],[0, 1, 1, 0, 1, 1, 0, 0],[1, 0, 1, 0, 1, 0, 1, 0],[0, 1, 1, 0, 0, 0, 1, 1]], dtype=torch.int32)
同样,将上面计算掩码的过程封装为一个函数。
def compute_mask(tensor, m, n):# 计算所有可能的模式patterns = compute_valid_1d_patterns(m,n) # m中取n所有可能的模式,N行4列# 生成初始掩码mask = torch.IntTensor(tensor.shape).fill_(1).view(-1,m) # 将张量转换成列为m的格式,若不能整除m则填充0if tensor.shape[1] % m > 0:mat = torch.FloatTensor(tensor.shape[0], tensor.shape[1] + (m - tensor.shape[1] % m)).fill_(0)mat[:, : tensor.shape[1]] = tensormat = mat.view(-1, m)else:mat = tensor.view(-1, m) pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1) # 16行N列,每一行的点积操作都得到N种可能,取点积最大值的元素下标mask[:] = patterns[pmax[:]] # 找到最大下标对应的排列mask = mask.view(tensor.shape) # 再转换成tensor的形状return maskmask = compute_mask(weight, 4, 2)
mask
tensor([[0, 0, 1, 1, 1, 0, 1, 0],[1, 0, 1, 0, 1, 0, 1, 0],[1, 0, 0, 1, 0, 1, 1, 0],[1, 0, 0, 1, 0, 0, 1, 1],[0, 0, 1, 1, 0, 1, 0, 1],[0, 1, 1, 0, 1, 1, 0, 0],[1, 0, 1, 0, 1, 0, 1, 0],[0, 1, 1, 0, 0, 0, 1, 1]], dtype=torch.int32)
4.3 用掩码剪枝
pruned_pattern = tensor.mul(mask)
plot_tensor(pruned_pattern, "pruned_pattern")
可以看到,2:4稀疏性将一半权重都置为了0,每4个权重中保留了两个非0权重,我们成功地在权重矩阵上应用了 2:4 稀疏性。
虽然都属于非结构化索引,但与前面基于权重的索引不同的是,2:4
结构的稀疏性可以被英伟达的稀疏张量核心加速。在运算时,稀疏矩阵W首先会被压缩,压缩后的矩阵存储着非零的数据值,而metadata则存储着对应非零元素在原矩阵W中的索引信息。
具体来说,metadata会将W中非零元素的行号和列号压缩成两个独立的一维数组,这两个数组就是metadata中存储的索引信息。
注:N:M 稀疏性是一种有效的稀疏技术,这种稀疏模式不仅减少了模型大小,还提高了计算效率,特别适用于硬件加速器,如 GPU 和 TPU,从而加速深度学习模型的训练和推理过程。
5.向量级剪枝
向量级剪枝以行或列为单位对权重进行裁剪。
def vector_pruning(tensor, point):rows, cols = pointprune_weight = tensor.clone()prune_weight[rows, :] = 0prune_weight[:, cols] = 0return prune_weightpoint = (2,3)
pruned_vector = vector_pruning(weight, point)
plot_tensor(pruned_vector, "pruned_vector")
经过上面的向量级剪枝后,可以直接去掉一行一列,整个权重矩阵的形状可以直接由[8,8]变为[7,7], 因此向量级剪枝属于结构化剪枝。
通常在进行向量级别的剪枝时,需要对模型的所有层统一进行剪枝,其目的是在整个模型中保持一致的稀疏结构,以确保上下游各层中结构和计算的一致性。因此,这种方法被称为“全局剪枝”或“统一剪枝”。
小结:本文主要介绍了权重级剪枝、基于模式剪枝和向量级剪枝三种不同粒度的剪枝方法,并结合可视化的方式,一步一步详细演示了每种剪枝方法的运算过程,和剪枝前后权重矩阵的变化。在实际场景中,基于模式的剪枝越来越受到青睐,因为其规则的稀疏模式可以充分利用硬件加速器的计算能力,从而显著提高计算效率。
参考阅读
- 模型压缩概览
- Awesome Compression
相关文章:

模型压缩——基于粒度剪枝
1.引言 模型剪枝本质上是一种利用稀疏性来减少模型大小和计算量,从而提高训练和推理效率的技术。它为何会有效呢? 理论依据:有研究发现,在许多深度神经网络中,大部分参数是接近于0的,这些参数对模型最终的…...

IntelliJ IDEA 2023.2x——图文配置
IntelliJ IDEA 2023.2——配置说明 界面如下图所示 : 绿泡泡查找 “码猿趣事” 查找【idea99】 IntelliJ IDEA 的官方下载地址 IntelliJ IDEA 官网下载地址 一路上NEXT 到结尾: 继续NEXT 下一步:...

SpringBoot(5)-SpringSecurity
目录 一、是什么 二、实战测试 2.1 认识 2.2 认证和授权 2.3 权限控制和注销 2.4 记住我 一、是什么 Spring Security是一个框架,侧重于为java应用程序提供身份验证和授权。 Web应用的安全性主要分为两个部分: 认证(Authentication&…...
fast-api后端 + fetch 前端流式文字响应
fast-api后端 fetch 前端流式文字响应 fast-api后台接口流式响应 前端fetch 流式文本数据处理 fast-api后台接口 流式响应 from fastapi.responses import StreamingResponse from tqdm import tqdm from pydantic import BaseModelclass ItemDataSingle(BaseModel):data: …...
Qt 的 QThread:多线程编程的基础
Qt 的 QThread:多线程编程的基础 在现代应用程序中,尤其是需要处理大量数据、进行长时间计算或者进行 I/O 操作时,多线程编程变得至关重要。Qt 提供了一个功能强大且易于使用的线程类 QThread,可以帮助开发者在 Qt 应用程序中实现…...
周末总结(2024/11/16)
工作 人际关系核心实践: 要学会随时回应别人的善意,执行时间控制在5分钟以内 坚持每天早会打招呼 遇到接不住的话题时拉低自己,抬高别人(无阴阳气息) 朋友圈点赞控制在5min以内,职场社交不要放在5min以外 职场的人际关系在面对利…...
Chrome和Chromium的区别?浏览器引擎都用的哪些?浏览器引擎的作用?
Chrome和Chromium的区别? Chrome是Google专属的产品,它是基于后者Chromium开源引擎开发。第三方浏览器公司为了加快开发流程,会直接选择开源的浏览器引擎,例如Chromium. Google将Chromium开源,本意为了打破浏览器被其他公司控制的…...

流程图图解@RequestBody @RequestPart @RequestParam @ModelAttribute
RequestBody 只能用一次,因为只有一个请求体 #mermaid-svg-8WZfkzl0GPvOiNj3 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-8WZfkzl0GPvOiNj3 .error-icon{fill:#552222;}#mermaid-svg-8WZfkzl0GPvOiNj…...
AutoUpdater.NET 实现 dotNET应用自动更新
AutoUpdater.NET 是一款用于WPF、Winform软件版本更新的框架,类似框架还有Squirrel、WinSparkle、NetSparkle、Google Omaha。 一、安装AutoUpdater.NET 首先,您需要在项目中安装AutoUpdater.NET库。您可以通过NuGet包管理器来安装它。在Visual Studio中…...

108. UE5 GAS RPG 实现地图名称更新和加载关卡
在这一篇里,我们将实现对存档的删除功能,在删除时会有弹框确认。接着实现获取玩家的等级和地图名称和存档位置,我们可以通过存档进入游戏,玩家在游戏中可以在存档点存储存档。 实现删除存档 删除存档需要一个弹框确认࿰…...

对称加密与非对称加密:密码学的基石及 RSA 算法详解
对称加密与非对称加密:密码学的基石及 RSA 算法详解 在当今数字化的时代,信息安全至关重要。对称加密和非对称加密作为密码学中的两种基本加密技术,为我们的数据安全提供了强大的保障。本文将深入探讨对称加密和非对称加密的特点、应用场景&…...
排列问题方法总结(递归+迭代)
递归 一、逐步生成结果法(无序) #include<iostream> #include<vector> #include<string> #include<algorithm>using namespace std;vector<string> GetChild(int n,int curIndex){vector<string> now;vector&…...

C#从入门到放弃
C#和.NET的区别 C# C#是一个编程语言 .NET .NET是一个在window下创建程序的框架 .NET框架不仅局限于C#,它还可以支持很多语言 .NET包括了2个组件,一个叫CLR(通用语言运行时),另一个是用来构建程序的类库 CLR 用C写一个程序,在一台8688的机器…...
视频质量评价学习笔记
目录 MD VQA:大淘宝团队: ReIQA KVQ 视频质量评价学习笔记 MD VQA:大淘宝团队: https://github.com/kunyou99/MD-VQA_cvpr2023?tabreadme-ov-file ReIQA GitHub - avinabsaha/ReIQA: Official implementation for CVPR2023 Paper "Re-IQA : U…...
OpenCV、YOLO、VOC、COCO之间的关系和区别
OpenCV、YOLO、COCO 和 VOC 是计算机视觉和深度学习领域常见的几个名词,它们分别代表不同的工具、算法和数据集,之间有一些联系和区别。下面分别说明它们的定义、用途以及相互关系。 1. OpenCV(Open Source Computer Vision Library…...
Pandas进行周期与时间戳转换
时间序列数据在数据分析和金融领域非常常见,处理这些数据时,通常会面临周期(Period)与时间戳(Timestamp)之间的转换需求。理解和掌握这种转换,对于时间序列数据的清洗、预处理以及进一步分析至关重要。Python 中的 pandas 库提供了一系列便捷的函数来帮助处理这些时间序…...

【GPTs】Get Simpsonized:一键变身趣味辛普森角色
博客主页: [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 💯GPTs指令💯前言💯Get Simpsonized主要功能适用场景优点缺点使用方式 💯小结 💯GPTs指令 中文翻译: 指令保护和安全规则&…...
概率论公式整理
1 概率 古典概型和几何概型 古典概型(有限等可能)几何概型(无限等可能) 条件概率 P ( A ∣ B ) P ( A B ) P ( B ) P(A|B) \frac{P(AB)}{P(B)} P(A∣B)P(B)P(AB) 全概率公式 P ( B ) ∑ i 1 n P ( A i ) P ( B ∣ A i ) P…...

【C++】—— stack和queue的模拟实现
前言 stack 和 queue使用起来都非常简单,现在来模拟实现一下,理解其底层的原理。 在实现之前,应该知道,stack 和 queue 都是容器适配器,通过看官网文件也可以看出来;其默认的容器都是dequeÿ…...
管家婆工贸ERP BR039.采购订单关联MRP明细表
最低适用版本: 工贸系列 23.8 插件简要功能说明: 采购订单明细表,支持显示采购订单明细上游请购单明细关联的MRP中对应销售订单明细产成品相关信息更多细节描述见下方详细文档 插件操作视频: 进销存类定制插件--采购订单关联M…...
Cesium1.95中高性能加载1500个点
一、基本方式: 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...
Spring Boot面试题精选汇总
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...
三体问题详解
从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...
Spring AI 入门:Java 开发者的生成式 AI 实践之路
一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...

ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...
.Net Framework 4/C# 关键字(非常用,持续更新...)
一、is 关键字 is 关键字用于检查对象是否于给定类型兼容,如果兼容将返回 true,如果不兼容则返回 false,在进行类型转换前,可以先使用 is 关键字判断对象是否与指定类型兼容,如果兼容才进行转换,这样的转换是安全的。 例如有:首先创建一个字符串对象,然后将字符串对象隐…...
虚拟电厂发展三大趋势:市场化、技术主导、车网互联
市场化:从政策驱动到多元盈利 政策全面赋能 2025年4月,国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》,首次明确虚拟电厂为“独立市场主体”,提出硬性目标:2027年全国调节能力≥2000万千瓦࿰…...
pycharm 设置环境出错
pycharm 设置环境出错 pycharm 新建项目,设置虚拟环境,出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...