Pytorch剪枝api测试和结果
Pytorch 官方给出的prune接口
下面是基于prune的接口进行剪枝的方法步骤
1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:
- RandomUnstructured
- L1Unstructured
- RandomStructured
- LnStructured
- CustomFromMask
ps:非结构性剪枝不会给剪枝后模型的速度带来提升。
2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数
from torch.nn.utils import prune
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0 , 1, 2, 3
'''
3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存
prune.remove(module, 'weight')
print(list(module.named_parameters()))
5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。
对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()
import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''
然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝
prune全局非结构性剪枝测试结果
# 推理模型tiny-yolov4def model_global_prune(amount: float):detect_model = Darknet('/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg') # TODO:改成相对路径detect_model.load_weights("/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")parameters_to_prune = list()nums = 0for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential): for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):nums += 1parameters_to_prune.append((detect_model.models[i][j], 'weight'))elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):nums += 2parameters_to_prune.append((detect_model.models[i][j], 'weight'))parameters_to_prune.append((detect_model.models[i][j], 'bias'))parameters_to_prune = tuple(parameters_to_prune)assert (nums == len(parameters_to_prune))pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential): for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):pt_prune.remove(detect_model.models[i][j], 'weight')elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):pt_prune.remove(detect_model.models[i][j], 'weight')pt_prune.remove(detect_model.models[i][j], 'bias')return detect_model
base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323
case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364
case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859
case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’
case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578
case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322
prune结构性剪枝测试结果
通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)
结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。
结论:pytorch的官方接口并不能直接使用
相关文章:
Pytorch剪枝api测试和结果
Pytorch 官方给出的prune接口 下面是基于prune的接口进行剪枝的方法步骤 1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有: RandomUnstructuredL1UnstructuredRandomStructuredLnStructuredCustomFromMask ps:非结构性剪…...

微服务下网关聚合Swagger文档、starter统一配置Swagger
一、starter实现统一配置微服务文档 把Swagger配置中的公共部分抽取出来Swagger与SpringBoot整合中,可能会由于版本问题出现各种问题 1、制作starter 参考: 【SpringBoot】自定义启动器 Starter【保姆级教程】用starter实现Oauth2中资源服务的统一配置用…...
剑指 Offer第二版:机器人的运动范围、正则表达式匹配、表示数值的字符串
剑指 Offer第二版 13. 机器人的运动范围19. 正则表达式匹配20. 表示数值的字符串 13. 机器人的运动范围 题目:地上有一个m行n列的方格,从坐标 [0,0] 到坐标 [m-1,n-1] 。一个机器人从坐标 [0, 0] 的格子开始移动,它每次可以向左、右、上、下移…...

Delaunay三角网生成算法
目录 一、分而治之算法二、三角网生长算法三、逐点插入算法四、约束Delaunay三角网1、方法一1、原始点云2、构网结果 1、方法二1、原始点云2、普通Delaunay3、约束Delaunay Delaunay三角剖分分为直接三角剖分和间接三角剖分。间接三角剖分首先计算为Voronoi图,然后由Voronoi图产…...
hashcode是什么?有什么作用?
文章目录 (1)hashcode()方法的作用(2)equals和hashcode的关系(3)百度百科(4)小白解释 Java中Object有一个方法: public native int hashcode(); (1࿰…...

【人体姿态估计】(一)原理介绍
【人体姿态估计】(一)原理介绍 一、背景 人体姿态估计本质上是一个关键点检测的项目; 关键点检测在生活中的应用十分广泛,包括人脸识别、手势识别,而人体姿态估计则是对身体的关键点进行检测; 本文将介…...
一种新的流:为 Java 加入生成器(Generator)特性
作者:文镭(依来) 前言 这篇文章不是工具推荐,也不是应用案例分享。其主题思想,是介绍一种全新的设计模式。它既拥有抽象的数学美感,仅仅从一个简单接口出发,就能推演出庞大的特性集合,引出许多全新概念。…...
《数据结构C++版》实验一:线性表的顺序存储结构
实验目的 1、实现线性表的顺序存储结构 2、熟悉C++程序的基本结构,掌握程序中的头文件、实现文件和主文件之间的相互关系及各自的作用 3、熟悉顺序表的基本操作方式,掌握顺序表相关操作的具体实现 实验内容 对顺序存储的线性表进行一些基本操作。主要包括: (1)插入:操作…...

ChatGPT的开源平替,终于来了!
最近这段时间,一个号称全球最大ChatGPT开源平替项目Open Assistant引起了大家的注意。 这不最近还登上了GitHub的Trending热榜。 https://github.com/LAION-AI/Open-Assistant 根据官方的介绍,Open Assistant也是一个对话式的大型语言模型项目ÿ…...

Redis基础
Redis6 1. NoSQL数据库简介 1.1 技术发展 技术的分类 1、解决功能性的问题:Java、Jsp、RDBMS、Tomcat、HTML、Linux、JDBC、SVN。 2、解决扩展性的问题:Struts、Spring、SpringMVC、Hibernate、Mybatis。 3、解决性能的问题:NoSQL、Jav…...

为什么重视安全的公司都在用SSL安全证书?
我们今天来讲一讲为什么重视安全的公司都在用SSL证书 SSL证书是什么? SSL安全证书是由权威认证机构颁发的,是CA机构将公钥和相关信息写入一个文件,CA机构用他们的私钥对我们的公钥和相关信息进行签名后,将签名信息也写入这个文件…...

嵌入式QT (使用 Qt Designer 开发)
一、使用 UI 设计器开发程序 1.1、 在 UI 文件添加一个按钮 1.2、在 UI 文件里连接信号与槽 所谓信号即是一个对象发出的信号,槽即是当这个对象发出这个信号时,对应连接的槽就发被执行或者触发。 UI 设计器里信号与槽的连接方法一: 在主窗…...

每日一个小技巧:今天告诉你拍照识别文字的软件有哪些
在现代社会里,手机已经成为了人们生活中必不可少的工具。它的功能众多,比如通讯、上网、拍照以及导航等,为我们的生活带来了许多便利。除此之外,手机还能帮助我们解决一些实际的问题,例如,当你需要识别图片…...
多版本VersionARXDBG
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、一级标题二级标题三级标题四级标题五级标题六级标题总结前言 提示:这里可以添加本文要记录的大概内容: VersionARXDBG,多版本,2023.4.22-4.23两天时间,分别研究了在多版本编译ARXDB…...

# 生成器
生成器 生成器是什么? 生成器(generator)是一种用来生成数据的对象。它们是普通函数的一种特殊形式,可以用来控制数据的生成过程。 生成器有什么优势? 使用生成器的优势在于它们可以在生成数据的同时控制数据的生成过程…...

Netty 源码解析(上)
序 Netty的影响力以及使用场景就不用多说了, 去年10月份后,就着手研究Netty源码,之前研究过Spring源码,MyBatis源码,java.util.concurrent源码,tomcat源码,发现一个特点,之前的源码都…...

Vue 消息订阅与发布
消息订阅与发布,也可以实现任意组件之间的通信。 订阅者:就相当于是我们,用于接收数据。 发布者:就相当于是媒体,用于传递数据。 安装消息订阅与发布插件: 在原生 JS 中 不太容易实现消息订阅与发布&…...

如何在你的云服务器/云主机上更新并使用最新版本的python(python3.11)
更新并使用最新版本的python3.11 第一步,登录云服务器,并更新系统包 打开您的终端(Terminal)或使用任意SSH客户端,输入如下命令来登录云主机: ssh 用户名IP地址 在输入密码后,您将成功登录到云…...

python学习——【第八弹】
前言 上篇文章 python学习——【第七弹】学习了python中的可变序列集合,自此python中的序列的学习就完成啦,这篇文章开始学习python中的函数。 函数 在学习其他编程语言的时候我们就了解过函数:函数就是执行特定任何以完成特定功能的一段代…...

铁路应答器传输系统介绍
应答器传输系统 应答器传输系统是安全点式信息传输系统,通过应答器实现地面设备向车载设备传输信息。 应答器可根据应用需求向车载设备传输固定的(通过无源应答器)或可变的(通过有源应答器)上行链路数据。 当天线单…...

超短脉冲激光自聚焦效应
前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
【解密LSTM、GRU如何解决传统RNN梯度消失问题】
解密LSTM与GRU:如何让RNN变得更聪明? 在深度学习的世界里,循环神经网络(RNN)以其卓越的序列数据处理能力广泛应用于自然语言处理、时间序列预测等领域。然而,传统RNN存在的一个严重问题——梯度消失&#…...
连锁超市冷库节能解决方案:如何实现超市降本增效
在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
linux 下常用变更-8
1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行,YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID: YW3…...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)
目录 一、👋🏻前言 二、😈sinx波动的基本原理 三、😈波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、🌊波动优化…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
JS设计模式(4):观察者模式
JS设计模式(4):观察者模式 一、引入 在开发中,我们经常会遇到这样的场景:一个对象的状态变化需要自动通知其他对象,比如: 电商平台中,商品库存变化时需要通知所有订阅该商品的用户;新闻网站中࿰…...

面向无人机海岸带生态系统监测的语义分割基准数据集
描述:海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而,目前该领域仍面临一个挑战,即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...