当前位置: 首页 > news >正文

Yolov8模型用torch_pruning剪枝

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )print(pruning_group.details())  # or print(pruning_group)# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):# handle groups in sequential orderidxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True)# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers,
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

相关文章:

Yolov8模型用torch_pruning剪枝

目录 &#x1f680;&#x1f680;&#x1f680;订阅专栏&#xff0c;更新及时查看不迷路&#x1f680;&#x1f680;&#x1f680; 原理 遍历所有分组 高级剪枝器 &#x1f680;&#x1f680;&#x1f680;订阅专栏&#xff0c;更新及时查看不迷路&#x1f680;&#x1f680…...

C++字符串操作【超详细】

零.前言 本文将重点围绕C的字符串来展开描述。 其中&#xff0c;对于C/C中字符串的一些区别也做出了回答&#xff0c;并对于C的&#xff08;string库&#xff09;进行了讲解&#xff0c;最后我们给出字符串的不同表达形式。 开发环境&#xff1a; VS2022 一.字符串常量跟字…...

Ps:画笔工具

画笔工具 Brush Tool是 Photoshop 中最常用的工具&#xff0c;可广泛地用于绘画与修饰工作之中。 快捷键&#xff1a;B ◆ ◆ ◆ 常用操作方法与技巧 1、熟练掌握画笔工具的操作对于使用其他工具也非常有益&#xff0c;因为 Photoshop 中许多与笔刷相关的工具有类似的选项和操…...

【鸿蒙 HarmonyOS 4.0】弹性布局(Flex)

一、介绍 弹性布局&#xff08;Flex&#xff09;提供更加有效的方式对容器中的子元素进行排列、对齐和分配剩余空间。容器默认存在主轴与交叉轴&#xff0c;子元素默认沿主轴排列&#xff0c;子元素在主轴方向的尺寸称为主轴尺寸&#xff0c;在交叉轴方向的尺寸称为交叉轴尺寸…...

Java 客户端向服务端上传文件(TCP通信)

一、实验内容 编写一个客户端向服务端上传文件的程序&#xff0c;要求使用TCP通信的的知识&#xff0c;完成将本地机器输入的路径下的文件上传到D盘中名称为upload的文件夹中。并把客户端的IP地址加上count标识作为上传后文件的文件名&#xff0c;即IP&#xff08;count&#…...

问题:前端获取long型数值精度丢失,后面几位都为0

文章目录 问题分析解决 问题 通过接口获取到的数据和 Postman 获取到的数据不一样&#xff0c;仔细看 data 的第17位之后 分析 该字段类型是long类型问题&#xff1a;前端接收到数据后&#xff0c;发现精度丢失&#xff0c;当返回的结果超过17位的时候&#xff0c;后面的全…...

Day26:安全开发-PHP应用模版引用Smarty渲染MVC模型数据联动RCE安全

目录 新闻列表 自写模版引用 Smarty模版引用 代码RCE安全测试 思维导图 PHP知识点&#xff1a; 功能&#xff1a;新闻列表&#xff0c;会员中心&#xff0c;资源下载&#xff0c;留言版&#xff0c;后台模块&#xff0c;模版引用&#xff0c;框架开发等 技术&#xff1a;输…...

LVS集群(Linux Virtual server)介绍----及LVS的NAT模式部署(一)

群集的含义 ●Cluster&#xff0c;集群、群集由多台主机构成&#xff0c;但对外只表现为一个整体&#xff0c;只提供访问入口(域名或IP地址)&#xff0c;相当于一台大型计算机 问题&#xff1a; 互联网应用中&#xff0c;随着站点对硬件性能、响应速度、服务稳定性、数据可靠…...

海外媒体宣发套餐如何利用3种方式洞察市场-华媒舍

在当今数字化时代&#xff0c;媒体宣发成为了企业推广产品和品牌的重要手段之一。其中&#xff0c;7FT媒体宣发套餐是一种常用而有效的宣传方式。本文将介绍这种媒体宣发套餐&#xff0c;以及如何利用它来洞察市场。 一、关键概念 在深入讨论7FT媒体宣发套餐之前&#xff0c;让…...

开发知识点-Apache Struts2框架

Apache Struts2 介绍S2-001S2CVE-2023-22530 介绍 Apache Struts2是一个基于MVC&#xff08;模型-视图-控制器&#xff09;设计模式的Web应用程序框架&#xff0c;它是Apache旗下的一个开源项目&#xff0c;并且是Struts1的下一代产品。Struts2是在Struts1和WebWork的技术基础…...

【Spring高级】第3讲 Bean的生命周期

目录 基本的生命周期后处理器总结 基本的生命周期 为了演示生命周期的过程&#xff0c;我们直接使用 SpringApplication.run()方法&#xff0c;他会直接诶返回一个容器对象。 import org.springframework.boot.SpringApplication; import org.springframework.context.Config…...

【C语言】linux内核tcp_write_xmit和tcp_write_queue_purge

tcp_write_xmit 一、讲解 这个函数 tcp_write_xmit 是Linux内核TCP协议栈中的一部分&#xff0c;其基本作用是发送数据包到网络。这个函数会根据不同情况推进发送队列的头部&#xff0c;确保只要远程窗口有空间&#xff0c;就可以发送数据。 下面是对该函数的一些主要逻辑的中…...

opencv实现视频人脸识别

一. 实现指定图像的人脸识别 注意&#xff1a; 以下实例参考《OpenCV轻松入门面向Python》李立宗著&#xff0c;使用python语言&#xff0c;编辑器为PyCharm&#xff0c;且都运行成功。 1.dface3.jpg图片文件和当前代码放在同一级目录下。 2.级联分类器文件和当前代码文件放在…...

【今日面经】24/3/9 广州Java某小厂电话面经

面经来源&#xff1a;https://www.nowcoder.com/?type818_1 目录 1、 和equals()有什么区别&#xff1f;2、String变量直接赋值和构造函数赋值比较相等吗&#xff1f;3、String一些方法&#xff1f;4、抽象类和接口有什么区别&#xff1f;5、Java容器有哪些&#xff1f;6、Lis…...

日期问题---算法精讲

前言 今天讲讲日期问题&#xff0c;所谓日期问题&#xff0c;在蓝桥杯中出现众多&#xff0c;但是解法比较固定。 一般有判断日期合法性&#xff0c;判断是否闰年&#xff0c;判断日期的特殊形式&#xff08;回文或abababab型等&#xff09; 目录 例题 题2 题三 总结 …...

倒计时35天

dp预备(来源&#xff1a;b站acm刘春英老师) 1. 2. 3. 4. 5. 6. 7....

JAVA后端开发面试基础知识(七)——多线程

1. 线程池原理 优点 降低资源消耗。通过重复利用已创建的线程降低线程创建和销毁造成的消耗。提高响应速度。当任务到达时,任务可以不需要等到线程创建就能立即执行。提高线程的可管理性。线程是稀缺资源,如果无限制的创建,不仅会消耗系统资源,还会降低系统的稳定性,使用线…...

Apache的安装与目录结构详细解说

1. Apache安装步骤 Apache是一款开源的Web服务器软件&#xff0c;常用于搭建网站和服务。以下是Apache的安装步骤&#xff1a; 在官方网站&#xff08;https://httpd.apache.org/&#xff09;下载最新版本的Apache软件包。解压下载的软件包到指定目录。运行安装程序&#xff…...

axios的详细使用

目录 axios&#xff1a;现代前端开发的HTTP客户端王者 一、axios简介 二、axios的基本用法 1. 安装axios 2. 发起GET请求 3. 发起POST请求 三、axios的高级特性 1. 拦截器 2. 取消请求 3. 自动转换JSON数据 四、axios在前端开发中的应用 五、总结 axios&#xff1a…...

空间复杂度的OJ练习——轮转数组

旋转数组OJ链接&#xff1a;https://leetcode-cn.com/problems/rotate-array/ 题目&#xff1a; 思路&#xff1a; 通过题目我们可以知道这是一个无序数组&#xff0c;只需要将数组中的数按给定条件重新排列&#xff0c;因此我们可以想到以下几种方法&#xff1a; 1.暴力求解法…...

dedecms 织梦自定义表单留言增加ajax验证码功能

增加ajax功能模块&#xff0c;用户不点击提交按钮&#xff0c;只要输入框失去焦点&#xff0c;就会提前提示验证码是否正确。 一&#xff0c;模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...

cf2117E

原题链接&#xff1a;https://codeforces.com/contest/2117/problem/E 题目背景&#xff1a; 给定两个数组a,b&#xff0c;可以执行多次以下操作&#xff1a;选择 i (1 < i < n - 1)&#xff0c;并设置 或&#xff0c;也可以在执行上述操作前执行一次删除任意 和 。求…...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar&#xff1a;依赖注入与仓储模式实践 在 C# 的应用开发中&#xff0c;数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护&#xff0c;许多开发者会选择成熟的 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;SqlSugar 就是其中备受…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

Linux --进程控制

本文从以下五个方面来初步认识进程控制&#xff1a; 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程&#xff0c;创建出来的进程就是子进程&#xff0c;原来的进程为父进程。…...

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列&#xff0c;以便知晓哪些列包含有价值的数据&#xff0c;…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...

C#学习第29天:表达式树(Expression Trees)

目录 什么是表达式树&#xff1f; 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持&#xff1a; 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...