当前位置: 首页 > news >正文

【深度学习-seq2seq模型-附核心encoder和decoder代码】

深度学习

  • 深度学习-seq2seq模型
    • 什么是seq2seq模型
    • 应用场景
    • 架构
      • 编码器
      • 解码器
      • 训练 & 预测
      • 损失
      • 预测
      • 评估BLEU
        • BELU背后的数学意义
    • 模型参考论文

深度学习-seq2seq模型

本文的网络架构模型参考 Sutskever et al., 2014以及Cho et al., 2014

什么是seq2seq模型

Sequence to sequence (seq2seq)是由encoder(编码器)和decoder(解码器)两个RNN组成的(注意本文中的RNN指代所有的循环神经网络,包括RNN、GRU、LSTM等)。
其中encoder负责对输入句子的理解,输出context vector(上下文变量)给decoder,decoder负责对理解后的句子的向量进行处理,解码,获得输出

应用场景

主要用来处理输入和输出序列长度不定的问题,在之前的RNN一文中,RNN的分类讲解过,其中就包括多对多结构,这个seq2seq模型就是典型的多对多,还是长度不一致的多对多,它的应用有很多场景,比如机器翻译,机器人问答,文章摘要,由关键字生成对话等等
例如翻译场景:
【hey took the little cat to the animal center】-> [他们把这只小猫送到了动物中心]
输入和输出长度没法一致

架构

在这里插入图片描述
整个架构是编码器-解码器结构

编码器

一般都是一个普通的RNN结构,不需要特殊的实现

class Encoder(nn.Module):"""用于序列到序列学习的循环神经网络 编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Encoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.gru = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):# input shape (batchsize, num_steps)-> (batchsize, num_steps, embedingdim)X = self.embedding(X)# 交换dim,pythorch要求batchsize位置X = X.permute(1, 0, 2)# encode编码# out的形状 (num_steps, batch_size, num_hiddens)# state的形状: (num_layers, batch_size, num_hiddens)output, state = self.gru(X)return output, state

解码器

在 (Sutskever et al., 2014)的设计:
输入序列的编码信息送入到解码器中来生成输出序列的。
(Cho et al., 2014)设计: 编码器最终的隐状态在每一个时间步都作为解码器的输入序列的一部分。
上面架构图中展示的正式这种设计
在解码器中,在训练的时候比较特殊,可以允许真实值(标签)成为原始的输出序列, 从源序列词元“”“Ils”“regardent”“.” 到新序列词元 “Ils”“regardent”“.”“”来移动预测的位置。
解码器

class Decoder(nn.Module):"""用于序列到序列学习的循环神经网络 解码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Decoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 与普通gru区别:input_size增加num_hiddens,用于input输入解码器encode的输出self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs):# 初始化decode的hidden, 使用enc_outputs[1],enc_outputs格式(output, hidden state)return enc_outputs[1]def forward(self, X, state):""":param X:     input,        shape is (num_steps, batch_size, embed_size):param state: hidden state, shape is( num_layers,batch_size, num_hiddens):return:"""# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X).permute(1, 0, 2)# 广播state的0维,使它与X具有相同的num_steps的维度,方便后续拼接,输出context的shape(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)# conect input and context (num_steps, batch_size, embed_size+num_hiddens)x_and_context = torch.cat((X, context), 2)# output的形状:(num_steps, batch_size, num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)output, state = self.rnn(x_and_context, state)# output的形状(batch_size,num_steps,vocab_size)output = self.dense(output).permute(1, 0, 2)return output, state

训练 & 预测

def train(net, data_iter, lr, num_epochs, tgt_vocab, device):net.to(device)loss = MaskedSoftmaxCELoss()optimizer = torch.optim.Adam(net.parameters(), lr=lr)for epoch in range(num_epochs):num_tokens = 0total_loss = 0for batch in data_iter:optimizer.zero_grad()X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],device=device).reshape(-1, 1)dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学Y_hat, _ = net(X, dec_input)# Y_hat的形状(batch_size,num_steps,vocab_size)# Y的形状batch_size,num_steps# loss内部permute Y_hat = Y_hat.permute(0, 2, 1)l = loss(Y_hat, Y, Y_valid_len)# 损失函数的标量进行“反向传播”l.sum().backward()#梯度裁剪grad_clipping(net, 1)#梯度更新optimizer.step()num_tokens = Y_valid_len.sum()total_loss = l.sum()print('epoch{}, loss{:.3f}'.format(epoch, total_loss/num_tokens))

损失

这里特别的说明一下,NLP中的损失通常用的都是基于交叉熵损失的masksoftmax损失,它只是在交叉熵损失的基础上封装了一点,mask了pad填充的词元,这个损失函数的意思,举个例子说明一下:
假设解码器的lable是【they are watching】,通常会用unk等pad这些句子到一定的长度,这个长度是代码中由你自行指定的,也是decoder的num_steps,比如我们设置了10,那么此时整个输入会被pad成【they are wathing unk unk unk unk unk unk unk eos】,但是计算损失的时候,我们不需要计算这部分,对应的损失需要置为0

预测

预测与评估的过程相同,但是稍有不同的是,预测过程不知道真实的输出标签,所以都是用上一步的预测值来作为下一个时间步的输入的。这里不再复述

评估BLEU

与其他输出固定的评估不一样,这次是一个句子的评估,常用的方法是:BLEU(bilingual evaluation understudy),最早用于机器翻译,现在也是被广泛用于各种其他的领域

在这里插入图片描述
BLEU的评估都是n-grams词元是否出现在标签序列中
lenlable表示标签序列中的词元数和
lenlpred表示预测序列中的词元数
pn 预测序列与标签序列中匹配的n元词元的数量, 与 预测序列中
n元语法的数量的比率
BELU肯定是越大越好,最好的情况肯定是1,那就是完全匹配

举个例子:给定标签序列A B C D E F 和预测序列 A B B C D
lenlable是6
lenlpred是5
p1 1元词元在lable和 pred中匹配的数量 B C D 也就是4 与 预测序列中1元词元个数 5 也就是0.8
其他pi也是依次计算 i 从1取到预测长度 -1 (也就是4)分别计算出来是3/4 1/3和0
前面的)

BLUE实现简单,此处也不再展现代码了

BELU背后的数学意义

首先后面概率相加的这部分:
在这里插入图片描述

n元词法,当n越长则匹配难度越大, 所以BLEU为更长的元语法的精确度分配更大的权重,否则一个不完全匹配的句子可能会比全匹配的概率更大,这里就表现为,n越大,pn1/2n就越大
在这里插入图片描述
这一项是惩罚项,越短的句子就会降低belu分数,比如
给定标签序列A B C D E F 和预测序列 A B 和ABC 虽然p1 和p2 都是1,惩罚因此会降低短序列的分数
篇幅有限,代码无法一一展现,如果需要全部代码的小伙伴可以私信我

模型参考论文

Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to sequence learning with neural networks. Advances in neural information processing systems (pp. 3104–3112).

Cho et al., 2014a
Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

Cho et al., 2014b
Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using rnn encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.

李沐 动手深度学习

相关文章:

【深度学习-seq2seq模型-附核心encoder和decoder代码】

深度学习 深度学习-seq2seq模型什么是seq2seq模型应用场景架构编码器解码器训练 & 预测损失预测评估BLEUBELU背后的数学意义 模型参考论文 深度学习-seq2seq模型 本文的网络架构模型参考 Sutskever et al., 2014以及Cho et al., 2014 什么是seq2seq模型 Sequence to seq…...

videojs 实现自定义组件(视频画质/清晰度切换) React

前言 最近使用videojs作为视频处理第三方库&#xff0c;用来对接m3u8视频类型。这里总结一下自定义组件遇到的问题及实现&#xff0c;目前看了许多文章也不全&#xff0c;官方文档写的也不是很详细&#xff0c;自己摸索了一段时间陆陆续续完成了&#xff0c;这是实现后的效果.…...

python 模块urllib3 HTTP 客户端库

官网文档地址&#xff1a;https://urllib3.readthedocs.io/en/stable/reference/index.html 一、安装 pip install urlib3二、基本使用 import urllib3 import threadingimg_list ["https://pic.netbian.com/uploads/allimg/220211/004115-1644511275bc26.jpg",&…...

2023 CCPC 华为云计算挑战赛 D-塔

首先先来看第一轮的 假如有n个,每轮那k个 他们的高度的可能性分别为 n 1/C(n,k) n1 C(n-(k-11),1)/C(n,k) n2 C(n-(k-21),2)/C(n,k) ni C(n-(k-i1,i)/C(n,k) 通过概率和高度算出第一轮增加的期望 然后乘上m轮增加的高度加上初始高度&#xff0c;就是总共增加的高度 下面是…...

手搓大模型值just gru

这些类是构建神经网络模型的有用工具,并提供了一些关键功能: EmAdd类使文本输入数据嵌入成为可能,在自然语言处理任务中被广泛使用。通过屏蔽处理填充序列的能力对许多应用程序也很重要。 HeadLoss类是训练神经网络模型进行分类任务的常见损失函数。它计算损失和准确率的能力…...

eslint

什么是eslint ESLint 是一个根据方案识别并报告 ECMAScript/JavaScript 代码问题的工具&#xff0c;其目的是使代码风格更加一致并避免错误。 安装eslint npm init eslint/config执行后会有很多选项&#xff0c;按照自己的需求去选择就好&#xff0c;运行成功后会生成 .esli…...

node_modules.cache是什么东西

一开始没明白这是啥玩意&#xff0c;还以为是npm的属性&#xff0c;网上也没说过具体的来源出处 .cache文件的产生是由webpack4的插件cache-loader生成的&#xff0c;node_modules里下载了cache-loader插件&#xff0c;很多朋友都是vuecli工具生成的项目&#xff0c;内置了这部…...

Python 包管理(pip、conda)基本使用指南

Python 包管理 概述 介绍 Python 有丰富的开源的第三方库和包&#xff0c;可以帮助完成各种任务&#xff0c;扩展 Python 的功能&#xff0c;例如 NumPy 用于科学计算&#xff0c;Pandas 用于数据处理&#xff0c;Matplotlib 用于绘图等。在开始编写 Pytlhon 程序之前&#…...

系统级封装(SiP)技术如何助力智能化应用发展呢?

智能化时代&#xff0c;各种智能设备、智能互连的高速发展与跨界融合&#xff0c;需要高密度、高性能的微系统集成技术作为重要支撑。 例如&#xff0c;在系统级封装&#xff08;SiP&#xff09;技术的加持下&#xff0c;5G手机的射频电路面积更小&#xff0c;但支持的频段更多…...

git配置代理(github配置代理)

命令行配置代理方式一git config --global http.proxy http://代理服务器地址:端口号git config --global https.proxy https://代理服务器地址:端口号如果有用户名密码按照下面命令配置 git config --global http.proxy http://用户名:密码代理服务器地址:端口号git config --…...

【数据结构】详解环形队列

文章目录 &#x1f30f;引言&#x1f340;[循环队列](https://leetcode.cn/problems/design-circular-queue/description/)&#x1f431;‍&#x1f464;题目描述&#x1f431;‍&#x1f453;示例&#xff1a;&#x1f431;‍&#x1f409;提示&#x1f431;‍&#x1f3cd;思…...

Python爬取网页详细教程:从入门到进阶

【导言】&#xff1a; Python作为一门强大的编程语言&#xff0c;常常被用于编写网络爬虫程序。本篇文章将为大家详细介绍Python爬取网页的整个流程&#xff0c;从安装Python和必要的库开始&#xff0c;到发送HTTP请求、解析HTML页面&#xff0c;再到提取和处理数据&#xff0…...

linux安装JDK及hadoop运行环境搭建

1.linux中安装jdk &#xff08;1&#xff09;下载JDK至opt/install目录下&#xff0c;opt下创建目录soft&#xff0c;并解压至当前目录 tar xvf ./jdk-8u321-linux-x64.tar.gz -C /opt/soft/ &#xff08;2&#xff09;改名 &#xff08;3&#xff09;配置环境变量&#xf…...

使用ChatGPT一键生成思维导图

指令1&#xff1a;接下来你回复的所有内容&#xff0c;都放到Markdown代码框中。 指令2&#xff1a;作为一个Docker专家&#xff0c;为我编写一个详细全面的Docker学习大纲&#xff0c;包括基础知识、进阶知识、项目实践案例&#xff0c;学习书籍推荐、学习网站推荐等&#xf…...

极简Vim教程

2023年8月27日&#xff0c;周日上午 我不想学那么多命令和快捷键&#xff0c;够用就行... 所以就把我自己认为比较常用的命令和快捷键记录成博客 目录 预备知识Vim的工作模式保存内容退出Vim复制、粘贴和剪切选中一段内容复制粘贴剪切撤回和反撤回撤回反撤回查找替换删除删除…...

在线帮助中心也属于知识管理的一种吗?

在线帮助中心是企业或组织为了提供客户支持而建立的一个在线平台&#xff0c;它包含了各种类型的知识和信息&#xff0c;旨在帮助用户解决问题和获取相关的信息。从知识管理的角度来看&#xff0c;可以说在线帮助中心也属于知识管理的一种形式。下面将详细介绍在线帮助中心作为…...

《Linux从练气到飞升》No.18 进程终止

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的…...

自动化运维工具——ansible安装及模块介绍

目录 一、ansible——自动化运维工具 1.1 Ansible 自动运维工具特点 1.2 Ansible 运维工具原理 二、安装ansible 三、ansible命令模块 3.1 command模块 3.2 shell模块 3.3 cron模块 3.4 user模块 3.5 group 模块 3.6 copy模块 3.7 file模块 3.8 ping模…...

Qt XML文件解析 QDomDocument

QtXml模块提供了一个读写XML文件的流&#xff0c;解析方法包含DOM和SAX,两者的区别是什么呢&#xff1f; DOM&#xff08;Document Object Model&#xff09;&#xff1a;将XML文件保存为树的形式&#xff0c;操作简单&#xff0c;便于访问。 SAX&#xff08;Simple API for …...

Vue2向Vue3过度Vuex状态管理工具快速入门

目录 1 Vuex概述1.是什么2.使用场景3.优势4.注意&#xff1a; 2 需求: 多组件共享数据1.创建项目2.创建三个组件, 目录如下3.源代码如下 3 vuex 的使用 - 创建仓库1.安装 vuex2.新建 store/index.js 专门存放 vuex3.创建仓库 store/index.js4 在 main.js 中导入挂载到 Vue 实例…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界&#xff0c;看笔记好好学多敲多打&#xff0c;每个人都是大神&#xff01; 题目&#xff1a;KubeSphere 容器平台高可用&#xff1a;环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时&#xff0c;你可能需要保留重要的数据&#xff0c;例如通讯录。好在&#xff0c;将通讯录从 iPhone 转移到 Android 手机非常简单&#xff0c;你可以从本文中学习 6 种可靠的方法&#xff0c;确保随时保持连接&#xff0c;不错过任何信息。 第 1…...

大数据学习(132)-HIve数据分析

​​​​&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4…...

无人机侦测与反制技术的进展与应用

国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机&#xff08;无人驾驶飞行器&#xff0c;UAV&#xff09;技术的快速发展&#xff0c;其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统&#xff0c;无人机的“黑飞”&…...

【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案

目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后&#xff0c;迭代器会失效&#xff0c;因为顺序迭代器在内存中是连续存储的&#xff0c;元素删除后&#xff0c;后续元素会前移。 但一些场景中&#xff0c;我们又需要在执行删除操作…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...

tomcat指定使用的jdk版本

说明 有时候需要对tomcat配置指定的jdk版本号&#xff0c;此时&#xff0c;我们可以通过以下方式进行配置 设置方式 找到tomcat的bin目录中的setclasspath.bat。如果是linux系统则是setclasspath.sh set JAVA_HOMEC:\Program Files\Java\jdk8 set JRE_HOMEC:\Program Files…...

自然语言处理——文本分类

文本分类 传统机器学习方法文本表示向量空间模型 特征选择文档频率互信息信息增益&#xff08;IG&#xff09; 分类器设计贝叶斯理论&#xff1a;线性判别函数 文本分类性能评估P-R曲线ROC曲线 将文本文档或句子分类为预定义的类或类别&#xff0c; 有单标签多类别文本分类和多…...

Spring Boot + MyBatis 集成支付宝支付流程

Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例&#xff08;电脑网站支付&#xff09; 1. 添加依赖 <!…...

机器学习的数学基础:线性模型

线性模型 线性模型的基本形式为&#xff1a; f ( x ) ω T x b f\left(\boldsymbol{x}\right)\boldsymbol{\omega}^\text{T}\boldsymbol{x}b f(x)ωTxb 回归问题 利用最小二乘法&#xff0c;得到 ω \boldsymbol{\omega} ω和 b b b的参数估计$ \boldsymbol{\hat{\omega}}…...