TensorRT量化工具pytorch_quantization代码解析(一)
量化工具箱pytorch_quantization 通过提供一个方便的 PyTorch 库来补充 TensorRT ,该库有助于生成可优化的 QAT 模型。该工具包提供了一个 API 来自动或手动为 QAT 或 PTQ 准备模型。
API 的核心是 TensorQuantizer 模块,它可以量化、伪量化或收集张量的统计信息。它与 QuantDescriptor 一起使用,后者描述了如何量化张量。在 TensorQuantizer 之上的是量化模块,这些模块被设计为 PyTorch 全精度模块的替代品。这些是使用 TensorQuantizer 对模块的权重和输入进行伪量化或收集统计信息的方便模块。
API 支持将 PyTorch 模块自动转换为其量化版本。转换也可以使用 API 手动完成,这允许在不想量化所有模块的情况下进行部分量化。例如,一些层可能对量化更敏感,并且使其未量化可提高任务精度。
量化第一步是将量化器模块添加到神经网络图中。该包提供了许多量化层模块,其中包含用于输入和权重的量化器。例如quant_nn.QuantLinear,它可以用来代替nn.Linear。这些量化层可以通过猴子修补或手动修改模型定义来自动替换。自动层替换是使用quant_module完成的。这应该在创建模型之前调用。
首先看以下代码:
from pytorch_quantization import quant_modules
quant_modules.initialize()
initialize()
会动态地修改 PyTorch 代码,适用于每个模块的所有实例,将 torch.nn.module 的一些子类替换为对应的量化版本。如果不希望所有模块都量化,则应手动替换量化模块。独立量化器也可以添加到带有quant_nn.TensorQuantizer的模型中。
initialize()
位于:tools\pytorch-quantization\pytorch_quantization\quant_modules.py
,作用使用使用monkey patching进行动态模块更换为量化版本
什么是猴子补丁
- Python是一种典型的动态脚本语言。它不仅具有 动态类型(dynamic type) ,而且它的 对象模型(object model) 也是动态的。Python的类是可变的(mutable),方法(methods)只是类的属性(attributes);这允许我们在 运行时(run time) 修改其行为。这被称为猴子补丁(Monkey Patching), 它指的是偷偷地更改代码。
- Monkey Patching只是在 运行时(run time) 动态替换属性(attributes)。而在Python中,术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改。
def initialize(float_module_list=None, custom_quant_modules=None):"""用量化版本动态地替换模块。在内部,状态由helper类对象维护,该对象有助于将原始模块替换回去。参数:float_module_list:列表,用户提供的列表,其中指明哪些模块不可执行替换custom_quant_modules:一个字典。用户提供的映射,用于指示除torch.nn及其相应量化版本之外的任何其他模块。Returns:空"""# 准备monkey patching中使用的内部变量quant_map和orginal_func_map_quant_module_helper_object.prepare_state(float_module_list, custom_quant_modules)#执行量化模块替换_quant_module_helper_object.apply_quant_modules()def deactivate():"""动态模块更换,可逆转monkey patching使用维护状态的helper类对象动态地替换回先前在initialize()函数调用中被monkey patching的原始模块。"""_quant_module_helper_object.restore_float_modules()# 维护被替换模块状态的全局对象。
_quant_module_helper_object = QuantModuleReplacementHelper()
自定义量化模块使用示例:
# torch.nn模块定义不可执行替换列表
float_module_list = ["Linear"]
# torch.nn以外的模块自定义映射
custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]
# Monkey修补模块
pytorch_quantization.quant_modules.initialize(float_module_list, custom_modules)
# 使用量化模块
pytorch_quantization.quant_modules.deactivate()
继续看helper
类QuantModuleReplacementHelper
class QuantModuleReplacementHelper():"""帮助量化版本替换torch.nn模块术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改该模块用工具内部实现或任何其他用户提供的自定义模块提供的量化版 替换(通过monkey patching)torch.nn模块属性:orginal_func_map:一个dict.维护原始torch.nn模块字典quant_support_list:列表,包含工具提供的量化版本的模块名称quant_map:一个字典,包含模块名称及其量化版本的字典quant_switch_opt:一个字典,用于指示哪些模块不能替换其量化版本。该dict由用户提供的列表更新,该列表指示在monkey patching中要忽略的模块"""def __init__(self):# 保留要更换的原始模块self.orginal_func_map = set()# 默认情况下,维护工具支持的量化模块列表self.default_quant_map = _DEFAULT_QUANT_MAP# 保存最终量化模块。self.quant_map = set()
_DEFAULT_QUANT_MAP
是包含量化模块映射的文件的全局成员
_DEFAULT_QUANT_MAP = [_quant_entry(torch.nn, "Conv1d", quant_nn.QuantConv1d),_quant_entry(torch.nn, "Conv2d", quant_nn.QuantConv2d),_quant_entry(torch.nn, "Conv3d", quant_nn.QuantConv3d),_quant_entry(torch.nn, "ConvTranspose1d", quant_nn.QuantConvTranspose1d),_quant_entry(torch.nn, "ConvTranspose2d", quant_nn.QuantConvTranspose2d),_quant_entry(torch.nn, "ConvTranspose3d", quant_nn.QuantConvTranspose3d),_quant_entry(torch.nn, "Linear", quant_nn.QuantLinear),_quant_entry(torch.nn, "LSTM", quant_nn.QuantLSTM),_quant_entry(torch.nn, "LSTMCell", quant_nn.QuantLSTMCell),_quant_entry(torch.nn, "AvgPool1d", quant_nn.QuantAvgPool1d),_quant_entry(torch.nn, "AvgPool2d", quant_nn.QuantAvgPool2d),_quant_entry(torch.nn, "AvgPool3d", quant_nn.QuantAvgPool3d),_quant_entry(torch.nn, "AdaptiveAvgPool1d", quant_nn.QuantAdaptiveAvgPool1d),_quant_entry(torch.nn, "AdaptiveAvgPool2d", quant_nn.QuantAdaptiveAvgPool2d),_quant_entry(torch.nn, "AdaptiveAvgPool3d", quant_nn.QuantAdaptiveAvgPool3d),]
_quant_entry
定义命名元组,用于存储量化模块映射,它拥有三个属性orig_mod mod_name replace_mod
_quant_entry = namedtuple('quant_entry', 'orig_mod mod_name replace_mod')
QuantModuleReplacementHelper
类的属性方法:
prepare_state
准备稍后在monkey patching机制中使用的量化模块的命名字典quant_map
和更换为原始模块orginal_func_map
- 设置torch.nn工具支持的量化模块列表
- 为torch.nn以外的模块设置自定义映射
- 使用float_module_list关闭用户指示模块的monkey patching替换
def prepare_state(self, float_module_list=None, custom_map=None):""""""# 对于支持的默认量化模块,生成quant_mapfor item in self.default_quant_map:if float_module_list is not None and item.mod_name in float_module_list:# 如果float_module_list中存在此模块,则跳过此模块continueelse:# 将模块追加到将在monkey patching中使用的变量中self.quant_map.add(item)# 存储要在反向monkey patching中使用的原始模块self.orginal_func_map.add(_quant_entry(item.orig_mod, item.mod_name,getattr(item.orig_mod, item.mod_name)))# 将自定义模块添加到quant_mapif custom_map is not None:for item in custom_map:# 将自定义模块附加到将在monkey补丁中使用的列表中# 将元组转换为命名元组self.quant_map.add(_quant_entry(item[0], item[1], item[2]))# 将原始模块存储在另一个列表中,该列表将用于反向monkey patchingself.orginal_func_map.add(_quant_entry(item[0], item[1], getattr(item[0], item[1])))
- apply_quant_modules:根据quant_map,执行替换为量化模块
def apply_quant_modules(self):for entry in self.quant_map:# 用于设置属性值,该属性不一定是存在的,对应函数 getattr()setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
- restore_float_modules:通过使用orginal_func_map替换回原始模块,反转monkey patch的效果
def restore_float_modules(self):for entry in self.orginal_func_map:setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
相关文章:
TensorRT量化工具pytorch_quantization代码解析(一)
量化工具箱pytorch_quantization 通过提供一个方便的 PyTorch 库来补充 TensorRT ,该库有助于生成可优化的 QAT 模型。该工具包提供了一个 API 来自动或手动为 QAT 或 PTQ 准备模型。 API 的核心是 TensorQuantizer 模块,它可以量化、伪量化或收集张量的…...

【Kubernetes】第二十七篇 - 布署前端项(下)
一,前言 上一篇,介绍了前端项目的部署:项目的创建和 jenkins 配置; 本篇,创建 Deployment、Service,完成前端项目的部署; 二,创建 Deployment 创建 Deployment 配置文件ÿ…...

【MFC】两个ListBox控件数据交互
一.控件ID名称 界面如图下所示: 候选数据列表的ID为: 已选数据列表的ID为: 二.数据添加 可以使用以下代码往框中添加数据: ((CListBox *)GetDlgItem(IDC_LIST_TO_CHO))->AddString("测试数据"); 显示效果如下&#…...
sklearn库学习--SelectKBest 、f_regression
目录 一、SelectKBest 介绍、代码使用 介绍: 代码使用: 二、评分函数 【1】f_regression: (1)介绍: (2)F值和相关系数 【2】除了f_regression函数,还有一些适用于…...
蓝桥杯刷题第十三天
第一题:特殊日期问题描述对于一个日期,我们可以计算出年份的各个数位上的数字之和,也可以分别计算月和日的各位数字之和。请问从 1900 年 11 月 1 日至 9999 年 12 月 31 日,总共有多少天,年份的数位数字之和等于月的数…...
CPU 和带宽之间的时空权衡
在 从一道面试题看 TCP 的吞吐极限 一文的开始,我提到在环形域上两个数字比较大小的前提是在同一个半圆内,进而得到滑动窗口最大值被限定在一个环形域的一半。 现在来看更为基本的问题。如果序列号只有 2bit,甚至仅有 1bit,保序传…...

ES+Redis+MySQL,这个高可用架构设计太顶了!
一、背景 会员系统是一种基础系统,跟公司所有业务线的下单主流程密切相关。如果会员系统出故障,会导致用户无法下单,影响范围是全公司所有业务线。所以,会员系统必须保证高性能、高可用,提供稳定、高效的基础服务。 …...

【Maven】Maven的常用命令
目录 一、Maven的常用命令 1、compile 编译命令 2、test 测试命令 3 、clean 清理命令 4、package 打包命令 5、 install 安装命令 6、Maven 指令的生命周期 二、maven 的概念模型 💟 创作不易,不妨点赞💚评论❤️收藏💙一…...
python的循环结构
python中有for循环和while循环两种形式。 1. for 循环 可以用for循环来遍历不同类型的对象,如数组、列表、元组、字典、集合或字符串,并对每个元素执行一段代码。 1.1 数组的for循环 用for循环遍历一个数组,并打印出每个元素:…...
五种Python中字典的高级用法
1. 引言 Python中的字典是一种非常有用的数据结构,它允许大家存储键值对。通常来说,字典灵活、高效且易于使用,是Python中最常用的数据结构之一。字典通常被用于统计频率、映射值等任务,但在Python中使用字典也可以达到许多意想不…...

[蓝桥杯单片机]——八到十一届初赛决赛客观题
第八届初赛 一、填空题 采用外部12MHz晶振,经过系统12分频时定时器获得最大定时长度,此时定时器定时脉冲为1MHz,周期为1s,而定时器计时均为16位加法计数器,即计时长度为。 二、 选择题 ①带阻滤波器是指能通过大多数频…...

多线程(初阶)
文章目录一.初始线程(Thread)1.1.线程的概念1.2.线程的优势1.2.1.线程比进程更轻量1.2.2.并发编程1.3.线程和进程的区别二.Thread类方法2.1. java 中创建线程的方法2.1.1. 继承Thread,重写run2.1.2. 实现Ruuable接口2.1.3. 使用匿名内部类,继承Thread2.1.4.使用匿名内部类,实现…...

【Vue从入门到进阶】Node.js安装与配置
✅作者简介:CSDN一位小博主,正在学习前端,欢迎大家一起来交流学习🏆 📃个人主页:白月光777的CSDN博客 🔥系列专栏:Vue从入门到进阶 💬个人格言:但行好事&…...

python 正则使用详解
python 正则使用详解什么是正则在 python 中使用正则一些正则的定义python 正则的方法match 从字符串开头匹配正则返回的结果分析(重要)fullmatch 严格匹配整个字符串search 任意位置开始匹配sub 替换匹配内容subn 以元组方式返回替换结果split 正则切割…...
一个深度学习项目需要什么
DataLoader1.数据预处理在将数据提供给模型之前,DataLoader需要对数据进行预处理。预处理可以包括数据增强、归一化、裁剪、缩放等操作。这些操作可以提高模型的性能和准确度。在处理点云数据时,可以通过最远点下采样到固定的点数。2.读取标签文件我 1 2…...

【Java进阶篇】—— 常用类和基础API
一、String类 1.1 String的特性 java.lang.String 类代表字符串,由final关键字修饰,在赋值后不能改变(常量),不能继承String类String 对象的字符内容是存储在一个字符数组 value[]中的 我们来看一下String在JDK8中的…...

手敲Mybatis(六)-反射工具天花板
历时漫长的岁月,终于鼓起勇气继续研究Mybatis的反射工具类们,简直就是把反射玩出花,但是理解起来还是很有难度的,涉及的内容代码也颇多,所以花费时间也比较浩大,不过当了解套路每个类的功能也好,…...

内含18禁~~关于自学\跳槽\转行做网络安全行业的一些建议
作者:Eason_LYC 悲观者预言失败,十言九中。 乐观者创造奇迹,一次即可。 一个人的价值,在于他所拥有的。所以可以不学无术,但不能一无所有! 技术领域:WEB安全、网络攻防 关注WEB安全、网络攻防。…...

春分策划×运维老王主讲:CMDB数据运营精准化公开课启动报名啦!
『CMDB数据运营精准化』 公开直播课 要来了! 👆扫描海报二维码,预约直播 CMDB似乎是运维中永恒的老话题。 提到CMDB很多人都是又爱又恨,爱的是它给我们提供了一个美好的未来,有了CMDB我们可以解决诸多运维中的难题。…...
制作INCA和CANape通用的A2L
文章目录 前言制作A2LA2ML定义MOD_COMMON定义MOD_PAR定义MEMORY_SEGMENTTransportLayer定义PROTOCOL_LAYERDAQ总结前言 由于INCA和CANape是两个不同的公司对XCP协议的实现,所以A2L中也会有不一样的地方,但是在标定时若每次都用两个A2L,是非常不方便的,本文介绍如何设计A2L…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...

51c自动驾驶~合集58
我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...
DockerHub与私有镜像仓库在容器化中的应用与管理
哈喽,大家好,我是左手python! Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库,用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

如何在看板中体现优先级变化
在看板中有效体现优先级变化的关键措施包括:采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中,设置任务排序规则尤其重要,因为它让看板视觉上直观地体…...

Docker 本地安装 mysql 数据库
Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker ;并安装。 基础操作不再赘述。 打开 macOS 终端,开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...
虚拟电厂发展三大趋势:市场化、技术主导、车网互联
市场化:从政策驱动到多元盈利 政策全面赋能 2025年4月,国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》,首次明确虚拟电厂为“独立市场主体”,提出硬性目标:2027年全国调节能力≥2000万千瓦࿰…...

[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】
大家好,我是java1234_小锋老师,看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】,分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...
【Elasticsearch】Elasticsearch 在大数据生态圈的地位 实践经验
Elasticsearch 在大数据生态圈的地位 & 实践经验 1.Elasticsearch 的优势1.1 Elasticsearch 解决的核心问题1.1.1 传统方案的短板1.1.2 Elasticsearch 的解决方案 1.2 与大数据组件的对比优势1.3 关键优势技术支撑1.4 Elasticsearch 的竞品1.4.1 全文搜索领域1.4.2 日志分析…...
Python 训练营打卡 Day 47
注意力热力图可视化 在day 46代码的基础上,对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...