当前位置: 首页 > news >正文

PyTorch使用Transformer进行机器翻译

文章目录

    • 简介
    • 数据集
    • 环境要求
    • 实验代码
    • 实验结果
    • 参考来源

简介

本文使用PyTorch自带的transformer层进行机器翻译:从德语翻译为英语。从零开始实现Transformer请参阅PyTorch从零开始实现Transformer,以便于获得对Transfomer更深的理解。

数据集

Multi30k

环境要求

使用torch, torchtext,spacy,其中spacy用来分词。另外,spacy要求在虚拟环境中下载语言模型,以便于进行tokenize(分词)

# To install spacy languages do:
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm

实验代码

代码来源请参考下方的GitHub链接

transformer_translation.py文件

# Bleu score 32.02
import torch
import torch.nn as nn
import torch.optim as optim
import spacy
from utils import translate_sentence, bleu, save_checkpoint, load_checkpoint
from torch.utils.tensorboard import SummaryWriter
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator"""
To install spacy languages do:
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
"""
spacy_ger = spacy.load("de_core_news_sm")
spacy_eng = spacy.load("en_core_web_sm")def tokenize_ger(text):return [tok.text for tok in spacy_ger.tokenizer(text)]# 将英语进行分词
def tokenize_eng(text):return [tok.text for tok in spacy_eng.tokenizer(text)]german = Field(tokenize=tokenize_ger, lower=True, init_token="<sos>", eos_token="<eos>")english = Field(tokenize=tokenize_eng, lower=True, init_token="<sos>", eos_token="<eos>"
)train_data, valid_data, test_data = Multi30k.splits(exts=(".de", ".en"), fields=(german, english)
)german.build_vocab(train_data, max_size=10000, min_freq=2)
english.build_vocab(train_data, max_size=10000, min_freq=2)class Transformer(nn.Module):def __init__(self,embedding_size,src_vocab_size,trg_vocab_size,src_pad_idx,num_heads,num_encoder_layers,num_decoder_layers,forward_expansion,dropout,max_len,device,):super(Transformer, self).__init__()self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)self.src_position_embedding = nn.Embedding(max_len, embedding_size)self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)self.trg_position_embedding = nn.Embedding(max_len, embedding_size)self.device = deviceself.transformer = nn.Transformer(embedding_size,num_heads,num_encoder_layers,num_decoder_layers,forward_expansion,dropout,)self.fc_out = nn.Linear(embedding_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)self.src_pad_idx = src_pad_idxdef make_src_mask(self, src):src_mask = src.transpose(0, 1) == self.src_pad_idx# (N, src_len)return src_mask.to(self.device)def forward(self, src, trg):src_seq_length, N = src.shapetrg_seq_length, N = trg.shapesrc_positions = (torch.arange(0, src_seq_length).unsqueeze(1).expand(src_seq_length, N).to(self.device))trg_positions = (torch.arange(0, trg_seq_length).unsqueeze(1).expand(trg_seq_length, N).to(self.device))embed_src = self.dropout((self.src_word_embedding(src) + self.src_position_embedding(src_positions)))embed_trg = self.dropout((self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)))src_padding_mask = self.make_src_mask(src)trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(self.device)out = self.transformer(embed_src,embed_trg,src_key_padding_mask=src_padding_mask,tgt_mask=trg_mask,)out = self.fc_out(out)return out# We're ready to define everything we need for training our Seq2Seq model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
load_model = False
save_model = True# Training hyperparameters
num_epochs = 10
learning_rate = 3e-4
batch_size = 32# Model hyperparameters
src_vocab_size = len(german.vocab)
trg_vocab_size = len(english.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 100
forward_expansion = 4
src_pad_idx = english.vocab.stoi["<pad>"]# Tensorboard to get nice loss plot
writer = SummaryWriter("runs/loss_plot")
step = 0train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train_data, valid_data, test_data),batch_size=batch_size,sort_within_batch=True,sort_key=lambda x: len(x.src),device=device,
)model = Transformer(embedding_size,src_vocab_size,trg_vocab_size,src_pad_idx,num_heads,num_encoder_layers,num_decoder_layers,forward_expansion,dropout,max_len,device,
).to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10, verbose=True
)pad_idx = english.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)if load_model:load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)# 'a', 'horse', 'is', 'walking', 'under', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.
sentence = "ein pferd geht unter einer brücke neben einem boot."for epoch in range(num_epochs):print(f"[Epoch {epoch} / {num_epochs}]")if save_model:checkpoint = {"state_dict": model.state_dict(),"optimizer": optimizer.state_dict(),}save_checkpoint(checkpoint)model.eval()translated_sentence = translate_sentence(model, sentence, german, english, device, max_length=50)print(f"Translated example sentence: \n {translated_sentence}")model.train()losses = []for batch_idx, batch in enumerate(train_iterator):# Get input and targets and get to cudainp_data = batch.src.to(device)target = batch.trg.to(device)# Forward propoutput = model(inp_data, target[:-1, :])# Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss# doesn't take input in that form. For example if we have MNIST we want to have# output to be: (N, 10) and targets just (N). Here we can view it in a similar# way that we have output_words * batch_size that we want to send in into# our cost function, so we need to do some reshapin.# Let's also remove the start token while we're at itoutput = output.reshape(-1, output.shape[2])target = target[1:].reshape(-1)optimizer.zero_grad()loss = criterion(output, target)losses.append(loss.item())# Back proploss.backward()# Clip to avoid exploding gradient issues, makes sure grads are# within a healthy rangetorch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)# Gradient descent stepoptimizer.step()# plot to tensorboardwriter.add_scalar("Training loss", loss, global_step=step)step += 1mean_loss = sum(losses) / len(losses)scheduler.step(mean_loss)# running on entire test data takes a while
score = bleu(test_data[1:100], model, german, english, device)
print(f"Bleu score {score * 100:.2f}")

utils.py文件

import torch
import spacy
from torchtext.data.metrics import bleu_score
import sysdef translate_sentence(model, sentence, german, english, device, max_length=50):# Load german tokenizerspacy_ger = spacy.load("de_core_news_sm")# Create tokens using spacy and everything in lower case (which is what our vocab is)if type(sentence) == str:tokens = [token.text.lower() for token in spacy_ger(sentence)]else:tokens = [token.lower() for token in sentence]# Add <SOS> and <EOS> in beginning and end respectivelytokens.insert(0, german.init_token)tokens.append(german.eos_token)# Go through each german token and convert to an indextext_to_indices = [german.vocab.stoi[token] for token in tokens]# Convert to Tensorsentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)outputs = [english.vocab.stoi["<sos>"]]for i in range(max_length):trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)with torch.no_grad():output = model(sentence_tensor, trg_tensor)best_guess = output.argmax(2)[-1, :].item()outputs.append(best_guess)if best_guess == english.vocab.stoi["<eos>"]:breaktranslated_sentence = [english.vocab.itos[idx] for idx in outputs]# remove start tokenreturn translated_sentence[1:]def bleu(data, model, german, english, device):targets = []outputs = []for example in data:src = vars(example)["src"]trg = vars(example)["trg"]prediction = translate_sentence(model, src, german, english, device)prediction = prediction[:-1]  # remove <eos> tokentargets.append([trg])outputs.append(prediction)return bleu_score(outputs, targets)def save_checkpoint(state, filename="my_checkpoint.pth.tar"):print("=> Saving checkpoint")torch.save(state, filename)def load_checkpoint(checkpoint, model, optimizer):print("=> Loading checkpoint")model.load_state_dict(checkpoint["state_dict"])optimizer.load_state_dict(checkpoint["optimizer"])

实验结果

对下面的这句德文进行翻译

sentence = "ein pferd geht unter einer brücke neben einem boot."

翻译结果为

['a', 'horse', 'walks', 'underneath', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.', '<eos>']

Bleu score为31.73

跑了10个Epoch,结果如下所示:


# Result
=> Loading checkpoint
[Epoch 0 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'under', 'a', 'boat', 'next', 'to', 'a', 'boat', '.', '<eos>']
[Epoch 1 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'underneath', 'a', 'bridge', 'beside', 'a', 'boat', '.', '<eos>']
[Epoch 2 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'is', 'walking', 'beside', 'a', 'boat', 'under', 'a', 'bridge', '.', '<eos>']
[Epoch 3 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'under', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.', '<eos>']
[Epoch 4 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'under', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.', '<eos>']
[Epoch 5 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'beside', 'a', 'boat', 'next', 'to', 'a', 'boat', '.', '<eos>']
[Epoch 6 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'is', 'walking', 'underneath', 'a', 'bridge', 'under', 'a', 'boat', '.', '<eos>']
[Epoch 7 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'under', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.', '<eos>']
[Epoch 8 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'beneath', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.', '<eos>']
[Epoch 9 / 10]
=> Saving checkpoint
Translated example sentence: ['a', 'horse', 'walks', 'underneath', 'a', 'bridge', 'next', 'to', 'a', 'boat', '.', '<eos>']
Bleu score 31.73

参考来源

[1] https://blog.csdn.net/weixin_43632501/article/details/98731800
[2] https://www.youtube.com/watch?v=M6adRGJe5cQ
[3] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
[4] https://blog.csdn.net/g11d111/article/details/100103208

相关文章:

PyTorch使用Transformer进行机器翻译

文章目录 简介数据集环境要求实验代码实验结果参考来源 简介 本文使用PyTorch自带的transformer层进行机器翻译&#xff1a;从德语翻译为英语。从零开始实现Transformer请参阅PyTorch从零开始实现Transformer&#xff0c;以便于获得对Transfomer更深的理解。 数据集 Multi30…...

LoadRunner使用教程

1. LoadRunner简介 LoadRunner是一款广泛使用的性能测试工具 可以对各种应用程序进行性能测试&#xff0c;包括Web应用程序、移动应用程序、企业级应用程序等。它提供了一个综合的性能测试解决方案&#xff0c;包括测试计划设计、脚本录制、测试执行、结果分析和报告生成等功…...

Zia和ChatGPT如何协同工作?

有没有集成ChatGPT的CRM系统推荐&#xff1f;Zoho CRM已经正式与ChatGPT集成。下面我们将从使用场景、使用价值和使用范围等方面切入讲述CRMAI的应用和作用。 Zia和ChatGPT如何协同工作&#xff1f; Zia和ChatGPT是不同的人工智能模型&#xff0c;在CRM中呈现出共生的关系。 …...

【位操作】——获取整数变量最低位为 1 的位置

获取整数变量最低位为 1 的位置 #define BIT_LOW_BIT(y) (((y)&BIT(0)) ? 0 : (((y)&BIT(1)) ? 1 : (((y)&BIT(2)) ? 2 : (((y)&BIT(3)) ? 3 : \(((y)&BIT(4)) ? 4 : (((y)&BIT(5)) ? 5 : (((y)&BIT(6)) ? 6 : (((y)&…...

gtest测试用例注册及自动化调度机制源代码流程分析

gtest的入门参见&#xff1a; 玩转Google开源C单元测试框架Google Test系列(gtest) gtest源码分析流程参见&#xff1a; gtest流程解析 测试用例注册流程分析要点&#xff1a;TEST_F宏替换、C静态成员的动态初始化。 自动化调度流程分析要点&#xff1a;UnitTest、UnitTestIm…...

IOS自动化测试环境搭建教程

目录 一、前言 二、环境依赖 1、环境依赖项 2、环境需求与支持 三、环境配置 1、xcode安装 2、Git安装 3、Homebrew安装&#xff08;用brew来安装依赖&#xff09; 4、npm和nodejs安装 5、libimobiledevice安装 6、idevicesinstaller安装 7、ios-deploy安装 8、Ca…...

常用API学习08(Java)

格式化 格式化指的是将数据按照指定的规则转化为指定的形式 。 那么为什么需要格式化&#xff1f;格式化有什么用&#xff1f; 以数字类为例&#xff0c;假设有一个比分牌&#xff0c;在无人得分的时候我们希望以&#xff1a;“00&#xff1a;00”的形式存在&#xff0c;那么…...

面试题-TS(八):什么是装饰器(decorators)?如何在 TypeScript 中使用它们?

面试题-TS(八)&#xff1a;什么是装饰器&#xff08;decorators&#xff09;&#xff1f;如何在 TypeScript 中使用它们&#xff1f; 在TypeScript中&#xff0c;装饰器&#xff08;Decorators&#xff09;是一种用于增强代码功能的特殊类型声明。装饰器提供了一种在类、方法、…...

Jenkins 还可以支持钉钉消息通知?一个插件带你搞定!

Jenkins 作为最流行的开源持续集成平台&#xff0c;其强大的拓展功能一直备受测试人员及开发人员的青睐。大家都知道我们可以在 Jenkins 中安装 Email 插件支持构建之后通过邮件将结果及时通知到相关人员。 但其实 Jenkins 还可以支持钉钉消息通知&#xff0c;其主要通过 Ding…...

7.ES使用

ES多条件查询 and , or这种的 ES模糊查询 like这种的 {"wildcard": {"title.keyword": {"value": "*宣讲*"}}}说明&#xff1a; title是要匹配的关键字段名称keyword是属性&#xff0c;表示匹配的是关键字信息&#xff0c;如果不用.ke…...

Web安全基础

1、HTML基础 什么是 HTML HTML 是用来描述网页的一种语言。 HTML 指的是超文本标记语言 (Hyper Text Markup Language) HTML 不是一种编程语言&#xff0c;而是一种标记语言 (Markup language) 标记语言是一套标记标签 (Markup tag) HTML 使用标记标签来描述网页 总的来说&…...

jQueryAPI

文章目录 1.jQuery 选择器1.1 jQuery 基础选择器1.2 jQuery 层级选择器1.3 隐式迭代1.4 jQuery 筛选选择器1.5 jQuery 筛选方法1.6 jQuery 里面的排他思想1.7 链式编程 2.jQuery 样式操作2.1 操作 css 方法2.2 设置类样式方法2.3 类操作与className区别 3.jQuery 效果3.1 显示隐…...

如何将路径字符串数组(string[])转成树结构(treeNode[])?

原文链接&#xff1a;如何将路径字符串数组(string[])转成树结构(treeNode[])&#xff1f; 需求 这里的UI使用的是Element-Plus。 将一个路径字符串数组&#xff08;当然也可能是其他目标字符串数组&#xff09;&#xff0c;渲染成树。 /*source:/a/b/c/d/e/a/b/e/f/g/a/b/h/a…...

中国工程院院士陈晓红一行莅临麒麟信安调研

7月20日下午&#xff0c;中国工程院院士、湘江实验室主任、湖南工商大学党委书记陈晓红&#xff0c;湘江实验室副主任、湖南工商大学副校长刘国权&#xff0c;湘江实验室副主任、湖南工商大学党委组织部统战部常务副部长胡春华等领导一行莅临麒麟信安调研。麒麟信安董事长杨涛&…...

解决Linux环境下启动idea服务,由于权限问题无法正常启动问题

问题&#xff1a; 在Linux环境下启动idea服务&#xff0c;一直提示&#xff1a; invalid registry store file /app/appuser/.dmf/dubbo,cause:failed to create directory /app/appuser! 原因&#xff1a;文件夹中没有操作权限。 解决&#xff1a; &#xff08;1&#xff0…...

Linux6.16 Docker consul的容器服务更新与发现

文章目录 计算机系统5G云计算第四章 LINUX Docker consul的容器服务更新与发现一、consul 概述1.什么是服务注册与发现2.什么是consul 二、consul 部署1.consul服务器2.registrator服务器3.consul-template4.consul 多节点 计算机系统 5G云计算 第四章 LINUX Docker consul的…...

Redis学习2--使用java操作Redis

1、java操作Redis库的比较 Redis有各种语言的客户端可以来操作redis数据库&#xff0c;其中java语言主要有Jedis与lettuce &#xff0c;Spring Data Redis封装了上边两个客户端&#xff0c;优缺点如下&#xff1a; 2、使用Jedis操作Redis Jedis使用的基本步骤&#xff1a; 引…...

[游戏数值] 常用刷新次数钻石消耗的设计

需满足要求 以一定规律增加能够在较少次数内增加到较大数值平滑增长 设计思路 增加值INT((当前序号-1)/X)*YZ X2&#xff0c;表示希望几个一组&#xff0c;通过INT()取整可获得0、0、1、1、2、2…这样的序列Y10&#xff0c;表示基础值&#xff0c;将上述序列变为0、0、10、1…...

rancher 2.5.7 证书过期处理方案

背景&#xff1a;现场搭建的单节点 rancher 问题&#xff1a;rancher的 ui 界面无法访问 排查&#xff1a;排查rancher容器日志报错 time“2021-12-29T08:27:32.616638402Z” levelinfo msg“Waiting for master node startup: resource name may not be empty” 2021-12-29 08…...

Tomcat中的缓存配置

Tomcat中的缓存配置通常是通过Web应用程序的context.xml文件或Tomcat的server.xml文件进行设置。下面提供一个简单的案例来说明如何在Tomcat中配置缓存。 假设您的Web应用程序名为"myapp"&#xff0c;我们将在context.xml中添加缓存配置。 打开Tomcat安装目录&…...

突破ComfyUI下载瓶颈:3大秘诀让开源工具效率提升300%实战指南

突破ComfyUI下载瓶颈&#xff1a;3大秘诀让开源工具效率提升300%实战指南 【免费下载链接】ComfyUI-Manager ComfyUI-Manager is an extension designed to enhance the usability of ComfyUI. It offers management functions to install, remove, disable, and enable variou…...

Harness Engineering(驾驭工程)

AI 模型已经能写出 100 万行代码。真正的挑战不再是"让它写得更好"&#xff0c;而是怎么驾驭它稳定、可靠、不失控地工作。这套围绕 AI 智能体构建约束、反馈与控制系统的方法论&#xff0c;就是 2026 年初迅速席卷工程圈的新范式——Harness Engineering&#xff08…...

证书创建方法说明

生成证书 方法一&#xff1a;合适&#xff08;临时测试&#xff0c;不需要管理&#xff09; 快速生成脚本&#xff08;一键完成&#xff09; 创建 create_lan_cert.sh&#xff1a; #!/bin/bash# 配置参数 IP_ADDR"192.168.1.100" # 修改为你的局域网IP DAYS365…...

终极指南:5步解锁MacBook Touch Bar在Windows系统的完整显示功能

终极指南&#xff1a;5步解锁MacBook Touch Bar在Windows系统的完整显示功能 【免费下载链接】DFRDisplayKm Windows infrastructure support for Apple DFR (Touch Bar) 项目地址: https://gitcode.com/gh_mirrors/df/DFRDisplayKm 还在为MacBook Pro的Touch Bar在Wind…...

Asian Beauty Z-Image Turbo实战:用nvidia-smi监控显存,小白也能轻松调优

Asian Beauty Z-Image Turbo实战&#xff1a;用nvidia-smi监控显存&#xff0c;小白也能轻松调优 如果你正在使用Asian Beauty Z-Image Turbo生成东方风格人像&#xff0c;是否遇到过生成过程中程序突然崩溃&#xff0c;或者生成速度越来越慢的情况&#xff1f;这些问题的罪魁…...

Awesome Rust核心库精选:异步编程与网络开发

Awesome Rust核心库精选&#xff1a;异步编程与网络开发 本文深入探讨了Rust生态系统中的核心库&#xff0c;重点分析了异步运行时&#xff08;Tokio与async-std&#xff09;、网络编程库、HTTP客户端/服务器框架、数据序列化工具链以及密码学与安全相关库。通过对比分析各库的…...

FreeCache内存管理终极指南:零GC开销的预分配机制详解

FreeCache内存管理终极指南&#xff1a;零GC开销的预分配机制详解 【免费下载链接】freecache A cache library for Go with zero GC overhead. 项目地址: https://gitcode.com/gh_mirrors/fr/freecache 在Go语言开发中&#xff0c;内存管理和垃圾回收&#xff08;GC&am…...

备战蓝桥杯效率翻倍:用快马平台一键生成算法测试脚手架

最近在备战蓝桥杯&#xff0c;发现很多时间都花在了重复搭建测试环境和编写输入输出代码上。为了提高效率&#xff0c;我用InsCode(快马)平台做了一个通用算法测试脚手架&#xff0c;分享下这个能提升备赛效率的实用工具。 项目设计思路 这个脚手架的核心目标是减少重复劳动。蓝…...

从RGB合并到多传感器融合:深入拆解AXI4-Stream Combiner IP在Zynq平台上的两种典型应用

从RGB合并到多传感器融合&#xff1a;深入拆解AXI4-Stream Combiner IP在Zynq平台上的两种典型应用 在FPGA开发中&#xff0c;数据流的高效处理一直是工程师面临的核心挑战之一。当系统需要同时处理多个并行数据源时&#xff0c;如何将这些数据流有序、高效地合并为单一数据流…...

新手开发者的第一课:用快马打造零基础的mc指令学习助手

作为一个刚接触《我的世界》指令系统的玩家&#xff0c;我最初完全搞不懂那些复杂的斜杠命令。直到自己动手做了一个指令查询工具&#xff0c;才发现原来理解指令可以这么简单。今天就来分享如何用InsCode(快马)平台快速打造一个零基础友好的MC指令助手。 为什么需要专门的指令…...