模型优化之剪枝
文章目录
- 什么是神经网络剪枝
- 剪枝的好处
- 不同粒度的剪枝
- 剪枝的分类
- 非结构化剪枝
- 结构化剪枝
- 哪些层的参数更容易被剪掉
- 剪枝效果
什么是神经网络剪枝
神经网络剪枝
- 在训练期间删除连接
- 密集张量将变得稀疏(用零填充)
- 可以通过结构化块( n m nm nm)或( 11 11 11)删除连接
剪枝的好处
- 减少过拟合
- 稀疏性优势
- 文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
- 模型更小,可以减少内存带宽消耗量。
- 对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。
不同粒度的剪枝
什么时候做剪枝?
one-shot pruning : 一次性修剪,包括三个步骤训练模型、剪枝、再训练
剪枝:通常根据某种标准(如权重的大小、梯度的大小等)一次性去除大量权重。
再训练:剪枝后,模型通常需要进行一定数量的额外训练(称为fine-tuning或再训练)来恢复剪枝过程中可能损失的性能。
iterative pruning: 迭代式训练,特点如下:
初始训练:首先,对未剪枝的完整模型进行训练,直到达到满意的性能水平。
剪枝:然后,根据某种剪枝策略(例如基于权重的大小或敏感度)剪除模型的部分组件(如权重、神经元或通道)。
再训练:剪枝后,重新训练模型以恢复因剪枝而丢失的性能。
迭代:重复剪枝和再训练的过程,直到达到所需的剪枝率或性能标准。
automated gradual pruning: 自动化渐进剪枝,特点如下:
剪枝策略:采用一种预定义的剪枝策略,例如基于权重阈值、敏感度分析等,该策略在整个剪枝过程中保持一致。
渐进剪枝:在整个训练过程中逐渐增加剪枝率,通常从较低的剪枝率开始,逐步增加到目标剪枝率。
无需再训练:在整个剪枝过程中,模型持续被训练,而不是在剪枝后重新训练。
自动化:整个过程高度自动化,可以减少人为干预的需求
剪枝的分类
结构化剪枝(Structured Pruning)和非结构化剪枝(Unstructured Pruning)是两种常见的神经网络剪枝方法,它们的主要区别在于剪枝后网络结构的变化以及剪枝操作的粒度。
非结构化剪枝
不改变网络结构或者参数数量,把连接上的参数置0即为剪枝。
基于某种度量(如权重的绝对值大小)对所有权重进行排序,然后根据预先设定的剪枝比例(例如去除50%的最小权重)来决定哪些权重被设置为零。这种剪枝方法不会考虑权重在模型中的位置或结构,只关注权重本身的价值。示例代码:
# 导入剪枝函数
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude# 计算两轮之后完成剪枝时对应的迭代次数end_step
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs# 定义剪枝模型参数,开始模型从50%稀疏度(权重为0的参数数量百分比),到80%稀疏度
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,final_sparsity=0.80,begin_step=0,end_step=end_step)
}model_for_pruning = prune_low_magnitude(model, **pruning_params)# 当使用函数`prune_low_magnitude`包装了一下模型后,需要重新编译一下
model_for_pruning.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model_for_pruning.summary()logdir = "./logs/mnist_pruning"callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]model_for_pruning.fit(train_images, train_labels,batch_size=batch_size, epochs=epochs, validation_split=validation_split,callbacks=callbacks)
# --------------------------------------------------
# 评估模型,对比剪枝前后模型的准确率变化
# 经过剪枝,这里有一个小的准确率下降,和没有进行剪枝相比的话
# --------------------------------------------------_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)
结构化剪枝
结构化剪枝改变了网络结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。
在下面的例子中是基于选择的模型层做剪枝,所以需要指出哪些层去做结构化剪枝。比如剪枝第二个卷积层和第一个全连接层,剪枝策略为pruning_params_2_by_4,表示该层剪枝比例为2 / 4,即该层保留一半(2/4)的权重,而将另一半设为零。
注意:第一个卷积层不能被结构化剪枝。要是结构化剪枝的话,应该至少大于一个input channels(本例所用图片为单通道灰度图),所以我们对第一个卷积层使用随机剪枝。
model = keras.Sequential([prune_low_magnitude(keras.layers.Conv2D(32, 5, padding='same', activation='relu',input_shape=(28, 28, 1),name="pruning_sparsity_0_5"),**pruning_params_sparsity_0_5),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),prune_low_magnitude(keras.layers.Conv2D(64, 5, padding='same',name="structural_pruning"),**pruning_params_2_by_4),keras.layers.BatchNormalization(),keras.layers.ReLU(),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),keras.layers.Flatten(),prune_low_magnitude(keras.layers.Dense(1024, activation='relu',name="structural_pruning_dense"),**pruning_params_2_by_4),keras.layers.Dropout(0.4),keras.layers.Dense(10)
])
哪些层的参数更容易被剪掉
因为卷积层(conv)中的参数相比全连接层(fc)来说参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。
剪枝效果
- 一般50%-70%左右的稀疏性,准确率降低幅度并不大
- 剪枝是独立于量化技巧,通常与量化配合效果不错
- 可以通过微调尝试不同的参数组合
相关文章:

模型优化之剪枝
文章目录 什么是神经网络剪枝剪枝的好处不同粒度的剪枝剪枝的分类非结构化剪枝结构化剪枝 哪些层的参数更容易被剪掉剪枝效果 什么是神经网络剪枝 神经网络剪枝 在训练期间删除连接密集张量将变得稀疏(用零填充)可以通过结构化块( n m nm nm&…...

JVM的组成
JVM 运行在操作系统之上 java二进制字节码文件的运行环境 JVM的组成部分 java代码在编写完成后编译成字节码文件通过类加载器 来到运行数据区,主要作用是加载字节码到内存 包含 方法区/元空间 堆 程序计数器,虚拟机栈,本地方法栈等等 随后来到执行引擎,主要作用是翻译字…...
快速上手 iOS Protocol Buffer
快速上手 iOS Protocol Buffer | 来自缤纷多彩的灰 本文主要介绍在 iOS 开发中如何快速上手使用 Protobuf。更多关于 Protobuf 的介绍和相关的功能 api,读者可自行查阅官网。 Protocol Buffer(简称 Protobuf)是一种由Google开发的语言中立、…...
每天一个数据分析题(四百八十)- 线性回归建模
关于线性回归建模,线性回归模型假设说法不正确的是? A. 因变量和自变量要有因果关系 B. 残差均值为0 C. 残差服从正态分布 D. 自变量不存在共线性 数据分析认证考试介绍:点击进入 题目来源于CDA模拟题库 点击此处获取答案 数据分析专…...

电动汽车和混动汽车DC-DC转换器的创新设计与测试方法
汽车 DC-DC 转换器市场规模将达到187亿美元,年复合增长率为10%。 DC-DC 转换器是汽车的重要组成部分,它可以通过电压转换为各种车载系统供电,例如日益复杂的车载信息娱乐系统、使用驾驶辅助系统(ADAS)实现的增强安全功…...

OriginPro快速上手指南:数据可视化与分析的利器
目录 OriginLab - Origin and OriginPro - Data Analysis and Graphing Softwarehttps://www.originlab.com/编辑 一、安装与界面概览 安装 界面概览 二、基础操作 数据输入 创建图表 三、高级功能 数据分析 自动化与脚本 Origin 提供了几个小工具 四、技巧与提示…...

缓存学习
缓存基本概念 概念 对于缓存,最普遍的理解是能让打开某些页面速度更快的工具。从技术角度来看,其本质上是因为缓存是基于内存建立的,而内存的读写速度相比之于硬盘快了xx倍,因此用内存来代替硬盘作为读写的介质当然能大大提高访…...

亚世光电:消费电子年度表演
机圈风云再起,消费电子乘风而起? 今天我们来聊——亚世光电 最近,华为mate60突然降价,被大家怀疑是为新品上市做准备,算算时间,下半年的消费电子大战也即将拉开帷幕,而亚世光电所在的光电显示领…...

AI 工程应用 建筑表面检测及修复
文章目录 1 项目概述(必写):2 技术方案与实施步骤2.1 模型选择(必写):2.2 数据的构建:2.3 功能整合(进阶): 3 实施步骤:3.1 环境搭建(…...

Qt-Qt中的小事项(7)
目录 命名风格 快捷键 查询文档 坐标系 代码理解 move 命名风格 这个也是老生常谈的问题了,入乡随俗就好啦 快捷键 这里是一些常用的快捷键,用多了自然就熟悉了 • 注释:ctrl/ • 运行:ctrlR • 编译:ctrlB …...
Android MediaRecorder 视频录制及报错解决
目录 一、start failed: -19 二、使用MediaRecorder录制视频 2.1 申请权限 2.2 布局文件 2.3 MediaRecordActivity 2.4 运行结果 三、拓展 3.1 录制视频模糊(解决) 3.2 阿里云OSS上传文件 3.2.1 权限(刚需) 3.2.2 安装SDK 3.2.3 使用 相关链接 一、start failed…...

HarmonyOS应用程序访问控制探究
关于作者 白晓明 宁夏图尔科技有限公司董事长兼CEO、坚果派联合创始人 华为HDE、润和软件HiHope社区专家、鸿蒙KOL、仓颉KOL 华为开发者学堂/51CTO学堂/CSDN学堂认证讲师 开放原子开源基金会2023开源贡献之星 一、引言 随着信息技术的飞速发展,移动应用程序已经成为…...

董卫民赴考拉悠然等企业调研,强调加快发展人工智能产业
8月14日,按照省政府重点产业链协同推进机制有关工作安排,省委常委、常务副省长董卫民在成都市调研人工智能产业发展情况,并召开座谈会。他强调,要坚决落实党的二十届三中全会精神和省委省政府决策部署,充分把握人工智能…...
MFC将类A中的事件在类B中处理采用回调函数实现
需求: 在类A的界面上有一个tab控件。tab控件上面有那个页面。在MFC编程中一个tab的一个页面就应该是一个新的类。在tab的一个页面上有一个list控件。现在需要将list控件的点击事件,双击事件等在类A里面处理。 解决: 在类B里面给控件list添加…...
公众号 微信登录
export function getWxCode(that, localhostUrl) { // localhostUrl 当前页面的路径 传这个也可以this.$route.fullPath// console.log(that.$store.state.wxSessionData)// console.log(that.$store.state.wxSessionData.openId)//openId为undefine执行获取openid判断是否没有…...

sanic + webSocket:股票实时行情推送服务实现
💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storm…...
Unity动态给按钮各个状态下的图片赋值
Unity动态给按钮各个状态下的图片赋值 using UnityEngine; using UnityEngine.UI; public class ButtonOnClickTest : MonoBehaviour {public Button btn;public Sprite _highlighterSprite;public Sprite _pressedSprite;public Sprite _selectesdSprite;public Sprite _disa…...

xiaomi pad 6PRO 小米平板6 pro hyperOS降级 澎湃os 降级MIUI 14 教程 免解锁BL 降级,168小时解锁绑定
小米平板 6 Pro 机型代号 :liuqin 降级MIUI 14 小米澎湃 OS 正式版 澎湃OS安卓发布日期卡刷包线刷包OS1.0.7.0.UMYCNXM14.02024-07-13miui_LIUQIN_OS1.0.7.0.UMYCNXM_d618a5c980_14.0.zipliuqin_images_OS1.0.7.0.UMYCNXM_20240705.0000.00_14.0_cn_8cbf5920be.…...
MySQL 备份一个表
语法(创建一个与table1结构相同的新表table2,并且将table1的数据复制到table2): create table table2 as select * from table1 举例(备份tb_log表到tb_log_20240815中去): create table tb_log_20240815 as select * from tb_log...

鸿蒙开发入门day10-组件导航
(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,还请三连支持一波哇ヾ(@^∇^@)ノ) 目录 组件导航 (Navigation) 设置页面显示模式 设置标题栏模式 设置菜…...
Python爬虫实战:研究MechanicalSoup库相关技术
一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...

Keil 中设置 STM32 Flash 和 RAM 地址详解
文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
Go语言多线程问题
打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...
从实验室到产业:IndexTTS 在六大核心场景的落地实践
一、内容创作:重构数字内容生产范式 在短视频创作领域,IndexTTS 的语音克隆技术彻底改变了配音流程。B 站 UP 主通过 5 秒参考音频即可克隆出郭老师音色,生成的 “各位吴彦祖们大家好” 语音相似度达 97%,单条视频播放量突破百万…...
Django RBAC项目后端实战 - 03 DRF权限控制实现
项目背景 在上一篇文章中,我们完成了JWT认证系统的集成。本篇文章将实现基于Redis的RBAC权限控制系统,为系统提供细粒度的权限控制。 开发目标 实现基于Redis的权限缓存机制开发DRF权限控制类实现权限管理API配置权限白名单 前置配置 在开始开发权限…...
LangChain【6】之输出解析器:结构化LLM响应的关键工具
文章目录 一 LangChain输出解析器概述1.1 什么是输出解析器?1.2 主要功能与工作原理1.3 常用解析器类型 二 主要输出解析器类型2.1 Pydantic/Json输出解析器2.2 结构化输出解析器2.3 列表解析器2.4 日期解析器2.5 Json输出解析器2.6 xml输出解析器 三 高级使用技巧3…...

如何做好一份技术文档?从规划到实践的完整指南
如何做好一份技术文档?从规划到实践的完整指南 🌟 嗨,我是IRpickstars! 🌌 总有一行代码,能点亮万千星辰。 🔍 在技术的宇宙中,我愿做永不停歇的探索者。 ✨ 用代码丈量世界&…...