模型压缩——基于粒度剪枝
1.引言
模型剪枝本质上是一种利用稀疏性来减少模型大小和计算量,从而提高训练和推理效率的技术。它为何会有效呢?
理论依据:有研究发现,在许多深度神经网络中,大部分参数是接近于0的,这些参数对模型最终的性能贡献较小。这也就意味着,识别并移除那些对模型性能影响较小的参数,可以减少模型的复杂度和计算成本,并且不会影响到模型的准确性。
根据粒度不同,剪枝可以有不同的方法,并且不同的模型由于内部结构不同剪枝方法也有所区别。就像卷积神经网络中可以对卷积核剪枝和通道进行剪枝,而transformer模型中则可以针对不活跃的自注意力头进行剪枝。
但是,不论哪种模型,以下三种粒度的剪枝是都适用的:
- 权重级剪枝
- 基于模式剪枝
- 向量级剪枝
本文将以一个二维矩阵为例,来分别介绍这三种基本粒度的剪枝,为了方便观察,我们先来讨论下矩阵的可视化。
2.矩阵可视化
首先,用随机数创建一个二维权重矩阵。
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3Dweight = torch.rand(8, 8)
封装一个可视化函数来直观的显示二维权重矩阵。
注:为了后面剪枝的可视化需要,我们会在显示时将0值元素与其它元素区分开。
def plot_tensor(tensor, title):fig, ax = plt.subplots()# 将矩阵转换为0和非0两个类别,并设置两个类别的颜色映射为tab20cax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')ax.set_title(title)ax.set_yticklabels([])ax.set_xticklabels([])# 遍历矩阵为每个文本元素添加文本标签rows, cols = tensor.shapefor i in range(rows):for j in range(cols):ax.text(i, j, f"{tensor[j, i].item():.2f}", ha='center', va='center', color='k')plt.show()plot_tensor(weight, "weight")
3.权重级剪枝
权重级剪枝又称为细粒度剪枝,以单个权重为剪枝单位,在具体操作时,一般会定义一个规则来决定移除哪些值。
在下面这个方法中,会通过目标张量与掩码相乘的方式,将小于threshold阀值的权重都置为0。
def weight_level_pruning(tensor, threshold: float) -> torch.Tensor:mask = torch.ge(tensor, threshold)return tensor.mul(mask)
注1:torch.ge函数的作用是对张量中的每个权重值与给定阀值threshold进行逐元素比较,大于阀值的会置为True,反之则置为False,函数运算结果是一个0(False)和1(True)组成的掩码。
注2:mul 用于对张量进行按元素乘法,要求两个矩阵的形状完全相同(请与矩阵乘法运算符@区分开)。
以threshold=0.2为例进行剪枝,剪枝后的结果如下所示。
pruned_weight = weight_level_pruning(weight, 0.2)
plot_tensor(pruned_weight, "pruned_weight")
权重级剪枝不关心权重在网络中的位置,灵活度最高,可实现高压缩比。但是,它破坏了原有模型的结构,现有硬件架构的计算方式通常无法对它进行加速,需要特殊的硬件或软件才能利用剪枝后模型的稀疏性,所以在目前通用的硬件上运行时速度并不能得到提升。
4.基于模式剪枝
基于模式剪枝通常是基于非常规则的N:M稀疏性进行剪枝,它要求在M个连续权重中固定有N个非零值,而其余元素均置为0。
下面我们将以2:4稀疏性为例子,一步一步说明如何实现N:M稀疏性。
4.1 稀疏模式计算
首先,创建一个长度为 4 的一维张量 sequence 并初始化为 0,表示总共有4个元素。
sequence = torch.zeros(4)
sequence
tensor([0., 0., 0., 0.])
对张量 patterns 的前 2 个元素设置为 1,表示4个元素中固定有2个非零值。
sequence[:2] = 1
sequence
tensor([1., 1., 0., 0.])
用permutations
函数生成patterns列表的所有可能排列,并用set去重。
from itertools import permutations
patterns = set(permutations(sequence.tolist()))
list(patterns)
[(0.0, 1.0, 0.0, 1.0),(1.0, 1.0, 0.0, 0.0),(0.0, 1.0, 1.0, 0.0),(1.0, 0.0, 1.0, 0.0),(1.0, 0.0, 0.0, 1.0),(0.0, 0.0, 1.0, 1.0)]
这个patterns中包含了长度为4恰好有两个1的所有可能排列模式。
为了方便复数,将上面的计算过程封装成一个函数。
def compute_valid_1d_patterns(m, n):patterns = torch.zeros(m)patterns[:n] = 1# permutations: 用于生成给定序列的所有可能排列valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))return valid_patternspatterns = compute_valid_1d_patterns(4,2)
patterns
tensor([[0., 1., 0., 1.],[1., 1., 0., 0.],[0., 1., 1., 0.],[1., 0., 1., 0.],[1., 0., 0., 1.],[0., 0., 1., 1.]])
4.2 生成掩码
计算掩码的目的是为了找到每个权重分组(每M个连续权重为一组)的最佳稀疏模式。
首先,生成一个初始掩码,它与权重矩阵的形状相同,并用 1 填充。然后将其视图更改为形状 (-1, 4),以便我们可以处理每 4 个权重一组。
tensor = weight
mask = torch.IntTensor(tensor.shape).fill_(1).view(-1, 4)
mask
tensor([[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1]], dtype=torch.int32)
重塑权重矩阵,使其形状与掩码一致。
mat = tensor.view(-1, 4)
mat
tensor([[0.6689, 0.4118, 0.9726, 0.9845],[0.8126, 0.4900, 0.8162, 0.0835],[0.5984, 0.1732, 0.7412, 0.2995],[0.7361, 0.1535, 0.9121, 0.1895],[0.8570, 0.1778, 0.1318, 0.5525],[0.0492, 0.5464, 0.4381, 0.2630],[0.9935, 0.0955, 0.6935, 0.7049],[0.1594, 0.5785, 0.9095, 0.8378],[0.0899, 0.0569, 0.7214, 0.3372],[0.3512, 0.9062, 0.0120, 0.7077],[0.1819, 0.6778, 0.7691, 0.5124],[0.3399, 0.4008, 0.2745, 0.2768],[0.9185, 0.1250, 0.9466, 0.5318],[0.9118, 0.1470, 0.6657, 0.6492],[0.1116, 0.8223, 0.7062, 0.2872],[0.1826, 0.4946, 0.5415, 0.8882]])
对于每一行权重,我们计算其绝对值与所有稀疏模式的点积,得到每一行权重在每一种稀疏模式下的加权和。
mat_patterns = torch.matmul(mat.abs(), patterns.t())
mat_patterns
tensor([[1.3963, 1.0807, 1.3844, 1.6415, 1.6534, 1.9571],[0.5735, 1.3027, 1.3062, 1.6288, 0.8961, 0.8997],[0.4727, 0.7716, 0.9144, 1.3396, 0.8979, 1.0406],[0.3429, 0.8896, 1.0656, 1.6481, 0.9255, 1.1015],[0.7304, 1.0348, 0.3096, 0.9888, 1.4095, 0.6843],[0.8095, 0.5956, 0.9845, 0.4873, 0.3122, 0.7012],[0.8005, 1.0890, 0.7891, 1.6870, 1.6984, 1.3985],[1.4162, 0.7379, 1.4880, 1.0689, 0.9972, 1.7473],[0.3941, 0.1468, 0.7783, 0.8113, 0.4271, 1.0585],[1.6139, 1.2574, 0.9183, 0.3632, 1.0589, 0.7197],[1.1902, 0.8597, 1.4469, 0.9510, 0.6943, 1.2815],[0.6776, 0.7407, 0.6752, 0.6144, 0.6167, 0.5513],[0.6569, 1.0435, 1.0717, 1.8651, 1.4503, 1.4785],[0.7962, 1.0587, 0.8126, 1.5774, 1.5610, 1.3149],[1.1095, 0.9339, 1.5285, 0.8177, 0.3988, 0.9934],[1.3828, 0.6772, 1.0361, 0.7240, 1.0708, 1.4296]])
我们只需要保留加权和最大的模式即可,为每一行权重选择加权和最大的模式索引:
pmax = torch.argmax(mat_patterns, dim=1)
pmax
tensor([5, 3, 3, 3, 4, 2, 4, 5, 5, 0, 2, 1, 3, 3, 2, 5])
使用 pmax 索引从稀疏模式中为每一行权重选择相应的模式,然后将其赋值给掩码:
mask[:] = patterns[pmax[:]]
mask
tensor([[0, 0, 1, 1],[1, 0, 1, 0],[1, 0, 1, 0],[1, 0, 1, 0],[1, 0, 0, 1],[0, 1, 1, 0],[1, 0, 0, 1],[0, 0, 1, 1],[0, 0, 1, 1],[0, 1, 0, 1],[0, 1, 1, 0],[1, 1, 0, 0],[1, 0, 1, 0],[1, 0, 1, 0],[0, 1, 1, 0],[0, 0, 1, 1]], dtype=torch.int32)
最后,将掩码视图重塑回原始权重的形状。
mask = mask.view(tensor.shape)
mask
tensor([[0, 0, 1, 1, 1, 0, 1, 0],[1, 0, 1, 0, 1, 0, 1, 0],[1, 0, 0, 1, 0, 1, 1, 0],[1, 0, 0, 1, 0, 0, 1, 1],[0, 0, 1, 1, 0, 1, 0, 1],[0, 1, 1, 0, 1, 1, 0, 0],[1, 0, 1, 0, 1, 0, 1, 0],[0, 1, 1, 0, 0, 0, 1, 1]], dtype=torch.int32)
同样,将上面计算掩码的过程封装为一个函数。
def compute_mask(tensor, m, n):# 计算所有可能的模式patterns = compute_valid_1d_patterns(m,n) # m中取n所有可能的模式,N行4列# 生成初始掩码mask = torch.IntTensor(tensor.shape).fill_(1).view(-1,m) # 将张量转换成列为m的格式,若不能整除m则填充0if tensor.shape[1] % m > 0:mat = torch.FloatTensor(tensor.shape[0], tensor.shape[1] + (m - tensor.shape[1] % m)).fill_(0)mat[:, : tensor.shape[1]] = tensormat = mat.view(-1, m)else:mat = tensor.view(-1, m) pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1) # 16行N列,每一行的点积操作都得到N种可能,取点积最大值的元素下标mask[:] = patterns[pmax[:]] # 找到最大下标对应的排列mask = mask.view(tensor.shape) # 再转换成tensor的形状return maskmask = compute_mask(weight, 4, 2)
mask
tensor([[0, 0, 1, 1, 1, 0, 1, 0],[1, 0, 1, 0, 1, 0, 1, 0],[1, 0, 0, 1, 0, 1, 1, 0],[1, 0, 0, 1, 0, 0, 1, 1],[0, 0, 1, 1, 0, 1, 0, 1],[0, 1, 1, 0, 1, 1, 0, 0],[1, 0, 1, 0, 1, 0, 1, 0],[0, 1, 1, 0, 0, 0, 1, 1]], dtype=torch.int32)
4.3 用掩码剪枝
pruned_pattern = tensor.mul(mask)
plot_tensor(pruned_pattern, "pruned_pattern")
可以看到,2:4稀疏性将一半权重都置为了0,每4个权重中保留了两个非0权重,我们成功地在权重矩阵上应用了 2:4 稀疏性。
虽然都属于非结构化索引,但与前面基于权重的索引不同的是,2:4
结构的稀疏性可以被英伟达的稀疏张量核心加速。在运算时,稀疏矩阵W首先会被压缩,压缩后的矩阵存储着非零的数据值,而metadata则存储着对应非零元素在原矩阵W中的索引信息。
具体来说,metadata会将W中非零元素的行号和列号压缩成两个独立的一维数组,这两个数组就是metadata中存储的索引信息。
注:N:M 稀疏性是一种有效的稀疏技术,这种稀疏模式不仅减少了模型大小,还提高了计算效率,特别适用于硬件加速器,如 GPU 和 TPU,从而加速深度学习模型的训练和推理过程。
5.向量级剪枝
向量级剪枝以行或列为单位对权重进行裁剪。
def vector_pruning(tensor, point):rows, cols = pointprune_weight = tensor.clone()prune_weight[rows, :] = 0prune_weight[:, cols] = 0return prune_weightpoint = (2,3)
pruned_vector = vector_pruning(weight, point)
plot_tensor(pruned_vector, "pruned_vector")
经过上面的向量级剪枝后,可以直接去掉一行一列,整个权重矩阵的形状可以直接由[8,8]变为[7,7], 因此向量级剪枝属于结构化剪枝。
通常在进行向量级别的剪枝时,需要对模型的所有层统一进行剪枝,其目的是在整个模型中保持一致的稀疏结构,以确保上下游各层中结构和计算的一致性。因此,这种方法被称为“全局剪枝”或“统一剪枝”。
小结:本文主要介绍了权重级剪枝、基于模式剪枝和向量级剪枝三种不同粒度的剪枝方法,并结合可视化的方式,一步一步详细演示了每种剪枝方法的运算过程,和剪枝前后权重矩阵的变化。在实际场景中,基于模式的剪枝越来越受到青睐,因为其规则的稀疏模式可以充分利用硬件加速器的计算能力,从而显著提高计算效率。
参考阅读
- 模型压缩概览
- Awesome Compression
相关文章:

模型压缩——基于粒度剪枝
1.引言 模型剪枝本质上是一种利用稀疏性来减少模型大小和计算量,从而提高训练和推理效率的技术。它为何会有效呢? 理论依据:有研究发现,在许多深度神经网络中,大部分参数是接近于0的,这些参数对模型最终的…...

IntelliJ IDEA 2023.2x——图文配置
IntelliJ IDEA 2023.2——配置说明 界面如下图所示 : 绿泡泡查找 “码猿趣事” 查找【idea99】 IntelliJ IDEA 的官方下载地址 IntelliJ IDEA 官网下载地址 一路上NEXT 到结尾: 继续NEXT 下一步:...

SpringBoot(5)-SpringSecurity
目录 一、是什么 二、实战测试 2.1 认识 2.2 认证和授权 2.3 权限控制和注销 2.4 记住我 一、是什么 Spring Security是一个框架,侧重于为java应用程序提供身份验证和授权。 Web应用的安全性主要分为两个部分: 认证(Authentication&…...
fast-api后端 + fetch 前端流式文字响应
fast-api后端 fetch 前端流式文字响应 fast-api后台接口流式响应 前端fetch 流式文本数据处理 fast-api后台接口 流式响应 from fastapi.responses import StreamingResponse from tqdm import tqdm from pydantic import BaseModelclass ItemDataSingle(BaseModel):data: …...
Qt 的 QThread:多线程编程的基础
Qt 的 QThread:多线程编程的基础 在现代应用程序中,尤其是需要处理大量数据、进行长时间计算或者进行 I/O 操作时,多线程编程变得至关重要。Qt 提供了一个功能强大且易于使用的线程类 QThread,可以帮助开发者在 Qt 应用程序中实现…...
周末总结(2024/11/16)
工作 人际关系核心实践: 要学会随时回应别人的善意,执行时间控制在5分钟以内 坚持每天早会打招呼 遇到接不住的话题时拉低自己,抬高别人(无阴阳气息) 朋友圈点赞控制在5min以内,职场社交不要放在5min以外 职场的人际关系在面对利…...
Chrome和Chromium的区别?浏览器引擎都用的哪些?浏览器引擎的作用?
Chrome和Chromium的区别? Chrome是Google专属的产品,它是基于后者Chromium开源引擎开发。第三方浏览器公司为了加快开发流程,会直接选择开源的浏览器引擎,例如Chromium. Google将Chromium开源,本意为了打破浏览器被其他公司控制的…...

流程图图解@RequestBody @RequestPart @RequestParam @ModelAttribute
RequestBody 只能用一次,因为只有一个请求体 #mermaid-svg-8WZfkzl0GPvOiNj3 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-8WZfkzl0GPvOiNj3 .error-icon{fill:#552222;}#mermaid-svg-8WZfkzl0GPvOiNj…...
AutoUpdater.NET 实现 dotNET应用自动更新
AutoUpdater.NET 是一款用于WPF、Winform软件版本更新的框架,类似框架还有Squirrel、WinSparkle、NetSparkle、Google Omaha。 一、安装AutoUpdater.NET 首先,您需要在项目中安装AutoUpdater.NET库。您可以通过NuGet包管理器来安装它。在Visual Studio中…...

108. UE5 GAS RPG 实现地图名称更新和加载关卡
在这一篇里,我们将实现对存档的删除功能,在删除时会有弹框确认。接着实现获取玩家的等级和地图名称和存档位置,我们可以通过存档进入游戏,玩家在游戏中可以在存档点存储存档。 实现删除存档 删除存档需要一个弹框确认࿰…...

对称加密与非对称加密:密码学的基石及 RSA 算法详解
对称加密与非对称加密:密码学的基石及 RSA 算法详解 在当今数字化的时代,信息安全至关重要。对称加密和非对称加密作为密码学中的两种基本加密技术,为我们的数据安全提供了强大的保障。本文将深入探讨对称加密和非对称加密的特点、应用场景&…...
排列问题方法总结(递归+迭代)
递归 一、逐步生成结果法(无序) #include<iostream> #include<vector> #include<string> #include<algorithm>using namespace std;vector<string> GetChild(int n,int curIndex){vector<string> now;vector&…...

C#从入门到放弃
C#和.NET的区别 C# C#是一个编程语言 .NET .NET是一个在window下创建程序的框架 .NET框架不仅局限于C#,它还可以支持很多语言 .NET包括了2个组件,一个叫CLR(通用语言运行时),另一个是用来构建程序的类库 CLR 用C写一个程序,在一台8688的机器…...
视频质量评价学习笔记
目录 MD VQA:大淘宝团队: ReIQA KVQ 视频质量评价学习笔记 MD VQA:大淘宝团队: https://github.com/kunyou99/MD-VQA_cvpr2023?tabreadme-ov-file ReIQA GitHub - avinabsaha/ReIQA: Official implementation for CVPR2023 Paper "Re-IQA : U…...
OpenCV、YOLO、VOC、COCO之间的关系和区别
OpenCV、YOLO、COCO 和 VOC 是计算机视觉和深度学习领域常见的几个名词,它们分别代表不同的工具、算法和数据集,之间有一些联系和区别。下面分别说明它们的定义、用途以及相互关系。 1. OpenCV(Open Source Computer Vision Library…...
Pandas进行周期与时间戳转换
时间序列数据在数据分析和金融领域非常常见,处理这些数据时,通常会面临周期(Period)与时间戳(Timestamp)之间的转换需求。理解和掌握这种转换,对于时间序列数据的清洗、预处理以及进一步分析至关重要。Python 中的 pandas 库提供了一系列便捷的函数来帮助处理这些时间序…...

【GPTs】Get Simpsonized:一键变身趣味辛普森角色
博客主页: [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 💯GPTs指令💯前言💯Get Simpsonized主要功能适用场景优点缺点使用方式 💯小结 💯GPTs指令 中文翻译: 指令保护和安全规则&…...
概率论公式整理
1 概率 古典概型和几何概型 古典概型(有限等可能)几何概型(无限等可能) 条件概率 P ( A ∣ B ) P ( A B ) P ( B ) P(A|B) \frac{P(AB)}{P(B)} P(A∣B)P(B)P(AB) 全概率公式 P ( B ) ∑ i 1 n P ( A i ) P ( B ∣ A i ) P…...

【C++】—— stack和queue的模拟实现
前言 stack 和 queue使用起来都非常简单,现在来模拟实现一下,理解其底层的原理。 在实现之前,应该知道,stack 和 queue 都是容器适配器,通过看官网文件也可以看出来;其默认的容器都是dequeÿ…...
管家婆工贸ERP BR039.采购订单关联MRP明细表
最低适用版本: 工贸系列 23.8 插件简要功能说明: 采购订单明细表,支持显示采购订单明细上游请购单明细关联的MRP中对应销售订单明细产成品相关信息更多细节描述见下方详细文档 插件操作视频: 进销存类定制插件--采购订单关联M…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
conda相比python好处
Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理:…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...

【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...

【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
unix/linux,sudo,其发展历程详细时间线、由来、历史背景
sudo 的诞生和演化,本身就是一部 Unix/Linux 系统管理哲学变迁的微缩史。来,让我们拨开时间的迷雾,一同探寻 sudo 那波澜壮阔(也颇为实用主义)的发展历程。 历史背景:su的时代与困境 ( 20 世纪 70 年代 - 80 年代初) 在 sudo 出现之前,Unix 系统管理员和需要特权操作的…...

(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
scikit-learn机器学习
# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...
Go语言多线程问题
打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...
深度剖析 DeepSeek 开源模型部署与应用:策略、权衡与未来走向
在人工智能技术呈指数级发展的当下,大模型已然成为推动各行业变革的核心驱动力。DeepSeek 开源模型以其卓越的性能和灵活的开源特性,吸引了众多企业与开发者的目光。如何高效且合理地部署与运用 DeepSeek 模型,成为释放其巨大潜力的关键所在&…...