当前位置: 首页 > news >正文

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 前言
  • 卷积层剪枝
  • 总结


前言

深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
“假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。


卷积层剪枝

可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:

  1. 初始化卷积层,并查看卷积层权重
    # 示例使用一个具有3个输入通道和5个输出通道的卷积层
    conv = nn.Conv2d(3, 5, 3)
    print("原始卷积层权重:")
    print(conv.weight.data)
    print(conv.weight.size())
    print("原始卷积层偏置:")
    print(conv.bias.data)
    print(conv.bias.size())
    
  2. 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
    # remove_zero_kernels方法内的代码
    weight = conv_layer.weight.data
    # 卷积核个数
    num_kernels = weight.size(0)
    # 随机对部分卷积置0
    pruned = torch.ones(num_kernels, 1, 1, 1)
    # 选择随着置0的卷积序号
    random_int = random.randint(1, num_kernels-1)
    for i in range(random_int):pruned[i, 0, 0, 0] = 0
    conv_layer.weight.data = weight * pruned
    weight = conv_layer.weight.data
    bias = conv_layer.bias.data
    
  3. 保存未被剪枝的卷积核的权重和偏置
    # 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了
    norms = torch.norm(weight.view(num_kernels, -1), dim=1)
    zero_kernel_indices = torch.nonzero(norms==0).squeeze()
    print(zero_kernel_indices)
    # 移除L2范数为零的卷积核
    new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])
    new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
    
  4. 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
    # 构建新的卷积层
    if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_bias
    else:new_conv_layer = conv_layer
    

完整代码

import torch
import torch.nn as nn
import randomdef remove_zero_kernels(conv_layer):# 卷积核权重weight = conv_layer.weight.data# 卷积核个数num_kernels = weight.size(0)# 随机对部分卷积置0pruned = torch.ones(num_kernels, 1, 1, 1)# 选择随着置0的卷积序号random_int = random.randint(1, num_kernels-1)for i in range(random_int):pruned[i, 0, 0, 0] = 0conv_layer.weight.data = weight * prunedweight = conv_layer.weight.databias = conv_layer.bias.data# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了norms = torch.norm(weight.view(num_kernels, -1), dim=1)zero_kernel_indices = torch.nonzero(norms==0).squeeze()print(zero_kernel_indices)# 移除L2范数为零的卷积核new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])# 构建新的卷积层if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_biaselse:new_conv_layer = conv_layerreturn new_conv_layer# 示例使用一个具有3个输入通道和5个输出通道的卷积层
conv = nn.Conv2d(3, 5, 3)
# print("原始卷积层权重:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("原始卷积层偏置:")
# print(conv.bias.data)
# print(conv.bias.size())# 将置零的卷积核移除
new_conv = remove_zero_kernels(conv)
# print("原始卷积层权重:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("原始卷积层偏置:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())

总结

博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。

相关文章:

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 前言卷积层剪枝总结 前言 深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经…...

机器人仿真-gazebo学习笔记(3)URDF和机器人模型

1.URDF简介 URDF(统一机器人麦哦书格式)是ROS中的重要机器人模型描述格式,ROS提供了URDF文件的c解析器,可以解析URDF文件中使用XML格式的机器人模型。 urdf - ROS Wiki 自己查阅ros官方对URDF的介绍其实会强于大部分网上流传的文章。 1.URDF文件常用的…...

lua-resty-request库写入爬虫ip实现数据抓取

根据提供的引用内容,正确的库名称应该是lua-resty-http,而不是lua-resty-request。使用lua-resty-http库可以方便地进行爬虫,需要先安装OpenResty和lua-resty-http库,并将其引入到Lua脚本中。然后,可以使用lua-resty-h…...

gitlab Activating and deactivating users

原文:Redirecting... Deactivating a userActivating a user Activating and deactivating users GitLab 管理员可以停用和激活用户. Deactivating a user 在 GitLab 12.4 中引入 . 为了临时阻止没有最近活动的 GitLab 用户访问,管理员可以选择停用…...

linux入门到精通-第五章-动态库和静态库

目录 参考概述1、静态链接2 、动态链接3 、静态、动态编译对比 静态库和动态库简介传统编译 静态库制作和使用1、创建静态库的过程2、使用静态库 动态库制作和使用1、创建动态库的过程1)、生成目标文件,此时要加编译选项:-fPIC (f…...

markdown 如何更改字体以及颜色等功能

markdown 是IT人士写文档的常用方式&#xff0c;但是markdown默认又不支持颜色字体等特殊功能&#xff0c;所以呢想实现字体颜色高亮等特殊功能&#xff0c;实现的方法呢就是使用HTML&#xff0c;所以将部分文字改成HTML代码就行 颜色 <font color#0099ff>color #0099f…...

一次cs上线服务器的练习

环境&#xff1a;利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线...

STM32-高级定时器

以STM32F407为例。 高级定时器 高级定时器比通用定时器增加了可编程死区互补输出、重复计数器、带刹车&#xff08;断路&#xff09;功能&#xff0c;这些功能都是针对工业电机控制方面。 功能框图 16位向上、向下、向上/向下自动重装载计数器。 16位可编程预分频器&#xff0c…...

三季度业绩狂飙后,贝泰妮将开启集团化运营的“中场战事”?

双十一前夕&#xff0c;贝泰妮交出了一份亮眼的答卷。 得益于销售端和研发端的发展动能强劲&#xff0c;第三季度贝泰妮营收10.64亿元&#xff0c;同比增长25.77%&#xff1b;扣非净利润1.34亿元&#xff0c;同比增长39.88%。 如此亮眼的业绩&#xff0c;自然引得资本市场侧目…...

快速了解:什么是优化问题

1. 定义 数学优化问题是&#xff1a;在给定约束条件下&#xff0c;找到一个目标函数的最优解&#xff08;最大值或最小值&#xff09;。 2. 快速get理解 初学者对优化技术陌生的话&#xff0c;可以把 “求解优化问题” 理解为 “解一个不等式方程组”&#xff0c;解方程的。…...

Unity在Project右键点击物体之后获取到点击物体的名称

Unity在Project右键点击物体之后获取到点击物体的名称 描述&#xff1a; 在Unity的Project右键点击物体之后选择对应的菜单选项点击之后打印出物体的名称 注意事项 如果获取到文件或者预制体需要传递objcet类型&#xff0c;然后使用 GameObject.Instantiate((GameObject)se…...

【带头学C++】----- 三、指针章 ---- 3.7 数组指针

3.7 数组指针 1.数组指针的概述 数组指针是一个指向数组的指针变量&#xff0c;是用来保存数组元素的地址。在C/C中&#xff0c;数组名代表了数组的首地址&#xff0c;可以被解释为一个指向数组第一个元素的指针。因此&#xff0c;一个指向数组的指针可以通过数组名来获…...

Ubuntu20.04安装CUDA、cuDNN、tensorflow2可行流程(症状:tensorflow2在RTX3090上运行卡住)

最近发现我之前在2080ti上运行好好的代码&#xff0c;结果在3090上运行会卡住很久&#xff0c;而且模型预测结果完全乱掉&#xff0c;于是被迫研究了一天怎么在Ubuntu20.04安装CUDA、cuDNN、tensorflow2。 1.安装CUDA&#xff08;包括CUDA驱动和CUDA toolkit&#xff0c;注意此…...

untiy打开关闭浏览器

最简单的打开方法&#xff0c;只能打开不能关闭&#xff0c;自动打开默认浏览器 Application.OpenURL("https://www.bilibili.com/");打开关闭谷歌浏览器 using System.Diagnostics;private static Process web;if (web null)//打开 {web Process.Start("Chr…...

独立站优缺点解析,如何用黑科技进行缺点优化

随着跨境电商第三方平台平台红利越来越少&#xff0c;经营风险的不断增加&#xff0c;大部分人知道前年发生的亚马逊封店潮&#xff0c;涉及约1000家企业&#xff0c;5万多个账号&#xff0c;预估损失超过千亿元。 正因如此&#xff0c;更多的国内品牌和卖家不再仅依赖于第三方…...

道本科技||紧跟数字化转型趋势,企业如何提高合同管理能效?

随着数字化转型的快速发展&#xff0c;合同管理对于企业的运营效率和风险控制起着至关重要的作用。那么&#xff0c;如何紧跟数字化转型趋势&#xff0c;利用现代技术和工具提高合同管理的能效&#xff0c;以实现企业更高效、更安全的合同管理就成了企业管理中的核心问题。 在…...

框架安全-CVE 复现Apache ShiroApache Solr漏洞复现

文章目录 服务攻防-框架安全&CVE 复现&Apache Shiro&Apache Solr漏洞复现中间件列表常见开发框架Apache Shiro-组件框架安全暴露的安全问题漏洞复现Apache Shiro认证绕过漏洞&#xff08;CVE-2020-1957&#xff09;CVE-2020-11989验证绕过漏洞CVE_2016_4437 Shiro-…...

【OpenCV实现图像梯度,Canny边缘检测】

文章目录 概要图像梯度Canny边缘检测小结 概要 OpenCV中&#xff0c;可以使用各种函数实现图像梯度和Canny边缘检测&#xff0c;这些操作对于图像处理和分析非常重要。 图像梯度通常用于寻找图像中的边缘和轮廓。在OpenCV中&#xff0c;可以使用cv2.Sobel()函数计算图像的梯度…...

Spring Boot 解决跨域问题的 5种方案

跨域问题本质是浏览器的一种保护机制&#xff0c;它的初衷是为了保证用户的安全&#xff0c;防止恶意网站窃取数据。 一、跨域三种情况 在请求时&#xff0c;如果出现了以下情况中的任意一种&#xff0c;那么它就是跨域请求&#xff1a; 协议不同&#xff0c;如 http 和 https…...

linux 3.13版本nvme驱动阅读记录一

内核版本较低的nvme驱动代码不多&#xff0c;而且使用的是单队列的架构&#xff0c;阅读起来会轻松一点。 这个版本涉及到的nvme驱动源码文件一共就4个&#xff0c;两个nvme.h文件&#xff0c;分别在include/linux ,include/uapi/linux目录下&#xff0c;nvme-core.c是主要源码…...

树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频

使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...

【JavaEE】-- HTTP

1. HTTP是什么&#xff1f; HTTP&#xff08;全称为"超文本传输协议"&#xff09;是一种应用非常广泛的应用层协议&#xff0c;HTTP是基于TCP协议的一种应用层协议。 应用层协议&#xff1a;是计算机网络协议栈中最高层的协议&#xff0c;它定义了运行在不同主机上…...

R语言AI模型部署方案:精准离线运行详解

R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

【位运算】消失的两个数字(hard)

消失的两个数字&#xff08;hard&#xff09; 题⽬描述&#xff1a;解法&#xff08;位运算&#xff09;&#xff1a;Java 算法代码&#xff1a;更简便代码 题⽬链接&#xff1a;⾯试题 17.19. 消失的两个数字 题⽬描述&#xff1a; 给定⼀个数组&#xff0c;包含从 1 到 N 所有…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)

宇树机器人多姿态起立控制强化学习框架论文解析 论文解读&#xff1a;交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架&#xff08;一&#xff09; 论文解读&#xff1a;交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...

Caliper 配置文件解析:config.yaml

Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词

Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵&#xff0c;其中每行&#xff0c;每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid&#xff0c;其中有多少个 3 3 的 “幻方” 子矩阵&am…...

虚拟电厂发展三大趋势:市场化、技术主导、车网互联

市场化&#xff1a;从政策驱动到多元盈利 政策全面赋能 2025年4月&#xff0c;国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》&#xff0c;首次明确虚拟电厂为“独立市场主体”&#xff0c;提出硬性目标&#xff1a;2027年全国调节能力≥2000万千瓦&#xff0…...