当前位置: 首页 > news >正文

Pytorch剪枝api测试和结果

Pytorch 官方给出的prune接口

下面是基于prune的接口进行剪枝的方法步骤

1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:

  • RandomUnstructured
  • L1Unstructured
  • RandomStructured
  • LnStructured
  • CustomFromMask
    ps:非结构性剪枝不会给剪枝后模型的速度带来提升。

2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数

from torch.nn.utils import prune 
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0123
'''

3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存

prune.remove(module, 'weight')
print(list(module.named_parameters()))

5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。

对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()

import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''

然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝

prune全局非结构性剪枝测试结果

# 推理模型tiny-yolov4def model_global_prune(amount: float):detect_model = Darknet('/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg')  # TODO:改成相对路径detect_model.load_weights("/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")parameters_to_prune = list()nums = 0for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential):  for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):nums += 1parameters_to_prune.append((detect_model.models[i][j], 'weight'))elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):nums += 2parameters_to_prune.append((detect_model.models[i][j], 'weight'))parameters_to_prune.append((detect_model.models[i][j], 'bias'))parameters_to_prune = tuple(parameters_to_prune)assert (nums == len(parameters_to_prune))pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential):  for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):pt_prune.remove(detect_model.models[i][j], 'weight')elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):pt_prune.remove(detect_model.models[i][j], 'weight')pt_prune.remove(detect_model.models[i][j], 'bias')return detect_model

base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323

case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364

case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859

case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’

case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578

case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322

prune结构性剪枝测试结果

通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)

结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。

结论:pytorch的官方接口并不能直接使用

相关文章:

Pytorch剪枝api测试和结果

Pytorch 官方给出的prune接口 下面是基于prune的接口进行剪枝的方法步骤 1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有: RandomUnstructuredL1UnstructuredRandomStructuredLnStructuredCustomFromMask ps:非结构性剪…...

微服务下网关聚合Swagger文档、starter统一配置Swagger

一、starter实现统一配置微服务文档 把Swagger配置中的公共部分抽取出来Swagger与SpringBoot整合中,可能会由于版本问题出现各种问题 1、制作starter 参考: 【SpringBoot】自定义启动器 Starter【保姆级教程】用starter实现Oauth2中资源服务的统一配置用…...

剑指 Offer第二版:机器人的运动范围、正则表达式匹配、表示数值的字符串

剑指 Offer第二版 13. 机器人的运动范围19. 正则表达式匹配20. 表示数值的字符串 13. 机器人的运动范围 题目:地上有一个m行n列的方格,从坐标 [0,0] 到坐标 [m-1,n-1] 。一个机器人从坐标 [0, 0] 的格子开始移动,它每次可以向左、右、上、下移…...

Delaunay三角网生成算法

目录 一、分而治之算法二、三角网生长算法三、逐点插入算法四、约束Delaunay三角网1、方法一1、原始点云2、构网结果 1、方法二1、原始点云2、普通Delaunay3、约束Delaunay Delaunay三角剖分分为直接三角剖分和间接三角剖分。间接三角剖分首先计算为Voronoi图,然后由Voronoi图产…...

hashcode是什么?有什么作用?

文章目录 (1)hashcode()方法的作用(2)equals和hashcode的关系(3)百度百科(4)小白解释 Java中Object有一个方法: public native int hashcode(); (1&#xff0…...

【人体姿态估计】(一)原理介绍

【人体姿态估计】(一)原理介绍 一、背景 人体姿态估计本质上是一个关键点检测的项目; 关键点检测在生活中的应用十分广泛,包括人脸识别、手势识别,而人体姿态估计则是对身体的关键点进行检测; 本文将介…...

一种新的流:为 Java 加入生成器(Generator)特性

作者:文镭(依来) 前言 这篇文章不是工具推荐,也不是应用案例分享。其主题思想,是介绍一种全新的设计模式。它既拥有抽象的数学美感,仅仅从一个简单接口出发,就能推演出庞大的特性集合,引出许多全新概念。…...

《数据结构C++版》实验一:线性表的顺序存储结构

实验目的 1、实现线性表的顺序存储结构 2、熟悉C++程序的基本结构,掌握程序中的头文件、实现文件和主文件之间的相互关系及各自的作用 3、熟悉顺序表的基本操作方式,掌握顺序表相关操作的具体实现 实验内容 对顺序存储的线性表进行一些基本操作。主要包括: (1)插入:操作…...

ChatGPT的开源平替,终于来了!

最近这段时间,一个号称全球最大ChatGPT开源平替项目Open Assistant引起了大家的注意。 这不最近还登上了GitHub的Trending热榜。 https://github.com/LAION-AI/Open-Assistant 根据官方的介绍,Open Assistant也是一个对话式的大型语言模型项目&#xff…...

Redis基础

Redis6 1. NoSQL数据库简介 1.1 技术发展 技术的分类 1、解决功能性的问题:Java、Jsp、RDBMS、Tomcat、HTML、Linux、JDBC、SVN。 2、解决扩展性的问题:Struts、Spring、SpringMVC、Hibernate、Mybatis。 3、解决性能的问题:NoSQL、Jav…...

为什么重视安全的公司都在用SSL安全证书?

我们今天来讲一讲为什么重视安全的公司都在用SSL证书 SSL证书是什么? SSL安全证书是由权威认证机构颁发的,是CA机构将公钥和相关信息写入一个文件,CA机构用他们的私钥对我们的公钥和相关信息进行签名后,将签名信息也写入这个文件…...

嵌入式QT (使用 Qt Designer 开发)

一、使用 UI 设计器开发程序 1.1、 在 UI 文件添加一个按钮 1.2、在 UI 文件里连接信号与槽 所谓信号即是一个对象发出的信号,槽即是当这个对象发出这个信号时,对应连接的槽就发被执行或者触发。 UI 设计器里信号与槽的连接方法一: 在主窗…...

每日一个小技巧:今天告诉你拍照识别文字的软件有哪些

在现代社会里,手机已经成为了人们生活中必不可少的工具。它的功能众多,比如通讯、上网、拍照以及导航等,为我们的生活带来了许多便利。除此之外,手机还能帮助我们解决一些实际的问题,例如,当你需要识别图片…...

多版本VersionARXDBG

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、一级标题二级标题三级标题四级标题五级标题六级标题总结前言 提示:这里可以添加本文要记录的大概内容: VersionARXDBG,多版本,2023.4.22-4.23两天时间,分别研究了在多版本编译ARXDB…...

# 生成器

生成器 生成器是什么? 生成器(generator)是一种用来生成数据的对象。它们是普通函数的一种特殊形式,可以用来控制数据的生成过程。 生成器有什么优势? 使用生成器的优势在于它们可以在生成数据的同时控制数据的生成过程…...

Netty 源码解析(上)

序 Netty的影响力以及使用场景就不用多说了, 去年10月份后,就着手研究Netty源码,之前研究过Spring源码,MyBatis源码,java.util.concurrent源码,tomcat源码,发现一个特点,之前的源码都…...

Vue 消息订阅与发布

消息订阅与发布,也可以实现任意组件之间的通信。 订阅者:就相当于是我们,用于接收数据。 发布者:就相当于是媒体,用于传递数据。 安装消息订阅与发布插件: 在原生 JS 中 不太容易实现消息订阅与发布&…...

如何在你的云服务器/云主机上更新并使用最新版本的python(python3.11)

更新并使用最新版本的python3.11 第一步,登录云服务器,并更新系统包 打开您的终端(Terminal)或使用任意SSH客户端,输入如下命令来登录云主机: ssh 用户名IP地址 在输入密码后,您将成功登录到云…...

python学习——【第八弹】

前言 上篇文章 python学习——【第七弹】学习了python中的可变序列集合,自此python中的序列的学习就完成啦,这篇文章开始学习python中的函数。 函数 在学习其他编程语言的时候我们就了解过函数:函数就是执行特定任何以完成特定功能的一段代…...

铁路应答器传输系统介绍

应答器传输系统 应答器传输系统是安全点式信息传输系统,通过应答器实现地面设备向车载设备传输信息。 应答器可根据应用需求向车载设备传输固定的(通过无源应答器)或可变的(通过有源应答器)上行链路数据。 当天线单…...

RestClient

什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级&#xff…...

day52 ResNet18 CBAM

在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能,我们需要对它的功能特点进行分析: 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具: mysql:关系型数据库&am…...

django filter 统计数量 按属性去重

在Django中,如果你想要根据某个属性对查询集进行去重并统计数量,你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求: 方法1:使用annotate()和Count 假设你有一个模型Item,并且你想…...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词

Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid,其中有多少个 3 3 的 “幻方” 子矩阵&am…...

九天毕昇深度学习平台 | 如何安装库?

pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子: 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库,分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷,但是文件存放起来数据比较冗余,用二进制能够更好管理咱们M…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现企业微信功能

1. 开发环境准备 ​​安装DevEco Studio 3.1​​: 从华为开发者官网下载最新版DevEco Studio安装HarmonyOS 5.0 SDK ​​项目配置​​: // module.json5 {"module": {"requestPermissions": [{"name": "ohos.permis…...

ubuntu22.04有线网络无法连接,图标也没了

今天突然无法有线网络无法连接任何设备,并且图标都没了 错误案例 往上一顿搜索,试了很多博客都不行,比如 Ubuntu22.04右上角网络图标消失 最后解决的办法 下载网卡驱动,重新安装 操作步骤 查看自己网卡的型号 lspci | gre…...