当前位置: 首页 > news >正文

【深度学习-seq2seq模型-附核心encoder和decoder代码】

深度学习

  • 深度学习-seq2seq模型
    • 什么是seq2seq模型
    • 应用场景
    • 架构
      • 编码器
      • 解码器
      • 训练 & 预测
      • 损失
      • 预测
      • 评估BLEU
        • BELU背后的数学意义
    • 模型参考论文

深度学习-seq2seq模型

本文的网络架构模型参考 Sutskever et al., 2014以及Cho et al., 2014

什么是seq2seq模型

Sequence to sequence (seq2seq)是由encoder(编码器)和decoder(解码器)两个RNN组成的(注意本文中的RNN指代所有的循环神经网络,包括RNN、GRU、LSTM等)。
其中encoder负责对输入句子的理解,输出context vector(上下文变量)给decoder,decoder负责对理解后的句子的向量进行处理,解码,获得输出

应用场景

主要用来处理输入和输出序列长度不定的问题,在之前的RNN一文中,RNN的分类讲解过,其中就包括多对多结构,这个seq2seq模型就是典型的多对多,还是长度不一致的多对多,它的应用有很多场景,比如机器翻译,机器人问答,文章摘要,由关键字生成对话等等
例如翻译场景:
【hey took the little cat to the animal center】-> [他们把这只小猫送到了动物中心]
输入和输出长度没法一致

架构

在这里插入图片描述
整个架构是编码器-解码器结构

编码器

一般都是一个普通的RNN结构,不需要特殊的实现

class Encoder(nn.Module):"""用于序列到序列学习的循环神经网络 编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Encoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.gru = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):# input shape (batchsize, num_steps)-> (batchsize, num_steps, embedingdim)X = self.embedding(X)# 交换dim,pythorch要求batchsize位置X = X.permute(1, 0, 2)# encode编码# out的形状 (num_steps, batch_size, num_hiddens)# state的形状: (num_layers, batch_size, num_hiddens)output, state = self.gru(X)return output, state

解码器

在 (Sutskever et al., 2014)的设计:
输入序列的编码信息送入到解码器中来生成输出序列的。
(Cho et al., 2014)设计: 编码器最终的隐状态在每一个时间步都作为解码器的输入序列的一部分。
上面架构图中展示的正式这种设计
在解码器中,在训练的时候比较特殊,可以允许真实值(标签)成为原始的输出序列, 从源序列词元“”“Ils”“regardent”“.” 到新序列词元 “Ils”“regardent”“.”“”来移动预测的位置。
解码器

class Decoder(nn.Module):"""用于序列到序列学习的循环神经网络 解码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Decoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 与普通gru区别:input_size增加num_hiddens,用于input输入解码器encode的输出self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs):# 初始化decode的hidden, 使用enc_outputs[1],enc_outputs格式(output, hidden state)return enc_outputs[1]def forward(self, X, state):""":param X:     input,        shape is (num_steps, batch_size, embed_size):param state: hidden state, shape is( num_layers,batch_size, num_hiddens):return:"""# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X).permute(1, 0, 2)# 广播state的0维,使它与X具有相同的num_steps的维度,方便后续拼接,输出context的shape(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)# conect input and context (num_steps, batch_size, embed_size+num_hiddens)x_and_context = torch.cat((X, context), 2)# output的形状:(num_steps, batch_size, num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)output, state = self.rnn(x_and_context, state)# output的形状(batch_size,num_steps,vocab_size)output = self.dense(output).permute(1, 0, 2)return output, state

训练 & 预测

def train(net, data_iter, lr, num_epochs, tgt_vocab, device):net.to(device)loss = MaskedSoftmaxCELoss()optimizer = torch.optim.Adam(net.parameters(), lr=lr)for epoch in range(num_epochs):num_tokens = 0total_loss = 0for batch in data_iter:optimizer.zero_grad()X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],device=device).reshape(-1, 1)dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学Y_hat, _ = net(X, dec_input)# Y_hat的形状(batch_size,num_steps,vocab_size)# Y的形状batch_size,num_steps# loss内部permute Y_hat = Y_hat.permute(0, 2, 1)l = loss(Y_hat, Y, Y_valid_len)# 损失函数的标量进行“反向传播”l.sum().backward()#梯度裁剪grad_clipping(net, 1)#梯度更新optimizer.step()num_tokens = Y_valid_len.sum()total_loss = l.sum()print('epoch{}, loss{:.3f}'.format(epoch, total_loss/num_tokens))

损失

这里特别的说明一下,NLP中的损失通常用的都是基于交叉熵损失的masksoftmax损失,它只是在交叉熵损失的基础上封装了一点,mask了pad填充的词元,这个损失函数的意思,举个例子说明一下:
假设解码器的lable是【they are watching】,通常会用unk等pad这些句子到一定的长度,这个长度是代码中由你自行指定的,也是decoder的num_steps,比如我们设置了10,那么此时整个输入会被pad成【they are wathing unk unk unk unk unk unk unk eos】,但是计算损失的时候,我们不需要计算这部分,对应的损失需要置为0

预测

预测与评估的过程相同,但是稍有不同的是,预测过程不知道真实的输出标签,所以都是用上一步的预测值来作为下一个时间步的输入的。这里不再复述

评估BLEU

与其他输出固定的评估不一样,这次是一个句子的评估,常用的方法是:BLEU(bilingual evaluation understudy),最早用于机器翻译,现在也是被广泛用于各种其他的领域

在这里插入图片描述
BLEU的评估都是n-grams词元是否出现在标签序列中
lenlable表示标签序列中的词元数和
lenlpred表示预测序列中的词元数
pn 预测序列与标签序列中匹配的n元词元的数量, 与 预测序列中
n元语法的数量的比率
BELU肯定是越大越好,最好的情况肯定是1,那就是完全匹配

举个例子:给定标签序列A B C D E F 和预测序列 A B B C D
lenlable是6
lenlpred是5
p1 1元词元在lable和 pred中匹配的数量 B C D 也就是4 与 预测序列中1元词元个数 5 也就是0.8
其他pi也是依次计算 i 从1取到预测长度 -1 (也就是4)分别计算出来是3/4 1/3和0
前面的)

BLUE实现简单,此处也不再展现代码了

BELU背后的数学意义

首先后面概率相加的这部分:
在这里插入图片描述

n元词法,当n越长则匹配难度越大, 所以BLEU为更长的元语法的精确度分配更大的权重,否则一个不完全匹配的句子可能会比全匹配的概率更大,这里就表现为,n越大,pn1/2n就越大
在这里插入图片描述
这一项是惩罚项,越短的句子就会降低belu分数,比如
给定标签序列A B C D E F 和预测序列 A B 和ABC 虽然p1 和p2 都是1,惩罚因此会降低短序列的分数
篇幅有限,代码无法一一展现,如果需要全部代码的小伙伴可以私信我

模型参考论文

Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to sequence learning with neural networks. Advances in neural information processing systems (pp. 3104–3112).

Cho et al., 2014a
Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

Cho et al., 2014b
Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using rnn encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.

李沐 动手深度学习

相关文章:

【深度学习-seq2seq模型-附核心encoder和decoder代码】

深度学习 深度学习-seq2seq模型什么是seq2seq模型应用场景架构编码器解码器训练 & 预测损失预测评估BLEUBELU背后的数学意义 模型参考论文 深度学习-seq2seq模型 本文的网络架构模型参考 Sutskever et al., 2014以及Cho et al., 2014 什么是seq2seq模型 Sequence to seq…...

videojs 实现自定义组件(视频画质/清晰度切换) React

前言 最近使用videojs作为视频处理第三方库&#xff0c;用来对接m3u8视频类型。这里总结一下自定义组件遇到的问题及实现&#xff0c;目前看了许多文章也不全&#xff0c;官方文档写的也不是很详细&#xff0c;自己摸索了一段时间陆陆续续完成了&#xff0c;这是实现后的效果.…...

python 模块urllib3 HTTP 客户端库

官网文档地址&#xff1a;https://urllib3.readthedocs.io/en/stable/reference/index.html 一、安装 pip install urlib3二、基本使用 import urllib3 import threadingimg_list ["https://pic.netbian.com/uploads/allimg/220211/004115-1644511275bc26.jpg",&…...

2023 CCPC 华为云计算挑战赛 D-塔

首先先来看第一轮的 假如有n个,每轮那k个 他们的高度的可能性分别为 n 1/C(n,k) n1 C(n-(k-11),1)/C(n,k) n2 C(n-(k-21),2)/C(n,k) ni C(n-(k-i1,i)/C(n,k) 通过概率和高度算出第一轮增加的期望 然后乘上m轮增加的高度加上初始高度&#xff0c;就是总共增加的高度 下面是…...

手搓大模型值just gru

这些类是构建神经网络模型的有用工具,并提供了一些关键功能: EmAdd类使文本输入数据嵌入成为可能,在自然语言处理任务中被广泛使用。通过屏蔽处理填充序列的能力对许多应用程序也很重要。 HeadLoss类是训练神经网络模型进行分类任务的常见损失函数。它计算损失和准确率的能力…...

eslint

什么是eslint ESLint 是一个根据方案识别并报告 ECMAScript/JavaScript 代码问题的工具&#xff0c;其目的是使代码风格更加一致并避免错误。 安装eslint npm init eslint/config执行后会有很多选项&#xff0c;按照自己的需求去选择就好&#xff0c;运行成功后会生成 .esli…...

node_modules.cache是什么东西

一开始没明白这是啥玩意&#xff0c;还以为是npm的属性&#xff0c;网上也没说过具体的来源出处 .cache文件的产生是由webpack4的插件cache-loader生成的&#xff0c;node_modules里下载了cache-loader插件&#xff0c;很多朋友都是vuecli工具生成的项目&#xff0c;内置了这部…...

Python 包管理(pip、conda)基本使用指南

Python 包管理 概述 介绍 Python 有丰富的开源的第三方库和包&#xff0c;可以帮助完成各种任务&#xff0c;扩展 Python 的功能&#xff0c;例如 NumPy 用于科学计算&#xff0c;Pandas 用于数据处理&#xff0c;Matplotlib 用于绘图等。在开始编写 Pytlhon 程序之前&#…...

系统级封装(SiP)技术如何助力智能化应用发展呢?

智能化时代&#xff0c;各种智能设备、智能互连的高速发展与跨界融合&#xff0c;需要高密度、高性能的微系统集成技术作为重要支撑。 例如&#xff0c;在系统级封装&#xff08;SiP&#xff09;技术的加持下&#xff0c;5G手机的射频电路面积更小&#xff0c;但支持的频段更多…...

git配置代理(github配置代理)

命令行配置代理方式一git config --global http.proxy http://代理服务器地址:端口号git config --global https.proxy https://代理服务器地址:端口号如果有用户名密码按照下面命令配置 git config --global http.proxy http://用户名:密码代理服务器地址:端口号git config --…...

【数据结构】详解环形队列

文章目录 &#x1f30f;引言&#x1f340;[循环队列](https://leetcode.cn/problems/design-circular-queue/description/)&#x1f431;‍&#x1f464;题目描述&#x1f431;‍&#x1f453;示例&#xff1a;&#x1f431;‍&#x1f409;提示&#x1f431;‍&#x1f3cd;思…...

Python爬取网页详细教程:从入门到进阶

【导言】&#xff1a; Python作为一门强大的编程语言&#xff0c;常常被用于编写网络爬虫程序。本篇文章将为大家详细介绍Python爬取网页的整个流程&#xff0c;从安装Python和必要的库开始&#xff0c;到发送HTTP请求、解析HTML页面&#xff0c;再到提取和处理数据&#xff0…...

linux安装JDK及hadoop运行环境搭建

1.linux中安装jdk &#xff08;1&#xff09;下载JDK至opt/install目录下&#xff0c;opt下创建目录soft&#xff0c;并解压至当前目录 tar xvf ./jdk-8u321-linux-x64.tar.gz -C /opt/soft/ &#xff08;2&#xff09;改名 &#xff08;3&#xff09;配置环境变量&#xf…...

使用ChatGPT一键生成思维导图

指令1&#xff1a;接下来你回复的所有内容&#xff0c;都放到Markdown代码框中。 指令2&#xff1a;作为一个Docker专家&#xff0c;为我编写一个详细全面的Docker学习大纲&#xff0c;包括基础知识、进阶知识、项目实践案例&#xff0c;学习书籍推荐、学习网站推荐等&#xf…...

极简Vim教程

2023年8月27日&#xff0c;周日上午 我不想学那么多命令和快捷键&#xff0c;够用就行... 所以就把我自己认为比较常用的命令和快捷键记录成博客 目录 预备知识Vim的工作模式保存内容退出Vim复制、粘贴和剪切选中一段内容复制粘贴剪切撤回和反撤回撤回反撤回查找替换删除删除…...

在线帮助中心也属于知识管理的一种吗?

在线帮助中心是企业或组织为了提供客户支持而建立的一个在线平台&#xff0c;它包含了各种类型的知识和信息&#xff0c;旨在帮助用户解决问题和获取相关的信息。从知识管理的角度来看&#xff0c;可以说在线帮助中心也属于知识管理的一种形式。下面将详细介绍在线帮助中心作为…...

《Linux从练气到飞升》No.18 进程终止

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的…...

自动化运维工具——ansible安装及模块介绍

目录 一、ansible——自动化运维工具 1.1 Ansible 自动运维工具特点 1.2 Ansible 运维工具原理 二、安装ansible 三、ansible命令模块 3.1 command模块 3.2 shell模块 3.3 cron模块 3.4 user模块 3.5 group 模块 3.6 copy模块 3.7 file模块 3.8 ping模…...

Qt XML文件解析 QDomDocument

QtXml模块提供了一个读写XML文件的流&#xff0c;解析方法包含DOM和SAX,两者的区别是什么呢&#xff1f; DOM&#xff08;Document Object Model&#xff09;&#xff1a;将XML文件保存为树的形式&#xff0c;操作简单&#xff0c;便于访问。 SAX&#xff08;Simple API for …...

Vue2向Vue3过度Vuex状态管理工具快速入门

目录 1 Vuex概述1.是什么2.使用场景3.优势4.注意&#xff1a; 2 需求: 多组件共享数据1.创建项目2.创建三个组件, 目录如下3.源代码如下 3 vuex 的使用 - 创建仓库1.安装 vuex2.新建 store/index.js 专门存放 vuex3.创建仓库 store/index.js4 在 main.js 中导入挂载到 Vue 实例…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

深入理解JavaScript设计模式之单例模式

目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式&#xff08;Singleton Pattern&#…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署&#xff0c;直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型&#xff0c;但是目前国内可能使用不多&#xff0c;至少实践例子很少看见。开发训练模型就不介绍了&am…...

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕&#xff0c;#AI 监考一度冲上热搜。当AI深度融入高考&#xff0c;#时间同步 不再是辅助功能&#xff0c;而是决定AI监考系统成败的“生命线”。 AI亮相2025高考&#xff0c;40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕&#xff0c;江西、…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

Linux离线(zip方式)安装docker

目录 基础信息操作系统信息docker信息 安装实例安装步骤示例 遇到的问题问题1&#xff1a;修改默认工作路径启动失败问题2 找不到对应组 基础信息 操作系统信息 OS版本&#xff1a;CentOS 7 64位 内核版本&#xff1a;3.10.0 相关命令&#xff1a; uname -rcat /etc/os-rele…...

鸿蒙(HarmonyOS5)实现跳一跳小游戏

下面我将介绍如何使用鸿蒙的ArkUI框架&#xff0c;实现一个简单的跳一跳小游戏。 1. 项目结构 src/main/ets/ ├── MainAbility │ ├── pages │ │ ├── Index.ets // 主页面 │ │ └── GamePage.ets // 游戏页面 │ └── model │ …...

数据分析六部曲?

引言 上一章我们说到了数据分析六部曲&#xff0c;何谓六部曲呢&#xff1f; 其实啊&#xff0c;数据分析没那么难&#xff0c;只要掌握了下面这六个步骤&#xff0c;也就是数据分析六部曲&#xff0c;就算你是个啥都不懂的小白&#xff0c;也能慢慢上手做数据分析啦。 第一…...