深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(2)
前言
《深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)》里面我只是提到了对conv1
层进行剪枝,只是为了验证这个剪枝的整个过程,但是后面也有提到:仅裁剪 conv1
层的影响极大,原因如下:
- 底层特征的重要性 :
conv1
输出的是最基础的图像特征,所有后续层的特征均基于此生成。裁剪 conv1 会直接限制后续所有层的特征表达能力。 - 结构连锁反应 :
conv1
的输出通道减少会触发bn1
、layer1.0.conv1
、downsample
等多个模块的调整,任何一个模块的调整失误(如通道数不匹配、参数初始化不当)都会导致整体性能下降。
虽然,在例子中,我们只是简单的进行了验证,发现效果也不是很差,但是如果具体到自己的数据,或者更加复杂的特征或者模型,可能就会影响到了整体的性能,因此,我们在原有的基础上做了如下的改动:
- 剪枝目标层调整 :将 conv1 改为 layer2.0.conv1 ,减少对底层特征的破坏。
- 通道评估优化 :通过前向传播收集激活值,优先剪枝激活值低的通道,更符合实际特征贡献。
- 微调策略改进 :动态解冻剪枝层及关联的BN、downsample层,学习率降低(0.0001),微调轮次增加(10轮),确保参数充分适应。
这些修改可显著提升剪枝后模型的稳定性和准确率。建议运行时观察微调阶段的Loss是否持续下降,若下降缓慢可进一步降低学习率(如0.00001)。
所有代码都在这:https://gitee.com/NOON47/model_prune
详细改动
- 剪枝目标层调整 :将 conv1 改为 layer2.0.conv1 ,减少对底层特征的破坏。
layer_to_prune = 'layer2.0.conv1' # 显式定义要剪枝的层名pruned_model = prune_conv_layer(model, layer_to_prune, amount=0.2)
- 通道评估优化 :通过前向传播收集激活值,优先剪枝激活值低的通道,更符合实际特征贡献。
model.eval()with torch.no_grad():test_input = torch.randn(128, 3, 32, 32).to(device) # 模拟 CIFAR10 输入features = []def hook_fn(module, input, output):features.append(output)handle = layer.register_forward_hook(hook_fn)model(test_input)handle.remove()activation = features[0] # shape: [128, out_channels, H, W]channel_importance = activation.mean(dim=(0, 2, 3)) # 按通道求平均激活值num_channels = weight.shape[0]num_prune = int(num_channels * amount)_, indices = torch.topk(channel_importance, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False # 生成剪枝掩码
- 微调策略改进 :动态解冻剪枝层及关联的BN、downsample层,学习率降低(0.0001),微调轮次增加(10轮),确保参数充分适应。
print("开始微调剪枝后的模型")# 新增:根据剪枝层动态解冻相关层(假设剪枝层为layer2.0.conv1)pruned_layer_prefix = layer_to_prune.rpartition('.')[0] # 例如 'layer2.0'for name, param in pruned_model.named_parameters():if (pruned_layer_prefix in name) or ('fc' in name) or ('bn' in name): # 解冻剪枝层、BN层和fc层param.requires_grad = Trueelse:param.requires_grad = Falseoptimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.0001) # 微调学习率降低pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=10) # 增加微调轮次
完整的裁剪函数:
def prune_conv_layer(model, layer_name, amount=0.2):device = next(model.parameters()).devicelayer = dict(model.named_modules())[layer_name]weight = layer.weight.data# 基于激活值的通道重要性评估model.eval()with torch.no_grad():test_input = torch.randn(128, 3, 32, 32).to(device) # 模拟 CIFAR10 输入features = []def hook_fn(module, input, output):features.append(output)handle = layer.register_forward_hook(hook_fn)model(test_input)handle.remove()activation = features[0] # shape: [128, out_channels, H, W]channel_importance = activation.mean(dim=(0, 2, 3)) # 按通道求平均激活值num_channels = weight.shape[0]num_prune = int(num_channels * amount)_, indices = torch.topk(channel_importance, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False # 生成剪枝掩码# 创建并替换新卷积层new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None).to(device)new_conv.weight.data = layer.weight.data[mask] # 应用掩码剪枝权重if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]# 替换原始卷积层parent_name, sep, name = layer_name.rpartition('.')parent = model.get_submodule(parent_name)setattr(parent, name, new_conv)# 仅处理首层 conv1 的特殊逻辑if layer_name == 'conv1':# 更新首层 BN 层(bn1)bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)with torch.no_grad():new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1# 处理 layer1.0 的 downsample 层(若不存在则创建)block = model.layer1[0]if not hasattr(block, 'downsample') or block.downsample is None:# 创建 1x1 卷积 + BN 用于通道匹配downsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,out_channels=block.conv2.out_channels, # 与主路径输出通道一致(ResNet18 为 64)kernel_size=1,stride=1,bias=False).to(device)# 初始化权重(使用原卷积层的统计量)with torch.no_grad():downsample_conv.weight.data = layer.weight.data.mean(dim=(2,3), keepdim=True) # 原卷积核均值初始化downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)with torch.no_grad():downsample_bn.weight.data.fill_(1.0)downsample_bn.bias.data.zero_()downsample_bn.running_mean.data.zero_()downsample_bn.running_var.data.fill_(1.0)block.downsample = nn.Sequential(downsample_conv, downsample_bn)print("✅ 为 layer1.0 添加新的 downsample 层")else:# 调整已有 downsample 层的输入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channelsdownsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)# 更新对应的 BN 层downsample_bn = block.downsample[1]new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)with torch.no_grad():new_downsample_bn.weight.data = downsample_bn.weight.data.clone()new_downsample_bn.bias.data = downsample_bn.bias.data.clone()new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()block.downsample[1] = new_downsample_bn# 同步 layer1.0.conv1 的输入通道target_conv = model.layer1[0].conv1if target_conv.in_channels != new_conv.out_channels:print(f"同步 layer1.0.conv1 输入通道: {target_conv.in_channels} → {new_conv.out_channels}")target_conv.in_channels = new_conv.out_channelstarget_conv.weight = nn.Parameter(target_conv.weight.data[:, mask, :, :].clone()).to(device)else:# 中间层剪枝逻辑(如 layer2.0.conv1)block_prefix = layer_name.rsplit('.', 1)[0] # 提取 block 前缀(如 'layer2.0')block = model.get_submodule(block_prefix) # 获取对应的 block(如 layer2.0)# 更新当前 block 内的 BN 层(conv1 对应 bn1,conv2 对应 bn2)target_bn_name = f"{block_prefix}.bn1" if 'conv1' in layer_name else f"{block_prefix}.bn2"try:target_bn = model.get_submodule(target_bn_name)new_bn = nn.BatchNorm2d(new_conv.out_channels).to(device)with torch.no_grad():new_bn.weight.data = target_bn.weight.data[mask].clone()new_bn.bias.data = target_bn.bias.data[mask].clone()new_bn.running_mean.data = target_bn.running_mean.data[mask].clone()new_bn.running_var.data = target_bn.running_var.data[mask].clone()setattr(block, target_bn_name.split('.')[-1], new_bn) # 替换原 BN 层print(f"✅ 更新剪枝层 {layer_name} 对应的 BN 层 {target_bn_name}")except AttributeError:print(f"⚠️ 未找到剪枝层 {layer_name} 对应的 BN 层,跳过 BN 更新")# 新增:同步后续卷积层的输入通道(如 conv1 后调整 conv2)if 'conv1' in layer_name:next_conv = block.conv2if next_conv.in_channels != new_conv.out_channels:print(f"同步 {block_prefix}.conv2 输入通道: {next_conv.in_channels} → {new_conv.out_channels}")next_conv.in_channels = new_conv.out_channelsnext_conv.weight = nn.Parameter(next_conv.weight.data[:, mask, :, :].clone()).to(device) # 按剪枝掩码筛选输入通道权重# 可选:如果存在 downsample 层,调整其输入通道(根据实际需求启用)# if hasattr(block, 'downsample') and block.downsample is not None:# downsample_conv = block.downsample[0]# downsample_conv.in_channels = new_conv.out_channels# downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)# print(f"✅ 调整剪枝层 {layer_name} 关联的 downsample 层输入通道")# 验证前向传播with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)try:model(test_input)print("✅ 前向传播验证通过")except Exception as e:print(f"❌ 验证失败: {str(e)}")raisereturn model
改动后结果
经过改动后, 增加微调轮次,得到的结果如下:
剪枝前模型大小信息:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
原始模型准确率: 81.42%剪枝后模型大小信息:
==========================================================================================
Total params: 11,138,392
Trainable params: 11,138,392
Non-trainable params: 0
Total mult-adds (M): 36.33
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.80
Params size (MB): 44.55
Estimated Total Size (MB): 45.37
==========================================================================================
剪枝后模型准确率: 83.28%
个人认为,这个才是比较符合实际应用的。
相关文章:
深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(2)
前言 《深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)》里面我只是提到了对conv1层进行剪枝,只是为了验证这个剪枝的整个过程,但是后面也有提到:仅裁剪 conv1层的影响极大,原因如…...

综采工作面电控4X型铜头连接器 conm/4x100s
综采工作面作为现代化煤矿生产的核心区域,其设备运行的稳定性和安全性直接关系到整个矿井的生产效率。在综采工作面的电气控制系统中,电控连接器扮演着至关重要的角色,而4X型铜头连接器CONM/4X100S作为其中的关键部件,其性能优劣直…...

用ApiFox MCP一键生成接口文档,做接口测试
日常开发过程中,尤其是针对长期维护的老旧项目,许多开发者都会遇到一系列相同的困扰:由于项目早期缺乏严格的开发规范和接口管理策略,导致接口文档缺失,甚至连基本的接口说明都难以找到。此外,由于缺乏规范…...

在compose中的Canvas用kotlin显示多数据波形闪烁的问题
在compose中的Canvas显示多数据波形闪烁的问题:当在Canvas多组记录波形数组时,从第一组开始记录多次显示,如图,当再次回到第一次记录位置再显示时,波形出现闪烁。 原码如下: data class DcWaveForm(var b…...

【学习笔记】MIME
文章目录 1. 引言2. MIME 构成Content-Type(内容类型)Content-Transfer-Encoding(传输编码)Multipart(多部分) 3. 常见 MIME 类型 1. 引言 早期的电子邮件只能发送 ASCII 文本,无法直接传输二进…...
【深尚想】OPA855QDSGRQ1运算放大器IC德州仪器TI汽车级高速8GHz增益带宽的全面解析
1. 元器件定义与核心特性 OPA855QDSGRQ1 是德州仪器(TI)推出的一款 汽车级高速运算放大器,专为宽带跨阻放大(TIA)和电压放大应用优化。核心特性包括: 超高速性能:增益带宽积(GBWP&a…...

单北斗定位芯片AT9880B
AT9880B 是面向北斗卫星导航系统的单模接收机单芯片(SOC),内部集成射频前端、数字基带处理单元、北斗多频信号处理引擎及电源管理模块,支持北斗二号与三号系统的 B1I、B1C、B2I、B3I、B2a、B2b 频点信号接收。 主要特征 支持北斗二…...

旅游微信小程序制作指南
想创建旅游微信小程序吗?知道旅游业企业怎么打造自己的小程序吗?这里有零基础小白也能学会的教程,教你快速制作旅游类微信小程序! 旅游行业能不能开发微信小程序呢?答案是肯定的。微信小程序对旅游企业来说可是个宝&am…...

Ubuntu ifconfig 查不到ens33网卡
BUG:ifconfig查看网络配置信息: 终端输入以下命令: sudo service network-manager stop sudo rm /var/lib/NetworkManager/NetworkManager.state sudo service network-manager start - service network - manager stop :停止…...
zookeeper 学习
Zookeeper 简介 github:https://github.com/apache/zookeeper 官网:https://zookeeper.apache.org/ 什么是 Zookeeper Zookeeper 是一个开源的分布式协调服务,用于管理分布式应用程序的配置、命名服务、分布式同步和组服务。其核心是通过…...

【python深度学习】Day 45 Tensorboard使用介绍
知识点: tensorboard的发展历史和原理tensorboard的常见操作tensorboard在cifar上的实战:MLP和CNN模型 效果展示如下,很适合拿去组会汇报撑页数: 作业:对resnet18在cifar10上采用微调策略下,用tensorboard监…...

【图像处理入门】5. 形态学处理:腐蚀、膨胀与图像的形状雕琢
摘要 形态学处理是基于图像形状特征的处理技术,在图像分析中扮演着关键角色。本文将深入讲解腐蚀、膨胀、开闭运算等形态学操作的原理,结合OpenCV代码展示其在去除噪声、提取边缘、分割图像等场景的应用,带你掌握通过结构元素雕琢图像形状的核心技巧。 一、形态学处理:基…...

并行智算MaaS云平台:打造你的专属AI助手,开启智能生活新纪元
目录 引言:AI助手,未来生活的必备伙伴 并行智算云:大模型API的卓越平台 实战指南:调用并行智算云API打造个人AI助手 3.1 准备工作 3.2 API调用示例 3.3 本地智能AI系统搭建 3.4 高级功能实现 并行智算云的优势 4.1 性能卓越…...
在 SpringBoot+Tomcat 环境中 线程安全问题的根本原因以及哪些变量会存在线程安全的问题。
文章目录 前言Tomcat SpringBoot单例加载结果分析多例加载:结果分析: 哪些变量存在线程安全的问题?线程不安全线程安全 总结 前言 本文带你去深入理解为什么在web环境中(Tomcat SpringBoot)会存在多线程的问题以及哪些变量会存在线程安全的…...

Day45 Python打卡训练营
知识点回顾: 1. tensorboard的发展历史和原理 2. tensorboard的常见操作 3. tensorboard在cifar上的实战:MLP和CNN模型 一、tensorboard的基本操作 1.1 发展历史 TensorBoard 是 TensorFlow 生态中的官方可视化工具(也可无缝集成 PyTorch&…...

2025年目前最新版本Android Studio自定义xml预览的屏幕分辨率
一、前言 在实际开发项目当中,我们的设备的分辨率可能会比较特殊,AS并没有自带这种屏幕分辨率的设备,但是我们又想一边编写XML界面,一边实时看到较为真实的预览效果,该怎么办呢?在早期的AS版本中ÿ…...

黑马Java面试笔记之 并发编程篇(线程池+使用场景)
一. 线程池的核心参数(线程池的执行原理) 线程池核心参数主要参考ThreadPoolExecutor这个类的7个参数的构造函数 corePoolSize 核心线程数目 maximumPoolSize 最大线程数目 (核心线程救急线程的最大数目) keepAliveTime 生存时间 - 救急线程的生存时…...
float和float32有什么区别
这是一个非常重要但容易混淆的问题。我们来一步一步讲清楚: ✅ 简答:float 和 float32 的区别 项目float(通用名称)float32(精确定义)含义通常指“浮点数”,具体精度由语言/平台决定明确指 32 …...

【AI学习】KV-cache和page attention
目录 带着问题学AI KV-cache KV-cache是什么? 之前每个token生成的K V矩阵给缓存起来有什么用? 为啥缓存K、V,没有缓存Q? KV-cache为啥在训练阶段不需要,只在推理阶段需要? KV cache的过程图解 阶段一:KV cac…...

七彩喜智慧养老平台:科技赋能下的市场蓝海,满足多样化养老服务需求
在人口老龄化加速与科技快速发展的双重驱动下,七彩喜智慧养老平台正成为破解养老服务供需矛盾、激活银发经济的核心引擎。 这一领域依托物联网、人工智能、大数据等技术,构建起覆盖居家、社区、机构的多层次服务体系。 既满足老年人多样化需求…...

《Pytorch深度学习实践》ch8-多分类
------B站《刘二大人》 1.Softmax Layer 在多分类问题中,输出的是每类的概率: 计算公式:保证了每类概率大于 0 ,又由保证了概率之和为 1; 举例如下: 2.Cross Entropy 计算损失: y np.array…...

国产录播一体机:科技赋能智慧教育信息化
在数字化时代,教育正经历着前所未有的变革。国产工控机作为信息化教育的核心载体,正在重新定义学习方式,赋能教师与学生,打造高效、互动、智能的教学环境,让我们一起感受科技与教育的深度融合!高能计算机推…...

关于逻辑回归的见解
逻辑回归通过将线性回归的输出映射到 [ 0 , 1 ] \left[0,1\right] [0,1]区间,来表示某个类别的概率。也就是其本质是先通过线性回归的预测值 y \boldsymbol{y} y输入到映射函数,既将线性回归的输出通过映射函数映射到 [ 0 , 1 ] \left[0,1\right] [0,1].常用的映射函数是sigm…...

Amazon Augmented AI:人类智慧与AI协作,破解机器学习审核难题
在人工智能日益渗透业务核心的今天,你是否遭遇过这样的困境:自动化AI处理海量数据时,面对模糊、复杂或高风险的场景频频“卡壳”?人工审核团队则被低效、重复的任务压得喘不过气?Amazon Augmented AI (A2I) 的诞生&…...
CMake入门:3、变量操作 set 和 list
在 CMake 中,set 和 list 是两个核心命令,用于变量管理和列表操作。理解它们的用法对于编写高效的 CMakeLists.txt 文件至关重要。下面详细介绍这两个命令的功能和常见用法: 一、set 命令:变量定义与赋值 set 命令用于创建、修改…...
聊聊FlaUI:让Windows UI自动化测试优雅起飞!
你还在为手动点点点测试Windows应用而感到膝盖疼?更愁于自动化测试工具价格贵得让钱包瑟瑟发抖?今天,我要给你安利一款“野路子有余,正经事儿也能干”的.NET UI自动化神器——FlaUI!别眨眼,看完你能少加三个…...

VIN码车辆识别码解析接口如何用C#进行调用?
一、什么是VIN码车辆识别码解析接口 输入17位vin码,获取到车辆的品牌、型号、出厂日期、发动机类型、驱动类型、车型、年份等信息。无论是汽车电商平台、二手车商、维修厂,还是保险公司、金融机构,都能通过接入该API实现信息自动化、决策智能…...
[论文阅读] 人工智能 | 用大语言模型解决软件元数据“身份谜题”:科研软件的“认脸”新方案
用大语言模型解决软件元数据“身份谜题”:科研软件的“认脸”新方案 论文信息 作者: Eva Martn del Pico, Josep Llus Gelp, Salvador Capella-Gutirrez 标题: Identity resolution of software metadata using Large Language Models 年份: 2025 来源: arX…...
gorm多租户插件的使用
一、关于gorm多租户插件的使用 1、安装依赖 go get -u github.com/kuangshp/gorm-tenant2、创建一个mysql数据表 DROP TABLE IF EXISTS user; CREATE TABLE user (id int(11) NOT NULL AUTO_INCREMENT primary key COMMENT 主键id,name varchar(50) not null comment 名称,ten…...

Playwright 测试框架 - Java
🚀【Playwright + Java 实战教程】从零到一掌握自动化测试利器! 🔧 本文专为 Java 开发者量身打造,通过详尽示例带你快速掌握 Playwright 自动化测试。涵盖基础操作、表单交互、测试框架集成、高阶功能及常见实战技巧,适用于企业 UI 测试与 CI/CD 场景。 🛠️ 一、环境…...