当前位置: 首页 > news >正文

Yolov8模型用torch_pruning剪枝

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )print(pruning_group.details())  # or print(pruning_group)# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):# handle groups in sequential orderidxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True)# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers,
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

相关文章:

Yolov8模型用torch_pruning剪枝

目录 &#x1f680;&#x1f680;&#x1f680;订阅专栏&#xff0c;更新及时查看不迷路&#x1f680;&#x1f680;&#x1f680; 原理 遍历所有分组 高级剪枝器 &#x1f680;&#x1f680;&#x1f680;订阅专栏&#xff0c;更新及时查看不迷路&#x1f680;&#x1f680…...

C++字符串操作【超详细】

零.前言 本文将重点围绕C的字符串来展开描述。 其中&#xff0c;对于C/C中字符串的一些区别也做出了回答&#xff0c;并对于C的&#xff08;string库&#xff09;进行了讲解&#xff0c;最后我们给出字符串的不同表达形式。 开发环境&#xff1a; VS2022 一.字符串常量跟字…...

Ps:画笔工具

画笔工具 Brush Tool是 Photoshop 中最常用的工具&#xff0c;可广泛地用于绘画与修饰工作之中。 快捷键&#xff1a;B ◆ ◆ ◆ 常用操作方法与技巧 1、熟练掌握画笔工具的操作对于使用其他工具也非常有益&#xff0c;因为 Photoshop 中许多与笔刷相关的工具有类似的选项和操…...

【鸿蒙 HarmonyOS 4.0】弹性布局(Flex)

一、介绍 弹性布局&#xff08;Flex&#xff09;提供更加有效的方式对容器中的子元素进行排列、对齐和分配剩余空间。容器默认存在主轴与交叉轴&#xff0c;子元素默认沿主轴排列&#xff0c;子元素在主轴方向的尺寸称为主轴尺寸&#xff0c;在交叉轴方向的尺寸称为交叉轴尺寸…...

Java 客户端向服务端上传文件(TCP通信)

一、实验内容 编写一个客户端向服务端上传文件的程序&#xff0c;要求使用TCP通信的的知识&#xff0c;完成将本地机器输入的路径下的文件上传到D盘中名称为upload的文件夹中。并把客户端的IP地址加上count标识作为上传后文件的文件名&#xff0c;即IP&#xff08;count&#…...

问题:前端获取long型数值精度丢失,后面几位都为0

文章目录 问题分析解决 问题 通过接口获取到的数据和 Postman 获取到的数据不一样&#xff0c;仔细看 data 的第17位之后 分析 该字段类型是long类型问题&#xff1a;前端接收到数据后&#xff0c;发现精度丢失&#xff0c;当返回的结果超过17位的时候&#xff0c;后面的全…...

Day26:安全开发-PHP应用模版引用Smarty渲染MVC模型数据联动RCE安全

目录 新闻列表 自写模版引用 Smarty模版引用 代码RCE安全测试 思维导图 PHP知识点&#xff1a; 功能&#xff1a;新闻列表&#xff0c;会员中心&#xff0c;资源下载&#xff0c;留言版&#xff0c;后台模块&#xff0c;模版引用&#xff0c;框架开发等 技术&#xff1a;输…...

LVS集群(Linux Virtual server)介绍----及LVS的NAT模式部署(一)

群集的含义 ●Cluster&#xff0c;集群、群集由多台主机构成&#xff0c;但对外只表现为一个整体&#xff0c;只提供访问入口(域名或IP地址)&#xff0c;相当于一台大型计算机 问题&#xff1a; 互联网应用中&#xff0c;随着站点对硬件性能、响应速度、服务稳定性、数据可靠…...

海外媒体宣发套餐如何利用3种方式洞察市场-华媒舍

在当今数字化时代&#xff0c;媒体宣发成为了企业推广产品和品牌的重要手段之一。其中&#xff0c;7FT媒体宣发套餐是一种常用而有效的宣传方式。本文将介绍这种媒体宣发套餐&#xff0c;以及如何利用它来洞察市场。 一、关键概念 在深入讨论7FT媒体宣发套餐之前&#xff0c;让…...

开发知识点-Apache Struts2框架

Apache Struts2 介绍S2-001S2CVE-2023-22530 介绍 Apache Struts2是一个基于MVC&#xff08;模型-视图-控制器&#xff09;设计模式的Web应用程序框架&#xff0c;它是Apache旗下的一个开源项目&#xff0c;并且是Struts1的下一代产品。Struts2是在Struts1和WebWork的技术基础…...

【Spring高级】第3讲 Bean的生命周期

目录 基本的生命周期后处理器总结 基本的生命周期 为了演示生命周期的过程&#xff0c;我们直接使用 SpringApplication.run()方法&#xff0c;他会直接诶返回一个容器对象。 import org.springframework.boot.SpringApplication; import org.springframework.context.Config…...

【C语言】linux内核tcp_write_xmit和tcp_write_queue_purge

tcp_write_xmit 一、讲解 这个函数 tcp_write_xmit 是Linux内核TCP协议栈中的一部分&#xff0c;其基本作用是发送数据包到网络。这个函数会根据不同情况推进发送队列的头部&#xff0c;确保只要远程窗口有空间&#xff0c;就可以发送数据。 下面是对该函数的一些主要逻辑的中…...

opencv实现视频人脸识别

一. 实现指定图像的人脸识别 注意&#xff1a; 以下实例参考《OpenCV轻松入门面向Python》李立宗著&#xff0c;使用python语言&#xff0c;编辑器为PyCharm&#xff0c;且都运行成功。 1.dface3.jpg图片文件和当前代码放在同一级目录下。 2.级联分类器文件和当前代码文件放在…...

【今日面经】24/3/9 广州Java某小厂电话面经

面经来源&#xff1a;https://www.nowcoder.com/?type818_1 目录 1、 和equals()有什么区别&#xff1f;2、String变量直接赋值和构造函数赋值比较相等吗&#xff1f;3、String一些方法&#xff1f;4、抽象类和接口有什么区别&#xff1f;5、Java容器有哪些&#xff1f;6、Lis…...

日期问题---算法精讲

前言 今天讲讲日期问题&#xff0c;所谓日期问题&#xff0c;在蓝桥杯中出现众多&#xff0c;但是解法比较固定。 一般有判断日期合法性&#xff0c;判断是否闰年&#xff0c;判断日期的特殊形式&#xff08;回文或abababab型等&#xff09; 目录 例题 题2 题三 总结 …...

倒计时35天

dp预备(来源&#xff1a;b站acm刘春英老师) 1. 2. 3. 4. 5. 6. 7....

JAVA后端开发面试基础知识(七)——多线程

1. 线程池原理 优点 降低资源消耗。通过重复利用已创建的线程降低线程创建和销毁造成的消耗。提高响应速度。当任务到达时,任务可以不需要等到线程创建就能立即执行。提高线程的可管理性。线程是稀缺资源,如果无限制的创建,不仅会消耗系统资源,还会降低系统的稳定性,使用线…...

Apache的安装与目录结构详细解说

1. Apache安装步骤 Apache是一款开源的Web服务器软件&#xff0c;常用于搭建网站和服务。以下是Apache的安装步骤&#xff1a; 在官方网站&#xff08;https://httpd.apache.org/&#xff09;下载最新版本的Apache软件包。解压下载的软件包到指定目录。运行安装程序&#xff…...

axios的详细使用

目录 axios&#xff1a;现代前端开发的HTTP客户端王者 一、axios简介 二、axios的基本用法 1. 安装axios 2. 发起GET请求 3. 发起POST请求 三、axios的高级特性 1. 拦截器 2. 取消请求 3. 自动转换JSON数据 四、axios在前端开发中的应用 五、总结 axios&#xff1a…...

空间复杂度的OJ练习——轮转数组

旋转数组OJ链接&#xff1a;https://leetcode-cn.com/problems/rotate-array/ 题目&#xff1a; 思路&#xff1a; 通过题目我们可以知道这是一个无序数组&#xff0c;只需要将数组中的数按给定条件重新排列&#xff0c;因此我们可以想到以下几种方法&#xff1a; 1.暴力求解法…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具&#xff0c;相比原生 Python 生态&#xff08;如 pip 虚拟环境&#xff09;有许多独特优势&#xff0c;尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处&#xff1a; 一、一站式环境管理&#xff1a…...

【杂谈】-递归进化:人工智能的自我改进与监管挑战

递归进化&#xff1a;人工智能的自我改进与监管挑战 文章目录 递归进化&#xff1a;人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管&#xff1f;3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

C++八股 —— 单例模式

文章目录 1. 基本概念2. 设计要点3. 实现方式4. 详解懒汉模式 1. 基本概念 线程安全&#xff08;Thread Safety&#xff09; 线程安全是指在多线程环境下&#xff0c;某个函数、类或代码片段能够被多个线程同时调用时&#xff0c;仍能保证数据的一致性和逻辑的正确性&#xf…...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...

基于SpringBoot在线拍卖系统的设计和实现

摘 要 随着社会的发展&#xff0c;社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统&#xff0c;主要的模块包括管理员&#xff1b;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

Windows安装Miniconda

一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...

pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)

目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 &#xff08;1&#xff09;输入单引号 &#xff08;2&#xff09;万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...

【SpringBoot自动化部署】

SpringBoot自动化部署方法 使用Jenkins进行持续集成与部署 Jenkins是最常用的自动化部署工具之一&#xff0c;能够实现代码拉取、构建、测试和部署的全流程自动化。 配置Jenkins任务时&#xff0c;需要添加Git仓库地址和凭证&#xff0c;设置构建触发器&#xff08;如GitHub…...

SQL Server 触发器调用存储过程实现发送 HTTP 请求

文章目录 需求分析解决第 1 步:前置条件,启用 OLE 自动化方式 1:使用 SQL 实现启用 OLE 自动化方式 2:Sql Server 2005启动OLE自动化方式 3:Sql Server 2008启动OLE自动化第 2 步:创建存储过程第 3 步:创建触发器扩展 - 如何调试?第 1 步:登录 SQL Server 2008第 2 步…...