网络剪枝——network-slimming 项目复现
目录
文章目录
- 目录
- 网络剪枝——network-slimming 项目复现
- clone 存储库
- Baseline
- vgg
- 训练
- 结果
- resnet
- 训练
- 结果
- densenet
- 训练
- 结果
- Sparsity
- vgg
- 训练
- 结果
- resnet
- 训练
- 结果
- densenet
- 训练
- 结果
- Prune
- vgg
- 命令
- 结果
- resnet
- 命令
- 结果
- densenet
- 命令
- 结果
- Fine-tune
- vgg
- 训练
- 结果
- resnet
- 训练
- 结果
- densenet
- 训练
- 结果
- 模型大小计算脚本 param_counter.py
- 结果汇总
- CIFAR10
网络剪枝——network-slimming 项目复现
- 【GiHnub】:Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)
- 【作者复现项目】:
- 通过百度网盘分享的文件:network-slimming-regin.zip
链接:https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwd=t1z2
提取码:t1z2- 仅以 CIFAR-10 为例,CIFAR-100 同理.
- 提供中文README_zh-CN.md.
- 包含 CIFAR-10/100 数据集data.cifar10、data.cifar100.
- 解决了 main.py 运行报错问题.
- 加入了计算训练后模型的 Parameters 大小脚本param_counter.py.
clone 存储库
注:若 clone 作者复现项目,则忽略这一步,直接进入下一步;若想自行从头复现,则 clone 以下存储库.

-
链接:https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwd=ud89
提取码:ud89
Baseline
vgg
训练
- 【命令】:
python main.py --dataset cifar10 --arch vgg --depth 19

- 这个报错通常出现在使用 Python 的
multiprocessing库来创建进程时,尤其是在 Windows 操作系统上. 在 Windows 上,Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同,它使用 “spawn” 来启动新进程,这意味着每个子进程都会从头开始执行脚本. 因此,如果在脚本顶层级别启动进程(而不是在受保护的if __name__ == '__main__':块中),每个子进程都会尝试再次启动子进程,从而导致无限递归和上述错误.
- 为了解决这个问题,应 确保多进程代码(即main.py)位于
if __name__ == '__main__':保护块内.
# 导入部分
...def main():...if __name__ == '__main__':main()
- 再次运行命令,又报错:

- 这个报错通常发生在尝试直接索引一个0维的张量(tensor)时. 在 PyTorch 中,0 维张量是一个单一值的张量,但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值,需要使用
.item()方法.
- 为了解决这个问题,应该 使用
.item()方法来替换所有.data[0]的用法:
# 在 train 函数中
if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# 在 test 函数中
for data, target in test_loader:if args.cuda:data, target = data.cuda(), target.cuda()data, target = Variable(data), Variable(target)output = model(data)test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch losspred = output.data.max(1, keepdim=True)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum()test_loss /= len(test_loader.dataset)
- 再次运行命令就正常运行了:

结果
- Terminal:

- 在 ./logs 生成文件:checkpoint.pth.tar、model_best.pth.tar

resnet
训练
- 【命令】:
python main.py --dataset cifar10 --arch resnet --depth 164
结果

densenet
训练
- 【命令】:
python main.py --dataset cifar10 --arch densenet --depth 40
结果

Sparsity
vgg
训练
- 【命令】:
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
结果

resnet
训练
- 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
结果

densenet
训练
- 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
结果

Prune
vgg
命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes

- 与main.py同理,为了解决这个问题,应 确保多进程代码位于
if __name__ == '__main__':保护块内:
# 导入部分
...def main():...if __name__ == '__main__':main()
- 之后就可以正常运行了.

结果
- Terminal:

- 在./prunes生成文件:prune.txt、pruned.pth.tar

- 在prune.txt中我们可以看到 Number of parameters、Test accuracy:

resnet
命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes
结果

densenet
命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes
结果

Fine-tune
vgg
训练
- 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160
结果

resnet
训练
- 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160
结果

densenet
训练
- 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160
结果

模型大小计算脚本 param_counter.py
- 【路径】:./script/param_counter.py
import torchdef load_model(model_path):model = torch.load(model_path, map_location=torch.device('cpu'))return modeldef count_parameters(model_state_dict):total_params = sum(p.numel() for p in model_state_dict.values())return total_paramsdef get_model_parameters(model_path):# 加载模型状态字典model = load_model(model_path)# 模型状态字典存储在 'state_dict' 键下model_state_dict = model['state_dict'] if 'state_dict' in model else model# 计算参数总数total_params = count_parameters(model_state_dict)return total_params
- 在main.py中:
from script.param_counter import get_model_parametersdef main():...# 计算 Parametersmodel_path = 'logs/model_best.pth.tar'total_params = get_model_parameters(model_path)print(f'Total parameters in the model: {total_params}')
结果汇总
注:与原项目结果略有差别.
CIFAR10
| CIFAR10-Vgg | Baseline | Sparsity(1e-4) | Prune(70%) | Fine-tune-160(70%) |
|---|---|---|---|---|
| Top1 Accuracy(%) | 93.72 | 93.60 | 33.98 | 93.75 |
| Parameters | 20.05M | 20.05M | 2.22M | 2.23M |
| CIFAR10-Resnet-164 | Baseline | Sparsity(1e-5) | Prune(40%) | Fine-tune-160(40%) |
|---|---|---|---|---|
| Top1 Accuracy(%) | 94.99 | 95.00 | 94.59 | 95.27 |
| Parameters | 1.74M | 1.74M | 1.46M | 1.49M |
| CIFAR10-Densenet-40 | Baseline | Sparsity(1e-5) | Prune(40%) | Fine-tune-160(40%) |
|---|---|---|---|---|
| Top1 Accuracy(%) | 94.15 | 94.37 | 94.14 | 94.48 |
| Parameters | 1.09M | 1.09M | 0.70M | 0.72M |
相关文章:
网络剪枝——network-slimming 项目复现
目录 文章目录 目录网络剪枝——network-slimming 项目复现clone 存储库Baselinevgg训练结果 resnet训练结果 densenet训练结果 Sparsityvgg训练结果 resnet训练结果 densenet训练结果 Prunevgg命令结果 resnet命令结果 densenet命令结果 Fine-tunevgg训练结果 resnet训练结果 …...
Spring 懒加载的实际应用
引言 在 Spring 框架中,懒加载机制允许你在应用程序运行时延迟加载 Bean。这意味着 Bean 只会在第一次被请求时才实例化,而不是在应用程序启动时就立即创建。这种机制可以提高应用程序的启动速度,并节省内存资源。 Spring 的懒加载机制 懒…...
PyQT 串口改动每次点开时更新串口信息
class MainWindow(QWidget, Ui_Form):def __init__(self):super().__init__(parentNone)self.setupUi(self)self.comboBox.installEventFilter(self) # 加载事件过滤器self.comboBox.addItems(get_ports())def eventFilter(self, obj, event): # 定义事件过滤器if isinstance(o…...
三级_网络技术_19_路由器的配置及使用
1.在Cisco路由器上配置DHCP服务,使得客户端可以分配到的地址范围是222.28.71.2-222.28.71.200地址租用时间是2小时30分钟,不记录地址冲突日志默认路由是222.28.71.1,分配的dns服务器地址是222.28126.27和222.28.126.26。以下配置完全正确的是…...
【STM32 Blue Pill编程】-STM32CubeIDE开发环境搭建与点亮LED
开发环境搭建与点亮LED 文章目录 开发环境搭建与点亮LED1、STM32F103C8T6及STM32 Blue Pill 介绍2、下载并安装STM32CubeIDE3、编程并点亮LED3.1 在Stm32CubeIDE中编写第一个STM32程序3.1.1 创建项目3.1.2 设备配置3.1.2.1 系统时钟配置3.1.2.2 系统调试配置3.1.2.3 GPIO配置3.…...
【数据结构】六、图:4.图的遍历(深度优先算法DFS、广度优先算法BFS)
三、基本操作 文章目录 三、基本操作1.图的遍历1.1 深度优先遍历DFS1.1.1 DFS算法1.1.2 DFS算法的性能分析1.1.3 深度优先的生成树和生成森林 1.2 广度优先遍历BFS1.2.1 BFS算法1.2.2 BFS算法性能分析1.2.3 广度优先的生成树和生成森林 1.3 图的遍历与图的连通性 1.图的遍历 图…...
29、号外!号外!ERA5再分析数据下载方式更新啦
文章目录 1. 前言2. 账号注册与协议签署2.1 账号注册2.2 签署CDS-Beta使用条款2.3 更新.cdsapi文件 3. 常见问题与解决方法(持续更新中)3.1 问题1:更新完.cdsapi文件之后,原有下载代码不可以使用3.2 问题2: RuntimeError: 403 Cli…...
智能识别,2024年SD卡数据恢复软件的智能进化
除了手机之外现在有不少的设备还是依靠SD卡来存储数据,比如相机、摄像头、无人机等。有的时候会因为一些意外的情况导致数据丢失,那是真的丢失了吗?大部分情况还是可以依靠sd卡数据恢复工具来找回这些“消失”的数据哦。 1.福昕数据恢复 链…...
浙大数据结构慕课课后题(04-树5 Root of AVL Tree)
题目要求: AVL 树是一种自平衡的二叉搜索树。在 AVL 树中,任何节点的两个子子树的高度最多相差一;如果在任何时候它们相差不止一,则进行重新平衡以恢复此属性。图 1-4 说明了旋转规则。 图1 图2 图3 图4 现在给定一系列插入,您应该…...
Golang | Leetcode Golang题解之第331题验证二叉树的前序序列化
题目: 题解: func isValidSerialization(preorder string) bool {n : len(preorder)slots : 1for i : 0; i < n; {if slots 0 {return false}if preorder[i] , {i} else if preorder[i] # {slots--i} else {// 读一个数字for i < n &&…...
zdppy+vue3+onlyoffice文档管理系统项目实战 20240812上课笔记
遗留问题 1、增加新建和导入按钮,有按钮了,但是还没有完善,图标还不对,需要解决 2、登录功能 3、用户管理 4、角色管理 5、权限管理 6、分享功能 解决新建和导入的图标问题 解决代码: <a-button type"prim…...
怎么将mov视频转换成mp4?将mov视频转换成mp4的方法
怎么将mov视频转换成mp4?由于mov格式通常与苹果设备兼容性较好,而mp4则更广泛地支持于各种播放器和设备中,因此将mov转换为mp4可以确保视频在更多场景下能够流畅播放。通过这种转换,你可以确保视频在各种平台和设备上的兼容性&…...
大数据技术——实战项目:广告数仓(第五部分)
目录 第9章 广告数仓DIM层 9.1 广告信息维度表 9.2 平台信息维度表 9.3 数据装载脚本 第10章 广告数仓DWD层 10.1 广告事件事实表 10.1.1 建表语句 10.1.2 数据装载 10.1.2.1 初步解析日志 10.1.2.2 解析IP和UA 10.1.2.3 标注无效流量 10.2 数据装载脚本 第9章 广…...
计算机毕业设计 家电销售展示平台 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试
🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…...
C# 根据MySQL数据库中数据,批量删除OSS上的垃圾文件
protected void btndeleteTask_Click(object sender, EventArgs e){//获取标识为已删除数据,一次加载500条int countlocks _goodsItemsApplication.CountAllNeedExecuteTask();int totalPagelocks (countlocks 500 - 1) / 500;//分批次处理for (int curentpage …...
Vue3+Element-plus+setup使用vuemap/vue-amap实现高德地图API相关操作
首先要下载依赖并且引入 npm安装 // 安装核心库 npm install vuemap/vue-amap --save// 安装loca库 npm install vuemap/vue-amap-loca --save// 安装扩展库 npm install vuemap/vue-amap-extra --save cdn <script src"https://cdn.jsdelivr.net/npm/vuemap/vue-a…...
Windows配置开机直达桌面并跳过锁屏登录界面在 Windows 10 中添加在启动时自动运行的应用
目录 Win10开机直达桌面并跳过锁屏登录界面修改组策略修改注册表跳过登录界面 在 Windows 10 中添加在启动时自动运行的应用设置系统级别服务一、Windows下使用sc将应用程序设置为系统服务1. 什么是sc命令?2. sc命令的基本语法3. 创建Windows服务的步骤与示例创建服…...
pythonUI自动化007::pytest的组成以及运行
pytest组成: 测试模块:以“test”开头或结尾的py文件 测试用例:在测试模块里或测试类里,名称符合test_xxx函数或者示例函数。 测试类:测试模块里面命名符合Test_xxx的类 函数级: import pytestclass Test…...
开放式耳机哪个品牌好用又实惠?五大口碑精品分享
如今开放式耳机市场日益火爆,不少知名品牌都在对产品进行升级迭代,那么如何在一众品牌型号中选择到自己最满意的那一款呢?开放式耳机哪个品牌好用又实惠?这就需要更专业的选购攻略,因此笔者专门整理出了专业机构的开放…...
代码随想录算法训练营day39||动态规划07:多重背包+打家劫舍
多重背包理论 描述: 有N种物品和一个容量为V 的背包。 第i种物品最多有Mi件可用,每件耗费的空间是Ci ,价值是Wi 。 求解将哪些物品装入背包可使这些物品的耗费的空间 总和不超过背包容量,且价值总和最大。 本质: …...
Spark 之 入门讲解详细版(1)
1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...
基于TurtleBot3在Gazebo地图实现机器人远程控制
1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...
算法:模拟
1.替换所有的问号 1576. 替换所有的问号 - 力扣(LeetCode) 遍历字符串:通过外层循环逐一检查每个字符。遇到 ? 时处理: 内层循环遍历小写字母(a 到 z)。对每个字母检查是否满足: 与…...
【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看
文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...
人工智能--安全大模型训练计划:基于Fine-tuning + LLM Agent
安全大模型训练计划:基于Fine-tuning LLM Agent 1. 构建高质量安全数据集 目标:为安全大模型创建高质量、去偏、符合伦理的训练数据集,涵盖安全相关任务(如有害内容检测、隐私保护、道德推理等)。 1.1 数据收集 描…...
永磁同步电机无速度算法--基于卡尔曼滤波器的滑模观测器
一、原理介绍 传统滑模观测器采用如下结构: 传统SMO中LPF会带来相位延迟和幅值衰减,并且需要额外的相位补偿。 采用扩展卡尔曼滤波器代替常用低通滤波器(LPF),可以去除高次谐波,并且不用相位补偿就可以获得一个误差较小的转子位…...
