【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝
最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论
文章目录
- 前言
- 卷积层剪枝
- 总结
前言
深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
“假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。
卷积层剪枝
可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:
- 初始化卷积层,并查看卷积层权重
# 示例使用一个具有3个输入通道和5个输出通道的卷积层 conv = nn.Conv2d(3, 5, 3) print("原始卷积层权重:") print(conv.weight.data) print(conv.weight.size()) print("原始卷积层偏置:") print(conv.bias.data) print(conv.bias.size())
- 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
# remove_zero_kernels方法内的代码 weight = conv_layer.weight.data # 卷积核个数 num_kernels = weight.size(0) # 随机对部分卷积置0 pruned = torch.ones(num_kernels, 1, 1, 1) # 选择随着置0的卷积序号 random_int = random.randint(1, num_kernels-1) for i in range(random_int):pruned[i, 0, 0, 0] = 0 conv_layer.weight.data = weight * pruned weight = conv_layer.weight.data bias = conv_layer.bias.data
- 保存未被剪枝的卷积核的权重和偏置
# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了 norms = torch.norm(weight.view(num_kernels, -1), dim=1) zero_kernel_indices = torch.nonzero(norms==0).squeeze() print(zero_kernel_indices) # 移除L2范数为零的卷积核 new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices]) new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
- 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
# 构建新的卷积层 if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_bias else:new_conv_layer = conv_layer
完整代码
import torch
import torch.nn as nn
import randomdef remove_zero_kernels(conv_layer):# 卷积核权重weight = conv_layer.weight.data# 卷积核个数num_kernels = weight.size(0)# 随机对部分卷积置0pruned = torch.ones(num_kernels, 1, 1, 1)# 选择随着置0的卷积序号random_int = random.randint(1, num_kernels-1)for i in range(random_int):pruned[i, 0, 0, 0] = 0conv_layer.weight.data = weight * prunedweight = conv_layer.weight.databias = conv_layer.bias.data# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了norms = torch.norm(weight.view(num_kernels, -1), dim=1)zero_kernel_indices = torch.nonzero(norms==0).squeeze()print(zero_kernel_indices)# 移除L2范数为零的卷积核new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])# 构建新的卷积层if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_biaselse:new_conv_layer = conv_layerreturn new_conv_layer# 示例使用一个具有3个输入通道和5个输出通道的卷积层
conv = nn.Conv2d(3, 5, 3)
# print("原始卷积层权重:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("原始卷积层偏置:")
# print(conv.bias.data)
# print(conv.bias.size())# 将置零的卷积核移除
new_conv = remove_zero_kernels(conv)
# print("原始卷积层权重:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("原始卷积层偏置:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())
总结
博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。
相关文章:
【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝
最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 前言卷积层剪枝总结 前言 深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经…...

机器人仿真-gazebo学习笔记(3)URDF和机器人模型
1.URDF简介 URDF(统一机器人麦哦书格式)是ROS中的重要机器人模型描述格式,ROS提供了URDF文件的c解析器,可以解析URDF文件中使用XML格式的机器人模型。 urdf - ROS Wiki 自己查阅ros官方对URDF的介绍其实会强于大部分网上流传的文章。 1.URDF文件常用的…...

lua-resty-request库写入爬虫ip实现数据抓取
根据提供的引用内容,正确的库名称应该是lua-resty-http,而不是lua-resty-request。使用lua-resty-http库可以方便地进行爬虫,需要先安装OpenResty和lua-resty-http库,并将其引入到Lua脚本中。然后,可以使用lua-resty-h…...
gitlab Activating and deactivating users
原文:Redirecting... Deactivating a userActivating a user Activating and deactivating users GitLab 管理员可以停用和激活用户. Deactivating a user 在 GitLab 12.4 中引入 . 为了临时阻止没有最近活动的 GitLab 用户访问,管理员可以选择停用…...

linux入门到精通-第五章-动态库和静态库
目录 参考概述1、静态链接2 、动态链接3 、静态、动态编译对比 静态库和动态库简介传统编译 静态库制作和使用1、创建静态库的过程2、使用静态库 动态库制作和使用1、创建动态库的过程1)、生成目标文件,此时要加编译选项:-fPIC (f…...
markdown 如何更改字体以及颜色等功能
markdown 是IT人士写文档的常用方式,但是markdown默认又不支持颜色字体等特殊功能,所以呢想实现字体颜色高亮等特殊功能,实现的方法呢就是使用HTML,所以将部分文字改成HTML代码就行 颜色 <font color#0099ff>color #0099f…...

一次cs上线服务器的练习
环境:利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线...

STM32-高级定时器
以STM32F407为例。 高级定时器 高级定时器比通用定时器增加了可编程死区互补输出、重复计数器、带刹车(断路)功能,这些功能都是针对工业电机控制方面。 功能框图 16位向上、向下、向上/向下自动重装载计数器。 16位可编程预分频器,…...

三季度业绩狂飙后,贝泰妮将开启集团化运营的“中场战事”?
双十一前夕,贝泰妮交出了一份亮眼的答卷。 得益于销售端和研发端的发展动能强劲,第三季度贝泰妮营收10.64亿元,同比增长25.77%;扣非净利润1.34亿元,同比增长39.88%。 如此亮眼的业绩,自然引得资本市场侧目…...

快速了解:什么是优化问题
1. 定义 数学优化问题是:在给定约束条件下,找到一个目标函数的最优解(最大值或最小值)。 2. 快速get理解 初学者对优化技术陌生的话,可以把 “求解优化问题” 理解为 “解一个不等式方程组”,解方程的。…...

Unity在Project右键点击物体之后获取到点击物体的名称
Unity在Project右键点击物体之后获取到点击物体的名称 描述: 在Unity的Project右键点击物体之后选择对应的菜单选项点击之后打印出物体的名称 注意事项 如果获取到文件或者预制体需要传递objcet类型,然后使用 GameObject.Instantiate((GameObject)se…...

【带头学C++】----- 三、指针章 ---- 3.7 数组指针
3.7 数组指针 1.数组指针的概述 数组指针是一个指向数组的指针变量,是用来保存数组元素的地址。在C/C中,数组名代表了数组的首地址,可以被解释为一个指向数组第一个元素的指针。因此,一个指向数组的指针可以通过数组名来获…...

Ubuntu20.04安装CUDA、cuDNN、tensorflow2可行流程(症状:tensorflow2在RTX3090上运行卡住)
最近发现我之前在2080ti上运行好好的代码,结果在3090上运行会卡住很久,而且模型预测结果完全乱掉,于是被迫研究了一天怎么在Ubuntu20.04安装CUDA、cuDNN、tensorflow2。 1.安装CUDA(包括CUDA驱动和CUDA toolkit,注意此…...
untiy打开关闭浏览器
最简单的打开方法,只能打开不能关闭,自动打开默认浏览器 Application.OpenURL("https://www.bilibili.com/");打开关闭谷歌浏览器 using System.Diagnostics;private static Process web;if (web null)//打开 {web Process.Start("Chr…...
独立站优缺点解析,如何用黑科技进行缺点优化
随着跨境电商第三方平台平台红利越来越少,经营风险的不断增加,大部分人知道前年发生的亚马逊封店潮,涉及约1000家企业,5万多个账号,预估损失超过千亿元。 正因如此,更多的国内品牌和卖家不再仅依赖于第三方…...

道本科技||紧跟数字化转型趋势,企业如何提高合同管理能效?
随着数字化转型的快速发展,合同管理对于企业的运营效率和风险控制起着至关重要的作用。那么,如何紧跟数字化转型趋势,利用现代技术和工具提高合同管理的能效,以实现企业更高效、更安全的合同管理就成了企业管理中的核心问题。 在…...

框架安全-CVE 复现Apache ShiroApache Solr漏洞复现
文章目录 服务攻防-框架安全&CVE 复现&Apache Shiro&Apache Solr漏洞复现中间件列表常见开发框架Apache Shiro-组件框架安全暴露的安全问题漏洞复现Apache Shiro认证绕过漏洞(CVE-2020-1957)CVE-2020-11989验证绕过漏洞CVE_2016_4437 Shiro-…...

【OpenCV实现图像梯度,Canny边缘检测】
文章目录 概要图像梯度Canny边缘检测小结 概要 OpenCV中,可以使用各种函数实现图像梯度和Canny边缘检测,这些操作对于图像处理和分析非常重要。 图像梯度通常用于寻找图像中的边缘和轮廓。在OpenCV中,可以使用cv2.Sobel()函数计算图像的梯度…...
Spring Boot 解决跨域问题的 5种方案
跨域问题本质是浏览器的一种保护机制,它的初衷是为了保证用户的安全,防止恶意网站窃取数据。 一、跨域三种情况 在请求时,如果出现了以下情况中的任意一种,那么它就是跨域请求: 协议不同,如 http 和 https…...
linux 3.13版本nvme驱动阅读记录一
内核版本较低的nvme驱动代码不多,而且使用的是单队列的架构,阅读起来会轻松一点。 这个版本涉及到的nvme驱动源码文件一共就4个,两个nvme.h文件,分别在include/linux ,include/uapi/linux目录下,nvme-core.c是主要源码…...
脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)
一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...
React Native 开发环境搭建(全平台详解)
React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility
Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果
让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比
在机器学习的回归分析中,损失函数的选择对模型性能具有决定性影响。均方误差(MSE)作为经典的损失函数,在处理干净数据时表现优异,但在面对包含异常值的噪声数据时,其对大误差的二次惩罚机制往往导致模型参数…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...

技术栈RabbitMq的介绍和使用
目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...