Pytorch剪枝api测试和结果
Pytorch 官方给出的prune接口
下面是基于prune的接口进行剪枝的方法步骤
1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:
- RandomUnstructured
- L1Unstructured
- RandomStructured
- LnStructured
- CustomFromMask
ps:非结构性剪枝不会给剪枝后模型的速度带来提升。
2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数
from torch.nn.utils import prune
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0 , 1, 2, 3
'''
3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存
prune.remove(module, 'weight')
print(list(module.named_parameters()))
5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。
对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()
import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''
然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝
prune全局非结构性剪枝测试结果
# 推理模型tiny-yolov4def model_global_prune(amount: float):detect_model = Darknet('/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg') # TODO:改成相对路径detect_model.load_weights("/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")parameters_to_prune = list()nums = 0for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential): for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):nums += 1parameters_to_prune.append((detect_model.models[i][j], 'weight'))elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):nums += 2parameters_to_prune.append((detect_model.models[i][j], 'weight'))parameters_to_prune.append((detect_model.models[i][j], 'bias'))parameters_to_prune = tuple(parameters_to_prune)assert (nums == len(parameters_to_prune))pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential): for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):pt_prune.remove(detect_model.models[i][j], 'weight')elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):pt_prune.remove(detect_model.models[i][j], 'weight')pt_prune.remove(detect_model.models[i][j], 'bias')return detect_model
base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323
case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364
case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859
case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’
case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578
case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322
prune结构性剪枝测试结果
通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)
结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。
结论:pytorch的官方接口并不能直接使用
相关文章:
Pytorch剪枝api测试和结果
Pytorch 官方给出的prune接口 下面是基于prune的接口进行剪枝的方法步骤 1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有: RandomUnstructuredL1UnstructuredRandomStructuredLnStructuredCustomFromMask ps:非结构性剪…...
微服务下网关聚合Swagger文档、starter统一配置Swagger
一、starter实现统一配置微服务文档 把Swagger配置中的公共部分抽取出来Swagger与SpringBoot整合中,可能会由于版本问题出现各种问题 1、制作starter 参考: 【SpringBoot】自定义启动器 Starter【保姆级教程】用starter实现Oauth2中资源服务的统一配置用…...
剑指 Offer第二版:机器人的运动范围、正则表达式匹配、表示数值的字符串
剑指 Offer第二版 13. 机器人的运动范围19. 正则表达式匹配20. 表示数值的字符串 13. 机器人的运动范围 题目:地上有一个m行n列的方格,从坐标 [0,0] 到坐标 [m-1,n-1] 。一个机器人从坐标 [0, 0] 的格子开始移动,它每次可以向左、右、上、下移…...
Delaunay三角网生成算法
目录 一、分而治之算法二、三角网生长算法三、逐点插入算法四、约束Delaunay三角网1、方法一1、原始点云2、构网结果 1、方法二1、原始点云2、普通Delaunay3、约束Delaunay Delaunay三角剖分分为直接三角剖分和间接三角剖分。间接三角剖分首先计算为Voronoi图,然后由Voronoi图产…...
hashcode是什么?有什么作用?
文章目录 (1)hashcode()方法的作用(2)equals和hashcode的关系(3)百度百科(4)小白解释 Java中Object有一个方法: public native int hashcode(); (1࿰…...
【人体姿态估计】(一)原理介绍
【人体姿态估计】(一)原理介绍 一、背景 人体姿态估计本质上是一个关键点检测的项目; 关键点检测在生活中的应用十分广泛,包括人脸识别、手势识别,而人体姿态估计则是对身体的关键点进行检测; 本文将介…...
一种新的流:为 Java 加入生成器(Generator)特性
作者:文镭(依来) 前言 这篇文章不是工具推荐,也不是应用案例分享。其主题思想,是介绍一种全新的设计模式。它既拥有抽象的数学美感,仅仅从一个简单接口出发,就能推演出庞大的特性集合,引出许多全新概念。…...
《数据结构C++版》实验一:线性表的顺序存储结构
实验目的 1、实现线性表的顺序存储结构 2、熟悉C++程序的基本结构,掌握程序中的头文件、实现文件和主文件之间的相互关系及各自的作用 3、熟悉顺序表的基本操作方式,掌握顺序表相关操作的具体实现 实验内容 对顺序存储的线性表进行一些基本操作。主要包括: (1)插入:操作…...
ChatGPT的开源平替,终于来了!
最近这段时间,一个号称全球最大ChatGPT开源平替项目Open Assistant引起了大家的注意。 这不最近还登上了GitHub的Trending热榜。 https://github.com/LAION-AI/Open-Assistant 根据官方的介绍,Open Assistant也是一个对话式的大型语言模型项目ÿ…...
Redis基础
Redis6 1. NoSQL数据库简介 1.1 技术发展 技术的分类 1、解决功能性的问题:Java、Jsp、RDBMS、Tomcat、HTML、Linux、JDBC、SVN。 2、解决扩展性的问题:Struts、Spring、SpringMVC、Hibernate、Mybatis。 3、解决性能的问题:NoSQL、Jav…...
为什么重视安全的公司都在用SSL安全证书?
我们今天来讲一讲为什么重视安全的公司都在用SSL证书 SSL证书是什么? SSL安全证书是由权威认证机构颁发的,是CA机构将公钥和相关信息写入一个文件,CA机构用他们的私钥对我们的公钥和相关信息进行签名后,将签名信息也写入这个文件…...
嵌入式QT (使用 Qt Designer 开发)
一、使用 UI 设计器开发程序 1.1、 在 UI 文件添加一个按钮 1.2、在 UI 文件里连接信号与槽 所谓信号即是一个对象发出的信号,槽即是当这个对象发出这个信号时,对应连接的槽就发被执行或者触发。 UI 设计器里信号与槽的连接方法一: 在主窗…...
每日一个小技巧:今天告诉你拍照识别文字的软件有哪些
在现代社会里,手机已经成为了人们生活中必不可少的工具。它的功能众多,比如通讯、上网、拍照以及导航等,为我们的生活带来了许多便利。除此之外,手机还能帮助我们解决一些实际的问题,例如,当你需要识别图片…...
多版本VersionARXDBG
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、一级标题二级标题三级标题四级标题五级标题六级标题总结前言 提示:这里可以添加本文要记录的大概内容: VersionARXDBG,多版本,2023.4.22-4.23两天时间,分别研究了在多版本编译ARXDB…...
# 生成器
生成器 生成器是什么? 生成器(generator)是一种用来生成数据的对象。它们是普通函数的一种特殊形式,可以用来控制数据的生成过程。 生成器有什么优势? 使用生成器的优势在于它们可以在生成数据的同时控制数据的生成过程…...
Netty 源码解析(上)
序 Netty的影响力以及使用场景就不用多说了, 去年10月份后,就着手研究Netty源码,之前研究过Spring源码,MyBatis源码,java.util.concurrent源码,tomcat源码,发现一个特点,之前的源码都…...
Vue 消息订阅与发布
消息订阅与发布,也可以实现任意组件之间的通信。 订阅者:就相当于是我们,用于接收数据。 发布者:就相当于是媒体,用于传递数据。 安装消息订阅与发布插件: 在原生 JS 中 不太容易实现消息订阅与发布&…...
如何在你的云服务器/云主机上更新并使用最新版本的python(python3.11)
更新并使用最新版本的python3.11 第一步,登录云服务器,并更新系统包 打开您的终端(Terminal)或使用任意SSH客户端,输入如下命令来登录云主机: ssh 用户名IP地址 在输入密码后,您将成功登录到云…...
python学习——【第八弹】
前言 上篇文章 python学习——【第七弹】学习了python中的可变序列集合,自此python中的序列的学习就完成啦,这篇文章开始学习python中的函数。 函数 在学习其他编程语言的时候我们就了解过函数:函数就是执行特定任何以完成特定功能的一段代…...
铁路应答器传输系统介绍
应答器传输系统 应答器传输系统是安全点式信息传输系统,通过应答器实现地面设备向车载设备传输信息。 应答器可根据应用需求向车载设备传输固定的(通过无源应答器)或可变的(通过有源应答器)上行链路数据。 当天线单…...
linux arm系统烧录
1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...
【分享】推荐一些办公小工具
1、PDF 在线转换 https://smallpdf.com/cn/pdf-tools 推荐理由:大部分的转换软件需要收费,要么功能不齐全,而开会员又用不了几次浪费钱,借用别人的又不安全。 这个网站它不需要登录或下载安装。而且提供的免费功能就能满足日常…...
GruntJS-前端自动化任务运行器从入门到实战
Grunt 完全指南:从入门到实战 一、Grunt 是什么? Grunt是一个基于 Node.js 的前端自动化任务运行器,主要用于自动化执行项目开发中重复性高的任务,例如文件压缩、代码编译、语法检查、单元测试、文件合并等。通过配置简洁的任务…...
STM32HAL库USART源代码解析及应用
STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...
省略号和可变参数模板
本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...
Vue 模板语句的数据来源
🧩 Vue 模板语句的数据来源:全方位解析 Vue 模板(<template> 部分)中的表达式、指令绑定(如 v-bind, v-on)和插值({{ }})都在一个特定的作用域内求值。这个作用域由当前 组件…...
Unity中的transform.up
2025年6月8日,周日下午 在Unity中,transform.up是Transform组件的一个属性,表示游戏对象在世界空间中的“上”方向(Y轴正方向),且会随对象旋转动态变化。以下是关键点解析: 基本定义 transfor…...
规则与人性的天平——由高考迟到事件引发的思考
当那位身着校服的考生在考场关闭1分钟后狂奔而至,他涨红的脸上写满绝望。铁门内秒针划过的弧度,成为改变人生的残酷抛物线。家长声嘶力竭的哀求与考务人员机械的"这是规定",构成当代中国教育最尖锐的隐喻。 一、刚性规则的必要性 …...
鸿蒙HarmonyOS 5军旗小游戏实现指南
1. 项目概述 本军旗小游戏基于鸿蒙HarmonyOS 5开发,采用DevEco Studio实现,包含完整的游戏逻辑和UI界面。 2. 项目结构 /src/main/java/com/example/militarychess/├── MainAbilitySlice.java // 主界面├── GameView.java // 游戏核…...
