实践教程|基于 pytorch 实现模型剪枝
PyTorch剪枝方法详解,附详细代码。
-
一,剪枝分类
-
1.1,非结构化剪枝
-
1.2,结构化剪枝
-
1.3,本地与全局修剪
-
二,PyTorch 的剪枝
-
2.1,pytorch 剪枝工作原理
-
2.2,局部剪枝
-
2.3,全局非结构化剪枝
-
三,总结
-
参考资料
一,剪枝分类
所谓模型剪枝,其实是一种从神经网络中移除"不必要"权重或偏差(weigths/bias)的模型压缩技术。关于什么参数才是“不必要的”,这是一个目前依然在研究的领域。
1.1,非结构化剪枝
非结构化剪枝(Unstructured Puning)是指修剪参数的单个元素,比如全连接层中的单个权重、卷积层中的单个卷积核参数元素或者自定义层中的浮点数(scaling floats)。其重点在于,剪枝权重对象是随机的,没有特定结构,因此被称为非结构化剪枝。
1.2,结构化剪枝
与非结构化剪枝相反,结构化剪枝会剪枝整个参数结构。比如,丢弃整行或整列的权重,或者在卷积层中丢弃整个过滤器(Filter
)。
1.3,本地与全局修剪
剪枝可以在每层(局部)或多层/所有层(全局)上进行。
二,PyTorch 的剪枝
目前 PyTorch 框架支持的权重剪枝方法有:
-
Random: 简单地修剪随机参数。
-
Magnitude: 修剪权重最小的参数(例如它们的 L2 范数)
以上两种方法实现简单、计算容易,且可以在没有任何数据的情况下应用。
2.1,pytorch 剪枝工作原理
剪枝功能在 torch.nn.utils.prune
类中实现,代码在文件 torch/nn/utils/prune.py 中,主要剪枝类如下图所示。
pytorch_pruning_api_file.png
剪枝原理是基于张量(Tensor)的掩码(Mask)实现。掩码是一个与张量形状相同的布尔类型的张量,掩码的值为 True 表示相应位置的权重需要保留,掩码的值为 False 表示相应位置的权重可以被删除。
Pytorch 将原始参数 <param>
复制到名为 <param>_original
的参数中,并创建一个缓冲区来存储剪枝掩码 <param>_mask
。同时,其也会创建一个模块级的 forward_pre_hook 回调函数(在模型前向传播之前会被调用的回调函数),将剪枝掩码应用于原始权重。
pytorch 剪枝的 api
和教程比较混乱,我个人将做了如下表格,希望能将 api 和剪枝方法及分类总结好。
pytorch_pruning_api
pytorch 中进行模型剪枝的工作流程如下:
-
选择剪枝方法(或者子类化 BasePruningMethod 实现自己的剪枝方法)。
-
指定剪枝模块和参数名称。
-
设置剪枝方法的参数,比如剪枝比例等。
2.2,局部剪枝
Pytorch 框架中的局部剪枝有非结构化和结构化剪枝两种类型,值得注意的是结构化剪枝只支持局部不支持全局。
2.2.1,局部非结构化剪枝
1,局部非结构化剪枝(Locall Unstructured Pruning)对应函数原型如下:
def random_unstructured(module, name, amount)
1,函数功能:用于对权重参数张量进行非结构化剪枝。该方法会在张量中随机选择一些权重或连接进行剪枝,剪枝率由用户指定。2,函数参数定义:
-
module
(nn.Module): 需要剪枝的网络层/模块,例如 nn.Conv2d() 和 nn.Linear()。 -
name
(str): 要剪枝的参数名称,比如 “weight” 或 “bias”。 -
amount
(int or float): 指定要剪枝的数量,如果是 0~1 之间的小数,则表示剪枝比例;如果是证书,则直接剪去参数的绝对数量。比如amount=0.2
,表示将随机选择 20% 的元素进行剪枝。
3,下面是 random_unstructured
函数的使用示例。
import torch
import torch.nn.utils.prune as prune
conv = torch.nn.Conv2d(1, 1, 4)
prune.random_unstructured(conv, name="weight", amount=0.5)
conv.weight
"""
tensor([[[[-0.1703, 0.0000, -0.0000, 0.0690], [ 0.1411, 0.0000, -0.0000, -0.1031], [-0.0527, 0.0000, 0.0640, 0.1666], [ 0.0000, -0.0000, -0.0000, 0.2281]]]], grad_fn=<MulBackward0>)
"""
可以看出输出的 conv 层中权重值有一半比例为 0
。
2.2.2,局部结构化剪枝
局部结构化剪枝(Locall Structured Pruning)有两种函数,对应函数原型如下:
def random_structured(module, name, amount, dim)
def ln_structured(module, name, amount, n, dim, importance_scores=None)
1,函数功能
与非结构化移除的是连接权重不同,结构化剪枝移除的是整个通道权重。
2,参数定义
与局部非结构化函数非常相似,唯一的区别是您必须定义 dim 参数(ln_structured 函数多了 n
参数)。
n
表示剪枝的范数,dim
表示剪枝的维度。
对于 torch.nn.Linear:
-
dim = 0
:移除一个神经元。 -
dim = 1
:移除与一个输入的所有连接。
对于 torch.nn.Conv2d:
-
dim = 0
(Channels) : 通道 channels 剪枝/过滤器 filters 剪枝 -
dim = 1
(Neurons): 二维卷积核 kernel 剪枝,即与输入通道相连接的 kernel
2.2.3,局部结构化剪枝示例代码
在写示例代码之前,我们先需要理解 Conv2d
函数参数、卷积核 shape、轴以及张量的关系。首先,Conv2d 函数原型如下;
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
而 pytorch 中常规卷积的卷积核权重 shape
都为(C_out, C_in, kernel_height, kernel_width
),所以在代码中卷积层权重 shape
为 [3, 2, 3, 3]
,dim = 0 对应的是 shape [3, 2, 3, 3] 中的 3
。这里我们 dim 设定了哪个轴,那自然剪枝之后权重张量对应的轴机会发生变换。
dim
理解了前面的关键概念,下面就可以实际使用了,dim=0
的示例如下所示。
conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])
print(norm1)
"""
tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)
print(conv.weight)
"""
tensor([[[[-0.0005, 0.1039, 0.0306], [ 0.1233, 0.1517, 0.0628], [ 0.1075, -0.0606, 0.1140]], [[ 0.2263, -0.0199, 0.1275], [-0.0455, -0.0639, -0.2153], [ 0.1587, -0.1928, 0.1338]]], [[[-0.2023, 0.0012, 0.1617], [-0.1089, 0.2102, -0.2222], [ 0.0645, -0.2333, -0.1211]], [[ 0.2138, -0.0325, 0.0246], [-0.0507, 0.1812, -0.2268], [-0.1902, 0.0798, 0.0531]]], [[[ 0.0000, -0.0000, -0.0000], [ 0.0000, -0.0000, -0.0000], [ 0.0000, -0.0000, 0.0000]], [[ 0.0000, 0.0000, 0.0000], [-0.0000, 0.0000, 0.0000], [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)
"""
从运行结果可以明显看出,卷积层参数的最后一个通道参数张量被移除了(为 0
张量),其解释参见下图。
dim_understand
dim = 1
的情况:
conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])
print(norm1)
"""
tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)
print(conv.weight)
"""
tensor([[[[ 0.0000, -0.0000, -0.0000], [-0.0000, 0.0000, 0.0000], [-0.0000, 0.0000, -0.0000]], [[-0.2140, 0.1038, 0.1660], [ 0.1265, -0.1650, -0.2183], [-0.0680, 0.2280, 0.2128]]], [[[-0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.0000], [-0.0000, -0.0000, -0.0000]], [[-0.2087, 0.1275, 0.0228], [-0.1888, -0.1345, 0.1826], [-0.2312, -0.1456, -0.1085]]], [[[-0.0000, 0.0000, 0.0000], [ 0.0000, -0.0000, 0.0000], [ 0.0000, -0.0000, 0.0000]], [[-0.0891, 0.0946, -0.1724], [-0.2068, 0.0823, 0.0272], [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)
"""
很明显,对于 dim=1
的维度,其第一个张量的 L2 范数更小,所以shape 为 [2, 3, 3] 的张量中,第一个 [3, 3] 张量参数会被移除(即张量为 0 矩阵) 。
2.3,全局非结构化剪枝
前文的 local 剪枝的对象是特定网络层,而 global 剪枝是将模型看作一个整体去移除指定比例(数量)的参数,同时 global 剪枝结果会导致模型中每层的稀疏比例是不一样的。
全局非结构化剪枝函数原型如下:
# v1.4.0 版本
def global_unstructured(parameters, pruning_method, **kwargs)
# v2.0.0-rc2版本
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
1,函数功能:
随机选择全局所有参数(包括权重和偏置)的一部分进行剪枝,而不管它们属于哪个层。
2,参数定义:
-
parameters
((Iterable of (module, name) tuples)): 修剪模型的参数列表,列表中的元素是 (module, name)。 -
pruning_method
(function): 目前好像官方只支持 pruning_method=prune.L1Unstuctured,另外也可以是自己实现的非结构化剪枝方法函数。 -
importance_scores
: 表示每个参数的重要性得分,如果为 None,则使用默认得分。 -
**kwargs
: 表示传递给特定剪枝方法的额外参数。比如amount
指定要剪枝的数量。
3,global_unstructured
函数的示例代码如下所示。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() # 1 input image channel, 6 output channels, 3x3 square conv kernel self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6, 16, 3) self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, int(x.nelement() / x.shape[0])) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = LeNet().to(device=device) model = LeNet() parameters_to_prune = ( (model.conv1, 'weight'), (model.conv2, 'weight'), (model.fc1, 'weight'), (model.fc2, 'weight'), (model.fc3, 'weight'),
) prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2,
)
# 计算卷积层和整个模型的稀疏度
# 其实调用的是 Tensor.numel 内内函数,返回输入张量中元素的总数
print( "Sparsity in conv1.weight: {:.2f}%".format( 100. * float(torch.sum(model.conv1.weight == 0)) / float(model.conv1.weight.nelement()) )
)
print( "Global sparsity: {:.2f}%".format( 100. * float( torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0) ) / float( model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement() ) )
)
# 程序运行结果
"""
Sparsity in conv1.weight: 3.70%
Global sparsity: 20.00%
"""
运行结果表明,虽然模型整体(全局)的稀疏度是 20%
,但每个网络层的稀疏度不一定是 20%。
三,总结
另外,pytorch 框架还提供了一些帮助函数:
-
torch.nn.utils.prune.is_pruned(module): 判断模块 是否被剪枝。
-
torch.nn.utils.prune.remove(module, name):用于将指定模块中指定参数上的剪枝操作移除,从而恢复该参数的原始形状和数值。
虽然 PyTorch 提供了内置剪枝 API
,也支持了一些非结构化和结构化剪枝方法,但是 API
比较混乱,对应文档描述也不清晰,所以后面我还会结合微软的开源 nni
工具来实现模型剪枝功能。
更多剪枝方法实践,可以参考这个 github
仓库:Model-Compression。
参考资料
-
How to Prune Neural Networks with PyTorch
-
PRUNING TUTORIAL
-
PyTorch Pruning
相关文章:

实践教程|基于 pytorch 实现模型剪枝
PyTorch剪枝方法详解,附详细代码。 一,剪枝分类 1.1,非结构化剪枝 1.2,结构化剪枝 1.3,本地与全局修剪 二,PyTorch 的剪枝 2.1,pytorch 剪枝工作原理 2.2,局部剪枝 2.3&#…...

[Docker精进篇] Docker镜像构建和实践 (三)
前言: Docker镜像构建的作用是将应用程序及其依赖打包到一个可移植、自包含的镜像中,以便在不同环境中快速、可靠地部署和运行应用程序。 文章目录 Docker镜像构建1️⃣是什么?2️⃣为什么?3️⃣镜像构建一、用现有容器构建新镜像…...

【Unity细节】Unity中的层级LayerMask
👨💻个人主页:元宇宙-秩沅 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 秩沅 原创 😶🌫️收录于专栏:unity细节和bug 😶🌫️优质专栏 ⭐【…...

修改el-table行悬停状态的背景颜色
.content:deep().el-table tr:hover>td {background-color: #f5f5f5 !important; /* 设置悬停时的背景颜色 */ }/*这一点很重要,否则可能会导致hover行时操作列还是原来的背景色*/ .content:deep().el-table__body tr.hover-row>td{background-color: #f5f5f5…...

记一次mysql not in的使用问题
现象:使用not in 某个id集合,出现脏数据,存在null数据。例如:not in(1,2,null),结果会一条数据都没有,为空 原因: 当使用NOT IN操作符时,传递给它的值列表中不能包含NULL值…...

JavaFx基础学习【四】:UI控件的通用属性
目录 前言 一、介绍 二、继承关系 三、常用通用属性 四、属性Properties 五、属性绑定 六、属性监听 七、事件驱动 八、其他章节 前言 如果你还没有看过前面的文章,可以通过以下链接快速前往学习: JavaFx基础学习【一】:基本认识_明…...

【Leetcode】101.对称二叉树
一、题目 1、题目描述 给你一个二叉树的根节点 root , 检查它是否轴对称。 示例1: 输入:root = [1,2,2,3,4,4,3] 输出:true示例2: 输入:root = [1,2,2,null,3,null,3] 输出:false提示: 树中节点数目在范围 [1, 1000] 内-100 <= Node.val <= 100进阶:你可以…...

用Java实现原神抽卡算法
哈喽~大家好,好久没有更新了,也确实遇到了很多事,这篇开始恢复更新,喜欢的话,可以给个的三连,什么?你要白嫖?那可以给个免费的赞麻。 🥇个人主页:个人主页…...

微服务—Eureka注册中心
eureka相当于是一个公司的管理人事HR,各部门之间如果有合作时,由HR进行人员的分配以及调度,具体选哪个人,全凭HR的心情,如果你这个部门存在没有意义,直接把你这个部门撤销,全体人员裁掉,所以不想…...

AI问答:JSBridge / WebView 与 Native 通信
一、理解JSBridge JSBridge是一种连接JavaScript和Native代码的桥梁,它提供了一种方法,使得JavaScript可以直接调用Native的代码,同时使得Native的代码也能直接调用JavaScript的方法,从而实现了JavaScript和Native之间的相互调用和…...

Mybatis动态SQL,标签大全
动态SQL常用场景 批量删除delete from t_car where id in(1,2,3,4,5,6,......这里的值是动态的,根据用户选择的 id不同,值是不同的);多条件查询哪些字段会作为查询条件是不确定的,根据用户而定 select * from 1 t_car where brand like 丰田…...

zotero在不同系统的安装(win/linux)
1 window系统安装 zotero 官网: https://www.zotero.org/ 官方文档 :https://www.zotero.org/support/ (官方)推荐常用的插件: https://www.zotero.org/support/plugins 入门视频推荐: Zotero 文献管理与知识整理最佳实践 点击 exe文件自…...

web会话跟踪以及JWT响应拦截机制
目录 JWT 会话跟踪 token 响应拦截器 http是无状态的,登录成功后,客户端就与服务器断开连接,之后再向后端发送请求时,后端需要知道前端是哪个用户在进行操作。 JWT Json web token (JWT), 是为了在网络应用环境间传递声明而…...

Web菜鸟入门教程 - Swagger实现自动生成文档
如果是一个人把啥都开发了,那用不到Swagger-UI,但一般情况是前后端分离的,所以就需要告诉前端开发人员都有哪些接口,传入什么参数,怎么调用,返回什么。有了Swagger-UI就能把这部分文档编写的业务给省去了。…...

2023国赛数学建模思路 - 复盘:校园消费行为分析
文章目录 0 赛题思路1 赛题背景2 分析目标3 数据说明4 数据预处理5 数据分析5.1 食堂就餐行为分析5.2 学生消费行为分析 建模资料 0 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 赛题背景 校园一卡通是集…...

第7章:贝叶斯分类器
贝叶斯决策论 贝叶斯分类器:使用贝叶斯公式 贝叶斯学习:使用分布估计(不同于频率主义的点估计) 极大似然估计 朴素贝叶斯分类 半朴素贝叶斯 条件独立性假设,在现实生活中往往很难成立。 半朴素贝叶 斯的一个常用策略…...

【LeetCode】88.合并两个有序数组
题目 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2,另有两个整数 m 和 n ,分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 nums1 中,使合并后的数组同样按 非递减顺序 排列。 注意:最终,合并…...

05 - 研究 .git 目录
查看所有文章链接:(更新中)GIT常用场景- 目录 文章目录 1. HEAD2. config3. refs4. objects 1. HEAD 2. config 3. refs 4. objects Git对象一共有三种:数据对象 blob、树对象 tree以及提交对象 commit,这些对象都被保…...

MySQL之索引和事务
索引什么是索引索引怎么用索引的原理 事务使用事务事务特性MySQL隔离级别 索引 什么是索引 索引包含数据表所有记录的引用指针;你可以对某一列或者多列创建索引和指定不同的类型(唯一索引、主键索引、普通索引等不同类型;他们底层实现也是不…...

⛳ 将本地已有的项目上传到 git 仓库
目录 ⛳ 将本地已有的项目上传到 git 仓库🏭 一、克隆 拷贝🎨 二、强行合并两个仓库 ⛳ 将本地已有的项目上传到 git 仓库 有两种方法: 一、克隆 拷贝 二、强行合并两个仓库 🏭 一、克隆 拷贝 直接用把远程仓库拉到本…...

ADB常用命令整理(全网最全)
调试Android程序时,我们经常需要使用adb shell命令。adb是Android Debug Bridge的缩写,它充当调试桥梁的作用,就像一条连接开发机和设备之间的桥梁。 通过adb,我们可以在Eclipse中使用DDMS来调试Android程序,简单来说…...

BBS项目day02、注册、登录(登录之随机验证码)、退出登录、密码加密加盐、首页(导航条、模态框,修改密码)
一、注册 1.注册之前端页面 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>注册页面</title><!--动态引入文件-->{% load static %}<script src"{% static js/jquery.min.js %…...

HTML5+CSS3自用笔记
助解:解析编译,加载运行 浏览器的渲染过程 JS加载执行 普通js/sync:阻塞 DOM加载解析 async:下载完就执行,无依赖 <script type"text/javascript" src"x.min.js" async"async"&g…...

无则插入有则更新(PostgreSQL,MySQL,Oracle、SqlServer)
无则插入有则更新 PostgreSQL 无则插入有则更新 conflict(带有唯一性约束的字段),根据此字段判断是更新还是插入 INSERT INTO student(id,name,sex) VALUES(1, 小明, 男) ON conflict (id) DO UPDATE SET id 1,name 小明,sex 男;无则插入有则不做操作 INSERT I…...

常见的 JavaScript 框架比较
以下是10种常见的JavaScript框架的比较: React:是由Facebook开发和维护的开源JavaScript库,用于构建用户界面。它允许你使用组件来构建复杂的UI,并专注于每个组件的内部逻辑,而不必担心管理整个应用程序的状态。WebBu…...

基于R语言APSIM模型进阶应用与参数优化、批量模拟
随着数字农业和智慧农业的发展,基于过程的农业生产系统模型在模拟作物对气候变化的响应与适应、农田管理优化、作物品种和株型筛选、农田固碳和温室气体排放等领域扮演着越来越重要的作用。APSIM (Agricultural Production Systems sIMulator)模型是世界知名的作物生…...

AMD卡启动Stable Diffusion AI绘画的方法
WindowsAMD安装法 1.安装python 3.10.6,在python官网上下载安装程序,***重要*** 在安装的第一个窗口下方勾选“将python添加到path”。 2.安装git 3.WindowsAMD使用AUTOMATIC1111的directml这一个fork,在这个页面的第一段:https:/…...

Ubuntu系统kubeadm安装K8S_v1.25.x容器使用docker(K8S_v1.24版本以后依然使用docker容器管理)
安装所需要的全部文档请点击这里下载 系统是: root@k8s-master:~# cat /etc/lsb-release DISTRIB_ID=Ubuntu DISTRIB_RELEASE=22.04 DISTRIB_CODENAME=jammy DISTRIB_DESCRIPTION=“Ubuntu 22.04.3 LTS” root@k8s-master:~# uname -a Linux k8s-master 5.15.0-76-generic #8…...

【MaxKey对接一】对接gitlab的oauth登录
MaxKey的Oauth过程 引导进入 GET http://{{maxKey_host}}/sign/authz/oauth/v20/authorize?client_idYOUR_CLIENT_ID&response_typecode&redirect_uriYOUR_REGISTERED_REDIRECT_URI 登录后回调地址 YOUR_REGISTERED_REDIRECT_URI/?code{{code}} 换取Access Token GET…...

【Buildroot】构建根文件系统等
文章目录 0. 前言10. 环境软件硬件 20. Buildroot 环境搭建简述下载环境搭建toolchain下载、安装构建镜像(仅供参考) 80. 问题点1. 编译、清除时提示权限不足 0. 前言 对嵌入式linux开发和linux开发环境不熟悉的同志们就不要往下看了 对嵌入式linux开发和…...