当前位置: 首页 > news >正文

seq2seq翻译实战-Pytorch复现

🍨 本文为[🔗365天深度学习训练营学习记录博客 🍦 参考文章:365天深度学习训练营 🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

一、前期准备 

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import randomimport torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

1.1 搭建语言类
 

定义了两个常量 SOS_token 和 EOS_token,其分别代表序列的开始和结束。 Lang 类,用于方便对语料库进行操作:
●word2index 是一个字典,将单词映射到索引
●word2count 是一个字典,记录单词出现的次数
●index2word 是一个字典,将索引映射到单词
●n_words 是单词的数量,初始值为 2,因为序列开始和结束的单词已经被添加

SOS_token = 0
EOS_token = 1# 语言类,方便对语料库进行操作
class Lang:def __init__(self, name):self.name = nameself.word2index = {}self.word2count = {}self.index2word = {0: "SOS", 1: "EOS"}self.n_words    = 2  # Count SOS and EOSdef addSentence(self, sentence):for word in sentence.split(' '):self.addWord(word)def addWord(self, word):if word not in self.word2index:self.word2index[word] = self.n_wordsself.word2count[word] = 1self.index2word[self.n_words] = wordself.n_words += 1else:self.word2count[word] += 1

1.2 文本处理函数

def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn')# 小写化,剔除标点与非字母符号
def normalizeString(s):s = unicodeToAscii(s.lower().strip())s = re.sub(r"([.!?])", r" \1", s)s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)return s

1.3 文件读取函数

def readLangs(lang1, lang2, reverse=False):print("Reading lines...")# 以行为单位读取文件lines = open('%s-%s.txt' % (lang1, lang2), encoding='utf-8'). \read().strip().split('\n')# 将每一行放入一个列表中# 一个列表中有两个元素,A语言文本与B语言文本pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]# 创建Lang实例,并确认是否反转语言顺序if reverse:pairs = [list(reversed(p)) for p in pairs]input_lang = Lang(lang2)output_lang = Lang(lang1)else:input_lang = Lang(lang1)output_lang = Lang(lang2)return input_lang, output_lang, pairsMAX_LENGTH = 10  # 定义语料最长长度eng_prefixes = ("i am ", "i m ","he is", "he s ","she is", "she s ","you are", "you re ","we are", "we re ","they are", "they re "
)def filterPair(p):return len(p[0].split(' ')) < MAX_LENGTH and \len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)def filterPairs(pairs):# 选取仅仅包含 eng_prefixes 开头的语料return [pair for pair in pairs if filterPair(pair)]def prepareData(lang1, lang2, reverse=False):# 读取文件中的数据input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)print("Read %s sentence pairs" % len(pairs))# 按条件选取语料pairs = filterPairs(pairs[:])print("Trimmed to %s sentence pairs" % len(pairs))print("Counting words...")# 将语料保存至相应的语言类for pair in pairs:input_lang.addSentence(pair[0])output_lang.addSentence(pair[1])# 打印语言类的信息print("Counted words:")print(input_lang.name, input_lang.n_words)print(output_lang.name, output_lang.n_words)return input_lang, output_lang, pairsinput_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

常量 MAX_LENGTH,表示语料中句子的最大长度。

元组 eng_prefixes,包含一些英语句子的前缀。这些前缀用于筛选语料,只选择以这些前缀开头的句子

filterPair 函数用于过滤语料对。它的返回值是一个布尔值,表示是否保留该语料对。这里的条件是:两个句子的长度都不超过 MAX_LENGTH,并且输出语句(第二个句子)以 eng_prefixes 中的某个前缀开头

filterPairs 函数接受一个语料对列表,然后调用 filterPair 函数过滤掉不符合条件的语料对,返回一个新的语料对列表。

prepareData 函数是主要的数据准备函数。它调用了之前定义的 readLangs 函数来读取语言对,然后使用 filterPairs 函数按条件过滤语料对。接着,它打印读取的句子对数、过滤后的句子对数,并统计语料中的词汇量。最后,它将语料保存到相应的语言类中,并返回这些语言类对象以及过滤后的语料对。

二、Seq2Seq 模型

 2.1 编码器(Encoder)

class EncoderRNN(nn.Module):def __init__(self, input_size, hidden_size):super(EncoderRNN, self).__init__()self.hidden_size = hidden_sizeself.embedding   = nn.Embedding(input_size, hidden_size)self.gru         = nn.GRU(hidden_size, hidden_size)def forward(self, input, hidden):embedded       = self.embedding(input).view(1, 1, -1)output         = embeddedoutput, hidden = self.gru(output, hidden)return output, hiddendef initHidden(self):return torch.zeros(1, 1, self.hidden_size, device=device)

2.2 解码器(Decoder)

class DecoderRNN(nn.Module):def __init__(self, hidden_size, output_size):super(DecoderRNN, self).__init__()self.hidden_size = hidden_sizeself.embedding   = nn.Embedding(output_size, hidden_size)self.gru         = nn.GRU(hidden_size, hidden_size)self.out         = nn.Linear(hidden_size, output_size)self.softmax     = nn.LogSoftmax(dim=1)def forward(self, input, hidden):output         = self.embedding(input).view(1, 1, -1)output         = F.relu(output)output, hidden = self.gru(output, hidden)output         = self.softmax(self.out(output[0]))return output, hiddendef initHidden(self):return torch.zeros(1, 1, self.hidden_size, device=device)

三、训练

3.1 数据预处理

def indexesFromSentence(lang, sentence):return [lang.word2index[word] for word in sentence.split(' ')]# 将数字化的文本,转化为tensor数据
def tensorFromSentence(lang, sentence):indexes = indexesFromSentence(lang, sentence)indexes.append(EOS_token)return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)# 输入pair文本,输出预处理好的数据
def tensorsFromPair(pair):input_tensor  = tensorFromSentence(input_lang, pair[0])target_tensor = tensorFromSentence(output_lang, pair[1])return (input_tensor, target_tensor)

3.2 训练函数

使用use_teacher_forcing 的目的是在训练过程中平衡解码器的预测能力和稳定性。以下是对两种策略的解释:
1. Teacher Forcing:在每个时间步(di循环中),解码器的输入都是目标序列中的真实标签。这样做的好处是,解码器可以直接获得正确的输入信息,加快训练速度,并且在训练早期提供更准确的梯度信号,帮助解码器更好地学习。然而,过度依赖目标序列可能会导致模型过于敏感,一旦目标序列中出现错误,可能会在解码器中产生累积的误差。
2. Without Teacher Forcing:在每个时间步,解码器的输入是前一个时间步的预测输出。这样做的好处是,解码器需要依靠自身的预测能力来生成下一个输入,从而更好地适应真实应用场景中可能出现的输入变化。这种策略可以提高模型的稳定性,但可能会导致训练过程更加困难,特别是在初始阶段。一般来说,Teacher Forcing策略在训练过程中可以帮助模型快速收敛,而Without Teacher Forcing策略则更接近真实应用中的生成场景。通常会使用一定比例的Teacher Forcing,在训练过程中逐渐减小这个比例,以便模型逐渐过渡到更自主的生成模式。
综上所述,通过使用use_teacher_forcing 来选择不同的策略,可以在训练解码器时平衡模型的预测能力和稳定性,同时也提供了更灵活的生成模式选择。

teacher_forcing_ratio = 0.5def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):# 编码器初始化encoder_hidden = encoder.initHidden()# grad属性归零encoder_optimizer.zero_grad()decoder_optimizer.zero_grad()input_length  = input_tensor.size(0)target_length = target_tensor.size(0)# 用于创建一个指定大小的全零张量(tensor),用作默认编码器输出encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)loss = 0# 将处理好的语料送入编码器for ei in range(input_length):encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)encoder_outputs[ei]            = encoder_output[0, 0]# 解码器默认输出decoder_input  = torch.tensor([[SOS_token]], device=device)decoder_hidden = encoder_hiddenuse_teacher_forcing = True if random.random() < teacher_forcing_ratio else False# 将编码器处理好的输出送入解码器if use_teacher_forcing:# Teacher forcing: Feed the target as the next inputfor di in range(target_length):decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)loss         += criterion(decoder_output, target_tensor[di])decoder_input = target_tensor[di]  # Teacher forcingelse:# Without teacher forcing: use its own predictions as the next inputfor di in range(target_length):decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)topv, topi    = decoder_output.topk(1)decoder_input = topi.squeeze().detach()  # detach from history as inputloss         += criterion(decoder_output, target_tensor[di])if decoder_input.item() == EOS_token:breakloss.backward()encoder_optimizer.step()decoder_optimizer.step()return loss.item() / target_lengthimport time
import mathdef asMinutes(s):m = math.floor(s / 60)s -= m * 60return '%dm %ds' % (m, s)def timeSince(since, percent):now = time.time()s = now - sincees = s / (percent)rs = es - sreturn '%s (- %s)' % (asMinutes(s), asMinutes(rs))def trainIters(encoder,decoder,n_iters,print_every=1000,plot_every=100,learning_rate=0.01):start = time.time()plot_losses      = []print_loss_total = 0  # Reset every print_everyplot_loss_total  = 0  # Reset every plot_everyencoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)# 在 pairs 中随机选取 n_iters 条数据用作训练集training_pairs    = [tensorsFromPair(random.choice(pairs)) for i in range(n_iters)]criterion         = nn.NLLLoss()for iter in range(1, n_iters + 1):training_pair = training_pairs[iter - 1]input_tensor  = training_pair[0]target_tensor = training_pair[1]loss = train(input_tensor, target_tensor, encoder,decoder, encoder_optimizer, decoder_optimizer, criterion)print_loss_total += lossplot_loss_total  += lossif iter % print_every == 0:print_loss_avg   = print_loss_total / print_everyprint_loss_total = 0print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),iter, iter / n_iters * 100, print_loss_avg))if iter % plot_every == 0:plot_loss_avg = plot_loss_total / plot_everyplot_losses.append(plot_loss_avg)plot_loss_total = 0return plot_losses

四、训练与评估

hidden_size   = 256
encoder1      = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)plot_losses = trainIters(encoder1, attn_decoder1, 100000, print_every=5000)

 

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               # 忽略警告信息
# plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        # 分辨率epochs_range = range(len(plot_losses))plt.figure(figsize=(8, 3))plt.subplot(1, 1, 1)
plt.plot(epochs_range, plot_losses, label='Training Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')
plt.show()

相关文章:

seq2seq翻译实战-Pytorch复现

&#x1f368; 本文为[&#x1f517;365天深度学习训练营学习记录博客 &#x1f366; 参考文章&#xff1a;365天深度学习训练营 &#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制]\n&#x1f680; 文章来源&#xff1a;[K同学的学习圈子](https://www.yuque.com/…...

软考69-上午题-【面向对象技术2-UML】-关系

一、关系 UML中有4种关系&#xff1a; 依赖&#xff1b;关联&#xff1b;泛化&#xff1b;实现。 1-1、依赖 行为&#xff08;参数&#xff09;&#xff0c;参数就是被依赖的事物&#xff0c;即&#xff1a;独立事物。 当独立事物发生变化时&#xff0c;依赖事务行为的语义也…...

智慧文旅|AI数字人导览:让旅游体验不再局限于传统

AI数字人导览作为一种创新的展示方式&#xff0c;已经逐渐成为了VR全景领域的一大亮点&#xff0c;不仅可以很好的嵌入在VR全景中&#xff0c;更是能够随时随地为观众提供一种声情并茂的讲解介绍&#xff0c;结合VR场景的沉浸式体验&#xff0c;让观众仿佛置身于真实场景之中&a…...

spring boot 集成 mysql ,mybatisplus多数据源

1、需要的依赖&#xff0c;版本自行控制 <dependency><groupId>com.alibaba</groupId><artifactId>druid</artifactId> </dependency><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java<…...

CLion中常用快捷键(仍适用其他编译软件)

基本编辑操作&#xff1a; 复制&#xff1a;Ctrl C粘贴&#xff1a;Ctrl V剪切&#xff1a;Ctrl X撤销&#xff1a;Ctrl Z重做&#xff1a;Ctrl Shift Z &#xff08;不小心撤销了 需要返回之前的操作 相当于下一步&#xff09;全选&#xff1a;Ctrl A 导航&#xff1…...

考研复习c语言初阶(1)

本人准备考研&#xff0c;现在开始每天更新408的内容&#xff0c;目标这个月结束C语言和数据结构&#xff0c;每天更新~ 一.再次认识c语言 C语言是一门通用计算机编程语言&#xff0c;广泛应用于底层开发。C语言的设计目标是提供一种能以简易 的方式编译、处理低级存储器、产生…...

HTML—常用标签

常用标签&#xff1a; 标题标签&#xff1a;<h1></h1>......<h6></h6>段落标签&#xff1a;<p></p>换行标签&#xff1a;<br/>列表&#xff1a;无序列表<ul><li></li></ul> 有序列表<ol>&…...

Midjourney绘图欣赏系列(七)

Midjourney介绍 Midjourney 是生成式人工智能的一个很好的例子&#xff0c;它根据文本提示创建图像。它与 Dall-E 和 Stable Diffusion 一起成为最流行的 AI 艺术创作工具之一。与竞争对手不同&#xff0c;Midjourney 是自筹资金且闭源的&#xff0c;因此确切了解其幕后内容尚不…...

深度学习应该如何入门?

深度学习是一门令人着迷的领域&#xff0c;但初学者可能会感到有些困惑。让我们从头开始&#xff0c;用通俗易懂的语言来探讨深度学习的基础知识。 1. 基础知识 深度学习需要一些数学和编程基础。首先&#xff0c;我们要掌握一些数学知识&#xff0c;如线性代数、微积分和概率…...

FreeRtos Queue(五)

本篇主要分析在中断中向队列里发消息xQueueGenericSendFromISR和在中断里从队列中读取消息xQueueReceiveFromISR。 前言: xQueueGenericSendFromISR 和 xQueueReceiveFromISR都是在中断里调用的而不是任务里调用的&#xff0c;所以队列满了或者是队列为空的时候自然就没有把当…...

解决虚拟机静态网址设置后还是变动的的问题

源头就是我的虚拟机静态网址设置好了以后但是网址还是会变动 这是我虚拟机的配置 vi /etc/sysconfig/network-scripts/ifcfg-ens33 这是出现的问题 进入这里 cd /etc/sysconfig/network-scripts/ 然后我去把多余的ens33的文件都删了 然后还不行 后来按照这个图片进行了下 然后…...

【教程】Github环境配置新手指南(超详细)

写在前面&#xff1a; 如果文章对你有帮助&#xff0c;记得点赞关注加收藏一波&#xff0c;利于以后需要的时候复习&#xff0c;多谢支持&#xff01; 文章目录 一、Github初始设置&#xff08;一&#xff09;登入Github&#xff08;二&#xff09;新建仓库 二、本地Git配置&am…...

突然发现一个很炸裂的平台!

平时小孟会开发很多的项目&#xff0c;很多项目不仅开发的功能比较齐全&#xff0c;而且效果比较炸裂。 今天给大家介绍一个我常用的平台&#xff0c;因含低代码平台&#xff0c;开发相当的快。 1&#xff0c;什么是低代码 低代码包括两种&#xff0c;一种低代码&#xff0c;…...

安卓开发面试题

安卓开发面试题 解释一下 Android 中的四大组件。 答&#xff1a;Android 中的四大组件是 Activity、Service、BroadcastReceiver 和 ContentProvider。其中&#xff0c;Activity 负责界面展示和与用户交互&#xff1b;Service 负责后台服务处理&#xff1b;BroadcastReceiver …...

es6面试题

ES6面试题 var、let、const区别 共同点&#xff1a;都是可以声明变量 区别&#xff1a; 1、var具有变量提升机制&#xff0c;let和const没有 2、var 声明的变量是函数作用域或全局作用域&#xff0c;而 const 和 let 声明的变量是块级作用域。 3、var可以多次声明同一个变量&a…...

Kafka MQ 生产者和消费者

Kafka MQ 生产者和消费者 Kafka 的客户端就是 Kafka 系统的用户&#xff0c;它们被分为两种基本类型:生产者和消费者。除 此之外&#xff0c;还有其他高级客户端 API——用于数据集成的 Kafka Connect API 和用于流式处理 的 Kafka Streams。这些高级客户端 API 使用生产者和消…...

tomcat优化与部署(三)------nignx优化与nginx +tomcat 部署

在目前流行的互联网架构中&#xff0c;Tomcat在目前的网络编程中是举足轻重的&#xff0c;由于Tomcat的运行依赖于JVM&#xff0c;从虚拟机的角度把Tomcat的调整分为外部环境调优 JVM 和 Tomcat 自身调优两部分 Tomcat 是一个流行的开源 Java 服务器&#xff0c;用于托管 Java …...

一个用libcurl多线程下载断言错误问题的排查

某数据下载程序&#xff0c;相同版本的代码&#xff0c;在64位系统中运行正常&#xff0c;但在32位系统中概率性出现断言错误。一旦出现&#xff0c;程序无法正常继续&#xff0c;即使重启亦不行。从年前会上领导提出要追到根&#xff0c;跟到底&#xff0c;到年后的今天&#…...

Docker的安装及MySQL的部署(CentOS版)

目录 1 前言 2 Docker安装步骤 2.1 卸载可能存在的旧版Docker 2.2 配置Docker的yum库 2.2.1 安装yum工具 2.2.2 配置Docker的yum源 2.3 安装Docker 2.4 启动和校验 2.5 配置镜像加速(使用阿里云) 2.5.1 进入控制台 2.5.2 进入容器镜像服务 2.5.3 获取指令并粘贴到…...

css 背景图片居中显示

background 简写 background: #ffffff url(https://profile-avatar.csdnimg.cn/b9abdd57de464582860bf8ade52373b6_misnice.jpg) center center / 100% no-repeat;效果如图&#xff1a;...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

【Java学习笔记】Arrays类

Arrays 类 1. 导入包&#xff1a;import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序&#xff08;自然排序和定制排序&#xff09;Arrays.binarySearch()通过二分搜索法进行查找&#xff08;前提&#xff1a;数组是…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek

文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama&#xff08;有网络的电脑&#xff09;2.2.3 安装Ollama&#xff08;无网络的电脑&#xff09;2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...