【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝
最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论
文章目录
- 前言
- 卷积层剪枝
- 总结
前言
深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
“假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。
卷积层剪枝
可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:
- 初始化卷积层,并查看卷积层权重
# 示例使用一个具有3个输入通道和5个输出通道的卷积层 conv = nn.Conv2d(3, 5, 3) print("原始卷积层权重:") print(conv.weight.data) print(conv.weight.size()) print("原始卷积层偏置:") print(conv.bias.data) print(conv.bias.size()) - 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
# remove_zero_kernels方法内的代码 weight = conv_layer.weight.data # 卷积核个数 num_kernels = weight.size(0) # 随机对部分卷积置0 pruned = torch.ones(num_kernels, 1, 1, 1) # 选择随着置0的卷积序号 random_int = random.randint(1, num_kernels-1) for i in range(random_int):pruned[i, 0, 0, 0] = 0 conv_layer.weight.data = weight * pruned weight = conv_layer.weight.data bias = conv_layer.bias.data - 保存未被剪枝的卷积核的权重和偏置
# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了 norms = torch.norm(weight.view(num_kernels, -1), dim=1) zero_kernel_indices = torch.nonzero(norms==0).squeeze() print(zero_kernel_indices) # 移除L2范数为零的卷积核 new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices]) new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices]) - 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
# 构建新的卷积层 if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_bias else:new_conv_layer = conv_layer
完整代码
import torch
import torch.nn as nn
import randomdef remove_zero_kernels(conv_layer):# 卷积核权重weight = conv_layer.weight.data# 卷积核个数num_kernels = weight.size(0)# 随机对部分卷积置0pruned = torch.ones(num_kernels, 1, 1, 1)# 选择随着置0的卷积序号random_int = random.randint(1, num_kernels-1)for i in range(random_int):pruned[i, 0, 0, 0] = 0conv_layer.weight.data = weight * prunedweight = conv_layer.weight.databias = conv_layer.bias.data# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了norms = torch.norm(weight.view(num_kernels, -1), dim=1)zero_kernel_indices = torch.nonzero(norms==0).squeeze()print(zero_kernel_indices)# 移除L2范数为零的卷积核new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])# 构建新的卷积层if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_biaselse:new_conv_layer = conv_layerreturn new_conv_layer# 示例使用一个具有3个输入通道和5个输出通道的卷积层
conv = nn.Conv2d(3, 5, 3)
# print("原始卷积层权重:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("原始卷积层偏置:")
# print(conv.bias.data)
# print(conv.bias.size())# 将置零的卷积核移除
new_conv = remove_zero_kernels(conv)
# print("原始卷积层权重:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("原始卷积层偏置:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())
总结
博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。
相关文章:
【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝
最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 前言卷积层剪枝总结 前言 深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经…...
机器人仿真-gazebo学习笔记(3)URDF和机器人模型
1.URDF简介 URDF(统一机器人麦哦书格式)是ROS中的重要机器人模型描述格式,ROS提供了URDF文件的c解析器,可以解析URDF文件中使用XML格式的机器人模型。 urdf - ROS Wiki 自己查阅ros官方对URDF的介绍其实会强于大部分网上流传的文章。 1.URDF文件常用的…...
lua-resty-request库写入爬虫ip实现数据抓取
根据提供的引用内容,正确的库名称应该是lua-resty-http,而不是lua-resty-request。使用lua-resty-http库可以方便地进行爬虫,需要先安装OpenResty和lua-resty-http库,并将其引入到Lua脚本中。然后,可以使用lua-resty-h…...
gitlab Activating and deactivating users
原文:Redirecting... Deactivating a userActivating a user Activating and deactivating users GitLab 管理员可以停用和激活用户. Deactivating a user 在 GitLab 12.4 中引入 . 为了临时阻止没有最近活动的 GitLab 用户访问,管理员可以选择停用…...
linux入门到精通-第五章-动态库和静态库
目录 参考概述1、静态链接2 、动态链接3 、静态、动态编译对比 静态库和动态库简介传统编译 静态库制作和使用1、创建静态库的过程2、使用静态库 动态库制作和使用1、创建动态库的过程1)、生成目标文件,此时要加编译选项:-fPIC (f…...
markdown 如何更改字体以及颜色等功能
markdown 是IT人士写文档的常用方式,但是markdown默认又不支持颜色字体等特殊功能,所以呢想实现字体颜色高亮等特殊功能,实现的方法呢就是使用HTML,所以将部分文字改成HTML代码就行 颜色 <font color#0099ff>color #0099f…...
一次cs上线服务器的练习
环境:利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线...
STM32-高级定时器
以STM32F407为例。 高级定时器 高级定时器比通用定时器增加了可编程死区互补输出、重复计数器、带刹车(断路)功能,这些功能都是针对工业电机控制方面。 功能框图 16位向上、向下、向上/向下自动重装载计数器。 16位可编程预分频器,…...
三季度业绩狂飙后,贝泰妮将开启集团化运营的“中场战事”?
双十一前夕,贝泰妮交出了一份亮眼的答卷。 得益于销售端和研发端的发展动能强劲,第三季度贝泰妮营收10.64亿元,同比增长25.77%;扣非净利润1.34亿元,同比增长39.88%。 如此亮眼的业绩,自然引得资本市场侧目…...
快速了解:什么是优化问题
1. 定义 数学优化问题是:在给定约束条件下,找到一个目标函数的最优解(最大值或最小值)。 2. 快速get理解 初学者对优化技术陌生的话,可以把 “求解优化问题” 理解为 “解一个不等式方程组”,解方程的。…...
Unity在Project右键点击物体之后获取到点击物体的名称
Unity在Project右键点击物体之后获取到点击物体的名称 描述: 在Unity的Project右键点击物体之后选择对应的菜单选项点击之后打印出物体的名称 注意事项 如果获取到文件或者预制体需要传递objcet类型,然后使用 GameObject.Instantiate((GameObject)se…...
【带头学C++】----- 三、指针章 ---- 3.7 数组指针
3.7 数组指针 1.数组指针的概述 数组指针是一个指向数组的指针变量,是用来保存数组元素的地址。在C/C中,数组名代表了数组的首地址,可以被解释为一个指向数组第一个元素的指针。因此,一个指向数组的指针可以通过数组名来获…...
Ubuntu20.04安装CUDA、cuDNN、tensorflow2可行流程(症状:tensorflow2在RTX3090上运行卡住)
最近发现我之前在2080ti上运行好好的代码,结果在3090上运行会卡住很久,而且模型预测结果完全乱掉,于是被迫研究了一天怎么在Ubuntu20.04安装CUDA、cuDNN、tensorflow2。 1.安装CUDA(包括CUDA驱动和CUDA toolkit,注意此…...
untiy打开关闭浏览器
最简单的打开方法,只能打开不能关闭,自动打开默认浏览器 Application.OpenURL("https://www.bilibili.com/");打开关闭谷歌浏览器 using System.Diagnostics;private static Process web;if (web null)//打开 {web Process.Start("Chr…...
独立站优缺点解析,如何用黑科技进行缺点优化
随着跨境电商第三方平台平台红利越来越少,经营风险的不断增加,大部分人知道前年发生的亚马逊封店潮,涉及约1000家企业,5万多个账号,预估损失超过千亿元。 正因如此,更多的国内品牌和卖家不再仅依赖于第三方…...
道本科技||紧跟数字化转型趋势,企业如何提高合同管理能效?
随着数字化转型的快速发展,合同管理对于企业的运营效率和风险控制起着至关重要的作用。那么,如何紧跟数字化转型趋势,利用现代技术和工具提高合同管理的能效,以实现企业更高效、更安全的合同管理就成了企业管理中的核心问题。 在…...
框架安全-CVE 复现Apache ShiroApache Solr漏洞复现
文章目录 服务攻防-框架安全&CVE 复现&Apache Shiro&Apache Solr漏洞复现中间件列表常见开发框架Apache Shiro-组件框架安全暴露的安全问题漏洞复现Apache Shiro认证绕过漏洞(CVE-2020-1957)CVE-2020-11989验证绕过漏洞CVE_2016_4437 Shiro-…...
【OpenCV实现图像梯度,Canny边缘检测】
文章目录 概要图像梯度Canny边缘检测小结 概要 OpenCV中,可以使用各种函数实现图像梯度和Canny边缘检测,这些操作对于图像处理和分析非常重要。 图像梯度通常用于寻找图像中的边缘和轮廓。在OpenCV中,可以使用cv2.Sobel()函数计算图像的梯度…...
Spring Boot 解决跨域问题的 5种方案
跨域问题本质是浏览器的一种保护机制,它的初衷是为了保证用户的安全,防止恶意网站窃取数据。 一、跨域三种情况 在请求时,如果出现了以下情况中的任意一种,那么它就是跨域请求: 协议不同,如 http 和 https…...
linux 3.13版本nvme驱动阅读记录一
内核版本较低的nvme驱动代码不多,而且使用的是单队列的架构,阅读起来会轻松一点。 这个版本涉及到的nvme驱动源码文件一共就4个,两个nvme.h文件,分别在include/linux ,include/uapi/linux目录下,nvme-core.c是主要源码…...
Android Studio 高版本兼容低版本项目配置
AndroidStudio开发工具高版本兼容低版本项目配置:1、 JDK 配置:gradle.properties 文件中指定jdk 版本:org.gradle.java.homeD\:\\ProgramFiles\\JDK\\jdk-11.0.262 配置Gradle 编译版本:3. 显示所有Gradle task 列表设置完成后&a…...
Pencil原型工具全攻略:从环境搭建到高级配置
Pencil原型工具全攻略:从环境搭建到高级配置 【免费下载链接】pencil DEPRECATED: Multiplatform GUI Prototyping/Wireframing 项目地址: https://gitcode.com/gh_mirrors/pen/pencil Pencil原型工具:开源价值定位与核心特性解析 核心价值&…...
利润中心(Profit Center)和段(Segment)在 SAP 中关系非常紧密,但它们的设计目的和应用场景有本质区别
利润中心(Profit Center)和段(Segment)在 SAP 中关系非常紧密,但它们的设计目的和应用场景有本质区别。简单来说,段(Segment)是利润中心的一个上级归类。它们之间通常是“一对多”的…...
一张照片秒变3D模型!用Splatter Image和3D高斯溅射快速上手单视图重建
从单张照片到3D模型:Splatter Image技术实战指南 想象一下,你刚在二手市场淘到一个绝版手办,想为它创建数字档案;或是设计师客户临时需要将一张产品照片转为3D模型。传统流程需要专业设备扫描或手工建模,耗时数小时甚…...
Phi-4-reasoning-vision-15B多场景落地:OCR/图表分析/GUI理解三类任务统一部署
Phi-4-reasoning-vision-15B多场景落地:OCR/图表分析/GUI理解三类任务统一部署 1. 模型介绍 Phi-4-reasoning-vision-15B是微软推出的视觉多模态推理模型,能够处理多种视觉理解任务。这个模型特别擅长从图像中提取和理解信息,无论是文档文字…...
WebGL开发者必备:用RenderDoc旧版本抓帧调试的完整避坑指南(附DEBUG_CHROME.bat脚本)
WebGL开发者必备:用RenderDoc旧版本抓帧调试的完整避坑指南(附DEBUG_CHROME.bat脚本) 最近在WebGL开发中遇到一个棘手问题:最新版RenderDoc已经禁止了对Chrome等浏览器的抓帧功能。这对于正在学习图形学课程(比如GAMES…...
OSI七层模型的意义:网络世界的工程思维密码
理解七层网络模型(OSI模型)的意义,不在于死记硬背哪一层叫什么名字,而在于它能帮你建立一套拆解复杂系统的思维框架。具体来说,学习它主要有以下几层价值:1. 建立“分而治之”的工程思维网络通信是一个极其…...
基于H5的初学开发
目标: 1.能搭出页面 2.能看懂基本标签 3.能做表单 4.能放图片、音频、视频 5.能做简单画布效果 6.能做一个 AI Photo Booth 静态演示页 7.每个实验做完都能看到结果,不容易卡死 开发工具:VS Cod 本实验覆盖哪些 H5 内容 1.h…...
如何降低ai率?盘点3个降ai率神器与5个手改技巧,降aigc全流程解析!
最近我发现很多同学都在苦恼ai率这件事,后台发来的截图里,那报告,简直红得触目惊心。 现在的系统早已是next level,不是看你用了什么词,而是在分析你的文本生成逻辑。今天这篇文章,我不讲虚的,…...
Wan2.2-I2V-A14B镜像免配置:SSH直连后cd /workspace即可执行全部命令
Wan2.2-I2V-A14B镜像免配置:SSH直连后cd /workspace即可执行全部命令 1. 镜像概述与核心优势 Wan2.2-I2V-A14B私有部署镜像是一款专为文生视频模型定制的开箱即用解决方案。这个镜像最大的特点就是"免配置"——通过SSH连接后,只需进入/works…...
