AI学习指南自然语言处理篇-Transformer模型的实践
AI学习指南自然语言处理篇 - Transformer模型的实践
目录
- 引言
- Transformer模型概述
- 自注意力机制
- 编码器-解码器结构
- 环境准备
- Transformer模型的实现
- 编码器实现
- 解码器实现
- Transformer模型整体实现
- Transformer在NLP任务中的应用
- 文本分类
- 机器翻译
- 总结与展望
引言
在过去的数年里,深度学习为自然语言处理(NLP)领域注入了新的活力。特别是Transformer模型的提出,极大地改善了许多NLP任务的效果。本文将深入探讨Transformer模型的实现,以及其在NLP应用中的使用方法,并提供实际的Python代码示例。
Transformer模型概述
自注意力机制
自注意力机制(Self-Attention)是Transformer模型的核心。在处理序列数据时,这种机制允许模型关注序列中的不同部分,从而捕捉到长距离的依赖关系。
给定输入序列 ( X = [ x 1 , x 2 , … , x n ] ) ( X = [x_1, x_2, \ldots, x_n] ) (X=[x1,x2,…,xn]),自注意力计算过程如下:
-
生成Query、Key、Value:
- ( Q = X W Q ) ( Q = XW^Q ) (Q=XWQ)
- ( K = X W K ) ( K = XW^K ) (K=XWK)
- ( V = X W V ) ( V = XW^V ) (V=XWV)
-
计算注意力权重:
- ( Attention ( Q , K , V ) = softmax ( Q K T d k ) V ) ( \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V ) (Attention(Q,K,V)=softmax(dkQKT)V)
-
输出:
- 最终输出与输入长度相同,捕捉到全局的上下文信息。
编码器-解码器结构
Transformer的架构主要分为编码器和解码器两部分。编码器对输入序列进行特征提取,而解码器负责生成目标序列。
- 编码器:由多个相同的层堆叠而成,每层包含自注意力机制和前馈神经网络。
- 解码器:同样由多个层堆叠而成,但每层包含掩蔽自注意力机制,以确保在生成序列时不会“看到”后续的token。
环境准备
在实现Transformer之前,我们需要设置好Python环境。推荐使用PyTorch
或TensorFlow
。以下是使用PyTorch
的环境准备步骤。
安装PyTorch
在命令行中运行以下命令以安装PyTorch:
pip install torch torchvision torchaudio
安装其他依赖
pip install numpy pandas matplotlib
Transformer模型的实现
编码器实现
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, d_model, nhead):super(MultiHeadAttention, self).__init__()self.d_model = d_modelself.nhead = nheadself.head_dim = d_model // nheadassert (self.head_dim * nhead == d_model), "d_model must be divisible by nhead"self.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out_linear = nn.Linear(d_model, d_model)def forward(self, query, key, value, mask=None):batch_size = query.size(0)Q = self.q_linear(query).view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)K = self.k_linear(key).view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)V = self.v_linear(value).view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)attn_weights = F.softmax(Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)if mask is not None:attn_weights = attn_weights.masked_fill(mask == 0, float("-inf"))output = (attn_weights @ V).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)return self.out_linear(output)class TransformerEncoderLayer(nn.Module):def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):super(TransformerEncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, nhead)self.linear1 = nn.Linear(d_model, dim_feedforward)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(dim_feedforward, d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)def forward(self, src, src_mask=None):src2 = self.self_attn(src, src, src, mask=src_mask)src = self.norm1(src + src2)src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))src = self.norm2(src + src2)return srcclass TransformerEncoder(nn.Module):def __init__(self, num_layers, d_model, nhead, dim_feedforward, dropout=0.1):super(TransformerEncoder, self).__init__()self.layers = nn.ModuleList([TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)])def forward(self, src, src_mask=None):for layer in self.layers:src = layer(src, src_mask)return src
解码器实现
class TransformerDecoderLayer(nn.Module):def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):super(TransformerDecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, nhead)self.cross_attn = MultiHeadAttention(d_model, nhead)self.linear1 = nn.Linear(d_model, dim_feedforward)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(dim_feedforward, d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):tgt2 = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)tgt = self.norm1(tgt + tgt2)tgt2 = self.cross_attn(tgt, memory, memory, mask=memory_mask)tgt = self.norm2(tgt + tgt2)tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))tgt = self.norm3(tgt + tgt2)return tgtclass TransformerDecoder(nn.Module):def __init__(self, num_layers, d_model, nhead, dim_feedforward, dropout=0.1):super(TransformerDecoder, self).__init__()self.layers = nn.ModuleList([TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)])def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):for layer in self.layers:tgt = layer(tgt, memory, tgt_mask, memory_mask)return tgt
Transformer模型整体实现
class Transformer(nn.Module):def __init__(self, num_encoder_layers, num_decoder_layers, d_model, nhead, dim_feedforward, dropout=0.1):super(Transformer, self).__init__()self.encoder = TransformerEncoder(num_encoder_layers, d_model, nhead, dim_feedforward, dropout)self.decoder = TransformerDecoder(num_decoder_layers, d_model, nhead, dim_feedforward, dropout)self.out_linear = nn.Linear(d_model, d_model)def forward(self, src, tgt, src_mask=None, tgt_mask=None):memory = self.encoder(src, src_mask)output = self.decoder(tgt, memory, tgt_mask)return self.out_linear(output)
Transformer在NLP任务中的应用
文本分类
在文本分类任务中,我们可以使用Transformer模型进行文本特征提取,然后将提取到的特征输入到全连接层进行分类。
实现文本分类模型
class TextClassifier(nn.Module):def __init__(self, num_classes, num_layers, d_model, nhead, dim_feedforward, dropout=0.1):super(TextClassifier, self).__init__()self.transformer = Transformer(num_layers, num_layers, d_model, nhead, dim_feedforward, dropout)self.fc = nn.Linear(d_model, num_classes)def forward(self, src):output = self.transformer(src, src) # src作为tgtoutput = output.mean(dim=1) # 全局平均池化return self.fc(output)# 实例化模型
model = TextClassifier(num_classes=3, num_layers=6, d_model=512, nhead=8, dim_feedforward=2048)
训练与评估
# 训练示例
import torch.optim as optim
from sklearn.metrics import accuracy_score# 假设有数据集train_loader和test_loader
optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练过程
for epoch in range(10):model.train()for batch in train_loader:optimizer.zero_grad()inputs, targets = batchoutputs = model(inputs)loss = F.cross_entropy(outputs, targets)loss.backward()optimizer.step()# 评估过程
model.eval()
y_true, y_pred = [], []
with torch.no_grad():for batch in test_loader:inputs, targets = batchoutputs = model(inputs)preds = outputs.argmax(dim=1)y_true.extend(targets.numpy())y_pred.extend(preds.numpy())accuracy = accuracy_score(y_true, y_pred)
print(f"准确率: {accuracy:.4f}")
机器翻译
在机器翻译任务中,Transformer已经成为了最常用的架构之一,以下是机器翻译的实现步骤。
数据预处理
首先,我们需要处理并准备我们的翻译数据集,例如使用torchtext
库来处理。
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator# 定义源语和目标语
SRC = Field(tokenize="spacy", src_lang="de", lower=True)
TRG = Field(tokenize="spacy", src_lang="en", lower=True)# 下载中文-英文数据集
train_data, valid_data, test_data = Multi30k.splits(exts=(".de", ".en"), fields=(SRC, TRG))# 构建词汇表
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)# 创建数据迭代器
train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train_data, valid_data, test_data), batch_size=32, device=torch.device("cuda")
)
实现机器翻译模型
机器翻译模型利用Transformer的编码器-解码器结构。
class Translator(nn.Module):def __init__(self, num_layers, d_model, nhead, dim_feedforward, dropout=0.1):super(Translator, self).__init__()self.transformer = Transformer(num_layers, num_layers, d_model, nhead, dim_feedforward, dropout)def forward(self, src, tgt):return self.transformer(src, tgt)
训练机器翻译模型
model = Translator(num_layers=6, d_model=512, nhead=8, dim_feedforward=2048)optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练过程
for epoch in range(10):model.train()for batch in train_iterator:src, tgt = batch.src, batch.trgtgt_input = tgt[:-1, :]optimizer.zero_grad()output = model(src, tgt_input)# 转换输出的维度output_dim = output.shape[-1]output = output.view(-1, output_dim)tgt = tgt[1:, :].view(-1)loss = F.cross_entropy(output, tgt)loss.backward()optimizer.step()
评估机器翻译模型
可以使用如BLEU等指标来评估翻译质量。
from nltk.translate.bleu_score import sentence_bleu# 评估过程
model.eval()
with torch.no_grad():for batch in test_iterator:src, tgt = batch.src, batch.trgoutput = model(src, tgt) # tgt作为输入# 假设似乎实现了一个解码过程# 这里我们假设生成了一系列翻译句子references = [tgt[i][1:].tolist() for i in range(tgt.size(0))]predictions = [output[i].argmax(dim=-1).tolist() for i in range(output.size(0))]for reference, prediction in zip(references, predictions):print("BLEU Score:", sentence_bleu([reference], prediction))
总结与展望
本文深入探讨了Transformer模型的实现及在NLP任务中的应用,包括文本分类与机器翻译。借助于PyTorch,我们能够轻松地构建和训练Transformer模型。
未来,Transformer模型可能会与更多的技术结合,继续推动自然语言处理领域的发展。随着NLP领域的快速发展,研究者和工程师可以期待新的创新与应用。
希望本文能够为您深入理解Transformer模型及其应用提供帮助!
相关文章:
AI学习指南自然语言处理篇-Transformer模型的实践
AI学习指南自然语言处理篇 - Transformer模型的实践 目录 引言Transformer模型概述 自注意力机制编码器-解码器结构 环境准备Transformer模型的实现 编码器实现解码器实现Transformer模型整体实现 Transformer在NLP任务中的应用 文本分类机器翻译 总结与展望 引言 在过去的数…...

【LVGL速成】LVGL修改标签文本(GUI Guider生成的字库问题)
目录 前置篇章: 一.问题背景 二.失败方案 三.成功方案 1.Gui guider的源码结构 2.手动生成字体 3.Keil中配置相关文件 编辑 4.修改文字 四.字体样式函数说明 前置篇章: 【LVGL快速入门(二)】LVGL开源框架入门教程之框架使用(UI界面设计)_lvgl…...

C语言项目实践-贪吃蛇
⽬录: 1. 游戏背景 2. 游戏效果演⽰ 3. 实现的⽬标 4. 实现的定位 5. 技术要点 6. 贪吃蛇游戏设计与分析 7. 贪吃蛇游戏数据结构设计 8. 相关Win32API介绍 9. 参考代码 正文开始 1. 游戏背景 贪吃蛇是久负盛名的游戏,它也和俄罗斯⽅块…...

在kanzi 3.9.8里使用API创建自定义材质
1. kanzi studio设置 1.1 创建一个纹理贴图,起名Render Target Texture 1.2 创建一个Image节点,使用该贴图 2. 代码设置 2.1 创建一个自定义节点类 class mynode2d : public Node2D { public: virtual void renderOverride(Renderer3D& renderer…...

IDEA中通义灵码的使用技巧
大家好,我是 V 哥。在日常写代码的过程中,通过 AI 工具辅助开发已是当下程序员惯用的方式,V 哥在使用了众多的 AI 工具后,多数情况下,选择通义灵码来辅助开发,尤其是解释代码和生成单元测试功能甚是好用&am…...
JS中let var 和const区别
在JavaScript中,let、var 和 const 都是用来声明变量的关键字,但它们之间有几个关键的区别: 作用域(Scope): var 声明的变量拥有函数作用域(function scope),这意味着如果 var 变量在…...
ansible详细介绍和具体步骤
Ansible简介 1.1 Ansible的基本概念 Ansible是一款开源的自动化工具,旨在简化IT操作的复杂性。它由Michael DeHaan创建,并于2012年发布,随后在2015年被Red Hat收购。Ansible的核心理念是“简单即美”,它通过使用YAML(…...

利用LangChain与LLM打造个性化私有文档搜索系统
我们知道LLM(大语言模型)的底模是基于已经过期的公开数据训练出来的,对于新的知识或者私有化的数据LLM一般无法作答,此时LLM会出现“幻觉”。针对“幻觉”问题,一般的解决方案是采用RAG做检索增强。 但是我们不可能把…...

linux中的软、硬链接
目录 引言 简单介绍 如何理解软硬链接 链接的应用 环路问题 引言 在Linux操作系统的广阔天地中,文件管理是其核心功能之一。而软链接和硬链接作为Linux文件系统中的两种特殊链接方式,它们为用户提供了灵活的文件访问途径和高效的磁盘空间利用手段。…...

Ubuntu 系统、Docker配置、Docker的常用软件配置(下)
前言 书接上文,现在操作系统已经有了,作为程序的载体Docker也安装配置好了,接下来我们需要让Docker发挥它的法力了。 Docker常用软件的安装 1.Redis 缓存安装 1.1 下载 docker pull redis:7.4.1 #可改为自己需要的版本 1.2 创建本地目录存储…...
jdk,openjdk,oraclejdk
Java是开发语言,不是软件。JDK是软件,使用OpenJDK是免费的,一直免费。而且OpenJDK正儿巴经的Java社区推出来的JDK。 Oracle JDK主要是面向付费能力强的企业用户,收费已经好多年了,不是一两年的事,JDK8是JDK…...
Docker Hub 镜像加速器
零、参考资料 https://gist.github.com/y0ngb1n/7e8f16af3242c7815e7ca2f0833d3ea6Daemon proxy configuration | Docker Docs 一、解决方案 1、问题现象 Error response from daemon: Get "https://index.docker.io/v1/search?qcarlasim%2Fcarla&n25": dia…...

DevOps赋能:优化业务价值流的实战策略与路径(上)
上篇:价值流引领与可视化体系构建 一、前言 在快速迭代的软件项目和产品开发生态中,我们始终围绕两个核心目标:一是确保每一项工作都能为客户创造实际价值,这是产品团队的核心使命;二是确保这些有价值的工作能够高效…...

int的取值范围
原码(True form):原码是一种计算机中对数字的二进制表示方法,数码序列中最高位为符号位,符号位为0表示正数,符号位为1表示负数;其余有效值部分用二进制的绝对值表示。 反码…...
图文检索(16):IDC: Boost Text-to-Image Retrieval via Indirect and Direct Connections
IDC: Boost Text-to-Image Retrieval via Indirect and Direct Connections 摘要3 方法3.1 直接连接3.2 间接连接3.3 DLB 正则化 结论 发布时间(2024 LREC-COLING) 标题:IDC:通过间接和直接连接增强文本到图像的检索 摘要 本文&…...

企业数字化转型:重识、深思、重启新征程-亿发
在当下这个日新月异的时代,企业数字化转型已然成为众多企业竞相追逐的发展方向,可真正能将其领悟透彻并有效落地实施的企业,却并非比比皆是。此刻,亿发软件针对企业数字化转型展开一次更为深入的重识、全面的深思,进而…...
仓颉刷题录-字符串数字转换(一)
文章目录 背景题目:交换后字典序最小的字符串个人感受 这是双子专栏: Cangjie仓颉程序设计-个人总结 本专栏还在持续更新: 仓颉编程cangjie刷题录 背景 报名了一个仓颉的比赛,感觉条件要求挺低的,就想上。哈哈哈。但…...

SpringBoot【实用篇】- 配置高级
文章目录 目标:1.ConfigurationProperties2.宽松绑定/松散绑定3. 常用计量单位绑定4.数据校验 目标: ConfigurationProperties宽松绑定/松散绑定常用计量单位绑定数据校验 1.ConfigurationProperties ConfigurationProperties 在学习yml的时候我们了解…...

liunx CentOs7安装MQTT服务器(mosquitto)
查找 mosquitto 软件包 yum list all | grep mosquitto出现以上两个即可进行安装,如果没有出现则需要安装EPEL软件库。 yum install https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm查看 mosquitto 信息 yum info mosquitto安装 mosquitt…...

【银河麒麟高级服务器操作系统】虚拟机lvm分区丢失现象分析及解决建议
了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer.kylinos.cn 文档中心:https://documentkylinos.cn 环境描述 系统环境 物理机/虚拟机/云/容器 虚拟…...

从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...

HBuilderX安装(uni-app和小程序开发)
下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...
什么是EULA和DPA
文章目录 EULA(End User License Agreement)DPA(Data Protection Agreement)一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA(End User License Agreement) 定义: EULA即…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)
漏洞概览 漏洞名称:Apache Flink REST API 任意文件读取漏洞CVE编号:CVE-2020-17519CVSS评分:7.5影响版本:Apache Flink 1.11.0、1.11.1、1.11.2修复版本:≥ 1.11.3 或 ≥ 1.12.0漏洞类型:路径遍历&#x…...
Mysql8 忘记密码重置,以及问题解决
1.使用免密登录 找到配置MySQL文件,我的文件路径是/etc/mysql/my.cnf,有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...

C# 表达式和运算符(求值顺序)
求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如,已知表达式3*52,依照子表达式的求值顺序,有两种可能的结果,如图9-3所示。 如果乘法先执行,结果是17。如果5…...

群晖NAS如何在虚拟机创建飞牛NAS
套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...

android RelativeLayout布局
<?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"android:gravity&…...