当前位置: 首页 > news >正文

NLP Seq2Seq模型

🍨 本文为[🔗365天深度学习训练营学习记录博客🍦 参考文章:365天深度学习训练营🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

Seq2Seq模型是一种深度学习模型,用于处理序列到序列的任务,它由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器(Encoder): 编码器负责将输入序列(例如源语言句子)编码成一个固定长度的向量,通常称为上下文向量或编码器的隐藏状态。编码器可以是循环神经网络(RNN)、长短期记忆网络(LSTM)或者变种如门控循环单元(GRU)等。编码器的目标是捕捉输入序列中的语义信息,并将其编码成一个固定维度的向量表示。

  2. 解码器(Decoder): 解码器接收编码器生成的上下文向量,并根据它来生成输出序列(例如目标语言句子)。解码器也可以是RNN、LSTM、GRU等。在训练阶段,解码器一次生成一个词或一个标记,并且其隐藏状态从一个时间步传递到下一个时间步。解码器的目标是根据上下文向量生成与输入序列对应的输出序列。

在训练阶段,Seq2Seq模型的目标是最大化目标序列的条件概率给定输入序列。为了实现这一点,通常使用了一种称为教师强制(Teacher Forcing)的技术,即将目标序列中的真实标记作为解码器的输入。但是,在推理阶段(即模型用于生成新的序列时),解码器则根据先前生成的标记生成下一个标记,直到生成一个特殊的终止标记或达到最大长度为止。

下面演示了如何使用PyTorch实现一个简单的Seq2Seq模型,用于将一个序列翻译成另一个序列。这里我们将使用一个虚构的数据集来进行简单的法语到英语翻译。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset# 定义数据集
class SimpleDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 定义Encoder
class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hidden_dim):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hidden_dim)def forward(self, src):embedded = self.embedding(src)outputs, hidden = self.rnn(embedded)return outputs, hidden# 定义Decoder
class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hidden_dim):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, output_dim)def forward(self, input, hidden):input = input.unsqueeze(0)embedded = self.embedding(input)output, hidden = self.rnn(embedded, hidden)prediction = self.fc_out(output.squeeze(0))return prediction, hidden# 定义Seq2Seq模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):batch_size = trg.shape[1]trg_len = trg.shape[0]trg_vocab_size = self.decoder.fc_out.out_featuresoutputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)encoder_outputs, hidden = self.encoder(src)input = trg[0,:]for t in range(1, trg_len):output, hidden = self.decoder(input, hidden)outputs[t] = outputteacher_force = np.random.rand() < teacher_forcing_ratiotop1 = output.argmax(1) input = trg[t] if teacher_force else top1return outputs# 设置参数
INPUT_DIM = 10
OUTPUT_DIM = 10
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
HID_DIM = 64
N_LAYERS = 1
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5# 实例化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(enc, dec, device).to(device)# 打印模型结构
print(model)# 定义训练函数
def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, batch in enumerate(iterator):src, trg = batchsrc = src.to(device)trg = trg.to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)# 定义测试函数
def evaluate(model, iterator, criterion):model.eval()epoch_loss = 0with torch.no_grad():for i, batch in enumerate(iterator):src, trg = batchsrc = src.to(device)trg = trg.to(device)output = model(src, trg, 0) # 关闭teacher forcingoutput_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)# 示例数据
train_data = [(torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])),(torch.tensor([4, 5, 6]), torch.tensor([6, 5, 4])),(torch.tensor([7, 8, 9]), torch.tensor([9, 8, 7]))]# 超参数
BATCH_SIZE = 3
N_EPOCHS = 10
LEARNING_RATE = 0.001
CLIP = 1# 数据集与迭代器
train_dataset = SimpleDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)# 定义损失函数与优化器
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()# 训练模型
for epoch in range(N_EPOCHS):train_loss = train(model, train_loader, optimizer, criterion, CLIP)print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')

相关文章:

NLP Seq2Seq模型

&#x1f368; 本文为[&#x1f517;365天深度学习训练营学习记录博客&#x1f366; 参考文章&#xff1a;365天深度学习训练营&#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制]\n&#x1f680; 文章来源&#xff1a;[K同学的学习圈子](https://www.yuque.com/mi…...

如何在 Linux 上使用 dmesg 命令

文章目录 1. Overview2.ring buffer怎样工作&#xff1f;3.dmesg命令4.移除sudo需求5. 强制彩色输出6.使用人性化的时间戳7.使用dmesg的人性化可读时间戳8.观察实时event9.检索最后10条消息10.搜索特定术语11.使用Log Levels12.使用Facility Categories13.Combining Facility a…...

WPF的DataGrid设置标题头

要设置DataGrid标题头的分割线、背景色和前景色等属性&#xff0c;您可以使用DataGrid的样式和模板来自定义标题头的外观。下面是详细解释以及示例代码&#xff1a; 分割线设置&#xff1a; 您可以使用DataGrid.ColumnHeaderStyle样式中的BorderThickness和BorderBrush属性来设…...

【软考】UML中的图之通信图

目录 1. 说明2. 图示3. 特性4. 例题4.1 例题1 1. 说明 1.通信图强调收发消息的对象的结构组织2.早期版本叫做协作图3.通信图强调参加交互的对象和组织4.首先将参加交互的对象作为图的顶点&#xff0c;然后把连接这些对象的链表示为图的弧&#xff0c;最后用对象发送和接收的消…...

为什么ChatGPT预训练能非常好地捕捉语言的普遍特征和模式

ChatGPT能够非常好地捕捉语言的普遍特征和模式&#xff0c;主要得益于以下几个方面的原因&#xff1a; 大规模语料库&#xff1a;ChatGPT的预训练是在大规模文本语料库上进行的&#xff0c;这些语料库涵盖了来自互联网、书籍、文章、对话记录等多种来源的丰富数据。这种大规模的…...

如何安装ProtoBuf环境

1 &#x1f351;下载 ProtoBuf&#x1f351; 下载 ProtoBuf 前⼀定要安装依赖库&#xff1a;autoconf automake libtool curl make g unzip 如未安装&#xff0c;安装命令如下&#xff1a; Ubuntu ⽤⼾选择&#xff1a; sudo apt-get install autoconf automake libtool cur…...

C语言 vs Rust应该学习哪个?

C语言 vs Rust应该学习哪个&#xff1f; 在开始前我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「C语言的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“888”之后私信回复“888”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&am…...

IT廉连看——Uniapp——配置文件pages

IT廉连看——Uniapp——配置文件pages [IT廉连看] 本堂课主要为大家介绍pages.json这个配置文件 一、打开官网查看pages.json可以配置哪些属性。 下面边写边讲解 新建一个home页面理解一下这句话。 以下一些页面的通用配置 通用设置里我们可以对导航栏和状态栏进行一些设…...

服务器上部署WEb服务方法

部署Web服务在服务器上是一个比较复杂的过程。这不仅仅涉及到配置环境、选择软件和设置端口&#xff0c;更有众多其它因素需要考虑。以下是在服务器上部署WEb服务的步骤&#xff1a; 1. 选择服务器&#xff1a;根据项目规模和预期访问量&#xff0c;选择合适的服务器类型和配置…...

设计模式:模版模式

模板模式&#xff08;Template Pattern&#xff09;是一种行为型设计模式&#xff0c;它定义了一个操作中的算法骨架&#xff0c;将一些步骤的具体实现延迟到子类中。模板模式使得子类可以在不改变算法结构的情况下重新定义算法的某些步骤。 在模板模式中&#xff0c;将算法的…...

pikachu之特殊注入之搜索型注入、xx型注入、insert/update注入、delete注入、宽字节注入

一步一脚印&#xff01;&#xff01;&#xff01; 补充&#xff1a;此处为什么不写http请求头注入&#xff0c;因为该注入类型只是换了注入点&#xff0c;语句其他根本没有什么变化 1.搜索型 先尝试输入常用payload&#xff1a; 1 or 11 #。 已经有回显 我们在查看提示 我们…...

docker构建hyperf环境

一&#xff0c;构建hyperf 镜像 官网git https://github.com/hyperf/hyperf-docker 使用dockerfile构建镜像 根据需要这里我使用8.1 swoole版本的镜像 在/home/hyperfdocker 目录中新建一个Dockerfile文件&#xff0c;将这个git上的Dockerfile内容复制粘贴进去 docker build…...

WPF常用mvvm开源框架介绍 vue的mvvm设计模式鼻祖

WPF&#xff08;Windows Presentation Foundation&#xff09;是一个用于构建桌面应用程序的.NET框架&#xff0c;它支持MVVM&#xff08;Model-View-ViewModel&#xff09;架构模式来分离UI逻辑和业务逻辑。以下是一些常用的WPF MVVM开源框架&#xff1a; Prism Prism是由微软…...

HTML <script>元素的10个属性

将javascrip插入HTML的主要方法是使用<script>元素&#xff0c;这个元素是网景公司&#xff08;Netscape&#xff09;创造出来的&#xff0c;script 元素所属类型因其用法而异。位于 head 元素中的 script 元素属于元数据元素&#xff0c;位于其他元素&#xff08;如 bod…...

NX二次开发:ListingWindow窗口的应用

一、概述 在NX二次开发的学习中&#xff0c;浏览博客时发现看到[社恐猫]和[王牌飞行员_里海]这两篇博客中写道有关信息窗口内容的打印和将窗口内容保存为txt,个人人为在二次开发项目很有必要&#xff0c;因此做以下记录。 ListingWindow信息窗口发送信息四种位置类型 设置Listi…...

设计模式-结构型模式-外观模式

外观模式&#xff08;Facade&#xff09;&#xff0c;为子系统中的一组接口提供一个一致的界面&#xff0c;此模式定义了一个高层接口&#xff0c;这个接口使得这一子系统更加容易使用。[DP] 首先&#xff0c;定义子系统的各个组件接口和具体实现类&#xff1a; // 子系统组件接…...

C++学习第四天(类与对象下)

1、构造函数的其他知识 构造函数体赋值 在创建对象时&#xff0c;编译器通过调用构造函数&#xff0c;给对象中各个成员变量一个合适的初始值 构造函数调用之后&#xff0c;对象中已经有了一个初始值&#xff0c;但是不能将其称为对对象中成员变量的初始化&#xff0c;构造函…...

【AI Agent系列】【MetaGPT多智能体学习】0. 环境准备 - 升级MetaGPT 0.7.2版本及遇到的坑

之前跟着《MetaGPT智能体开发入门课程》学了一些MetaGPT的知识和实践&#xff0c;主要关注在MetaGPT入门和单智能体部分&#xff08;系列文章附在文末&#xff0c;感兴趣的可以看下&#xff09;。现在新的教程来了&#xff0c;新教程主要关注多智能体部分。 本系列文章跟随《M…...

python自动化管理和zabbix监控网络设备(无线AC控制瘦ap配置部分)

目录 前言 拓扑 一、AC-SW1 二、Core-sw1 三、Core-sw2 四、汇聚层 五、AC1 六、SW1-6 七、DMZ区域 前言 具体原理和操作可以访问我的主页视频 白帽小丑的个人空间-白帽小丑个人主页-哔哩哔哩视频 拓扑 一、AC-SW1 sys sysname AC-SW1 vlan batch 100 200 210 220 2…...

XSS中级漏洞(靶场)

目录 一、环境 二、正式开始闯关 0x01 0x02 0x03 0x04 0x05 0x06 0x07 0x08 0x0B 0x0C 0x0D 0x0E ​ 0x0F 0x10 0x11 0x12 一、环境 在线环境&#xff08;gethub上面的&#xff09; alert(1) 二、正式开始闯关 0x01 源码&#xff1a; 思路&#xff1a;闭…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 &#xff08;FL&#xff09; 支持跨分布式客户端进行协作模型训练&#xff0c;而无需共享原始数据&#xff0c;这使其成为在互联和自动驾驶汽车 &#xff08;CAV&#xff09; 等领域保护隐私的机器学习的一种很有前途的方法。然而&#xff0c;最近的研究表明&…...

【网络安全产品大调研系列】2. 体验漏洞扫描

前言 2023 年漏洞扫描服务市场规模预计为 3.06&#xff08;十亿美元&#xff09;。漏洞扫描服务市场行业预计将从 2024 年的 3.48&#xff08;十亿美元&#xff09;增长到 2032 年的 9.54&#xff08;十亿美元&#xff09;。预测期内漏洞扫描服务市场 CAGR&#xff08;增长率&…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

基于Docker Compose部署Java微服务项目

一. 创建根项目 根项目&#xff08;父项目&#xff09;主要用于依赖管理 一些需要注意的点&#xff1a; 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件&#xff0c;否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

今日科技热点速览

&#x1f525; 今日科技热点速览 &#x1f3ae; 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售&#xff0c;主打更强图形性能与沉浸式体验&#xff0c;支持多模态交互&#xff0c;受到全球玩家热捧 。 &#x1f916; 人工智能持续突破 DeepSeek-R1&…...

推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)

推荐 github 项目:GeminiImageApp(图片生成方向&#xff0c;可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...

Neko虚拟浏览器远程协作方案:Docker+内网穿透技术部署实践

前言&#xff1a;本文将向开发者介绍一款创新性协作工具——Neko虚拟浏览器。在数字化协作场景中&#xff0c;跨地域的团队常需面对实时共享屏幕、协同编辑文档等需求。通过本指南&#xff0c;你将掌握在Ubuntu系统中使用容器化技术部署该工具的具体方案&#xff0c;并结合内网…...