当前位置: 首页 > news >正文

TensorRT量化工具pytorch_quantization代码解析(二)

有些地方看的不是透彻,后续继续补充!

继续看张量量化函数,代码位于:tools\pytorch-quantization\pytorch_quantization\tensor_quant.py

ScaledQuantDescriptor

量化的支持描述符:描述张量应该如何量化。QuantDescriptor和张量定义了量化张量。

class ScaledQuantDescriptor():def __init__(self, num_bits=8, name=None, **kwargs):if not isinstance(num_bits, int):raise TypeError("num_bits must be an integer, not {}.".format(type(num_bits)))if num_bits < 0:raise ValueError("num_bits must be >= 0, not {}.".format(num_bits))if num_bits == 0:logging.error("num_bits is 0. This will result in the tensor being quantized to all zeros."" This mode should only be used for debugging purposes.")self._num_bits = num_bitsif not isinstance(name, str) and name is not None:raise TypeError("name must be a string or None, not {}.".format(type(name)))self._name = nameself._fake_quant = kwargs.pop('fake_quant', True)self._axis = kwargs.pop('axis', None)if self._axis is not None:logging.debug("Meaning of axis has changed since v2.0. Make sure to update.")self._learn_amax = kwargs.pop('learn_amax', False)if self._learn_amax and self._axis is not None:raise TypeError("axis is ignored and must be None when learn_amax is true, got {}.".format(type(self._axis)))amax = kwargs.pop('amax', None)if amax is not None:if not isinstance(amax, float) and not isinstance(amax, list) and not isinstance(amax, np.ndarray):raise TypeError("amax must be float, list or ndarray, not {}".format(type(amax)))# Make it single precision arrayself._amax = np.array(amax, dtype=np.float32)else:self._amax = amaxself._scale_amax = kwargs.pop('scale_amax', None)self._calib_method = kwargs.pop('calib_method', "max")self._unsigned = kwargs.pop('unsigned', False)self._narrow_range = kwargs.pop('narrow_range', False)if kwargs:raise TypeError("Unused keys: {}".format(kwargs.keys()))

参数:

  • num_bits:int,量化位数,用于计算比例因子。默认值8。
  • name:看起来很不错

关键字参数:

  • fake_quant:布尔值。如果为True,则使用fake量化模式。默认为True

  • axisNone, int或整数的tuple,轴将利用自己的最大值以计算缩放因子,默认None。

    • 如果None(默认值),则使用per tensor scale
      确保在范围[-rank(input_tensor),rank(输入_tensor))内。
      例如,对于KCRS权重张量,quant_axis=(0)将产生per channel scaling
  • amax:用户指定的绝对最大范围的floatlist/ndarray。如果提供,忽略quant_axis并使用它进行量化。如果learn_amax为True,将用于初始化可学习的amax。默认None

  • learn_amaxboolean,如果为True,学习amax。默认为False。

  • scale_amaxfloat,如果提供,将amax乘以scale_amax,默认无。

  • calib_methodstring[“max”,“histogram”]中的一个校准要使用的指标。除了
    max calibration,其他都是基于hisogram的。默认值“max”。

  • unsignedBoolean,如果为True,则使用无符号。默认为False

Raises:

  • TypeError:如果传入了不支持的类型。

Read-only properties:

  • fake_quant:
  • name:
  • learn_amax:
  • scale_amax:
  • axis:
  • calib_method:
  • num_bits:
  • amax:
  • unsigned:

QuantDescriptor定义了张量应该如何量化。预定义的QuantDescriptor张量描述符如下:

QuantDescriptor = ScaledQuantDescriptor# Predefined descriptors
QUANT_DESC_8BIT_PER_TENSOR = QuantDescriptor(num_bits=8)
QUANT_DESC_UNSIGNED_8BIT_PER_TENSOR = QuantDescriptor(num_bits=8, unsigned=True)
QUANT_DESC_8BIT_CONV1D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONV3D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONVTRANSPOSE1D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONVTRANSPOSE3D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))

如果在QuantDescriptor中给出最amaxTensorQuantizer将使用它进行量化。否则,TensorQuantizer将计算amax,然后进行量化。amax被计算通过指定的axis轴。注意QuantDescriptor将剩余轴指定与max()轴相反。

例子:

from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizerquant_desc = QuantDescriptor(num_bits=4, fake_quant=False, axis=(0), unsigned=True)

接下来看量化函数:pytorch_quantization提供3个自定义的张量量化函数算子,继承torch.autograd.function,实现函数的前向传播、反向传播

TensorQuantFunction

  • 通用的张量量化函数TensorQuantFunction
class TensorQuantFunction(Function):"""一个输入张量,输出一个量化张量。`scale`的粒度可以从amax的形状来解释"""

forward

在前向过程中,对浮点权重和激活进行伪量化,并使用这些伪量化的权重和激活来执行层的操作

    @staticmethoddef forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True):ctx.save_for_backward(inputs, amax)outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)# Check if scale overflows FP16if outputs.dtype == torch.half and scale.max() > 65504:raise ValueError("scale is too large for FP16 with amax={}".format(amax))return outputs, scale.to(inputs.dtype)

output_dtype指示量化值是以整数还是浮点形式存储。希望将其存储在浮点中的原因是pytorch函数接受量化值,它可能不接受整数输入,例如Conv2D
它使用2num_bits−12^{num\_bits-1}2num_bits1值,例如,对于num_bits=8,使用[-127,127]

遵循tensorflow约定,传入最大值并用于确定比例,而不是直接输入比例。尽管直接输入比例可能更自然。

参数:

  • ctx:一个用于向后存储张量的Context对象。

  • inputs:float32型张量。

  • amax:float32型张量。输入将在[-amax,amax]范围内量化,amax将广播到inputs tensor。

  • num_bits:用于计算缩放因子的整数,scale=(2num_bits−1−1)/maxscale=(2^{num\_bits-1}-1)/maxscale=(2num_bits11)/max。默认值8。

  • output_dtype:张量的一种类型。torch.int32或torch.float32。希望存储为float,pytorch函数接受float量化值,它可能不接受整数输入。
    unsigned:boolean,使用无符号整数范围。例如,对于num_bits=8,[0,255]。默认为False。

  • narrow_range:布尔值。使用对称整数范围进行有符号量化
    例如,对于num_bits=8,用[-127,127]代替[-128,127]。默认为True。

Returns:

  • outputsoutput_dtype类型的张量。

  • scalefloat32型张量。outputs / scale将对输出张量进行反量化。

Raises:

  • ValueError:

backward

通过clipping实现直通估计。对于-amax<=input<=amax,梯度直接通过,否则梯度为零。
参数:

  • ctx:一个上下文对象,其中保存了来自forward的张量。
  • grad_outputs:outputs梯度张量。
  • grad_scale:scale梯度张量。

Returns:

  • grad_inputs:梯度张量。
    @staticmethoddef backward(ctx, grad_outputs, grad_scale):"""Implements straight through estimation with clipping. For -amax <= input <= amaxthe gradient passes straight through, otherwise the gradient is zero.Args:ctx: A Context object with saved tensors from forward.grad_outputs: A tensor of gradient of outputs.grad_scale: A tensor of gradient of scale.Returns:grad_inputs: A tensor of gradient."""inputs, amax = ctx.saved_tensorszero = grad_outputs.new_zeros(1) # create a zero tensor with the same type and devicegrad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)return grad_inputs, None, None, None, None
tensor_quant = TensorQuantFunction.apply

TensorQuantFunction.apply赋予一个别名tensor_quant,这样可以直接调用tensor_quant进行量化,例如:

from pytorch_quantization import tensor_quant# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x = torch.rand(10)# quantize tensor x. quant_x will be
# tensor([126., 113., 127.,  59.,  11.,  23.,  47.,  73.,  43.,  27.])
# with scale=128.0057
quant_x, scale = tensor_quant.tensor_quant(x, x.abs().max())

FakeTensorQuantFunction

class FakeTensorQuantFunction(Function):"""Fake version of TensorQuantFunctionSee comments of TensorQuantFunction, arguments are the same."""@staticmethoddef forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True):ctx.save_for_backward(inputs, amax)outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)return outputs / scale.to(inputs.dtype)@staticmethoddef backward(ctx, grad_outputs):inputs, amax = ctx.saved_tensorszero = grad_outputs.new_zeros(1)grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)return grad_inputs, None, None, None, None

在向后过程中,使用权重的渐变来更新浮点权重。为了处理量化梯度,除了未定义的点之外,几乎所有地方都是零,可以使用 直通估计器 ( STE ),它通过伪量化操作符传递梯度。

fake_tensor_quant = FakeTensorQuantFunction.apply

TensorQuantFunction.apply赋予一个别名fake_tensor_quant,这样可以直接调用fake_tensor_quant进行量化,例如:

from pytorch_quantization import tensor_quant# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x = torch.rand(10)# fake quantize tensor x. fake_quant_x will be
# tensor([0.9843, 0.8828, 0.9921, 0.4609, 0.0859, 0.1797, 0.3672, 0.5703, 0.3359, 0.2109])
fake_quant_x = tensor_quant.fake_tensor_quant(x, x.abs().max())

_tensor_quant

def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):"""Shared function body between TensorQuantFunction and FakeTensorQuantFunction"""# Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning.if isinstance(amax, torch.Tensor) and inputs.dim() != amax.dim():logging.debug("amax %s has different shape than inputs %s. Make sure broadcast works as expected!",amax.size(), inputs.size())logging.debug("{} bits quantization on shape {} tensor.".format(num_bits, inputs.size()))if unsigned:if inputs.min() < 0.:raise TypeError("Negative values encountered in unsigned quantization.")# Computation must be in FP32 to prevent potential over flow.input_dtype = inputs.dtypeif inputs.dtype == torch.half:inputs = inputs.float()if amax.dtype == torch.half:amax = amax.float()min_amax = amax.min()if min_amax < 0:raise ValueError("Negative values in amax")max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)if unsigned:min_bound = 0elif narrow_range:min_bound = -max_boundelse:min_bound = -max_bound - 1scale = max_bound / amaxepsilon = 1. / (1<<24)if min_amax <= epsilon:  # Treat amax smaller than minimum representable of fp16 0zero_amax_mask = (amax <= epsilon)scale[zero_amax_mask] = 0  # Value quantized with amax=0 should all be 0outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)if min_amax <= epsilon:scale[zero_amax_mask] = 1.  # Return 1 makes more sense for values quantized to 0 with amax=0if input_dtype == torch.half:outputs = outputs.half()return outputs, scale

待梳理!!!

相关文章:

TensorRT量化工具pytorch_quantization代码解析(二)

有些地方看的不是透彻&#xff0c;后续继续补充&#xff01; 继续看张量量化函数&#xff0c;代码位于&#xff1a;tools\pytorch-quantization\pytorch_quantization\tensor_quant.py ScaledQuantDescriptor 量化的支持描述符:描述张量应该如何量化。QuantDescriptor和张量…...

buu [BJDCTF2020]easyrsa 1

题目描述 &#xff1a; from Crypto.Util.number import getPrime,bytes_to_long from sympy import Derivative from fractions import Fraction from secret import flagpgetPrime(1024) qgetPrime(1024) e65537 np*q zFraction(1,Derivative(arctan(p),p))-Fraction(1,Deri…...

taobao.user.openuid.getbyorder( 根据订单获取买家openuid )

&#xffe5;免费不需用户授权 根据订单获取买家openuid&#xff0c;最大查询30个 公共参数 请求地址: HTTP地址 http://gw.api.taobao.com/router/rest 公共请求参数: 请求示例 TaobaoClient client new DefaultTaobaoClient(url, appkey, secret); UserOpenuidGetbyorderR…...

Mac iTerm2 rz sz

1、安装brew&#xff08;找了很多&#x1f517;&#xff0c;就这个博主的好用&#xff09; Mac如何安装brew&#xff1f;_行走的码农00的博客-CSDN博客_mac brew 2、安装lrzsz brew install lrzsz 检查是否安装成功 brew list 定位lrzsz的安装目录 brew list lrzsz 执…...

高通平台开发系列讲解(Sensor篇)Gsensor基础知识

文章目录 一、什么是SENSOR?二、Sensor的分类及作用三、Gsensor的工作原理及介绍3.1、常见Gsensor3.2、Gsensor的特性沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇文章将介绍 Sensor 基础 一、什么是SENSOR? 传感器(英文名称:sensor )是一种检测装置,能感…...

图像处理实战--Opencv实现人像迁移

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 今天来学习一下如何使用Opencv实现人像迁移&#xff0c;欢迎大家一起参与探讨交流~ 本文目录&#xff1a;一、实验要求二、实验环境三、实验原理及操作1.照片准备2.图像增强3.实现美颜功能4.背景虚化5.图像二值化处理6.人…...

OnlyOffice验证(二)在Centos7上部署OnlyOffice编译结果

在Centos7上部署OnlyOffice编译结果 此处将尝试将OnlyOffice验证&#xff08;一&#xff09;DocumentServer编译验证的结果部署到Centos7上。并且使用其它服务器现有的RabbitMq和Mysql。 安装Nginx 先安装Nginx需要的依赖环境&#xff1a; yum install openssl* -y yum insta…...

6.补充和总结【Java面试第三季】

6.补充和总结【Java面试第三季】前言推荐6.补充和总结69_总结闲聊回顾和总结继续学习最后前言 2023-2-4 19:08:01 以下内容源自 【尚硅谷Java大厂面试题第3季&#xff0c;跳槽必刷题目必扫技术盲点&#xff08;周阳主讲&#xff09;-哔哩哔哩】 仅供学习交流使用 推荐 Jav…...

基于ssm框架大学生社团管理系统(源码+数据库+文档)

一、项目简介 本项目是一套基于ssm框架大学生社团管理系统&#xff0c;主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目可以直接作为bishe使用。 项目都经过严格调试&#xff0c;确保可…...

vulnhub靶场NAPPING: 1.0.1教程

靶场搭建靶机下载地址&#xff1a;Napping: 1.0.1 ~ VulnHub直接解压双击ova文件即可使用软件&#xff1a;靶机VirtualBox&#xff0c;攻击机VMware攻击机&#xff1a;kali信息收集arp-scan -l上帝之眼直接来看看网站可以注册账号&#xff0c;那就先试试。注册完后登入哦。要输…...

Docker基本介绍

最近需要将项目做成一个web应用并部署到多台服务器上&#xff0c;于是就简单学习了一下docker&#xff0c;做一下小小的记录。 1、简单介绍一下docker 我们经常遇到这样一个问题&#xff0c;自己写的代码在自己的电脑上运行的很流畅&#xff0c;在其他人电脑上就各种bug&…...

可用于标记蛋白质216699-36-4,6-ROX,SE,6-羧基-X-罗丹明琥珀酰亚胺酯

一.6-ROX&#xff0c;SE产品描述&#xff1a;6-羧基-X-罗丹明琥珀酰亚胺酯&#xff08;6-ROX&#xff0c;SE&#xff09;是一种用于寡核苷酸标记和自动DNA测序的荧光染料&#xff0c;可用于标记蛋白质&#xff0c;寡核苷酸和其他含胺分子的伯胺&#xff08;-NH2&#xff09;。西…...

高数:极限的定义

目录 极限的定义&#xff1a; 数列极限的几何意义&#xff1a; 由极限的定义得出的极限的两个结论&#xff1a; ​编辑 极限的第三个结论&#xff1a; 例题 方法1&#xff1a; ​编辑 方法2&#xff1a; ​编辑 方法3&#xff1a; ​编辑 极限的定义&#xff1a; 如何理…...

大数据技术之Hadoop

第1章 Hadoop概述1.1 Hadoop是什么1.2 Hadoop发展历史&#xff08;了解&#xff09;1.3 Hadoop三大发行版本&#xff08;了解&#xff09;Hadoop三大发行版本&#xff1a;Apache、Cloudera、Hortonworks。Apache版本最原始&#xff08;最基础&#xff09;的版本&#xff0c;对于…...

一文带你搞懂Go语言函数选项模式,Go函数一等公民。

前言 通过这篇文章《为什么说Go的函数是”一等公民“》&#xff0c;我们了解到了什么是“一等公民”&#xff0c;以及都具备哪些特性&#xff0c;同时对函数的基本使用也更加深入。 本文重点介绍下Go设计模式之函数选项模式&#xff0c;它得益于Go的函数是“一等公民”&#…...

Window.location 详细介绍

如果你需要获取网站的 URL 信息&#xff0c;那么 window.location 对象就是为你准备的。使用它提供的属性来获取当前页面地址的信息&#xff0c;或使用其方法进行某些页面的重定向或刷新。 https://www.samanthaming.com/tidbits/?filterJS#2 window.location.origin → htt…...

js侧滑显示删除按钮

效果图&#xff1a; <!DOCTYPE html> <html><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0, maximum-scale1.0, user-scalableno"><title>js侧滑显示删…...

Python - DIY - 使用dump取json某些键值对合成新的json文件

Python - Json处理前言&#xff1a;应用场景&#xff1a;基本工具&#xff1a;文件操作&#xff1a;打开文件&#xff1a;写文件&#xff1a;读文件&#xff1a;关闭文件并刷新缓冲区&#xff1a;Json字符串和字典转换&#xff1a;json.loads()&#xff1a;json.dumps():Json文…...

深度剖析指针(中)——“C”

各位CSDN的uu们你们好呀&#xff0c;今天小雅兰的内容仍旧是深度剖析指针噢&#xff0c;在上一篇博客中&#xff0c;我已经写过了字符指针、数组指针、指针数组、数组传参和指针传参的知识点&#xff0c;那么这篇博客小雅兰会讲解一下函数指针、函数指针数组 、指向函数指针数组…...

论文阅读 | Video Frame Synthesis using Deep Voxel Flow

前言&#xff1a; 视频帧生成方法&#xff08;视频插帧/视频预测&#xff09;ICCV2017 oral Video Frame Synthesis using Deep Voxel Flow 引言 当下进行视频帧合成的方法分为两种&#xff0c;第一种是光流法&#xff0c;光流准确的话效果好&#xff0c;光流不准确的话则生…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

vue3 字体颜色设置的多种方式

在Vue 3中设置字体颜色可以通过多种方式实现&#xff0c;这取决于你是想在组件内部直接设置&#xff0c;还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法&#xff1a; 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题

分区配置 (ptab.json) img 属性介绍&#xff1a; img 属性指定分区存放的 image 名称&#xff0c;指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件&#xff0c;则以 proj_name:binary_name 格式指定文件名&#xff0c; proj_name 为工程 名&…...

嵌入式学习笔记DAY33(网络编程——TCP)

一、网络架构 C/S &#xff08;client/server 客户端/服务器&#xff09;&#xff1a;由客户端和服务器端两个部分组成。客户端通常是用户使用的应用程序&#xff0c;负责提供用户界面和交互逻辑 &#xff0c;接收用户输入&#xff0c;向服务器发送请求&#xff0c;并展示服务…...

20个超级好用的 CSS 动画库

分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码&#xff0c;而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库&#xff0c;可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画&#xff0c;可以包含在你的网页或应用项目中。 3.An…...

三分算法与DeepSeek辅助证明是单峰函数

前置 单峰函数有唯一的最大值&#xff0c;最大值左侧的数值严格单调递增&#xff0c;最大值右侧的数值严格单调递减。 单谷函数有唯一的最小值&#xff0c;最小值左侧的数值严格单调递减&#xff0c;最小值右侧的数值严格单调递增。 三分的本质 三分和二分一样都是通过不断缩…...

(一)单例模式

一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...

Vue 模板语句的数据来源

&#x1f9e9; Vue 模板语句的数据来源&#xff1a;全方位解析 Vue 模板&#xff08;<template> 部分&#xff09;中的表达式、指令绑定&#xff08;如 v-bind, v-on&#xff09;和插值&#xff08;{{ }}&#xff09;都在一个特定的作用域内求值。这个作用域由当前 组件…...

海云安高敏捷信创白盒SCAP入选《中国网络安全细分领域产品名录》

近日&#xff0c;嘶吼安全产业研究院发布《中国网络安全细分领域产品名录》&#xff0c;海云安高敏捷信创白盒&#xff08;SCAP&#xff09;成功入选软件供应链安全领域产品名录。 在数字化转型加速的今天&#xff0c;网络安全已成为企业生存与发展的核心基石&#xff0c;为了解…...