模型压缩——基于粒度剪枝
1.引言
模型剪枝本质上是一种利用稀疏性来减少模型大小和计算量,从而提高训练和推理效率的技术。它为何会有效呢?
理论依据:有研究发现,在许多深度神经网络中,大部分参数是接近于0的,这些参数对模型最终的性能贡献较小。这也就意味着,识别并移除那些对模型性能影响较小的参数,可以减少模型的复杂度和计算成本,并且不会影响到模型的准确性。
根据粒度不同,剪枝可以有不同的方法,并且不同的模型由于内部结构不同剪枝方法也有所区别。就像卷积神经网络中可以对卷积核剪枝和通道进行剪枝,而transformer模型中则可以针对不活跃的自注意力头进行剪枝。
但是,不论哪种模型,以下三种粒度的剪枝是都适用的:
- 权重级剪枝
- 基于模式剪枝
- 向量级剪枝
本文将以一个二维矩阵为例,来分别介绍这三种基本粒度的剪枝,为了方便观察,我们先来讨论下矩阵的可视化。
2.矩阵可视化
首先,用随机数创建一个二维权重矩阵。
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3Dweight = torch.rand(8, 8)
封装一个可视化函数来直观的显示二维权重矩阵。
注:为了后面剪枝的可视化需要,我们会在显示时将0值元素与其它元素区分开。
def plot_tensor(tensor, title):fig, ax = plt.subplots()# 将矩阵转换为0和非0两个类别,并设置两个类别的颜色映射为tab20cax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')ax.set_title(title)ax.set_yticklabels([])ax.set_xticklabels([])# 遍历矩阵为每个文本元素添加文本标签rows, cols = tensor.shapefor i in range(rows):for j in range(cols):ax.text(i, j, f"{tensor[j, i].item():.2f}", ha='center', va='center', color='k')plt.show()plot_tensor(weight, "weight")
3.权重级剪枝
权重级剪枝又称为细粒度剪枝,以单个权重为剪枝单位,在具体操作时,一般会定义一个规则来决定移除哪些值。
在下面这个方法中,会通过目标张量与掩码相乘的方式,将小于threshold阀值的权重都置为0。
def weight_level_pruning(tensor, threshold: float) -> torch.Tensor:mask = torch.ge(tensor, threshold)return tensor.mul(mask)
注1:torch.ge函数的作用是对张量中的每个权重值与给定阀值threshold进行逐元素比较,大于阀值的会置为True,反之则置为False,函数运算结果是一个0(False)和1(True)组成的掩码。
注2:mul 用于对张量进行按元素乘法,要求两个矩阵的形状完全相同(请与矩阵乘法运算符@区分开)。
以threshold=0.2为例进行剪枝,剪枝后的结果如下所示。
pruned_weight = weight_level_pruning(weight, 0.2)
plot_tensor(pruned_weight, "pruned_weight")
权重级剪枝不关心权重在网络中的位置,灵活度最高,可实现高压缩比。但是,它破坏了原有模型的结构,现有硬件架构的计算方式通常无法对它进行加速,需要特殊的硬件或软件才能利用剪枝后模型的稀疏性,所以在目前通用的硬件上运行时速度并不能得到提升。
4.基于模式剪枝
基于模式剪枝通常是基于非常规则的N:M稀疏性进行剪枝,它要求在M个连续权重中固定有N个非零值,而其余元素均置为0。
下面我们将以2:4稀疏性为例子,一步一步说明如何实现N:M稀疏性。
4.1 稀疏模式计算
首先,创建一个长度为 4 的一维张量 sequence 并初始化为 0,表示总共有4个元素。
sequence = torch.zeros(4)
sequence
tensor([0., 0., 0., 0.])
对张量 patterns 的前 2 个元素设置为 1,表示4个元素中固定有2个非零值。
sequence[:2] = 1
sequence
tensor([1., 1., 0., 0.])
用permutations
函数生成patterns列表的所有可能排列,并用set去重。
from itertools import permutations
patterns = set(permutations(sequence.tolist()))
list(patterns)
[(0.0, 1.0, 0.0, 1.0),(1.0, 1.0, 0.0, 0.0),(0.0, 1.0, 1.0, 0.0),(1.0, 0.0, 1.0, 0.0),(1.0, 0.0, 0.0, 1.0),(0.0, 0.0, 1.0, 1.0)]
这个patterns中包含了长度为4恰好有两个1的所有可能排列模式。
为了方便复数,将上面的计算过程封装成一个函数。
def compute_valid_1d_patterns(m, n):patterns = torch.zeros(m)patterns[:n] = 1# permutations: 用于生成给定序列的所有可能排列valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternspatterns = compute_valid_1d_patterns(4,2)
patterns
tensor([[0., 1., 0., 1.],[1., 1., 0., 0.],[0., 1., 1., 0.],[1., 0., 1., 0.],[1., 0., 0., 1.],[0., 0., 1., 1.]])
4.2 生成掩码
计算掩码的目的是为了找到每个权重分组(每M个连续权重为一组)的最佳稀疏模式。
首先,生成一个初始掩码,它与权重矩阵的形状相同,并用 1 填充。然后将其视图更改为形状 (-1, 4),以便我们可以处理每 4 个权重一组。
tensor = weight
mask = torch.IntTensor(tensor.shape).fill_(1).view(-1, 4)
mask
tensor([[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1]], dtype=torch.int32)
重塑权重矩阵,使其形状与掩码一致。
mat = tensor.view(-1, 4)
mat
tensor([[0.6689, 0.4118, 0.9726, 0.9845],[0.8126, 0.4900, 0.8162, 0.0835],[0.5984, 0.1732, 0.7412, 0.2995],[0.7361, 0.1535, 0.9121, 0.1895],[0.8570, 0.1778, 0.1318, 0.5525],[0.0492, 0.5464, 0.4381, 0.2630],[0.9935, 0.0955, 0.6935, 0.7049],[0.1594, 0.5785, 0.9095, 0.8378],[0.0899, 0.0569, 0.7214, 0.3372],[0.3512, 0.9062, 0.0120, 0.7077],[0.1819, 0.6778, 0.7691, 0.5124],[0.3399, 0.4008, 0.2745, 0.2768],[0.9185, 0.1250, 0.9466, 0.5318],[0.9118, 0.1470, 0.6657, 0.6492],[0.1116, 0.8223, 0.7062, 0.2872],[0.1826, 0.4946, 0.5415, 0.8882]])
对于每一行权重,我们计算其绝对值与所有稀疏模式的点积,得到每一行权重在每一种稀疏模式下的加权和。
mat_patterns = torch.matmul(mat.abs(), patterns.t())
mat_patterns
tensor([[1.3963, 1.0807, 1.3844, 1.6415, 1.6534, 1.9571],[0.5735, 1.3027, 1.3062, 1.6288, 0.8961, 0.8997],[0.4727, 0.7716, 0.9144, 1.3396, 0.8979, 1.0406],[0.3429, 0.8896, 1.0656, 1.6481, 0.9255, 1.1015],[0.7304, 1.0348, 0.3096, 0.9888, 1.4095, 0.6843],[0.8095, 0.5956, 0.9845, 0.4873, 0.3122, 0.7012],[0.8005, 1.0890, 0.7891, 1.6870, 1.6984, 1.3985],[1.4162, 0.7379, 1.4880, 1.0689, 0.9972, 1.7473],[0.3941, 0.1468, 0.7783, 0.8113, 0.4271, 1.0585],[1.6139, 1.2574, 0.9183, 0.3632, 1.0589, 0.7197],[1.1902, 0.8597, 1.4469, 0.9510, 0.6943, 1.2815],[0.6776, 0.7407, 0.6752, 0.6144, 0.6167, 0.5513],[0.6569, 1.0435, 1.0717, 1.8651, 1.4503, 1.4785],[0.7962, 1.0587, 0.8126, 1.5774, 1.5610, 1.3149],[1.1095, 0.9339, 1.5285, 0.8177, 0.3988, 0.9934],[1.3828, 0.6772, 1.0361, 0.7240, 1.0708, 1.4296]])
我们只需要保留加权和最大的模式即可,为每一行权重选择加权和最大的模式索引:
pmax = torch.argmax(mat_patterns, dim=1)
pmax
tensor([5, 3, 3, 3, 4, 2, 4, 5, 5, 0, 2, 1, 3, 3, 2, 5])
使用 pmax 索引从稀疏模式中为每一行权重选择相应的模式,然后将其赋值给掩码:
mask[:] = patterns[pmax[:]]
mask
tensor([[0, 0, 1, 1],[1, 0, 1, 0],[1, 0, 1, 0],[1, 0, 1, 0],[1, 0, 0, 1],[0, 1, 1, 0],[1, 0, 0, 1],[0, 0, 1, 1],[0, 0, 1, 1],[0, 1, 0, 1],[0, 1, 1, 0],[1, 1, 0, 0],[1, 0, 1, 0],[1, 0, 1, 0],[0, 1, 1, 0],[0, 0, 1, 1]], dtype=torch.int32)
最后,将掩码视图重塑回原始权重的形状。
mask = mask.view(tensor.shape)
mask
tensor([[0, 0, 1, 1, 1, 0, 1, 0],[1, 0, 1, 0, 1, 0, 1, 0],[1, 0, 0, 1, 0, 1, 1, 0],[1, 0, 0, 1, 0, 0, 1, 1],[0, 0, 1, 1, 0, 1, 0, 1],[0, 1, 1, 0, 1, 1, 0, 0],[1, 0, 1, 0, 1, 0, 1, 0],[0, 1, 1, 0, 0, 0, 1, 1]], dtype=torch.int32)
同样,将上面计算掩码的过程封装为一个函数。
def compute_mask(tensor, m, n):# 计算所有可能的模式patterns = compute_valid_1d_patterns(m,n) # m中取n所有可能的模式,N行4列# 生成初始掩码mask = torch.IntTensor(tensor.shape).fill_(1).view(-1,m) # 将张量转换成列为m的格式,若不能整除m则填充0if tensor.shape[1] % m > 0:mat = torch.FloatTensor(tensor.shape[0], tensor.shape[1] + (m - tensor.shape[1] % m)).fill_(0)mat[:, : tensor.shape[1]] = tensormat = mat.view(-1, m)else:mat = tensor.view(-1, m) pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1) # 16行N列,每一行的点积操作都得到N种可能,取点积最大值的元素下标mask[:] = patterns[pmax[:]] # 找到最大下标对应的排列mask = mask.view(tensor.shape) # 再转换成tensor的形状return maskmask = compute_mask(weight, 4, 2)
mask
tensor([[0, 0, 1, 1, 1, 0, 1, 0],[1, 0, 1, 0, 1, 0, 1, 0],[1, 0, 0, 1, 0, 1, 1, 0],[1, 0, 0, 1, 0, 0, 1, 1],[0, 0, 1, 1, 0, 1, 0, 1],[0, 1, 1, 0, 1, 1, 0, 0],[1, 0, 1, 0, 1, 0, 1, 0],[0, 1, 1, 0, 0, 0, 1, 1]], dtype=torch.int32)
4.3 用掩码剪枝
pruned_pattern = tensor.mul(mask)
plot_tensor(pruned_pattern, "pruned_pattern")
可以看到,2:4稀疏性将一半权重都置为了0,每4个权重中保留了两个非0权重,我们成功地在权重矩阵上应用了 2:4 稀疏性。
虽然都属于非结构化索引,但与前面基于权重的索引不同的是,2:4
结构的稀疏性可以被英伟达的稀疏张量核心加速。在运算时,稀疏矩阵W首先会被压缩,压缩后的矩阵存储着非零的数据值,而metadata则存储着对应非零元素在原矩阵W中的索引信息。
具体来说,metadata会将W中非零元素的行号和列号压缩成两个独立的一维数组,这两个数组就是metadata中存储的索引信息。
注:N:M 稀疏性是一种有效的稀疏技术,这种稀疏模式不仅减少了模型大小,还提高了计算效率,特别适用于硬件加速器,如 GPU 和 TPU,从而加速深度学习模型的训练和推理过程。
5.向量级剪枝
向量级剪枝以行或列为单位对权重进行裁剪。
def vector_pruning(tensor, point):rows, cols = pointprune_weight = tensor.clone()prune_weight[rows, :] = 0prune_weight[:, cols] = 0return prune_weightpoint = (2,3)
pruned_vector = vector_pruning(weight, point)
plot_tensor(pruned_vector, "pruned_vector")
经过上面的向量级剪枝后,可以直接去掉一行一列,整个权重矩阵的形状可以直接由[8,8]变为[7,7], 因此向量级剪枝属于结构化剪枝。
通常在进行向量级别的剪枝时,需要对模型的所有层统一进行剪枝,其目的是在整个模型中保持一致的稀疏结构,以确保上下游各层中结构和计算的一致性。因此,这种方法被称为“全局剪枝”或“统一剪枝”。
小结:本文主要介绍了权重级剪枝、基于模式剪枝和向量级剪枝三种不同粒度的剪枝方法,并结合可视化的方式,一步一步详细演示了每种剪枝方法的运算过程,和剪枝前后权重矩阵的变化。在实际场景中,基于模式的剪枝越来越受到青睐,因为其规则的稀疏模式可以充分利用硬件加速器的计算能力,从而显著提高计算效率。
参考阅读
- 模型压缩概览
- Awesome Compression
相关文章:

模型压缩——基于粒度剪枝
1.引言 模型剪枝本质上是一种利用稀疏性来减少模型大小和计算量,从而提高训练和推理效率的技术。它为何会有效呢? 理论依据:有研究发现,在许多深度神经网络中,大部分参数是接近于0的,这些参数对模型最终的…...

IntelliJ IDEA 2023.2x——图文配置
IntelliJ IDEA 2023.2——配置说明 界面如下图所示 : 绿泡泡查找 “码猿趣事” 查找【idea99】 IntelliJ IDEA 的官方下载地址 IntelliJ IDEA 官网下载地址 一路上NEXT 到结尾: 继续NEXT 下一步:...

SpringBoot(5)-SpringSecurity
目录 一、是什么 二、实战测试 2.1 认识 2.2 认证和授权 2.3 权限控制和注销 2.4 记住我 一、是什么 Spring Security是一个框架,侧重于为java应用程序提供身份验证和授权。 Web应用的安全性主要分为两个部分: 认证(Authentication&…...

fast-api后端 + fetch 前端流式文字响应
fast-api后端 fetch 前端流式文字响应 fast-api后台接口流式响应 前端fetch 流式文本数据处理 fast-api后台接口 流式响应 from fastapi.responses import StreamingResponse from tqdm import tqdm from pydantic import BaseModelclass ItemDataSingle(BaseModel):data: …...

Qt 的 QThread:多线程编程的基础
Qt 的 QThread:多线程编程的基础 在现代应用程序中,尤其是需要处理大量数据、进行长时间计算或者进行 I/O 操作时,多线程编程变得至关重要。Qt 提供了一个功能强大且易于使用的线程类 QThread,可以帮助开发者在 Qt 应用程序中实现…...

周末总结(2024/11/16)
工作 人际关系核心实践: 要学会随时回应别人的善意,执行时间控制在5分钟以内 坚持每天早会打招呼 遇到接不住的话题时拉低自己,抬高别人(无阴阳气息) 朋友圈点赞控制在5min以内,职场社交不要放在5min以外 职场的人际关系在面对利…...

Chrome和Chromium的区别?浏览器引擎都用的哪些?浏览器引擎的作用?
Chrome和Chromium的区别? Chrome是Google专属的产品,它是基于后者Chromium开源引擎开发。第三方浏览器公司为了加快开发流程,会直接选择开源的浏览器引擎,例如Chromium. Google将Chromium开源,本意为了打破浏览器被其他公司控制的…...

流程图图解@RequestBody @RequestPart @RequestParam @ModelAttribute
RequestBody 只能用一次,因为只有一个请求体 #mermaid-svg-8WZfkzl0GPvOiNj3 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-8WZfkzl0GPvOiNj3 .error-icon{fill:#552222;}#mermaid-svg-8WZfkzl0GPvOiNj…...

AutoUpdater.NET 实现 dotNET应用自动更新
AutoUpdater.NET 是一款用于WPF、Winform软件版本更新的框架,类似框架还有Squirrel、WinSparkle、NetSparkle、Google Omaha。 一、安装AutoUpdater.NET 首先,您需要在项目中安装AutoUpdater.NET库。您可以通过NuGet包管理器来安装它。在Visual Studio中…...

108. UE5 GAS RPG 实现地图名称更新和加载关卡
在这一篇里,我们将实现对存档的删除功能,在删除时会有弹框确认。接着实现获取玩家的等级和地图名称和存档位置,我们可以通过存档进入游戏,玩家在游戏中可以在存档点存储存档。 实现删除存档 删除存档需要一个弹框确认࿰…...

对称加密与非对称加密:密码学的基石及 RSA 算法详解
对称加密与非对称加密:密码学的基石及 RSA 算法详解 在当今数字化的时代,信息安全至关重要。对称加密和非对称加密作为密码学中的两种基本加密技术,为我们的数据安全提供了强大的保障。本文将深入探讨对称加密和非对称加密的特点、应用场景&…...

排列问题方法总结(递归+迭代)
递归 一、逐步生成结果法(无序) #include<iostream> #include<vector> #include<string> #include<algorithm>using namespace std;vector<string> GetChild(int n,int curIndex){vector<string> now;vector&…...

C#从入门到放弃
C#和.NET的区别 C# C#是一个编程语言 .NET .NET是一个在window下创建程序的框架 .NET框架不仅局限于C#,它还可以支持很多语言 .NET包括了2个组件,一个叫CLR(通用语言运行时),另一个是用来构建程序的类库 CLR 用C写一个程序,在一台8688的机器…...

视频质量评价学习笔记
目录 MD VQA:大淘宝团队: ReIQA KVQ 视频质量评价学习笔记 MD VQA:大淘宝团队: https://github.com/kunyou99/MD-VQA_cvpr2023?tabreadme-ov-file ReIQA GitHub - avinabsaha/ReIQA: Official implementation for CVPR2023 Paper "Re-IQA : U…...

OpenCV、YOLO、VOC、COCO之间的关系和区别
OpenCV、YOLO、COCO 和 VOC 是计算机视觉和深度学习领域常见的几个名词,它们分别代表不同的工具、算法和数据集,之间有一些联系和区别。下面分别说明它们的定义、用途以及相互关系。 1. OpenCV(Open Source Computer Vision Library…...

Pandas进行周期与时间戳转换
时间序列数据在数据分析和金融领域非常常见,处理这些数据时,通常会面临周期(Period)与时间戳(Timestamp)之间的转换需求。理解和掌握这种转换,对于时间序列数据的清洗、预处理以及进一步分析至关重要。Python 中的 pandas 库提供了一系列便捷的函数来帮助处理这些时间序…...

【GPTs】Get Simpsonized:一键变身趣味辛普森角色
博客主页: [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 💯GPTs指令💯前言💯Get Simpsonized主要功能适用场景优点缺点使用方式 💯小结 💯GPTs指令 中文翻译: 指令保护和安全规则&…...

概率论公式整理
1 概率 古典概型和几何概型 古典概型(有限等可能)几何概型(无限等可能) 条件概率 P ( A ∣ B ) P ( A B ) P ( B ) P(A|B) \frac{P(AB)}{P(B)} P(A∣B)P(B)P(AB) 全概率公式 P ( B ) ∑ i 1 n P ( A i ) P ( B ∣ A i ) P…...

【C++】—— stack和queue的模拟实现
前言 stack 和 queue使用起来都非常简单,现在来模拟实现一下,理解其底层的原理。 在实现之前,应该知道,stack 和 queue 都是容器适配器,通过看官网文件也可以看出来;其默认的容器都是dequeÿ…...

管家婆工贸ERP BR039.采购订单关联MRP明细表
最低适用版本: 工贸系列 23.8 插件简要功能说明: 采购订单明细表,支持显示采购订单明细上游请购单明细关联的MRP中对应销售订单明细产成品相关信息更多细节描述见下方详细文档 插件操作视频: 进销存类定制插件--采购订单关联M…...

SwanLab安装教程
SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。 其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能&…...

MySQL EXPLAIN,数据库调优的秘密通道
EXPLAIN 是 MySQL 中一个非常有用的工具,它用于分析 SQL 查询的执行计划。通过 EXPLAIN,你可以获取 MySQL 是如何准备执行你的 SQL 语句的,包括使用的索引、连接类型、扫描的行数等信息。这些信息对于优化查询性能、识别性能瓶颈至关重要。 使…...

利用redis的key失效监听器KeyExpirationEventMessageListener作任务定时提醒功能
某需求: 要求在任务截止日期的前3天时,系统自动给用户发一条消息提醒。 用定时任务的话感觉很不舒服。间隔时间不好弄。不能精准卡到那个点。 由于系统简单,没有使用消息列队,也不能使用延时队列来做。 用Timer的话开销还挺大的&a…...

如何基于Tesseract实现图片的文本识别
在前一篇文章基础上,如何将报告图片中的文本解析出来,最近研究了基于Tesseract的OCR方案,Tesseract OCR是一个开源的OCR引擎,主要结合开源的tesseract和pytesseract,实现了jpg/png等格式图片文本识别,供大家…...

JavaWeb之AJAX
前言 这一节讲JavaWeb之AJAX 1.概述 以前我们在servlet中得到数据,必须通过域给jsp,然后jsp在响应给浏览器 纯html不能获取servlet返回数据 所以我们用jsp 但是现在我们可以同AJAX给返回数据了 我们可以在sevlet中直接通过AJAX返回给浏览器 html中的J…...

算法---解决“汉诺塔”问题
# 初始化步骤计数器 i 1 # 定义移动盘子的函数 def move(n, mfrom, mto): global i # 使用全局变量i来跟踪步骤 print("第%d步:将%d号盘子从%s->%s" % (i, n, mfrom, mto)) # 打印移动步骤 i 1 # 步骤计数器加1 #第一种方法 # 定义汉诺塔问题的递归…...

1-Equity-Transformer:求解NP-Hard Min-Max路由问题的顺序生成算法(AAAI-24)(完)(code)
文章目录 AbstractIntroduction问题表述Methodology多智能体位置编码公平上下文编码训练方案ExperimentsmTSP的性能评估mPDP的性能评估Related WorkConclusionAbstract 最小最大路由问题旨在通过智能体合作完成任务来最小化多个智能体中最长行程的长度。这些问题包括对现实世界…...

linux001.在Oracle VM VirtualBox中ubuntu虚拟系统扩容
1.打开终端切换到virtualBox安装目录 2.输入命令扩容 如上终端中的代码解释: D:\Program Files\Oracle\VirtualBox>.\VBoxManage modifyhd D:\ubuntu18.04\Ubuntu18.04\Ubuntu18.04.vdi --resize 40960如上代码说明:D:\Program Files\Oracle\Virtual…...

RabbitMQ教程:路由(Routing)(四)
文章目录 RabbitMQ教程:路由(Routing)(四)一、引言二、基本概念2.1 路由与绑定2.2 Direct交换机2.3 多绑定2.4 发送日志2.5 订阅 三、整合代码3.1 EmitLogDirectApp.cs3.2 ReceiveLogsDirectApp.cs3.3 推送所有和接收e…...

华为Ensp模拟器配置RIP路由协议
目录 RIP路由详解:另一种视角解读 1. RIP简介:轻松理解基础概念 2. RIP的核心机制:距离向量的魅力 3. RIP的实用与局限 RIP配置实验 实验图 编辑 PC的ip配置 RIP配置步骤 测试 结语:RIP的今天与明天 RIP路由详解&…...