Pytorch剪枝api测试和结果
Pytorch 官方给出的prune接口
下面是基于prune的接口进行剪枝的方法步骤
1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:
- RandomUnstructured
- L1Unstructured
- RandomStructured
- LnStructured
- CustomFromMask
ps:非结构性剪枝不会给剪枝后模型的速度带来提升。
2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数
from torch.nn.utils import prune
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0 , 1, 2, 3
'''
3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存
prune.remove(module, 'weight')
print(list(module.named_parameters()))
5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。
对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()
import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''
然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝
prune全局非结构性剪枝测试结果
# 推理模型tiny-yolov4def model_global_prune(amount: float):detect_model = Darknet('/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg') # TODO:改成相对路径detect_model.load_weights("/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")parameters_to_prune = list()nums = 0for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential): for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):nums += 1parameters_to_prune.append((detect_model.models[i][j], 'weight'))elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):nums += 2parameters_to_prune.append((detect_model.models[i][j], 'weight'))parameters_to_prune.append((detect_model.models[i][j], 'bias'))parameters_to_prune = tuple(parameters_to_prune)assert (nums == len(parameters_to_prune))pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential): for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):pt_prune.remove(detect_model.models[i][j], 'weight')elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):pt_prune.remove(detect_model.models[i][j], 'weight')pt_prune.remove(detect_model.models[i][j], 'bias')return detect_model
base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323
case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364
case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859
case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’
case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578
case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322
prune结构性剪枝测试结果
通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)
结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。
结论:pytorch的官方接口并不能直接使用
相关文章:
Pytorch剪枝api测试和结果
Pytorch 官方给出的prune接口 下面是基于prune的接口进行剪枝的方法步骤 1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有: RandomUnstructuredL1UnstructuredRandomStructuredLnStructuredCustomFromMask ps:非结构性剪…...
微服务下网关聚合Swagger文档、starter统一配置Swagger
一、starter实现统一配置微服务文档 把Swagger配置中的公共部分抽取出来Swagger与SpringBoot整合中,可能会由于版本问题出现各种问题 1、制作starter 参考: 【SpringBoot】自定义启动器 Starter【保姆级教程】用starter实现Oauth2中资源服务的统一配置用…...
剑指 Offer第二版:机器人的运动范围、正则表达式匹配、表示数值的字符串
剑指 Offer第二版 13. 机器人的运动范围19. 正则表达式匹配20. 表示数值的字符串 13. 机器人的运动范围 题目:地上有一个m行n列的方格,从坐标 [0,0] 到坐标 [m-1,n-1] 。一个机器人从坐标 [0, 0] 的格子开始移动,它每次可以向左、右、上、下移…...
Delaunay三角网生成算法
目录 一、分而治之算法二、三角网生长算法三、逐点插入算法四、约束Delaunay三角网1、方法一1、原始点云2、构网结果 1、方法二1、原始点云2、普通Delaunay3、约束Delaunay Delaunay三角剖分分为直接三角剖分和间接三角剖分。间接三角剖分首先计算为Voronoi图,然后由Voronoi图产…...
hashcode是什么?有什么作用?
文章目录 (1)hashcode()方法的作用(2)equals和hashcode的关系(3)百度百科(4)小白解释 Java中Object有一个方法: public native int hashcode(); (1࿰…...
【人体姿态估计】(一)原理介绍
【人体姿态估计】(一)原理介绍 一、背景 人体姿态估计本质上是一个关键点检测的项目; 关键点检测在生活中的应用十分广泛,包括人脸识别、手势识别,而人体姿态估计则是对身体的关键点进行检测; 本文将介…...
一种新的流:为 Java 加入生成器(Generator)特性
作者:文镭(依来) 前言 这篇文章不是工具推荐,也不是应用案例分享。其主题思想,是介绍一种全新的设计模式。它既拥有抽象的数学美感,仅仅从一个简单接口出发,就能推演出庞大的特性集合,引出许多全新概念。…...
《数据结构C++版》实验一:线性表的顺序存储结构
实验目的 1、实现线性表的顺序存储结构 2、熟悉C++程序的基本结构,掌握程序中的头文件、实现文件和主文件之间的相互关系及各自的作用 3、熟悉顺序表的基本操作方式,掌握顺序表相关操作的具体实现 实验内容 对顺序存储的线性表进行一些基本操作。主要包括: (1)插入:操作…...
ChatGPT的开源平替,终于来了!
最近这段时间,一个号称全球最大ChatGPT开源平替项目Open Assistant引起了大家的注意。 这不最近还登上了GitHub的Trending热榜。 https://github.com/LAION-AI/Open-Assistant 根据官方的介绍,Open Assistant也是一个对话式的大型语言模型项目ÿ…...
Redis基础
Redis6 1. NoSQL数据库简介 1.1 技术发展 技术的分类 1、解决功能性的问题:Java、Jsp、RDBMS、Tomcat、HTML、Linux、JDBC、SVN。 2、解决扩展性的问题:Struts、Spring、SpringMVC、Hibernate、Mybatis。 3、解决性能的问题:NoSQL、Jav…...
为什么重视安全的公司都在用SSL安全证书?
我们今天来讲一讲为什么重视安全的公司都在用SSL证书 SSL证书是什么? SSL安全证书是由权威认证机构颁发的,是CA机构将公钥和相关信息写入一个文件,CA机构用他们的私钥对我们的公钥和相关信息进行签名后,将签名信息也写入这个文件…...
嵌入式QT (使用 Qt Designer 开发)
一、使用 UI 设计器开发程序 1.1、 在 UI 文件添加一个按钮 1.2、在 UI 文件里连接信号与槽 所谓信号即是一个对象发出的信号,槽即是当这个对象发出这个信号时,对应连接的槽就发被执行或者触发。 UI 设计器里信号与槽的连接方法一: 在主窗…...
每日一个小技巧:今天告诉你拍照识别文字的软件有哪些
在现代社会里,手机已经成为了人们生活中必不可少的工具。它的功能众多,比如通讯、上网、拍照以及导航等,为我们的生活带来了许多便利。除此之外,手机还能帮助我们解决一些实际的问题,例如,当你需要识别图片…...
多版本VersionARXDBG
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、一级标题二级标题三级标题四级标题五级标题六级标题总结前言 提示:这里可以添加本文要记录的大概内容: VersionARXDBG,多版本,2023.4.22-4.23两天时间,分别研究了在多版本编译ARXDB…...
# 生成器
生成器 生成器是什么? 生成器(generator)是一种用来生成数据的对象。它们是普通函数的一种特殊形式,可以用来控制数据的生成过程。 生成器有什么优势? 使用生成器的优势在于它们可以在生成数据的同时控制数据的生成过程…...
Netty 源码解析(上)
序 Netty的影响力以及使用场景就不用多说了, 去年10月份后,就着手研究Netty源码,之前研究过Spring源码,MyBatis源码,java.util.concurrent源码,tomcat源码,发现一个特点,之前的源码都…...
Vue 消息订阅与发布
消息订阅与发布,也可以实现任意组件之间的通信。 订阅者:就相当于是我们,用于接收数据。 发布者:就相当于是媒体,用于传递数据。 安装消息订阅与发布插件: 在原生 JS 中 不太容易实现消息订阅与发布&…...
如何在你的云服务器/云主机上更新并使用最新版本的python(python3.11)
更新并使用最新版本的python3.11 第一步,登录云服务器,并更新系统包 打开您的终端(Terminal)或使用任意SSH客户端,输入如下命令来登录云主机: ssh 用户名IP地址 在输入密码后,您将成功登录到云…...
python学习——【第八弹】
前言 上篇文章 python学习——【第七弹】学习了python中的可变序列集合,自此python中的序列的学习就完成啦,这篇文章开始学习python中的函数。 函数 在学习其他编程语言的时候我们就了解过函数:函数就是执行特定任何以完成特定功能的一段代…...
铁路应答器传输系统介绍
应答器传输系统 应答器传输系统是安全点式信息传输系统,通过应答器实现地面设备向车载设备传输信息。 应答器可根据应用需求向车载设备传输固定的(通过无源应答器)或可变的(通过有源应答器)上行链路数据。 当天线单…...
CnOpenData 沪市IPO发行文件-B来源
IPO(Initial Public Offing),即首次公开募股,是指一家企业(发行人)第一次将它的股份向公众出售。资本市场是现代金融体系的核心,是企业最高效的融资渠道和最强大的资本运作平台,IPO作为公司登陆资本市场的唯一路径,将使…...
保姆级教程:用PyTorch 1.13.1在GPU上跑通PointNet分类与分割(附自写推理脚本)
从零实现PointNet分类与分割:PyTorch 1.13.1 GPU实战指南 当你第一次接触3D点云处理时,可能会被各种复杂的数学公式和算法吓退。但PointNet的出现改变了这一局面——这个开创性的网络架构直接处理原始点云数据,无需复杂的体素化或网格化预处理…...
终极指南:OPAL外部数据源配置与API策略源实战
终极指南:OPAL外部数据源配置与API策略源实战 【免费下载链接】opal Policy and data administration, distribution, and real-time updates on top of Policy Agents (OPA, Cedar, ...) 项目地址: https://gitcode.com/gh_mirrors/opal1/opal OPAL…...
DownKyi:三分钟学会B站视频下载的终极解决方案
DownKyi:三分钟学会B站视频下载的终极解决方案 【免费下载链接】downkyi 哔哩下载姬downkyi,哔哩哔哩网站视频下载工具,支持批量下载,支持8K、HDR、杜比视界,提供工具箱(音视频提取、去水印等)。…...
AI技术原理--AI Token是什么:10分钟搞懂大模型基础单位
当你在ChatGPT里输入"你好,今天天气怎么样"的时候,你以为它真的读懂你的话吗? 并不是。 在你看不到的地方,有一个叫"分词器"的程序,正在把你的文字拆解成一个一个叫"Token"的单元。 …...
JTAG引脚定义:从接口信号到调试实践的深度解析
1. JTAG接口的核心引脚功能解析 第一次接触JTAG接口时,看到那一排密密麻麻的引脚确实有点发怵。但实际用起来你会发现,真正关键的信号线就那么几根。我调试过的板子少说也有上百块,总结下来最核心的就是TCK、TMS、TDO、TDI这四根线࿰…...
告别官方解锁后,我们还能怎么玩?深挖华为荣耀解BL的三种野路子与工具链
华为荣耀设备解锁技术深度解析:原理、工具与风险规避 在移动设备定制化需求日益增长的今天,安卓系统的Bootloader解锁成为技术爱好者绕不开的话题。特别是对于华为和荣耀设备用户而言,自2018年官方关闭解锁服务后,这一过程变得更…...
解密Abaqus许可证“心跳”机制与合理超时时间设置
解密Abaqus许可证“心跳”机制跟合理超时时间设置你是不单是也碰到过这种情况:Abaqus许可证明明用不了,可系统还在继续计费?我在一家制造企业做许可证优化,就碰到了此老问题。为何许可证会“死掉”?这跟许可证的心跳&a…...
Windows和Office激活终极解决方案:KMS_VL_ALL_AIO完全指南
Windows和Office激活终极解决方案:KMS_VL_ALL_AIO完全指南 【免费下载链接】KMS_VL_ALL_AIO Smart Activation Script 项目地址: https://gitcode.com/gh_mirrors/km/KMS_VL_ALL_AIO 你是否厌倦了Windows系统频繁弹出激活提醒?是否因为Office软件…...
告别慢查询:用快马ai智能生成postgresql性能优化与索引方案
告别慢查询:用快马AI智能生成PostgreSQL性能优化与索引方案 在电商系统中,订单查询是最常见的操作之一。随着业务量的增长,数据库查询性能往往会成为瓶颈。最近我在优化一个电商平台的订单查询模块时,发现几个典型的性能问题&…...
