当前位置: 首页 > news >正文

LSTM 词语模型上的动态量化

原文链接 

(beta) Dynamic Quantization on an LSTM Word Language Model — PyTorch Tutorials 2.3.0+cu121 documentation

引言

量化涉及将模型的权重和激活值从浮点数转换为整数,这样可以缩小模型大小,加快推理速度,但对准确性的影响很小。
在本教程中,我们将把最简单的量化形式--动态量化--应用到基于 LSTM 的下一个单词预测模型中,这与 PyTorch 示例中的单词语言模型密切相关。

# imports
import os
from io import open
import timeimport torch
import torch.nn as nn
import torch.nn.functional as F

 定义模型

  在此,我们按照单词语言模型示例中的模型,定义 LSTM 模型架构。

class LSTMModel(nn.Module):"""Container module with an encoder, a recurrent module, and a decoder."""def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):super(LSTMModel, self).__init__()self.drop = nn.Dropout(dropout)self.encoder = nn.Embedding(ntoken, ninp)self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)self.decoder = nn.Linear(nhid, ntoken)self.init_weights()self.nhid = nhidself.nlayers = nlayersdef init_weights(self):initrange = 0.1self.encoder.weight.data.uniform_(-initrange, initrange)self.decoder.bias.data.zero_()self.decoder.weight.data.uniform_(-initrange, initrange)def forward(self, input, hidden):emb = self.drop(self.encoder(input))output, hidden = self.rnn(emb, hidden)output = self.drop(output)decoded = self.decoder(output)return decoded, hiddendef init_hidden(self, bsz):weight = next(self.parameters())return (weight.new_zeros(self.nlayers, bsz, self.nhid),weight.new_zeros(self.nlayers, bsz, self.nhid))

加载文本数据

 接下来,我们将 Wikitext-2 数据集加载到[Corpus]{.title-ref}中,同样按照单词语言模型示例进行预处理。

class Dictionary(object):def __init__(self):self.word2idx = {}self.idx2word = []def add_word(self, word):if word not in self.word2idx:self.idx2word.append(word)self.word2idx[word] = len(self.idx2word) - 1return self.word2idx[word]def __len__(self):return len(self.idx2word)class Corpus(object):def __init__(self, path):self.dictionary = Dictionary()self.train = self.tokenize(os.path.join(path, 'train.txt'))self.valid = self.tokenize(os.path.join(path, 'valid.txt'))self.test = self.tokenize(os.path.join(path, 'test.txt'))def tokenize(self, path):"""Tokenizes a text file."""print(path)assert os.path.exists(path), f"Error: The path {path} does not exist."# Add words to the dictionarywith open(path, 'r', encoding="utf8") as f:for line in f:words = line.split() + ['<eos>']for word in words:self.dictionary.add_word(word)# Tokenize file contentwith open(path, 'r', encoding="utf8") as f:idss = []for line in f:words = line.split() + ['<eos>']ids = []for word in words:ids.append(self.dictionary.word2idx[word])idss.append(torch.tensor(ids).type(torch.int64))ids = torch.cat(idss)return idsmodel_data_filepath = ".\data\\"corpus = Corpus(model_data_filepath + 'wikitext-2')

加载预训练模型

 这是一个关于动态量化的教程,一种在模型训练完成后应用的量化技术。因此,我们只需将一些预先训练好的权重加载到该模型架构中;这些权重是通过使用单词语言模型示例中的默认设置进行五次历时训练获得的。

ntokens = len(corpus.dictionary)model = LSTMModel(ntoken=ntokens,ninp=512,nhid=256,nlayers=5,
)# model.load_state_dict(
#     torch.load(
#         model_data_filepath + 'word_language_model_quantize.pth',
#         map_location=torch.device('cpu')
#     )
# )model.eval()
print(model)

现在让我们生成一些文本,以确保预训练模型正常工作 - 与之前类似,我们遵循此处

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000with open(model_data_filepath + 'out.txt', 'w') as outf:with torch.no_grad():  # no tracking historyfor i in range(num_words):output, hidden = model(input_, hidden)word_weights = output.squeeze().div(temperature).exp().cpu()word_idx = torch.multinomial(word_weights, 1)[0]input_.fill_(word_idx)word = corpus.dictionary.idx2word[word_idx]outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))if i % 100 == 0:print('| Generated {}/{} words'.format(i, 1000))with open(model_data_filepath + 'out.txt', 'r') as outf:all_output = outf.read()print(all_output)

虽然不是 GPT-2,但看起来模型已经开始学习语言结构了!
我们差不多可以演示动态量化了。我们只需要再定义几个辅助函数:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1# create test data set
def batchify(data, bsz):# Work out how cleanly we can divide the dataset into ``bsz`` parts.nbatch = data.size(0) // bsz# Trim off any extra elements that wouldn't cleanly fit (remainders).data = data.narrow(0, 0, nbatch * bsz)# Evenly divide the data across the ``bsz`` batches.return data.view(bsz, -1).t().contiguous()test_data = batchify(corpus.test, eval_batch_size)# Evaluation functions
def get_batch(source, i):seq_len = min(bptt, len(source) - 1 - i)data = source[i:i + seq_len]target = source[i + 1:i + 1 + seq_len].reshape(-1)return data, targetdef repackage_hidden(h):"""Wraps hidden states in new Tensors, to detach them from their history."""if isinstance(h, torch.Tensor):return h.detach()else:return tuple(repackage_hidden(v) for v in h)def evaluate(model_, data_source):# Turn on evaluation mode which disables dropout.model_.eval()total_loss = 0.hidden = model_.init_hidden(eval_batch_size)with torch.no_grad():for i in range(0, data_source.size(0) - 1, bptt):data, targets = get_batch(data_source, i)output, hidden = model_(data, hidden)hidden = repackage_hidden(hidden)output_flat = output.view(-1, ntokens)total_loss += len(data) * criterion(output_flat, targets).item()return total_loss / (len(data_source) - 1)

测试动态量化

最后,我们可以在模型上调用 torch.quantization.quantize_dynamic!具体来说就是
我们指定要对模型中的 nn.LSTM 和 nn.Linear 模块进行量化
我们指定要将权重转换为 int8 值

import torch.quantizationquantized_model = torch.quantization.quantize_dynamic(model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)# 模型看起来没有变化,这对我们有什么好处呢?首先,我们看到模型的尺寸大幅缩小:
def print_size_of_model(model):torch.save(model.state_dict(), "temp.p")print('Size (MB):', os.path.getsize("temp.p") / 1e6)os.remove('temp.p')print_size_of_model(model)
print_size_of_model(quantized_model)

其次,我们看到推理时间更快,而评估损失没有区别:
注:我们将单线程比较的线程数设为一个,因为量化模型是单线程运行的。

torch.set_num_threads(1)def time_model_evaluation(model, test_data):s = time.time()loss = evaluate(model, test_data)elapsed = time.time() - sprint('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

在本地 MacBook Pro 上运行这个程序,在不进行量化的情况下,推理时间约为 200 秒,而在进行量化的情况下,推理时间仅为 100 秒左右。

 结论

动态量化是减少模型大小的一种简单方法,但对准确性的影响有限。
感谢您的阅读!我们一如既往地欢迎任何反馈,如果您有任何问题,请在此创建一个问题。

相关文章:

LSTM 词语模型上的动态量化

原文链接 (beta) Dynamic Quantization on an LSTM Word Language Model — PyTorch Tutorials 2.3.0cu121 documentation 引言 量化涉及将模型的权重和激活值从浮点数转换为整数&#xff0c;这样可以缩小模型大小&#xff0c;加快推理速度&#xff0c;但对准确性的影响很小…...

STM32 proteus + STM32Cubemx仿真教程(第一课LED教程)

文章目录 前言一、STM32点亮LED灯的原理1.1GPIO是什么1.2点亮LED灯的原理 二、STM32Cubemx创建工程三、proteus仿真电路图四、程序代码编写1.LED灯操作函数介绍HAL_GPIO_WritePin函数原型参数说明示例代码 HAL_GPIO_TogglePin函数原型参数说明示例代码 2.代码编写3.烧写程序 总…...

享元模式

前言 享元模式&#xff1a;运用共享技术有效地支持大量细粒度的对象。 在享元对象内部并且不会随环境改变而改变的共享部分&#xff0c;可以称为是享元对象的内部状态&#xff0c;而随环境改变而改变的、不可以共享的状态就是外部状态了。事实上&#xff0c;享元模式可以避免大…...

R语言数据分析16-针对芬兰污染指数的分析与考察

1. 研究背景及意义 近年来&#xff0c;随着我国科技和经济高速发展&#xff0c;人们生活质量也随之显著提高。但是&#xff0c; 环境污染问题也日趋严重&#xff0c;给人们的生活质量和社会生产的各个方面都造成了许多不 利的影响。空气污染作为环境污染主要方面&#xff0c;更…...

Search用法Python:深入探索搜索功能的应用与技巧

Search用法Python&#xff1a;深入探索搜索功能的应用与技巧 在Python编程中&#xff0c;搜索功能是一项至关重要的技能&#xff0c;它能够帮助我们快速定位并处理数据。然而&#xff0c;对于初学者来说&#xff0c;如何高效地使用搜索功能可能会带来一些困惑。本文将从四个方…...

STM32的FreeRtos的学习

首先就是去官网下载一个源文件&#xff1a;FreeRtos官网 下载下来的是一个zip文件&#xff0c;解压缩了。 然后再工程文件夹中创建个文件夹&#xff1a; 在这个文件夹中创建3个文件夹&#xff1a; 然后开始把下载下来的文件夹中的文件挑选出来放到我们的工程文件夹中&#xff1…...

从零入手人工智能(2)——搭建开发环境

1.前言 作为一名单片机工程师&#xff0c;想要转型到人工智能开发领域的道路确实充满了挑战与未知。记得当我刚开始这段旅程时&#xff0c;心中充满了迷茫和困惑。面对全新的领域&#xff0c;我既不清楚如何入手&#xff0c;也不知道能用人工智能干什么。正是这些迷茫和困惑&a…...

Web前端指南

前言 前端开发员主要负责网站的设计、外观和感觉。他们设计引人入胜的在线用户体验,激发用户兴趣,鼓励用户重复访问。他们与设计师密切合作,使网站美观、实用、快捷。 如果您喜欢创造性思维、打造更好的体验并对视觉设计感兴趣,这将是您的理想职业道路。 探讨前端、后端以…...

AI菜鸟向前飞 — LangChain系列之十七 - 剖析AgentExecutor

AgentExecutor 顾名思义&#xff0c;Agent执行器&#xff0c;本篇先简单看看LangChain是如何实现的。 先回顾 AI菜鸟向前飞 — LangChain系列之十四 - Agent系列&#xff1a;从现象看机制&#xff08;上篇&#xff09; AI菜鸟向前飞 — LangChain系列之十五 - Agent系列&#…...

nodejs 第三方库 exiftool-vendored

exiftool-vendored 是一款可以帮助你快捷修改图片信息的第三方库。如果你想要批量修改图片信息的话&#xff0c;那么它是一个不错的选择。 1.导入第三方库 在控制台中执行下面代码即可。 npm install exiftool-vendored --save2.获取信息 这里给出例子。 const { exiftool …...

docker部署redis实践

1.拉取redis镜像 # 拉取镜像 sudo docker pull redis2.创建映射持久化目录 # 创建目录 sudo mkdir -p $PWD/redis/{conf,data}3. 运行redis 容器&#xff0c;查看当前redis 版本号 # 运行 sudo docker run --name redis -d -p 6379:6379 redis # 查看版本号 sudo docker ex…...

Web前端学习之路:深入探索学习时长与技能进阶的奥秘

Web前端学习之路&#xff1a;深入探索学习时长与技能进阶的奥秘 在数字化时代&#xff0c;Web前端技术成为了连接用户与互联网世界的桥梁。对于初学者来说&#xff0c;学习Web前端究竟需要多久&#xff0c;以及如何高效掌握相关技能&#xff0c;一直是困扰他们的难题。本文将从…...

如何不用命令创建用户

都有哪些操作&#xff1a; 1、在/etc/passwd添加一行 2、在/etc/shadow添加一行 3、在/etc/group添加一行 4、创建用户家目录 5、创建用户邮件文件 例如&#xff1a; 创建用户jerry 要求&#xff1a; uid&#xff1a;777 主组&#xff1a;hadoop&#xff08;800&#xff09…...

基于Python实现可视化分析中国500强排行榜数据的设计与实现

基于Python实现可视化分析中国500强排行榜数据的设计与实现 “Design and Implementation of Visual Analysis for China’s Top 500 Companies Ranking Data using Python” 完整下载链接:基于Python实现可视化分析中国500强排行榜数据的设计与实现 文章目录 基于Python实现…...

VUE3 学习笔记(13):VUE3 下的Element-Plus基本使用

UI是页面的门面&#xff0c;一个好的UI自然令人赏心悦目&#xff1b;国人团队开发的ElementUI在众多UI中较为常见&#xff0c;因此通过介绍它的使用让大家更好的了解第三方UI的使用。 安装 Npm install element-plus --save 或 Cnpm install element-plus --save 配置 全局配置…...

MySql数据库安全加固

设置密码复杂度策略 查看密码策略 SHOW VARIABLES LIKE validate_password%; 设置密码策略 INSTALL PLUGIN validate_password SONAME validate_password.so; 设置登陆失败策略 安装插件&#xff08;谨慎操作&#xff0c;可能会导致数据库卡死&#xff09; install plug…...

Nginx(title小图标)修改方法

本章主要讲述Nginx如何上传网站图标。 操作系统&#xff1a; CentOS Stream 9 首先我们bing搜索ico网站图标在线设计&#xff0c;找到喜欢的设计分格并下载。 是一个压缩包 然后我们上传到nginx解压 [rootlocalhost html]# rz[rootlocalhost html]# unzip favicon_logosc.z…...

iOS 17.5中的一个漏洞

i0S 17.5中的一个漏洞 iOS 17.5中的一个漏洞会使已刚除的照片重新出现&#xff0c;并目此问题似乎会影响甚至已擦除并出售给他人的 iPhone 和 iPad. 在2023年9月&#xff0c;一位Reddit用户根据Apple的指南擦除了他的iPad&#xff0c;并将其卖给了一位朋友。然而&#xff0c;这…...

如何在 iPhone 上恢复已删除的短信

本文介绍如何检索已删除的短信和 iMessage 以及恢复丢失的消息。说明适用于 iOS 17 及更高版本。 如何在 iOS 17及更高版本中恢复文本 恢复已删除短信的最简单方法是使用 iOS 17。从删除短信到恢复它有 30 到 40 天的时间。 在“信息”的对话屏幕中&#xff0c;选择“过滤器”…...

矩阵练习1

73.矩阵置零 这道题相对简单。 首先我们需要标记需要置零的行和列&#xff0c;可以在遍历矩阵中的元素遇到0&#xff0c;则将其行首和列首元素置为0。在此过程中首行、首列会受影响&#xff0c;因此先用两个变量记录首行、首列是否需要被置0&#xff0c;接着遍历非首行、非首…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python&#xff5c;GIF 解析与构建&#xff08;5&#xff09;&#xff1a;手搓截屏和帧率控制 一、引言 二、技术实现&#xff1a;手搓截屏模块 2.1 核心原理 2.2 代码解析&#xff1a;ScreenshotData类 2.2.1 截图函数&#xff1a;capture_screen 三、技术实现&…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库&#xff0c;例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体&#xff0c;比如 SnowballFight、Huggy the Do…...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕&#xff0c;#AI 监考一度冲上热搜。当AI深度融入高考&#xff0c;#时间同步 不再是辅助功能&#xff0c;而是决定AI监考系统成败的“生命线”。 AI亮相2025高考&#xff0c;40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕&#xff0c;江西、…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

Git常用命令完全指南:从入门到精通

Git常用命令完全指南&#xff1a;从入门到精通 一、基础配置命令 1. 用户信息配置 # 设置全局用户名 git config --global user.name "你的名字"# 设置全局邮箱 git config --global user.email "你的邮箱example.com"# 查看所有配置 git config --list…...