当前位置: 首页 > news >正文

模型优化【2】-剪枝[局部剪枝]

模型剪枝是一种常见的模型压缩技术,它可以通过去除模型中不必要的参数和结构来减小模型的大小和计算量,从而提高模型的效率和速度。在 PyTorch 中,我们可以使用一些库和工具来实现模型剪枝。

pytorch实现剪枝的思路是生成一个掩码,然后同时保存原参数、mask、新参数,如下图:
在这里插入图片描述

Pytorch实现模型剪枝的基本步骤

  1. 加载模型:我们首先需要加载一个已经训练好的模型,可以使用 PyTorch 提供的模型库或者自己训练的模型。

  2. 定义剪枝方法:我们需要定义一种剪枝方法,来决定哪些参数和结构需要被剪枝。

  3. 执行剪枝操作:我们需要执行剪枝操作,将不必要的参数和结构从模型中去除。

  4. 保存剪枝后的模型:我们需要将剪枝后的模型保存下来,以便后续使用。

pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;

局部剪枝

局部剪枝是指在什么网络的单个层或局部范围内进行剪枝。
Pytorch中与剪枝有关的接口封装在torch.nn.utils.prune中。下面开始演示三种剪枝在LeNet网络中的应用效果,首先给出LeNet网络结构。

加载模型

import torch
from torch import nnclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸self.conv1 = nn.Conv2d(1, 6, 3)self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
model = LeNet().to(device=device)

局部剪枝实验,假定对模型的第一个卷积层中的权重进行剪枝

# 打印输出剪枝前的参数
module = model.conv1
print(list(module.named_parameters()))
print(list(module.buffers()))
print(module.weight)

运行结果

[('weight', Parameter containing:
tensor([[[[ 0.1158, -0.0091, -0.2742],[-0.1132,  0.1059, -0.0381],[ 0.0430, -0.1634, -0.1345]]],...[[[-0.0226,  0.2091, -0.1479],[ 0.2302, -0.0988,  0.2117],[-0.2000, -0.2531,  0.2770]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2658,  0.2096, -0.2639, -0.3063, -0.1453,  0.1201], device='cuda:0',requires_grad=True))]
[]
Parameter containing:
tensor([[[[ 0.1158, -0.0091, -0.2742],[-0.1132,  0.1059, -0.0381],[ 0.0430, -0.1634, -0.1345]]],...[[[-0.0226,  0.2091, -0.1479],[ 0.2302, -0.0988,  0.2117],[-0.2000, -0.2531,  0.2770]]]], device='cuda:0', requires_grad=True)

定义剪枝+执行剪枝

# 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数
# random_unstructured 是一种裁剪技术,随机非结构化裁剪
# 第一个参数:modeul,代表要进行剪枝的特定模型,之前我们已经制定了module=module.conv1,说明这里要对第一个卷积层执行剪枝
# 第二个参数:name,指定要对选中模块中的那些参数执行剪枝,这里设定为name='weight',意味着对连接网络的weight剪枝,而不死bias剪枝
# 第三个参数:amount,指定要对模型中的多大比例的参数执行剪枝,amount是一个介于0.0~1.0的float数值,或者一个正整数指定裁剪多少条连接边。
prune.random_unstructured(module, name="weight", amount=0.3)      # weight    bias
print(list(module.named_parameters()))
# 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区
print(list(module.named_buffers()))# 经过裁剪操作后的模型,原始的参数存放在了weight-orig中,
# 对应的剪枝矩阵存放在weight-mask中,而将weight-mask视作掩码张量
# 再和weight-orig相乘的结果就存放在了weight中
print(module.weight)
print(module.bias)

运行结果

[('bias', Parameter containing:
tensor([ 0.1303,  0.1208, -0.0989, -0.0611, -0.1103, -0.2433], device='cuda:0',requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-1.1443e-01,  3.2276e-01, -2.4664e-02],[ 4.6659e-02,  1.8311e-01,  6.6681e-02],[-2.5493e-01, -1.1471e-01,  2.8336e-01]]],...[[[ 1.4041e-01,  2.0963e-02,  2.2884e-01],[ 3.5870e-02,  7.5861e-02,  8.4728e-02],[ 4.1965e-02, -1.2838e-01,  8.8462e-02]]]], device='cuda:0',requires_grad=True))]
[('weight_mask', tensor([[[[1., 0., 0.],[1., 0., 0.],[0., 0., 0.]]],...[[[1., 0., 0.],[1., 0., 1.],[0., 1., 0.]]]], device='cuda:0'))]
tensor([[[[-1.1443e-01,  0.0000e+00, -0.0000e+00],[ 4.6659e-02,  0.0000e+00,  0.0000e+00],[-0.0000e+00, -0.0000e+00,  0.0000e+00]]],...[[[ 1.4041e-01,  0.0000e+00,  0.0000e+00],[ 3.5870e-02,  0.0000e+00,  8.4728e-02],[ 0.0000e+00, -1.2838e-01,  0.0000e+00]]]], device='cuda:0',grad_fn=<MulBackward0>)
Parameter containing:
tensor([ 0.1303,  0.1208, -0.0989, -0.0611, -0.1103, -0.2433], device='cuda:0',requires_grad=True)

保存剪枝后的模型

# 保存剪枝后的模型
torch.save(model.state_dict(), 'pruned_model.pth')

模型经历剪枝以后,原始的权重矩阵weight参数不见了,变成了weight_orig。并且剪枝前打印为空的列表module.name_buffers(),此时拥有了一个weight_mask参数。经过剪枝操作后的模型,原始的参数存放在了weight_orig中,对应的剪枝矩阵存在weight_mask中,而将weight_mask视作掩码张量,再和weight_orig相乘的结果就存在了weight中。

Q1:打印经过剪枝处理的 weight 参数。这个 weight 实际上是原始的 weight_orig 和 weight_mask 的元素乘积,其中被剪枝的权重会被设置为0。这个weight不是剪枝了嘛?为什么还能打印出来?

答:在Pytorch的剪枝过程中,当我们说剪枝一个权重的参数时,并不是真的从网络中移除这些参数,而是通过一个掩码来“禁用”它们。这是通过将某些权重的值设为0来实现的,从而在网络的前向传播中这些权重不会有任何作用,这种方法允许我们在保留原始权重信息的同时,实现剪枝的效果。

在 PyTorch 的剪枝过程中,当我们说“剪枝”一个权重参数时,并不是真的从网络中移除这些权重,而是通过应用一个掩码来“禁用”它们,为什么禁用就可以达到模型压缩的目的?为什么剪枝完,执行print(list(module.named_parameters())),没有显示weight属性,但是执行print(module.weight)时,weight依然存在?

  • 为什么“禁用”权重可以达到模型压缩的目的?
    虽然剪枝后的权重仍然占据内存空间,但在实际计算中,值为0的权重不会对前向传播产生任何影响。这意味着在计算层面可以忽略这些权重,从而减少计算量。
  • 为什么 print(list(module.named_parameters())) 没有显示 weight 属性,但执行 print(module.weight) 时 weight 依然存在?
    • 修改参数列表:当执行剪枝操作时,PyTorch 会修改模块的参数列表。原始的 weight 参数被重命名为 weight_orig,并且创建了一个新的名为 weight_mask 的缓冲区。原始的 weight 参数(现在是 weight_orig)和 weight_mask 通过一个钩子(hook)相结合,生成了新的 weight 属性。

    • 动态权重生成:在调用 module.weight 时,由于剪枝过程中添加的前向钩子,weight 参数是动态生成的,它是 weight_orig 和 weight_mask 的元素乘积。因此,尽管 weight 在 named_parameters 列表中看起来已经不存在,但它实际上是在运行时动态生成的。

    • 参数与属性的区别:在 PyTorch 中,模块的参数(可通过 named_parameters 访问)和模块的属性(如直接通过 module.weight 访问)是不同的。module.weight 被视为一个可访问的属性,但由于剪枝过程的内部处理,它可能不再直接列在模块的参数列表中。

既然原始的 weight 参数被重命名为 weight_orig,那参数是不是并没有发生变化,又怎么能达到剪枝的效果呢?
原始的 weight 参数在剪枝过程中被重命名为 weight_orig,并且保持不变。剪枝的效果是通过以下几个关键步骤实现的:

  1. 掩码(Mask)创建:
  • 在剪枝过程中,PyTorch 创建了一个掩码(weight_mask),它是一个与 weight 形状相同的二进制张量(由0和1组成)。
  • 在这个掩码中,1表示相应的权重保持不变,而0表示相应的权重被“剪枝”(实际上是被禁用)。
  1. 动态权重更新:
  • 尽管 weight_orig 保持不变,但是模块的 weight 属性被动态更新为 weight_orig 和 weight_mask 的元素乘积。
  • 这意味着,在模型的前向传播过程中,实际使用的 weight 是被掩码修改过的。在这个新的 weight 中,被剪枝的权重(在 weight_mask 中对应0的位置)的值为0,而其他位置的权重保持原始值。
  1. 前向传播的影响:
  • 当模型进行前向传播时,使用的是被掩码修改过的 weight。因此,尽管原始的 weight 参数(现在是 weight_orig)没有变化,模型实际上使用的权重已经被剪枝修改了。
  • 剪枝过程实际上通过使某些权重值为0,从而在模型的计算过程中禁用了这些权重。
  1. 模型复杂度的降低:
  • 通过这种方式,模型的复杂度在实际运行时降低了,因为一部分权重不再对输出产生影响。
  • 这可以提高计算效率,并且在某些情况下,可以通过专门的硬件和软件优化来利用权重的这种稀疏性。
    综上所述,虽然原始的 weight 参数作为 weight_orig 保留下来,但是实际上模型使用的是被掩码修改过的权重,这就是剪枝效果的实现方式。这种方法的一个优势是可以在不永久性地移除权重的情况下,测试和评估剪枝的影响,甚至可以在必要时撤销剪枝操作。

相关文章:

模型优化【2】-剪枝[局部剪枝]

模型剪枝是一种常见的模型压缩技术&#xff0c;它可以通过去除模型中不必要的参数和结构来减小模型的大小和计算量&#xff0c;从而提高模型的效率和速度。在 PyTorch 中&#xff0c;我们可以使用一些库和工具来实现模型剪枝。 pytorch实现剪枝的思路是生成一个掩码&#xff0…...

VMware 系列:ESXI6.7升级7.0

ESXI6.7升级7.0 一、下载补丁二、上传文件三 启用Shell四、登录Shell后台五、删除不兼容驱动六、正常升级最近,将一台使用ESXI6.7的虚拟机升级到了7.0版本,下面记录一下自己的升级过程。 升级条件 首先确保硬件是否能升级到7.0版本,物理网卡驱动为e1000e不能升级,如果是ig…...

4-20mA高精度采集方案

下载链接&#xff01;https://mp.weixin.qq.com/s?__bizMzU2OTc4ODA4OA&mid2247557466&idx1&snb5a323285c2629a41d2a896764db27eb&chksmfcfaf28dcb8d7b9bb6211030d9bda53db63ab51f765b4165d9fa630e54301f0406efdabff0fb&token976581939&langzh_CN#rd …...

案例022:基于微信小程序的行政复议在线预约系统

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…...

Go 工具链详解(七):模块缓存清理工具

go mod 缓存 在 Golang 中&#xff0c;模块是对一组版本化的包的集合的描述。Go 1.11 版本引入了模块支持&#xff0c;通过 go mod 命令提供了对模块的管理。Go 模块的一个重要特性是依赖管理&#xff0c;可以清晰地定义项目所依赖的模块及对应的版本&#xff0c;并确保代码使…...

1.7 C语言之函数概述

1.7 C语言之函数概述 一、概述二、练习 一、概述 函数就是把一组计算操作封装起来&#xff0c;供程序员调用&#xff0c;我们只需知道其提供了什么功能&#xff0c;而无需关注具体实现细节(前提是其久经考验&#xff0c;设计没有问题&#xff0c;后续我们自己写的函数大概率还…...

CTA-GAN:基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影 CT到增强CT的合成技术

Generative Adversarial Network–based Noncontrast CT Angiography for Aorta and Carotid Arteries 基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影背景贡献实验方法损失函数Thinking 基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影 https://github.com/ying-f…...

电源控制系统架构(PCSA)之电源管理基础设施组件

目录 6.5 电源管理基础设施组件 6.5.1 电源策略单元 6.5.2 时钟控制器 6.5.3 低功耗Distributor 6.5.4 低功耗Combiner 6.5.5 P-Channel到Q-Channel转换器 6.5 电源管理基础设施组件 6.5.1 电源策略单元 本节介绍电源策略单元(Power Policy Unit, PPU)。PPU的完整细节见…...

影刀RPA_boss直聘翻页(避坑)

boss直聘翻页这里有个坑 问题&#xff1a; 无限循环中&#xff0c;点击下一页按钮&#xff0c;直到不可点击为止。 发现&#xff0c;在点到第5页的时候&#xff0c;再次点击下一页&#xff0c;直接就点击了页码10&#xff0c;导致流程直接就结束了。 在第5页进行校验&#xff0…...

第十四章 控制值的转换 - 在DISPLAYLIST中投影值

文章目录 第十四章 控制值的转换 - 在DISPLAYLIST中投影值在DISPLAYLIST中投影值 第十四章 控制值的转换 - 在DISPLAYLIST中投影值 在DISPLAYLIST中投影值 对于 %String 类型&#xff08;或任何子类&#xff09;的属性&#xff0c;XML 投影可以使用 DISPLAYLIST 参数。 简单…...

C++类与对象(5)—流运算符重载、const、取地址

目录 一、流输出 1、实现单个输出 2、实现连续输出 二、流输入 总结&#xff1a; 三、const修饰 四、取地址 .取地址及const取地址操作符重载 五、[ ]运算符重载 一、流输出 1、实现单个输出 创建一个日期类。 class Date { public:Date(int year 1, int month 1,…...

Vue框架学习笔记——事件修饰符

文章目录 前文提要事件修饰符prevent&#xff08;常用&#xff09;stop&#xff08;不常用&#xff09;事件冒泡stop使用方法三层嵌套下的stop三层嵌套看出的stop&#xff1a; once&#xff08;常用&#xff09;capture&#xff08;不常用&#xff09;self&#xff08;不常用&a…...

嵌入式虚拟机原理

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和技术。关…...

AMESim|Make failed:Unable to create an excutable for the system

最近在AMESIM与MATLAB进行联合仿真的时候遇到如下问题&#xff1a; Make failed:Unable to create an excutable for the system. 看了网上的解决办法如下 配置环境变量重装AMESIM&#xff0c;有顺序要求&#xff0c;首先是VS&#xff0c;然后是AMESIM与MATLAB。在AMESIM安装…...

OpenHarmony之NAPI框架介绍

张志成 诚迈科技高级技术专家 NAPI是什么 NAPI的概念源自Nodejs&#xff0c;为了实现javascript脚本与C库之间的相互调用&#xff0c;Nodejs对V8引擎的api做了一层封装&#xff0c;称为NAPI。可以在Nodejs官网&#xff08;https://nodejs.org/dist/latest-v20.x/docs/api/n-api…...

计算机网络之网络层

一、概述 主要任务是实现网络互连&#xff0c;进而实现数据包在各网络之间的传输 1.1网络引入的目的 从7层结构上看&#xff0c;网络层下是数据链路层 从4层结构上看&#xff0c;网络层下面是网络接口层 至少我们看到的网络层下面是以太网 以太网解决了什么问题&#xff1f; 答…...

【C指针(五)】6种转移表实现整合longjmp()/setjmp()函数和qsort函数详解分析模拟实现

&#x1f308;write in front :&#x1f50d;个人主页 &#xff1a; 啊森要自信的主页 ✏️真正相信奇迹的家伙&#xff0c;本身和奇迹一样了不起啊&#xff01; 欢迎大家关注&#x1f50d;点赞&#x1f44d;收藏⭐️留言&#x1f4dd;>希望看完我的文章对你有小小的帮助&am…...

浅谈电力设备智能无线温度检测系统

安科瑞 华楠 摘要&#xff1a;在长期工作中&#xff0c;由于设备基础变化、温湿度变化、严重超负荷运行、触点氧化等原因造成的电力设备压接不紧&#xff0c;触头接触部分发生改变。终导致接触电阻增大&#xff0c;造成巨大的风险隐患。本系统将通过无线测温的方式&#xff0c…...

通过ros系统中websocket中发送sensor_msgs::Image数据给web端显示(二)

通过ros系统中websocket中发送sensor_msgs::Image数据给web端显示(二) mp4媒体流数据 #include <ros/ros.h> #include <signal.h> #include <sensor_msgs/Image.h> #include <message_filters/subscriber.h> #include <message_filters/synchroniz…...

LeetCode [简单] 283. 移动零

给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元素的相对顺序。 请注意 &#xff0c;必须在不复制数组的情况下原地对数组进行操作。 283. 移动零 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 快慢指针&…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

高危文件识别的常用算法:原理、应用与企业场景

高危文件识别的常用算法&#xff1a;原理、应用与企业场景 高危文件识别旨在检测可能导致安全威胁的文件&#xff0c;如包含恶意代码、敏感数据或欺诈内容的文档&#xff0c;在企业协同办公环境中&#xff08;如Teams、Google Workspace&#xff09;尤为重要。结合大模型技术&…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署&#xff0c;直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型&#xff0c;但是目前国内可能使用不多&#xff0c;至少实践例子很少看见。开发训练模型就不介绍了&am…...

人工智能(大型语言模型 LLMs)对不同学科的影响以及由此产生的新学习方式

今天是关于AI如何在教学中增强学生的学习体验&#xff0c;我把重要信息标红了。人文学科的价值被低估了 ⬇️ 转型与必要性 人工智能正在深刻地改变教育&#xff0c;这并非炒作&#xff0c;而是已经发生的巨大变革。教育机构和教育者不能忽视它&#xff0c;试图简单地禁止学生使…...

手机平板能效生态设计指令EU 2023/1670标准解读

手机平板能效生态设计指令EU 2023/1670标准解读 以下是针对欧盟《手机和平板电脑生态设计法规》(EU) 2023/1670 的核心解读&#xff0c;综合法规核心要求、最新修正及企业合规要点&#xff1a; 一、法规背景与目标 生效与强制时间 发布于2023年8月31日&#xff08;OJ公报&…...

wpf在image控件上快速显示内存图像

wpf在image控件上快速显示内存图像https://www.cnblogs.com/haodafeng/p/10431387.html 如果你在寻找能够快速在image控件刷新大图像&#xff08;比如分辨率3000*3000的图像&#xff09;的办法&#xff0c;尤其是想把内存中的裸数据&#xff08;只有图像的数据&#xff0c;不包…...

LangFlow技术架构分析

&#x1f527; LangFlow 的可视化技术栈 前端节点编辑器 底层框架&#xff1a;基于 &#xff08;一个现代化的 React 节点绘图库&#xff09; 功能&#xff1a; 拖拽式构建 LangGraph 状态机 实时连线定义节点依赖关系 可视化调试循环和分支逻辑 与 LangGraph 的深…...

2025年低延迟业务DDoS防护全攻略:高可用架构与实战方案

一、延迟敏感行业面临的DDoS攻击新挑战 2025年&#xff0c;金融交易、实时竞技游戏、工业物联网等低延迟业务成为DDoS攻击的首要目标。攻击呈现三大特征&#xff1a; AI驱动的自适应攻击&#xff1a;攻击流量模拟真实用户行为&#xff0c;差异率低至0.5%&#xff0c;传统规则引…...