当前位置: 首页 > news >正文

NLP(六十七)BERT模型训练后动态量化(PTDQ)

  本文将会介绍BERT模型训练后动态量化(Post Training Dynamic Quantization,PTDQ)。

量化

  在深度学习中,量化(Quantization)指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算。这么做的好处主要有如下几点:

  • 更少的模型体积,接近4倍的减少
  • 可以更快地计算,由于更少的内存访问和更快的int8计算,可以快2~4倍

  PyTorch中的模型参数默认以FP32精度储存。对于量化后的模型,其部分或者全部的tensor操作会使用int类型来计算,而不是使用量化之前的float类型。当然,量化还需要底层硬件支持,x86 CPU(支持AVX2)、ARM CPU、Google TPU、Nvidia Volta/Turing/Ampere、Qualcomm DSP这些主流硬件都对量化提供了支持。

模型量化示例图片

PTDQ

  PyTorch对量化的支持目前有如下三种方式:

  • Post Training Dynamic Quantization:模型训练完毕后的动态量化
  • Post Training Static Quantization:模型训练完毕后的静态量化
  • QAT (Quantization Aware Training):模型训练中开启量化

  本文仅介绍Post Training Dynamic Quantization(PTDQ)
  对训练后的模型权重执行动态量化,将浮点模型转换为动态量化模型,仅对模型权重进行量化,偏置不会量化。默认情况下,仅对Linear和RNN变体量化 (因为这些layer的参数量很大,收益更高)。

torch.quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)

参数解释:

  • model:模型(默认为FP32)
  • qconfig_spec:
  1. 集合:比如: qconfig_spec={nn.LSTM, nn.Linear} 。列出要量化的神经网络模块。
  2. 字典: qconfig_spec = {nn.Linear: default_dynamic_qconfig, nn.LSTM: default_dynamic_qconfig}
  • dtype: float16 或 qint8
  • mapping:就地执行模型转换,原始模块发生变异
  • inplace:将子模块的类型映射到需要替换子模块的相应动态量化版本的类型

例子:

# -*- coding: utf-8 -*-
# 动态量化模型,只量化权重
import torch
from torch import nnclass DemoModel(torch.nn.Module):def __init__(self):super(DemoModel, self).__init__()self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)self.relu = nn.ReLU()self.fc = torch.nn.Linear(2, 2)def forward(self, x):x = self.conv(x)x = self.relu(x)x = self.fc(x)return xif __name__ == "__main__":model_fp32 = DemoModel()# 创建一个量化的模型实例model_int8 = torch.quantization.quantize_dynamic(model=model_fp32,  # 原始模型qconfig_spec={torch.nn.Linear},  # 要动态量化的算子dtype=torch.qint8)  # 将权重量化为:qint8print(model_fp32)print(model_int8)# 运行模型input_fp32 = torch.randn(1, 1, 2, 2)output_fp32 = model_fp32(input_fp32)print(output_fp32)output_int8 = model_int8(input_fp32)print(output_int8)

输出结果如下:

DemoModel((conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))(relu): ReLU()(fc): Linear(in_features=2, out_features=2, bias=True)
)
DemoModel((conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))(relu): ReLU()(fc): DynamicQuantizedLinear(in_features=2, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
tensor([[[[0.3120, 0.3042],[0.3120, 0.3042]]]], grad_fn=<AddBackward0>)
tensor([[[[0.3120, 0.3042],[0.3120, 0.3042]]]])

模型量化策略

  当前,由于量化算子的覆盖有限,因此,对于不同的深度学习模型,其量化策略不同,见下表:

模型量化策略原因
LSTM/RNNDynamic Quantization模型吞吐量由权重的计算/内存带宽决定
BERT/TransformerDynamic Quantization模型吞吐量由权重的计算/内存带宽决定
CNNStatic Quantization模型吞吐量由激活函数的内存带宽决定
CNNQuantization Aware Training模型准确率不能由Static Quantization获取的情况

   下面对BERT模型进行训练后动态量化,分析模型在量化前后,推理效果和推理性能的变化。

实验

   我们使用的训练后的模型为中文文本分类模型,其训练过程可以参考文章:NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调 。
   训练后的BERT模型动态量化实验的设置如下:

  1. base model: bert-base-chinese
  2. CPU info: x86-64, Intel® Core™ i5-10210U CPU @ 1.60GHz
  3. batch size: 1
  4. thread: 1

   具体的实验过程如下

  • 加载模型及tokenizer
import torch
from transformers import AutoModelForSequenceClassificationMAX_LENGTH = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = f"./sougou_test_trainer_{MAX_LENGTH}/checkpoint-96"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)
from transformers import AutoTokenizer, DataCollatorWithPaddingtokenizer = AutoTokenizer.from_pretrained(checkpoint)
  • 测试数据集
import pandas as pdtest_df = pd.read_csv("./data/sougou/test.csv")test_df.head()
textlabel
0届数比赛时间比赛地点参加国家和地区冠军亚军决赛成绩第一届1956-1957英国11美国丹麦6...0
1商品属性材质软橡胶带加浮雕工艺+合金彩色队徽吊牌规格162mm数量这一系列产品不限量发行图案...0
2今天下午,沈阳金德和长春亚泰队将在五里河相遇。在这两支球队中沈阳籍球员居多,因此这场比赛实际...0
3本报讯中国足协准备好了与特鲁西埃谈判的合同文本,也在北京给他预订好了房间,但特鲁西埃爽约了!...0
4网友点击发表评论祝贺中国队夺得五连冠搜狐体育讯北京时间5月6日,2006年尤伯杯羽毛球赛在日...0
  • 量化前模型的推理时间及评估指标
import numpy as np
import times_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():row_s_time = time.time()true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(device)# print(encoded_text)logits = model(**encoded_text)label_id = np.argmax(logits[0].detach().cpu().numpy(), axis=1)[0]pred_labels.append(label_id)print(i, (time.time() - row_s_time)*1000, label_id)print("avg time: ", (time.time() - s_time) * 1000 / test_df.shape[0])
0 229.3872833251953 0
100 362.0314598083496 1
200 311.16747856140137 2
300 324.13792610168457 3
400 406.9099426269531 4
avg time:  352.44047810332944
from sklearn.metrics import classification_reportprint(classification_report(true_labels, pred_labels, digits=4))
              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9691    0.9495    0.9592        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9895    0.9495    0.9691        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495
  • 设置量化后端
# 模型量化
cpu_device = torch.device("cpu")
torch.backends.quantized.supported_engines
['none', 'onednn', 'x86', 'fbgemm']
torch.backends.quantized.engine = 'x86'
  • 量化后模型的推理时间及评估指标
# 8-bit 量化
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8
).to(cpu_device)
q_s_time = time.time()
q_true_labels, q_pred_labels = [], [] for i, row in test_df.iterrows():row_s_time = time.time()q_true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(cpu_device)logits = quantized_model(**encoded_text)label_id = np.argmax(logits[0].detach().numpy(), axis=1)[0]q_pred_labels.append(label_id)print(i, (time.time() - row_s_time) * 1000, label_id)print("avg time: ", (time.time() - q_s_time) * 1000 / test_df.shape[0])
0 195.47462463378906 0
100 247.33805656433105 1
200 219.41304206848145 2
300 206.44831657409668 3
400 187.4992847442627 4
avg time:  217.63229466447928
from sklearn.metrics import classification_reportprint(classification_report(q_true_labels, q_pred_labels, digits=4))
              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9688    0.9394    0.9538        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9896    0.9596    0.9744        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495
  • 量化前后模型大小对比
import osdef print_size_of_model(model):torch.save(model.state_dict(), "temp.p")print("Size (MB): ", os.path.getsize("temp.p")/1e6)os.remove("temp.p")print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB):  409.155273
Size (MB):  152.627621

  量化后端(Quantization backend)取决于CPU架构,不同计算机的CPU架构不同,因此,默认的动态量化不一定在所有的CPU上都能生效,需根据自己计算机的CPU架构设置好对应的量化后端。另外,不同的量化后端也有些许差异。Linux服务器使用uname -a可查看CPU信息。
  重复上述实验过程,以模型的最大输入长度为变量,取值为128,256,384,每种情况各做3次实验,结果如下:

实验最大长度量化前平均推理时间(ms)量化前weighted F1值量化前平均推理时间(ms)量化前weighted F1值
实验138410660.97976860.9838
实验23841047.60.9899738.10.9879
实验33841020.90.9817714.00.9838
实验1256668.70.9717431.40.9718
实验2256675.10.9717449.90.9718
实验3256656.00.9717446.50.9718
实验1128335.80.9737200.50.9737
实验2128336.50.9737227.20.9737
实验3128352.40.9737217.60.9737

  综上所述,对于训练后的BERT模型(文本分类模型)进行动态量化,其结论如下:

  • 模型推理效果:量化前后基本相同,量化后略有下降
  • 模型推理时间:量化后平均提速约1.52倍

总结

  本文介绍了量化基本概念,PyTorch模型量化方式,以及对BERT模型训练后进行动态量化后在推理效果和推理性能上的实验。
  本文项目已开源至Github项目:https://github.com/percent4/dynamic_quantization_on_bert 。
  本人已开通个人博客网站,网址为:https://percent4.github.io/ ,欢迎大家访问~

相关文章:

NLP(六十七)BERT模型训练后动态量化(PTDQ)

本文将会介绍BERT模型训练后动态量化&#xff08;Post Training Dynamic Quantization&#xff0c;PTDQ&#xff09;。 量化 在深度学习中&#xff0c;量化&#xff08;Quantization&#xff09;指的是使用更少的bit来存储原本以浮点数存储的tensor&#xff0c;以及使用更少的…...

机器学习和数据挖掘04-PowerTransformer与 MinMaxScaler

概念 PowerTransformer&#xff08;幂变换器&#xff09; PowerTransformer 是用于对数据进行幂变换&#xff08;也称为Box-Cox变换&#xff09;的预处理工具。幂变换可以使数据更接近正态分布&#xff0c;这有助于某些机器学习算法的性能提升。它支持两种常用的幂变换&#x…...

1.15 自实现GetProcAddress

在正常情况下&#xff0c;要想使用GetProcAddress函数&#xff0c;需要首先调用LoadLibraryA函数获取到kernel32.dll动态链接库的内存地址&#xff0c;接着在调用GetProcAddress函数时传入模块基址以及模块中函数名即可动态获取到特定函数的内存地址&#xff0c;但在有时这个函…...

总结ADX指标交易的好处

股神巴菲特从一个穷小子变成世界富豪&#xff0c;而闻名世界。anzo capital昂首资本以为这辈子再也不会和巴菲特产生任何交集&#xff0c;直到我看了巴菲特的发家史&#xff0c;才发现原来我们都使用过ADX指标盈利过&#xff0c;下面anzo capital昂首资本就总结一下使用ADX指…...

ConsoleApplication815项目(直接加载+VEH Hook Load)

上线图 ConsoleApplication815.cpp #include <iostream> #include<Windows.h> #include "detours.h" #include "detver.h" #pragma comment(lib,"detours.lib")#pragma warning(disable:4996)LPVOID Beacon_address; SIZE_T Beacon…...

事务(SQL)

事务概述 事务是一组操作的集合&#xff0c;他是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整体一起向西永提交或撤销操作请求。这组操作&#xff0c;要么全部执行成功&#xff0c;要么全部执行失败。 事务操作 查看/设置事务提交方式 -- 查看/设置事务…...

原型,原型链,继承(圣杯模式)

经典模式和圣杯模式区别 经典模式和圣杯模式都是用于解决构造函数继承和原型继承的问题&#xff0c;但它们在实现继承的方式上有所不同。 经典模式是通过将子类的原型对象设置为父类的实例来实现继承&#xff0c;然后将子类的构造函数设置为子类本身。这样子类既可以继承父类…...

远程方法调用中间件Dubbo在spring项目中的使用

Dubbo是一个分布式服务框架&#xff0c;它可以帮助我们快速开发和提供高性能、高可靠性的分布式服务&#xff0c;同时提供服务治理、容错、负载均衡等功能。 使用Dubbo可以分为以下步骤&#xff1a; 引入Dubbo依赖&#xff1a;在项目的pom.xml文件中引入Dubbo的依赖。编写服务…...

MFC -- Date Time Picker 控件使用

当前环境&#xff1a;VS2015 Windows 10 //&#xff08;一&#xff09;使用普通函数&#xff0c; 获取当前时间CString strCurrentTime; COleDateTime m_time COleDateTime::GetCurrentTime(); strCurrentTime m_time.Format(_T("%Y-%m-%d %H:%M:%S")); SetDlgIt…...

zookeeper 集群

zookeeper 集群 1、zookeeper 集群说明 initLimit 是Zookeeper用它来限定集群中的Zookeeper服务器连接到Leader的时限 syncLimit 限制了follower服务器与leader服务器之间请求和应答之间的时限 服务器名称与地址&#xff1a;集群信息&#xff08;服务器编号&#xff0c;服务器…...

stable diffusion实践操作-随机种子seed

系列文章目录 stable diffusion实践操作 文章目录 系列文章目录前言一、seed是什么&#xff1f;二、使用步骤1.多批次随机生成多张图片2.提取图片seed3. 根据seed 再次培养4 seed使用4.1 复原别人图4.1 轻微修改4.2 固定某个人物-修改背景 三、差异随机种子1. webUI位置2. 什么…...

大数据可视化大屏实战项目(4)物流数据云看台(包括动态登陆页面)—HTML+CSS+JS【源码在文末】(可用于比赛项目或者作业参考中)

大数据可视化大屏实战项目&#xff08;4&#xff09;物流数据云看台&#xff08;包括动态登陆页面&#xff09;—HTMLCSSJS【源码在文末】&#xff08;可用于比赛项目或者作业参考中&#x1f415;&#x1f415;&#x1f415;&#xff09; 一&#xff0c;项目概览 ☞☞☞☞☞☞…...

在ubuntu下远程链接仓库gitte/github

后期适当加点图片&#xff0c;提高可读性。 本教程是最基础的连接教程&#xff0c;设计git的操作也仅仅局限于push/pull&#xff0c;如果想全面了解&#xff0c;可以参考廖雪峰git教程 在Ubuntu上初始化本地Git仓库并链接到远程Gitee仓库(github同理)&#xff0c;需要按照以下步…...

一些自己整理的工具实用参数

工具实用参数 sqlmap -u: 指定需要测试的目标URL&#xff08;格式&#xff1a;http://www.example.com/test.php?id1&#xff09; --cookie: 设置需要发送的 HTTP Cookie&#xff0c;例如&#xff1a;--cookie"sid123456;PHPSESSID654321" --threads&#xff1a;…...

C# Timer定时器

C# Timer定时器 Timer定时器定时器主要用到的就是Timer的Tick事件&#xff0c;另外还要设置时间间隔&#xff1a; 下面这个实力演示了每隔一秒&#xff0c;picturebox中的图片来回切换&#xff0c;每隔一秒&#xff0c;文本框中显示当前时间。 using System; using System.Co…...

oracle怎么删除表索引

Oracle是目前常用的企业级关系型数据库管理系统&#xff0c;用于存储和管理大量数据。在Oracle中&#xff0c;表索引是用于提高查询效率的重要组成部分&#xff0c;但也有时候需要删除表索引。本文将介绍如何在Oracle中删除表索引。 一、查看表索引 在删除表索引之前&#xff…...

【Tkinter系列13/15】标准化外观和选项数据库

27. 标准化外观和选项数据库 可以轻松地将颜色、字体和其他选项应用于 小部件&#xff0c;当您创建它们时。然而 如果您希望很多小部件具有相同的 背景颜色或字体&#xff0c;指定每个都很乏味 每次选项&#xff0c;以及 让用户覆盖您的选择是很好的 他们最喜欢的配色方案、字…...

springboot 集成dubbo

上一篇我们一起认识了Dubbo与RPC&#xff0c;今天我们就来一起学习如何使用Dubbo&#xff0c;并将Dubbo集成到Spring Boot的项目中。我们来看下今天要使用到的软件及版本&#xff1a; 软件 版本 说明 Java 11 Spring Boot 2.7.13 Spring Boot 3.0版本开始&#xff0c;最…...

基于YOLOV8模型和CCPD数据集的车牌目标检测系统(PyTorch+Pyside6+YOLOv8模型)

摘要&#xff1a;基于YOLOV8模型和CCPD数据集的车牌目标检测系统可用于日常生活中检测与定位车牌目标&#xff0c;利用深度学习算法可实现图片、视频、摄像头等方式的目标检测&#xff0c;另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算…...

net user安全隐患

net user test 123456 /add #添加一个test账户,密码是123456net user test /delete #删除test账户windows administrator权限下删除其他账户&#xff0c;难道不需要知道该账户的密码吗&#xff1f; 以前没太注意&#xff0c;但是这算不算是一个漏洞呢&#xff1f; 另外&#…...

[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?

&#x1f9e0; 智能合约中的数据是如何在区块链中保持一致的&#xff1f; 为什么所有区块链节点都能得出相同结果&#xff1f;合约调用这么复杂&#xff0c;状态真能保持一致吗&#xff1f;本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里&#xf…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

Caliper 配置文件解析:config.yaml

Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

Java + Spring Boot + Mybatis 实现批量插入

在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法&#xff1a;使用 MyBatis 的 <foreach> 标签和批处理模式&#xff08;ExecutorType.BATCH&#xff09;。 方法一&#xff1a;使用 XML 的 <foreach> 标签&#xff…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...

【UE5 C++】通过文件对话框获取选择文件的路径

目录 效果 步骤 源码 效果 步骤 1. 在“xxx.Build.cs”中添加需要使用的模块 &#xff0c;这里主要使用“DesktopPlatform”模块 2. 添加后闭UE编辑器&#xff0c;右键点击 .uproject 文件&#xff0c;选择 "Generate Visual Studio project files"&#xff0c;重…...

【笔记】AI Agent 项目 SUNA 部署 之 Docker 构建记录

#工作记录 构建过程记录 Microsoft Windows [Version 10.0.27871.1000] (c) Microsoft Corporation. All rights reserved.(suna-py3.12) F:\PythonProjects\suna>python setup.py --admin███████╗██╗ ██╗███╗ ██╗ █████╗ ██╔════╝…...

使用python进行图像处理—图像滤波(5)

图像滤波是图像处理中最基本和最重要的操作之一。它的目的是在空间域上修改图像的像素值&#xff0c;以达到平滑&#xff08;去噪&#xff09;、锐化、边缘检测等效果。滤波通常通过卷积操作实现。 5.1卷积(Convolution)原理 卷积是滤波的核心。它是一种数学运算&#xff0c;…...

2025 后端自学UNIAPP【项目实战:旅游项目】7、景点详情页面【完结】

1、获取景点详情的请求【my_api.js】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http(/login/getWXSessionKey, {code,avatar}); };//…...