TensorRT量化工具pytorch_quantization代码解析(一)
量化工具箱pytorch_quantization 通过提供一个方便的 PyTorch 库来补充 TensorRT ,该库有助于生成可优化的 QAT 模型。该工具包提供了一个 API 来自动或手动为 QAT 或 PTQ 准备模型。
API 的核心是 TensorQuantizer 模块,它可以量化、伪量化或收集张量的统计信息。它与 QuantDescriptor 一起使用,后者描述了如何量化张量。在 TensorQuantizer 之上的是量化模块,这些模块被设计为 PyTorch 全精度模块的替代品。这些是使用 TensorQuantizer 对模块的权重和输入进行伪量化或收集统计信息的方便模块。
API 支持将 PyTorch 模块自动转换为其量化版本。转换也可以使用 API 手动完成,这允许在不想量化所有模块的情况下进行部分量化。例如,一些层可能对量化更敏感,并且使其未量化可提高任务精度。
量化第一步是将量化器模块添加到神经网络图中。该包提供了许多量化层模块,其中包含用于输入和权重的量化器。例如quant_nn.QuantLinear,它可以用来代替nn.Linear。这些量化层可以通过猴子修补或手动修改模型定义来自动替换。自动层替换是使用quant_module完成的。这应该在创建模型之前调用。
首先看以下代码:
from pytorch_quantization import quant_modules
quant_modules.initialize()
initialize()会动态地修改 PyTorch 代码,适用于每个模块的所有实例,将 torch.nn.module 的一些子类替换为对应的量化版本。如果不希望所有模块都量化,则应手动替换量化模块。独立量化器也可以添加到带有quant_nn.TensorQuantizer的模型中。
initialize()位于:tools\pytorch-quantization\pytorch_quantization\quant_modules.py,作用使用使用monkey patching进行动态模块更换为量化版本
什么是猴子补丁
- Python是一种典型的动态脚本语言。它不仅具有 动态类型(dynamic type) ,而且它的 对象模型(object model) 也是动态的。Python的类是可变的(mutable),方法(methods)只是类的属性(attributes);这允许我们在 运行时(run time) 修改其行为。这被称为猴子补丁(Monkey Patching), 它指的是偷偷地更改代码。
- Monkey Patching只是在 运行时(run time) 动态替换属性(attributes)。而在Python中,术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改。
def initialize(float_module_list=None, custom_quant_modules=None):"""用量化版本动态地替换模块。在内部,状态由helper类对象维护,该对象有助于将原始模块替换回去。参数:float_module_list:列表,用户提供的列表,其中指明哪些模块不可执行替换custom_quant_modules:一个字典。用户提供的映射,用于指示除torch.nn及其相应量化版本之外的任何其他模块。Returns:空"""# 准备monkey patching中使用的内部变量quant_map和orginal_func_map_quant_module_helper_object.prepare_state(float_module_list, custom_quant_modules)#执行量化模块替换_quant_module_helper_object.apply_quant_modules()def deactivate():"""动态模块更换,可逆转monkey patching使用维护状态的helper类对象动态地替换回先前在initialize()函数调用中被monkey patching的原始模块。"""_quant_module_helper_object.restore_float_modules()# 维护被替换模块状态的全局对象。
_quant_module_helper_object = QuantModuleReplacementHelper()
自定义量化模块使用示例:
# torch.nn模块定义不可执行替换列表
float_module_list = ["Linear"]
# torch.nn以外的模块自定义映射
custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]
# Monkey修补模块
pytorch_quantization.quant_modules.initialize(float_module_list, custom_modules)
# 使用量化模块
pytorch_quantization.quant_modules.deactivate()
继续看helper类QuantModuleReplacementHelper
class QuantModuleReplacementHelper():"""帮助量化版本替换torch.nn模块术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改该模块用工具内部实现或任何其他用户提供的自定义模块提供的量化版 替换(通过monkey patching)torch.nn模块属性:orginal_func_map:一个dict.维护原始torch.nn模块字典quant_support_list:列表,包含工具提供的量化版本的模块名称quant_map:一个字典,包含模块名称及其量化版本的字典quant_switch_opt:一个字典,用于指示哪些模块不能替换其量化版本。该dict由用户提供的列表更新,该列表指示在monkey patching中要忽略的模块"""def __init__(self):# 保留要更换的原始模块self.orginal_func_map = set()# 默认情况下,维护工具支持的量化模块列表self.default_quant_map = _DEFAULT_QUANT_MAP# 保存最终量化模块。self.quant_map = set()
_DEFAULT_QUANT_MAP是包含量化模块映射的文件的全局成员
_DEFAULT_QUANT_MAP = [_quant_entry(torch.nn, "Conv1d", quant_nn.QuantConv1d),_quant_entry(torch.nn, "Conv2d", quant_nn.QuantConv2d),_quant_entry(torch.nn, "Conv3d", quant_nn.QuantConv3d),_quant_entry(torch.nn, "ConvTranspose1d", quant_nn.QuantConvTranspose1d),_quant_entry(torch.nn, "ConvTranspose2d", quant_nn.QuantConvTranspose2d),_quant_entry(torch.nn, "ConvTranspose3d", quant_nn.QuantConvTranspose3d),_quant_entry(torch.nn, "Linear", quant_nn.QuantLinear),_quant_entry(torch.nn, "LSTM", quant_nn.QuantLSTM),_quant_entry(torch.nn, "LSTMCell", quant_nn.QuantLSTMCell),_quant_entry(torch.nn, "AvgPool1d", quant_nn.QuantAvgPool1d),_quant_entry(torch.nn, "AvgPool2d", quant_nn.QuantAvgPool2d),_quant_entry(torch.nn, "AvgPool3d", quant_nn.QuantAvgPool3d),_quant_entry(torch.nn, "AdaptiveAvgPool1d", quant_nn.QuantAdaptiveAvgPool1d),_quant_entry(torch.nn, "AdaptiveAvgPool2d", quant_nn.QuantAdaptiveAvgPool2d),_quant_entry(torch.nn, "AdaptiveAvgPool3d", quant_nn.QuantAdaptiveAvgPool3d),]
_quant_entry定义命名元组,用于存储量化模块映射,它拥有三个属性orig_mod mod_name replace_mod
_quant_entry = namedtuple('quant_entry', 'orig_mod mod_name replace_mod')
QuantModuleReplacementHelper类的属性方法:
prepare_state准备稍后在monkey patching机制中使用的量化模块的命名字典quant_map和更换为原始模块orginal_func_map- 设置torch.nn工具支持的量化模块列表
- 为torch.nn以外的模块设置自定义映射
- 使用float_module_list关闭用户指示模块的monkey patching替换
def prepare_state(self, float_module_list=None, custom_map=None):""""""# 对于支持的默认量化模块,生成quant_mapfor item in self.default_quant_map:if float_module_list is not None and item.mod_name in float_module_list:# 如果float_module_list中存在此模块,则跳过此模块continueelse:# 将模块追加到将在monkey patching中使用的变量中self.quant_map.add(item)# 存储要在反向monkey patching中使用的原始模块self.orginal_func_map.add(_quant_entry(item.orig_mod, item.mod_name,getattr(item.orig_mod, item.mod_name)))# 将自定义模块添加到quant_mapif custom_map is not None:for item in custom_map:# 将自定义模块附加到将在monkey补丁中使用的列表中# 将元组转换为命名元组self.quant_map.add(_quant_entry(item[0], item[1], item[2]))# 将原始模块存储在另一个列表中,该列表将用于反向monkey patchingself.orginal_func_map.add(_quant_entry(item[0], item[1], getattr(item[0], item[1])))
- apply_quant_modules:根据quant_map,执行替换为量化模块
def apply_quant_modules(self):for entry in self.quant_map:# 用于设置属性值,该属性不一定是存在的,对应函数 getattr()setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
- restore_float_modules:通过使用orginal_func_map替换回原始模块,反转monkey patch的效果
def restore_float_modules(self):for entry in self.orginal_func_map:setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
相关文章:
TensorRT量化工具pytorch_quantization代码解析(一)
量化工具箱pytorch_quantization 通过提供一个方便的 PyTorch 库来补充 TensorRT ,该库有助于生成可优化的 QAT 模型。该工具包提供了一个 API 来自动或手动为 QAT 或 PTQ 准备模型。 API 的核心是 TensorQuantizer 模块,它可以量化、伪量化或收集张量的…...
【Kubernetes】第二十七篇 - 布署前端项(下)
一,前言 上一篇,介绍了前端项目的部署:项目的创建和 jenkins 配置; 本篇,创建 Deployment、Service,完成前端项目的部署; 二,创建 Deployment 创建 Deployment 配置文件ÿ…...
【MFC】两个ListBox控件数据交互
一.控件ID名称 界面如图下所示: 候选数据列表的ID为: 已选数据列表的ID为: 二.数据添加 可以使用以下代码往框中添加数据: ((CListBox *)GetDlgItem(IDC_LIST_TO_CHO))->AddString("测试数据"); 显示效果如下&#…...
sklearn库学习--SelectKBest 、f_regression
目录 一、SelectKBest 介绍、代码使用 介绍: 代码使用: 二、评分函数 【1】f_regression: (1)介绍: (2)F值和相关系数 【2】除了f_regression函数,还有一些适用于…...
蓝桥杯刷题第十三天
第一题:特殊日期问题描述对于一个日期,我们可以计算出年份的各个数位上的数字之和,也可以分别计算月和日的各位数字之和。请问从 1900 年 11 月 1 日至 9999 年 12 月 31 日,总共有多少天,年份的数位数字之和等于月的数…...
CPU 和带宽之间的时空权衡
在 从一道面试题看 TCP 的吞吐极限 一文的开始,我提到在环形域上两个数字比较大小的前提是在同一个半圆内,进而得到滑动窗口最大值被限定在一个环形域的一半。 现在来看更为基本的问题。如果序列号只有 2bit,甚至仅有 1bit,保序传…...
ES+Redis+MySQL,这个高可用架构设计太顶了!
一、背景 会员系统是一种基础系统,跟公司所有业务线的下单主流程密切相关。如果会员系统出故障,会导致用户无法下单,影响范围是全公司所有业务线。所以,会员系统必须保证高性能、高可用,提供稳定、高效的基础服务。 …...
【Maven】Maven的常用命令
目录 一、Maven的常用命令 1、compile 编译命令 2、test 测试命令 3 、clean 清理命令 4、package 打包命令 5、 install 安装命令 6、Maven 指令的生命周期 二、maven 的概念模型 💟 创作不易,不妨点赞💚评论❤️收藏💙一…...
python的循环结构
python中有for循环和while循环两种形式。 1. for 循环 可以用for循环来遍历不同类型的对象,如数组、列表、元组、字典、集合或字符串,并对每个元素执行一段代码。 1.1 数组的for循环 用for循环遍历一个数组,并打印出每个元素:…...
五种Python中字典的高级用法
1. 引言 Python中的字典是一种非常有用的数据结构,它允许大家存储键值对。通常来说,字典灵活、高效且易于使用,是Python中最常用的数据结构之一。字典通常被用于统计频率、映射值等任务,但在Python中使用字典也可以达到许多意想不…...
[蓝桥杯单片机]——八到十一届初赛决赛客观题
第八届初赛 一、填空题 采用外部12MHz晶振,经过系统12分频时定时器获得最大定时长度,此时定时器定时脉冲为1MHz,周期为1s,而定时器计时均为16位加法计数器,即计时长度为。 二、 选择题 ①带阻滤波器是指能通过大多数频…...
多线程(初阶)
文章目录一.初始线程(Thread)1.1.线程的概念1.2.线程的优势1.2.1.线程比进程更轻量1.2.2.并发编程1.3.线程和进程的区别二.Thread类方法2.1. java 中创建线程的方法2.1.1. 继承Thread,重写run2.1.2. 实现Ruuable接口2.1.3. 使用匿名内部类,继承Thread2.1.4.使用匿名内部类,实现…...
【Vue从入门到进阶】Node.js安装与配置
✅作者简介:CSDN一位小博主,正在学习前端,欢迎大家一起来交流学习🏆 📃个人主页:白月光777的CSDN博客 🔥系列专栏:Vue从入门到进阶 💬个人格言:但行好事&…...
python 正则使用详解
python 正则使用详解什么是正则在 python 中使用正则一些正则的定义python 正则的方法match 从字符串开头匹配正则返回的结果分析(重要)fullmatch 严格匹配整个字符串search 任意位置开始匹配sub 替换匹配内容subn 以元组方式返回替换结果split 正则切割…...
一个深度学习项目需要什么
DataLoader1.数据预处理在将数据提供给模型之前,DataLoader需要对数据进行预处理。预处理可以包括数据增强、归一化、裁剪、缩放等操作。这些操作可以提高模型的性能和准确度。在处理点云数据时,可以通过最远点下采样到固定的点数。2.读取标签文件我 1 2…...
【Java进阶篇】—— 常用类和基础API
一、String类 1.1 String的特性 java.lang.String 类代表字符串,由final关键字修饰,在赋值后不能改变(常量),不能继承String类String 对象的字符内容是存储在一个字符数组 value[]中的 我们来看一下String在JDK8中的…...
手敲Mybatis(六)-反射工具天花板
历时漫长的岁月,终于鼓起勇气继续研究Mybatis的反射工具类们,简直就是把反射玩出花,但是理解起来还是很有难度的,涉及的内容代码也颇多,所以花费时间也比较浩大,不过当了解套路每个类的功能也好,…...
内含18禁~~关于自学\跳槽\转行做网络安全行业的一些建议
作者:Eason_LYC 悲观者预言失败,十言九中。 乐观者创造奇迹,一次即可。 一个人的价值,在于他所拥有的。所以可以不学无术,但不能一无所有! 技术领域:WEB安全、网络攻防 关注WEB安全、网络攻防。…...
春分策划×运维老王主讲:CMDB数据运营精准化公开课启动报名啦!
『CMDB数据运营精准化』 公开直播课 要来了! 👆扫描海报二维码,预约直播 CMDB似乎是运维中永恒的老话题。 提到CMDB很多人都是又爱又恨,爱的是它给我们提供了一个美好的未来,有了CMDB我们可以解决诸多运维中的难题。…...
制作INCA和CANape通用的A2L
文章目录 前言制作A2LA2ML定义MOD_COMMON定义MOD_PAR定义MEMORY_SEGMENTTransportLayer定义PROTOCOL_LAYERDAQ总结前言 由于INCA和CANape是两个不同的公司对XCP协议的实现,所以A2L中也会有不一样的地方,但是在标定时若每次都用两个A2L,是非常不方便的,本文介绍如何设计A2L…...
Spark 之 入门讲解详细版(1)
1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
css的定位(position)详解:相对定位 绝对定位 固定定位
在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
Map相关知识
数据结构 二叉树 二叉树,顾名思义,每个节点最多有两个“叉”,也就是两个子节点,分别是左子 节点和右子节点。不过,二叉树并不要求每个节点都有两个子节点,有的节点只 有左子节点,有的节点只有…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...
人工智能(大型语言模型 LLMs)对不同学科的影响以及由此产生的新学习方式
今天是关于AI如何在教学中增强学生的学习体验,我把重要信息标红了。人文学科的价值被低估了 ⬇️ 转型与必要性 人工智能正在深刻地改变教育,这并非炒作,而是已经发生的巨大变革。教育机构和教育者不能忽视它,试图简单地禁止学生使…...
Java数值运算常见陷阱与规避方法
整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...
R 语言科研绘图第 55 期 --- 网络图-聚类
在发表科研论文的过程中,科研绘图是必不可少的,一张好看的图形会是文章很大的加分项。 为了便于使用,本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中,获取方式: R 语言科研绘图模板 --- sciRplothttps://mp.…...
从面试角度回答Android中ContentProvider启动原理
Android中ContentProvider原理的面试角度解析,分为已启动和未启动两种场景: 一、ContentProvider已启动的情况 1. 核心流程 触发条件:当其他组件(如Activity、Service)通过ContentR…...
