当前位置: 首页 > news >正文

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 前言
  • 卷积层剪枝
  • 总结


前言

深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
“假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。


卷积层剪枝

可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:

  1. 初始化卷积层,并查看卷积层权重
    # 示例使用一个具有3个输入通道和5个输出通道的卷积层
    conv = nn.Conv2d(3, 5, 3)
    print("原始卷积层权重:")
    print(conv.weight.data)
    print(conv.weight.size())
    print("原始卷积层偏置:")
    print(conv.bias.data)
    print(conv.bias.size())
    
  2. 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
    # remove_zero_kernels方法内的代码
    weight = conv_layer.weight.data
    # 卷积核个数
    num_kernels = weight.size(0)
    # 随机对部分卷积置0
    pruned = torch.ones(num_kernels, 1, 1, 1)
    # 选择随着置0的卷积序号
    random_int = random.randint(1, num_kernels-1)
    for i in range(random_int):pruned[i, 0, 0, 0] = 0
    conv_layer.weight.data = weight * pruned
    weight = conv_layer.weight.data
    bias = conv_layer.bias.data
    
  3. 保存未被剪枝的卷积核的权重和偏置
    # 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了
    norms = torch.norm(weight.view(num_kernels, -1), dim=1)
    zero_kernel_indices = torch.nonzero(norms==0).squeeze()
    print(zero_kernel_indices)
    # 移除L2范数为零的卷积核
    new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])
    new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
    
  4. 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
    # 构建新的卷积层
    if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_bias
    else:new_conv_layer = conv_layer
    

完整代码

import torch
import torch.nn as nn
import randomdef remove_zero_kernels(conv_layer):# 卷积核权重weight = conv_layer.weight.data# 卷积核个数num_kernels = weight.size(0)# 随机对部分卷积置0pruned = torch.ones(num_kernels, 1, 1, 1)# 选择随着置0的卷积序号random_int = random.randint(1, num_kernels-1)for i in range(random_int):pruned[i, 0, 0, 0] = 0conv_layer.weight.data = weight * prunedweight = conv_layer.weight.databias = conv_layer.bias.data# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了norms = torch.norm(weight.view(num_kernels, -1), dim=1)zero_kernel_indices = torch.nonzero(norms==0).squeeze()print(zero_kernel_indices)# 移除L2范数为零的卷积核new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])# 构建新的卷积层if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_biaselse:new_conv_layer = conv_layerreturn new_conv_layer# 示例使用一个具有3个输入通道和5个输出通道的卷积层
conv = nn.Conv2d(3, 5, 3)
# print("原始卷积层权重:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("原始卷积层偏置:")
# print(conv.bias.data)
# print(conv.bias.size())# 将置零的卷积核移除
new_conv = remove_zero_kernels(conv)
# print("原始卷积层权重:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("原始卷积层偏置:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())

总结

博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。

相关文章:

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 前言卷积层剪枝总结 前言 深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经…...

机器人仿真-gazebo学习笔记(3)URDF和机器人模型

1.URDF简介 URDF(统一机器人麦哦书格式)是ROS中的重要机器人模型描述格式,ROS提供了URDF文件的c解析器,可以解析URDF文件中使用XML格式的机器人模型。 urdf - ROS Wiki 自己查阅ros官方对URDF的介绍其实会强于大部分网上流传的文章。 1.URDF文件常用的…...

lua-resty-request库写入爬虫ip实现数据抓取

根据提供的引用内容,正确的库名称应该是lua-resty-http,而不是lua-resty-request。使用lua-resty-http库可以方便地进行爬虫,需要先安装OpenResty和lua-resty-http库,并将其引入到Lua脚本中。然后,可以使用lua-resty-h…...

gitlab Activating and deactivating users

原文:Redirecting... Deactivating a userActivating a user Activating and deactivating users GitLab 管理员可以停用和激活用户. Deactivating a user 在 GitLab 12.4 中引入 . 为了临时阻止没有最近活动的 GitLab 用户访问,管理员可以选择停用…...

linux入门到精通-第五章-动态库和静态库

目录 参考概述1、静态链接2 、动态链接3 、静态、动态编译对比 静态库和动态库简介传统编译 静态库制作和使用1、创建静态库的过程2、使用静态库 动态库制作和使用1、创建动态库的过程1)、生成目标文件,此时要加编译选项:-fPIC (f…...

markdown 如何更改字体以及颜色等功能

markdown 是IT人士写文档的常用方式&#xff0c;但是markdown默认又不支持颜色字体等特殊功能&#xff0c;所以呢想实现字体颜色高亮等特殊功能&#xff0c;实现的方法呢就是使用HTML&#xff0c;所以将部分文字改成HTML代码就行 颜色 <font color#0099ff>color #0099f…...

一次cs上线服务器的练习

环境&#xff1a;利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线...

STM32-高级定时器

以STM32F407为例。 高级定时器 高级定时器比通用定时器增加了可编程死区互补输出、重复计数器、带刹车&#xff08;断路&#xff09;功能&#xff0c;这些功能都是针对工业电机控制方面。 功能框图 16位向上、向下、向上/向下自动重装载计数器。 16位可编程预分频器&#xff0c…...

三季度业绩狂飙后,贝泰妮将开启集团化运营的“中场战事”?

双十一前夕&#xff0c;贝泰妮交出了一份亮眼的答卷。 得益于销售端和研发端的发展动能强劲&#xff0c;第三季度贝泰妮营收10.64亿元&#xff0c;同比增长25.77%&#xff1b;扣非净利润1.34亿元&#xff0c;同比增长39.88%。 如此亮眼的业绩&#xff0c;自然引得资本市场侧目…...

快速了解:什么是优化问题

1. 定义 数学优化问题是&#xff1a;在给定约束条件下&#xff0c;找到一个目标函数的最优解&#xff08;最大值或最小值&#xff09;。 2. 快速get理解 初学者对优化技术陌生的话&#xff0c;可以把 “求解优化问题” 理解为 “解一个不等式方程组”&#xff0c;解方程的。…...

Unity在Project右键点击物体之后获取到点击物体的名称

Unity在Project右键点击物体之后获取到点击物体的名称 描述&#xff1a; 在Unity的Project右键点击物体之后选择对应的菜单选项点击之后打印出物体的名称 注意事项 如果获取到文件或者预制体需要传递objcet类型&#xff0c;然后使用 GameObject.Instantiate((GameObject)se…...

【带头学C++】----- 三、指针章 ---- 3.7 数组指针

3.7 数组指针 1.数组指针的概述 数组指针是一个指向数组的指针变量&#xff0c;是用来保存数组元素的地址。在C/C中&#xff0c;数组名代表了数组的首地址&#xff0c;可以被解释为一个指向数组第一个元素的指针。因此&#xff0c;一个指向数组的指针可以通过数组名来获…...

Ubuntu20.04安装CUDA、cuDNN、tensorflow2可行流程(症状:tensorflow2在RTX3090上运行卡住)

最近发现我之前在2080ti上运行好好的代码&#xff0c;结果在3090上运行会卡住很久&#xff0c;而且模型预测结果完全乱掉&#xff0c;于是被迫研究了一天怎么在Ubuntu20.04安装CUDA、cuDNN、tensorflow2。 1.安装CUDA&#xff08;包括CUDA驱动和CUDA toolkit&#xff0c;注意此…...

untiy打开关闭浏览器

最简单的打开方法&#xff0c;只能打开不能关闭&#xff0c;自动打开默认浏览器 Application.OpenURL("https://www.bilibili.com/");打开关闭谷歌浏览器 using System.Diagnostics;private static Process web;if (web null)//打开 {web Process.Start("Chr…...

独立站优缺点解析,如何用黑科技进行缺点优化

随着跨境电商第三方平台平台红利越来越少&#xff0c;经营风险的不断增加&#xff0c;大部分人知道前年发生的亚马逊封店潮&#xff0c;涉及约1000家企业&#xff0c;5万多个账号&#xff0c;预估损失超过千亿元。 正因如此&#xff0c;更多的国内品牌和卖家不再仅依赖于第三方…...

道本科技||紧跟数字化转型趋势,企业如何提高合同管理能效?

随着数字化转型的快速发展&#xff0c;合同管理对于企业的运营效率和风险控制起着至关重要的作用。那么&#xff0c;如何紧跟数字化转型趋势&#xff0c;利用现代技术和工具提高合同管理的能效&#xff0c;以实现企业更高效、更安全的合同管理就成了企业管理中的核心问题。 在…...

框架安全-CVE 复现Apache ShiroApache Solr漏洞复现

文章目录 服务攻防-框架安全&CVE 复现&Apache Shiro&Apache Solr漏洞复现中间件列表常见开发框架Apache Shiro-组件框架安全暴露的安全问题漏洞复现Apache Shiro认证绕过漏洞&#xff08;CVE-2020-1957&#xff09;CVE-2020-11989验证绕过漏洞CVE_2016_4437 Shiro-…...

【OpenCV实现图像梯度,Canny边缘检测】

文章目录 概要图像梯度Canny边缘检测小结 概要 OpenCV中&#xff0c;可以使用各种函数实现图像梯度和Canny边缘检测&#xff0c;这些操作对于图像处理和分析非常重要。 图像梯度通常用于寻找图像中的边缘和轮廓。在OpenCV中&#xff0c;可以使用cv2.Sobel()函数计算图像的梯度…...

Spring Boot 解决跨域问题的 5种方案

跨域问题本质是浏览器的一种保护机制&#xff0c;它的初衷是为了保证用户的安全&#xff0c;防止恶意网站窃取数据。 一、跨域三种情况 在请求时&#xff0c;如果出现了以下情况中的任意一种&#xff0c;那么它就是跨域请求&#xff1a; 协议不同&#xff0c;如 http 和 https…...

linux 3.13版本nvme驱动阅读记录一

内核版本较低的nvme驱动代码不多&#xff0c;而且使用的是单队列的架构&#xff0c;阅读起来会轻松一点。 这个版本涉及到的nvme驱动源码文件一共就4个&#xff0c;两个nvme.h文件&#xff0c;分别在include/linux ,include/uapi/linux目录下&#xff0c;nvme-core.c是主要源码…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案

一、TRS收益互换的本质与业务逻辑 &#xff08;一&#xff09;概念解析 TRS&#xff08;Total Return Swap&#xff09;收益互换是一种金融衍生工具&#xff0c;指交易双方约定在未来一定期限内&#xff0c;基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南

文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/55aefaea8a9f477e86d065227851fe3d.pn…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比

在机器学习的回归分析中&#xff0c;损失函数的选择对模型性能具有决定性影响。均方误差&#xff08;MSE&#xff09;作为经典的损失函数&#xff0c;在处理干净数据时表现优异&#xff0c;但在面对包含异常值的噪声数据时&#xff0c;其对大误差的二次惩罚机制往往导致模型参数…...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)

考察一般的三次多项式&#xff0c;以r为参数&#xff1a; p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]&#xff1b; 此多项式的根为&#xff1a; 尽管看起来这个多项式是特殊的&#xff0c;其实一般的三次多项式都是可以通过线性变换化为这个形式…...