当前位置: 首页 > news >正文

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 前言
  • 卷积层剪枝
  • 总结


前言

深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
“假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。


卷积层剪枝

可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:

  1. 初始化卷积层,并查看卷积层权重
    # 示例使用一个具有3个输入通道和5个输出通道的卷积层
    conv = nn.Conv2d(3, 5, 3)
    print("原始卷积层权重:")
    print(conv.weight.data)
    print(conv.weight.size())
    print("原始卷积层偏置:")
    print(conv.bias.data)
    print(conv.bias.size())
    
  2. 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
    # remove_zero_kernels方法内的代码
    weight = conv_layer.weight.data
    # 卷积核个数
    num_kernels = weight.size(0)
    # 随机对部分卷积置0
    pruned = torch.ones(num_kernels, 1, 1, 1)
    # 选择随着置0的卷积序号
    random_int = random.randint(1, num_kernels-1)
    for i in range(random_int):pruned[i, 0, 0, 0] = 0
    conv_layer.weight.data = weight * pruned
    weight = conv_layer.weight.data
    bias = conv_layer.bias.data
    
  3. 保存未被剪枝的卷积核的权重和偏置
    # 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了
    norms = torch.norm(weight.view(num_kernels, -1), dim=1)
    zero_kernel_indices = torch.nonzero(norms==0).squeeze()
    print(zero_kernel_indices)
    # 移除L2范数为零的卷积核
    new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])
    new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
    
  4. 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
    # 构建新的卷积层
    if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_bias
    else:new_conv_layer = conv_layer
    

完整代码

import torch
import torch.nn as nn
import randomdef remove_zero_kernels(conv_layer):# 卷积核权重weight = conv_layer.weight.data# 卷积核个数num_kernels = weight.size(0)# 随机对部分卷积置0pruned = torch.ones(num_kernels, 1, 1, 1)# 选择随着置0的卷积序号random_int = random.randint(1, num_kernels-1)for i in range(random_int):pruned[i, 0, 0, 0] = 0conv_layer.weight.data = weight * prunedweight = conv_layer.weight.databias = conv_layer.bias.data# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了norms = torch.norm(weight.view(num_kernels, -1), dim=1)zero_kernel_indices = torch.nonzero(norms==0).squeeze()print(zero_kernel_indices)# 移除L2范数为零的卷积核new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])# 构建新的卷积层if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_biaselse:new_conv_layer = conv_layerreturn new_conv_layer# 示例使用一个具有3个输入通道和5个输出通道的卷积层
conv = nn.Conv2d(3, 5, 3)
# print("原始卷积层权重:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("原始卷积层偏置:")
# print(conv.bias.data)
# print(conv.bias.size())# 将置零的卷积核移除
new_conv = remove_zero_kernels(conv)
# print("原始卷积层权重:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("原始卷积层偏置:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())

总结

博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。

相关文章:

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 前言卷积层剪枝总结 前言 深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经…...

机器人仿真-gazebo学习笔记(3)URDF和机器人模型

1.URDF简介 URDF(统一机器人麦哦书格式)是ROS中的重要机器人模型描述格式,ROS提供了URDF文件的c解析器,可以解析URDF文件中使用XML格式的机器人模型。 urdf - ROS Wiki 自己查阅ros官方对URDF的介绍其实会强于大部分网上流传的文章。 1.URDF文件常用的…...

lua-resty-request库写入爬虫ip实现数据抓取

根据提供的引用内容,正确的库名称应该是lua-resty-http,而不是lua-resty-request。使用lua-resty-http库可以方便地进行爬虫,需要先安装OpenResty和lua-resty-http库,并将其引入到Lua脚本中。然后,可以使用lua-resty-h…...

gitlab Activating and deactivating users

原文:Redirecting... Deactivating a userActivating a user Activating and deactivating users GitLab 管理员可以停用和激活用户. Deactivating a user 在 GitLab 12.4 中引入 . 为了临时阻止没有最近活动的 GitLab 用户访问,管理员可以选择停用…...

linux入门到精通-第五章-动态库和静态库

目录 参考概述1、静态链接2 、动态链接3 、静态、动态编译对比 静态库和动态库简介传统编译 静态库制作和使用1、创建静态库的过程2、使用静态库 动态库制作和使用1、创建动态库的过程1)、生成目标文件,此时要加编译选项:-fPIC (f…...

markdown 如何更改字体以及颜色等功能

markdown 是IT人士写文档的常用方式&#xff0c;但是markdown默认又不支持颜色字体等特殊功能&#xff0c;所以呢想实现字体颜色高亮等特殊功能&#xff0c;实现的方法呢就是使用HTML&#xff0c;所以将部分文字改成HTML代码就行 颜色 <font color#0099ff>color #0099f…...

一次cs上线服务器的练习

环境&#xff1a;利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线...

STM32-高级定时器

以STM32F407为例。 高级定时器 高级定时器比通用定时器增加了可编程死区互补输出、重复计数器、带刹车&#xff08;断路&#xff09;功能&#xff0c;这些功能都是针对工业电机控制方面。 功能框图 16位向上、向下、向上/向下自动重装载计数器。 16位可编程预分频器&#xff0c…...

三季度业绩狂飙后,贝泰妮将开启集团化运营的“中场战事”?

双十一前夕&#xff0c;贝泰妮交出了一份亮眼的答卷。 得益于销售端和研发端的发展动能强劲&#xff0c;第三季度贝泰妮营收10.64亿元&#xff0c;同比增长25.77%&#xff1b;扣非净利润1.34亿元&#xff0c;同比增长39.88%。 如此亮眼的业绩&#xff0c;自然引得资本市场侧目…...

快速了解:什么是优化问题

1. 定义 数学优化问题是&#xff1a;在给定约束条件下&#xff0c;找到一个目标函数的最优解&#xff08;最大值或最小值&#xff09;。 2. 快速get理解 初学者对优化技术陌生的话&#xff0c;可以把 “求解优化问题” 理解为 “解一个不等式方程组”&#xff0c;解方程的。…...

Unity在Project右键点击物体之后获取到点击物体的名称

Unity在Project右键点击物体之后获取到点击物体的名称 描述&#xff1a; 在Unity的Project右键点击物体之后选择对应的菜单选项点击之后打印出物体的名称 注意事项 如果获取到文件或者预制体需要传递objcet类型&#xff0c;然后使用 GameObject.Instantiate((GameObject)se…...

【带头学C++】----- 三、指针章 ---- 3.7 数组指针

3.7 数组指针 1.数组指针的概述 数组指针是一个指向数组的指针变量&#xff0c;是用来保存数组元素的地址。在C/C中&#xff0c;数组名代表了数组的首地址&#xff0c;可以被解释为一个指向数组第一个元素的指针。因此&#xff0c;一个指向数组的指针可以通过数组名来获…...

Ubuntu20.04安装CUDA、cuDNN、tensorflow2可行流程(症状:tensorflow2在RTX3090上运行卡住)

最近发现我之前在2080ti上运行好好的代码&#xff0c;结果在3090上运行会卡住很久&#xff0c;而且模型预测结果完全乱掉&#xff0c;于是被迫研究了一天怎么在Ubuntu20.04安装CUDA、cuDNN、tensorflow2。 1.安装CUDA&#xff08;包括CUDA驱动和CUDA toolkit&#xff0c;注意此…...

untiy打开关闭浏览器

最简单的打开方法&#xff0c;只能打开不能关闭&#xff0c;自动打开默认浏览器 Application.OpenURL("https://www.bilibili.com/");打开关闭谷歌浏览器 using System.Diagnostics;private static Process web;if (web null)//打开 {web Process.Start("Chr…...

独立站优缺点解析,如何用黑科技进行缺点优化

随着跨境电商第三方平台平台红利越来越少&#xff0c;经营风险的不断增加&#xff0c;大部分人知道前年发生的亚马逊封店潮&#xff0c;涉及约1000家企业&#xff0c;5万多个账号&#xff0c;预估损失超过千亿元。 正因如此&#xff0c;更多的国内品牌和卖家不再仅依赖于第三方…...

道本科技||紧跟数字化转型趋势,企业如何提高合同管理能效?

随着数字化转型的快速发展&#xff0c;合同管理对于企业的运营效率和风险控制起着至关重要的作用。那么&#xff0c;如何紧跟数字化转型趋势&#xff0c;利用现代技术和工具提高合同管理的能效&#xff0c;以实现企业更高效、更安全的合同管理就成了企业管理中的核心问题。 在…...

框架安全-CVE 复现Apache ShiroApache Solr漏洞复现

文章目录 服务攻防-框架安全&CVE 复现&Apache Shiro&Apache Solr漏洞复现中间件列表常见开发框架Apache Shiro-组件框架安全暴露的安全问题漏洞复现Apache Shiro认证绕过漏洞&#xff08;CVE-2020-1957&#xff09;CVE-2020-11989验证绕过漏洞CVE_2016_4437 Shiro-…...

【OpenCV实现图像梯度,Canny边缘检测】

文章目录 概要图像梯度Canny边缘检测小结 概要 OpenCV中&#xff0c;可以使用各种函数实现图像梯度和Canny边缘检测&#xff0c;这些操作对于图像处理和分析非常重要。 图像梯度通常用于寻找图像中的边缘和轮廓。在OpenCV中&#xff0c;可以使用cv2.Sobel()函数计算图像的梯度…...

Spring Boot 解决跨域问题的 5种方案

跨域问题本质是浏览器的一种保护机制&#xff0c;它的初衷是为了保证用户的安全&#xff0c;防止恶意网站窃取数据。 一、跨域三种情况 在请求时&#xff0c;如果出现了以下情况中的任意一种&#xff0c;那么它就是跨域请求&#xff1a; 协议不同&#xff0c;如 http 和 https…...

linux 3.13版本nvme驱动阅读记录一

内核版本较低的nvme驱动代码不多&#xff0c;而且使用的是单队列的架构&#xff0c;阅读起来会轻松一点。 这个版本涉及到的nvme驱动源码文件一共就4个&#xff0c;两个nvme.h文件&#xff0c;分别在include/linux ,include/uapi/linux目录下&#xff0c;nvme-core.c是主要源码…...

Perplexity数据验证功能全链路解析(98.7%准确率背后的4层校验架构)

更多请点击&#xff1a; https://kaifayun.com 第一章&#xff1a;Perplexity数据验证功能全链路解析&#xff08;98.7%准确率背后的4层校验架构&#xff09; Perplexity 的数据验证并非单一规则匹配&#xff0c;而是融合语义一致性、来源可信度、时效性约束与逻辑闭环性的四维…...

初创公司如何借助Taotoken降低大模型API的试用与集成门槛

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 初创公司如何借助Taotoken降低大模型API的试用与集成门槛 对于初创公司而言&#xff0c;技术选型阶段的效率与成本控制至关重要。在…...

CrapFixer深度解析:为什么这个7年老工具依然是Windows优化的首选

CrapFixer深度解析&#xff1a;为什么这个7年老工具依然是Windows优化的首选 【免费下载链接】Crapfixer Cr*ap Fixer 项目地址: https://gitcode.com/gh_mirrors/cr/Crapfixer 在Windows 11和Windows 10系统中&#xff0c;你是否厌倦了无处不在的广告、烦人的数据收集和…...

手把手教你用wget和迅雷搞定nuScenes数据集下载(附完整性校验命令)

高效获取nuScenes数据集的两种技术方案与完整性验证指南 在自动驾驶与计算机视觉研究领域&#xff0c;nuScenes数据集因其丰富的传感器数据和精细的标注体系已成为行业基准测试的重要资源。但对于大多数研究者而言&#xff0c;获取这个总容量超过550GB的数据集却面临着网络不稳…...

Escrcpy终极指南:5分钟掌握Android设备图形化控制与屏幕镜像

Escrcpy终极指南&#xff1a;5分钟掌握Android设备图形化控制与屏幕镜像 【免费下载链接】escrcpy &#x1f4f1; Display and control your Android device graphically with scrcpy. 项目地址: https://gitcode.com/GitHub_Trending/es/escrcpy 你是否曾经为在电脑上控…...

美业门店商业模式开发(系统介绍)

美业门店商业模式开发美业门店的商业模式开发需要考虑多个方面&#xff0c;包括目标客户群体、服务类型、定价策略、营销渠道和盈利模式。常见的商业模式包括单店经营、连锁加盟、线上预约结合线下服务、会员制等。单店经营适合初创品牌&#xff0c;成本较低&#xff0c;管理简…...

告别Modelsim命令行!用Notepad++插件NppExec一键检查Verilog语法(附详细配置命令)

硬件工程师的效率革命&#xff1a;Notepad与Verilog语法检查的终极整合方案 在数字电路设计领域&#xff0c;Verilog作为主流硬件描述语言&#xff0c;其语法检查是每位工程师日常工作中不可或缺的环节。传统工作流程中&#xff0c;工程师们不得不在文本编辑器与EDA工具之间频繁…...

如何在Windows上快速挂载ISO镜像?WinCDEmu虚拟光驱终极指南

如何在Windows上快速挂载ISO镜像&#xff1f;WinCDEmu虚拟光驱终极指南 【免费下载链接】WinCDEmu 项目地址: https://gitcode.com/gh_mirrors/wi/WinCDEmu 还在为ISO、IMG等光盘镜像文件无法直接使用而烦恼吗&#xff1f;还在为没有物理光驱而无法读取光盘内容而困扰吗…...

终极指南:3步快速掌握日语漫画OCR识别神器MangaOCR

终极指南&#xff1a;3步快速掌握日语漫画OCR识别神器MangaOCR 【免费下载链接】manga-ocr Optical character recognition for Japanese text, with the main focus being Japanese manga 项目地址: https://gitcode.com/gh_mirrors/ma/manga-ocr 你是否曾经面对日文漫…...

51单片机电子秤的语音播报怎么选?JQ8400模块 vs OTP芯片,实测成本与易用性对比

51单片机电子秤语音方案实战选型&#xff1a;JQ8400模块与OTP芯片的深度拆解 在智能硬件开发中&#xff0c;语音交互功能正从锦上添花的附加项逐渐变为核心用户体验的关键组成部分。以51单片机电子秤为例&#xff0c;语音播报功能不仅能提升产品的无障碍使用体验&#xff0c;还…...