当前位置: 首页 > news >正文

【深度学习-seq2seq模型-附核心encoder和decoder代码】

深度学习

  • 深度学习-seq2seq模型
    • 什么是seq2seq模型
    • 应用场景
    • 架构
      • 编码器
      • 解码器
      • 训练 & 预测
      • 损失
      • 预测
      • 评估BLEU
        • BELU背后的数学意义
    • 模型参考论文

深度学习-seq2seq模型

本文的网络架构模型参考 Sutskever et al., 2014以及Cho et al., 2014

什么是seq2seq模型

Sequence to sequence (seq2seq)是由encoder(编码器)和decoder(解码器)两个RNN组成的(注意本文中的RNN指代所有的循环神经网络,包括RNN、GRU、LSTM等)。
其中encoder负责对输入句子的理解,输出context vector(上下文变量)给decoder,decoder负责对理解后的句子的向量进行处理,解码,获得输出

应用场景

主要用来处理输入和输出序列长度不定的问题,在之前的RNN一文中,RNN的分类讲解过,其中就包括多对多结构,这个seq2seq模型就是典型的多对多,还是长度不一致的多对多,它的应用有很多场景,比如机器翻译,机器人问答,文章摘要,由关键字生成对话等等
例如翻译场景:
【hey took the little cat to the animal center】-> [他们把这只小猫送到了动物中心]
输入和输出长度没法一致

架构

在这里插入图片描述
整个架构是编码器-解码器结构

编码器

一般都是一个普通的RNN结构,不需要特殊的实现

class Encoder(nn.Module):"""用于序列到序列学习的循环神经网络 编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Encoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.gru = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):# input shape (batchsize, num_steps)-> (batchsize, num_steps, embedingdim)X = self.embedding(X)# 交换dim,pythorch要求batchsize位置X = X.permute(1, 0, 2)# encode编码# out的形状 (num_steps, batch_size, num_hiddens)# state的形状: (num_layers, batch_size, num_hiddens)output, state = self.gru(X)return output, state

解码器

在 (Sutskever et al., 2014)的设计:
输入序列的编码信息送入到解码器中来生成输出序列的。
(Cho et al., 2014)设计: 编码器最终的隐状态在每一个时间步都作为解码器的输入序列的一部分。
上面架构图中展示的正式这种设计
在解码器中,在训练的时候比较特殊,可以允许真实值(标签)成为原始的输出序列, 从源序列词元“”“Ils”“regardent”“.” 到新序列词元 “Ils”“regardent”“.”“”来移动预测的位置。
解码器

class Decoder(nn.Module):"""用于序列到序列学习的循环神经网络 解码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Decoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 与普通gru区别:input_size增加num_hiddens,用于input输入解码器encode的输出self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs):# 初始化decode的hidden, 使用enc_outputs[1],enc_outputs格式(output, hidden state)return enc_outputs[1]def forward(self, X, state):""":param X:     input,        shape is (num_steps, batch_size, embed_size):param state: hidden state, shape is( num_layers,batch_size, num_hiddens):return:"""# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X).permute(1, 0, 2)# 广播state的0维,使它与X具有相同的num_steps的维度,方便后续拼接,输出context的shape(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)# conect input and context (num_steps, batch_size, embed_size+num_hiddens)x_and_context = torch.cat((X, context), 2)# output的形状:(num_steps, batch_size, num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)output, state = self.rnn(x_and_context, state)# output的形状(batch_size,num_steps,vocab_size)output = self.dense(output).permute(1, 0, 2)return output, state

训练 & 预测

def train(net, data_iter, lr, num_epochs, tgt_vocab, device):net.to(device)loss = MaskedSoftmaxCELoss()optimizer = torch.optim.Adam(net.parameters(), lr=lr)for epoch in range(num_epochs):num_tokens = 0total_loss = 0for batch in data_iter:optimizer.zero_grad()X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],device=device).reshape(-1, 1)dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学Y_hat, _ = net(X, dec_input)# Y_hat的形状(batch_size,num_steps,vocab_size)# Y的形状batch_size,num_steps# loss内部permute Y_hat = Y_hat.permute(0, 2, 1)l = loss(Y_hat, Y, Y_valid_len)# 损失函数的标量进行“反向传播”l.sum().backward()#梯度裁剪grad_clipping(net, 1)#梯度更新optimizer.step()num_tokens = Y_valid_len.sum()total_loss = l.sum()print('epoch{}, loss{:.3f}'.format(epoch, total_loss/num_tokens))

损失

这里特别的说明一下,NLP中的损失通常用的都是基于交叉熵损失的masksoftmax损失,它只是在交叉熵损失的基础上封装了一点,mask了pad填充的词元,这个损失函数的意思,举个例子说明一下:
假设解码器的lable是【they are watching】,通常会用unk等pad这些句子到一定的长度,这个长度是代码中由你自行指定的,也是decoder的num_steps,比如我们设置了10,那么此时整个输入会被pad成【they are wathing unk unk unk unk unk unk unk eos】,但是计算损失的时候,我们不需要计算这部分,对应的损失需要置为0

预测

预测与评估的过程相同,但是稍有不同的是,预测过程不知道真实的输出标签,所以都是用上一步的预测值来作为下一个时间步的输入的。这里不再复述

评估BLEU

与其他输出固定的评估不一样,这次是一个句子的评估,常用的方法是:BLEU(bilingual evaluation understudy),最早用于机器翻译,现在也是被广泛用于各种其他的领域

在这里插入图片描述
BLEU的评估都是n-grams词元是否出现在标签序列中
lenlable表示标签序列中的词元数和
lenlpred表示预测序列中的词元数
pn 预测序列与标签序列中匹配的n元词元的数量, 与 预测序列中
n元语法的数量的比率
BELU肯定是越大越好,最好的情况肯定是1,那就是完全匹配

举个例子:给定标签序列A B C D E F 和预测序列 A B B C D
lenlable是6
lenlpred是5
p1 1元词元在lable和 pred中匹配的数量 B C D 也就是4 与 预测序列中1元词元个数 5 也就是0.8
其他pi也是依次计算 i 从1取到预测长度 -1 (也就是4)分别计算出来是3/4 1/3和0
前面的)

BLUE实现简单,此处也不再展现代码了

BELU背后的数学意义

首先后面概率相加的这部分:
在这里插入图片描述

n元词法,当n越长则匹配难度越大, 所以BLEU为更长的元语法的精确度分配更大的权重,否则一个不完全匹配的句子可能会比全匹配的概率更大,这里就表现为,n越大,pn1/2n就越大
在这里插入图片描述
这一项是惩罚项,越短的句子就会降低belu分数,比如
给定标签序列A B C D E F 和预测序列 A B 和ABC 虽然p1 和p2 都是1,惩罚因此会降低短序列的分数
篇幅有限,代码无法一一展现,如果需要全部代码的小伙伴可以私信我

模型参考论文

Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to sequence learning with neural networks. Advances in neural information processing systems (pp. 3104–3112).

Cho et al., 2014a
Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

Cho et al., 2014b
Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using rnn encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.

李沐 动手深度学习

相关文章:

【深度学习-seq2seq模型-附核心encoder和decoder代码】

深度学习 深度学习-seq2seq模型什么是seq2seq模型应用场景架构编码器解码器训练 & 预测损失预测评估BLEUBELU背后的数学意义 模型参考论文 深度学习-seq2seq模型 本文的网络架构模型参考 Sutskever et al., 2014以及Cho et al., 2014 什么是seq2seq模型 Sequence to seq…...

videojs 实现自定义组件(视频画质/清晰度切换) React

前言 最近使用videojs作为视频处理第三方库&#xff0c;用来对接m3u8视频类型。这里总结一下自定义组件遇到的问题及实现&#xff0c;目前看了许多文章也不全&#xff0c;官方文档写的也不是很详细&#xff0c;自己摸索了一段时间陆陆续续完成了&#xff0c;这是实现后的效果.…...

python 模块urllib3 HTTP 客户端库

官网文档地址&#xff1a;https://urllib3.readthedocs.io/en/stable/reference/index.html 一、安装 pip install urlib3二、基本使用 import urllib3 import threadingimg_list ["https://pic.netbian.com/uploads/allimg/220211/004115-1644511275bc26.jpg",&…...

2023 CCPC 华为云计算挑战赛 D-塔

首先先来看第一轮的 假如有n个,每轮那k个 他们的高度的可能性分别为 n 1/C(n,k) n1 C(n-(k-11),1)/C(n,k) n2 C(n-(k-21),2)/C(n,k) ni C(n-(k-i1,i)/C(n,k) 通过概率和高度算出第一轮增加的期望 然后乘上m轮增加的高度加上初始高度&#xff0c;就是总共增加的高度 下面是…...

手搓大模型值just gru

这些类是构建神经网络模型的有用工具,并提供了一些关键功能: EmAdd类使文本输入数据嵌入成为可能,在自然语言处理任务中被广泛使用。通过屏蔽处理填充序列的能力对许多应用程序也很重要。 HeadLoss类是训练神经网络模型进行分类任务的常见损失函数。它计算损失和准确率的能力…...

eslint

什么是eslint ESLint 是一个根据方案识别并报告 ECMAScript/JavaScript 代码问题的工具&#xff0c;其目的是使代码风格更加一致并避免错误。 安装eslint npm init eslint/config执行后会有很多选项&#xff0c;按照自己的需求去选择就好&#xff0c;运行成功后会生成 .esli…...

node_modules.cache是什么东西

一开始没明白这是啥玩意&#xff0c;还以为是npm的属性&#xff0c;网上也没说过具体的来源出处 .cache文件的产生是由webpack4的插件cache-loader生成的&#xff0c;node_modules里下载了cache-loader插件&#xff0c;很多朋友都是vuecli工具生成的项目&#xff0c;内置了这部…...

Python 包管理(pip、conda)基本使用指南

Python 包管理 概述 介绍 Python 有丰富的开源的第三方库和包&#xff0c;可以帮助完成各种任务&#xff0c;扩展 Python 的功能&#xff0c;例如 NumPy 用于科学计算&#xff0c;Pandas 用于数据处理&#xff0c;Matplotlib 用于绘图等。在开始编写 Pytlhon 程序之前&#…...

系统级封装(SiP)技术如何助力智能化应用发展呢?

智能化时代&#xff0c;各种智能设备、智能互连的高速发展与跨界融合&#xff0c;需要高密度、高性能的微系统集成技术作为重要支撑。 例如&#xff0c;在系统级封装&#xff08;SiP&#xff09;技术的加持下&#xff0c;5G手机的射频电路面积更小&#xff0c;但支持的频段更多…...

git配置代理(github配置代理)

命令行配置代理方式一git config --global http.proxy http://代理服务器地址:端口号git config --global https.proxy https://代理服务器地址:端口号如果有用户名密码按照下面命令配置 git config --global http.proxy http://用户名:密码代理服务器地址:端口号git config --…...

【数据结构】详解环形队列

文章目录 &#x1f30f;引言&#x1f340;[循环队列](https://leetcode.cn/problems/design-circular-queue/description/)&#x1f431;‍&#x1f464;题目描述&#x1f431;‍&#x1f453;示例&#xff1a;&#x1f431;‍&#x1f409;提示&#x1f431;‍&#x1f3cd;思…...

Python爬取网页详细教程:从入门到进阶

【导言】&#xff1a; Python作为一门强大的编程语言&#xff0c;常常被用于编写网络爬虫程序。本篇文章将为大家详细介绍Python爬取网页的整个流程&#xff0c;从安装Python和必要的库开始&#xff0c;到发送HTTP请求、解析HTML页面&#xff0c;再到提取和处理数据&#xff0…...

linux安装JDK及hadoop运行环境搭建

1.linux中安装jdk &#xff08;1&#xff09;下载JDK至opt/install目录下&#xff0c;opt下创建目录soft&#xff0c;并解压至当前目录 tar xvf ./jdk-8u321-linux-x64.tar.gz -C /opt/soft/ &#xff08;2&#xff09;改名 &#xff08;3&#xff09;配置环境变量&#xf…...

使用ChatGPT一键生成思维导图

指令1&#xff1a;接下来你回复的所有内容&#xff0c;都放到Markdown代码框中。 指令2&#xff1a;作为一个Docker专家&#xff0c;为我编写一个详细全面的Docker学习大纲&#xff0c;包括基础知识、进阶知识、项目实践案例&#xff0c;学习书籍推荐、学习网站推荐等&#xf…...

极简Vim教程

2023年8月27日&#xff0c;周日上午 我不想学那么多命令和快捷键&#xff0c;够用就行... 所以就把我自己认为比较常用的命令和快捷键记录成博客 目录 预备知识Vim的工作模式保存内容退出Vim复制、粘贴和剪切选中一段内容复制粘贴剪切撤回和反撤回撤回反撤回查找替换删除删除…...

在线帮助中心也属于知识管理的一种吗?

在线帮助中心是企业或组织为了提供客户支持而建立的一个在线平台&#xff0c;它包含了各种类型的知识和信息&#xff0c;旨在帮助用户解决问题和获取相关的信息。从知识管理的角度来看&#xff0c;可以说在线帮助中心也属于知识管理的一种形式。下面将详细介绍在线帮助中心作为…...

《Linux从练气到飞升》No.18 进程终止

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的…...

自动化运维工具——ansible安装及模块介绍

目录 一、ansible——自动化运维工具 1.1 Ansible 自动运维工具特点 1.2 Ansible 运维工具原理 二、安装ansible 三、ansible命令模块 3.1 command模块 3.2 shell模块 3.3 cron模块 3.4 user模块 3.5 group 模块 3.6 copy模块 3.7 file模块 3.8 ping模…...

Qt XML文件解析 QDomDocument

QtXml模块提供了一个读写XML文件的流&#xff0c;解析方法包含DOM和SAX,两者的区别是什么呢&#xff1f; DOM&#xff08;Document Object Model&#xff09;&#xff1a;将XML文件保存为树的形式&#xff0c;操作简单&#xff0c;便于访问。 SAX&#xff08;Simple API for …...

Vue2向Vue3过度Vuex状态管理工具快速入门

目录 1 Vuex概述1.是什么2.使用场景3.优势4.注意&#xff1a; 2 需求: 多组件共享数据1.创建项目2.创建三个组件, 目录如下3.源代码如下 3 vuex 的使用 - 创建仓库1.安装 vuex2.新建 store/index.js 专门存放 vuex3.创建仓库 store/index.js4 在 main.js 中导入挂载到 Vue 实例…...

多模态扩展:OpenClaw结合Qwen3.5-4B-Claude处理截图信息

多模态扩展&#xff1a;OpenClaw结合Qwen3.5-4B-Claude处理截图信息 1. 为什么需要多模态能力 作为一个长期依赖文本交互的技术爱好者&#xff0c;我最初对OpenClaw的理解停留在"能通过自然语言控制电脑的AI助手"层面。直到上个月需要处理大量产品截图中的文字信息…...

IE浏览器已成过去式?Win10用户必看的IE性能优化与安全设置

IE浏览器性能优化与安全设置指南&#xff1a;告别卡顿与劫持困扰 微软宣布放弃IE浏览器已经过去多年&#xff0c;但这款"古董级"浏览器依然顽固地存在于我们的Windows系统中。对于许多企业用户和特定行业从业者来说&#xff0c;完全卸载IE并非可行选项——某些老旧的…...

OpenCV手眼标定避坑指南:inner和outer内参到底怎么选?

OpenCV手眼标定避坑指南&#xff1a;inner和outer内参到底怎么选&#xff1f; 在工业自动化领域&#xff0c;手眼标定&#xff08;Eye to Hand&#xff09;是连接视觉系统与机械臂的关键技术环节。许多工程师在使用OpenCV进行标定时&#xff0c;常常对getOptimalNewCameraMatri…...

别再只用总基尼系数了!用Python实现Dagum分解,看清区域差距的‘里子’

用Python拆解经济差距&#xff1a;Dagum基尼系数分解实战指南 当一份区域经济报告只给出一个总的基尼系数时&#xff0c;就像医生只告诉你"体温偏高"却不说明是哪个器官发炎——数据研究者常陷入这种诊断困境。传统基尼系数虽能反映整体不平等程度&#xff0c;却无法…...

【AI智能体实战】基于Dify构建自然语言数据库查询系统的全流程解析

1. 为什么需要自然语言查询数据库&#xff1f; 想象一下这个场景&#xff1a;市场部的同事小王需要从公司数据库里找出"去年销售额超过100万且退货率低于5%的客户名单"。如果他不会写SQL&#xff0c;要么得找IT部门帮忙&#xff0c;要么得花半天时间导出Excel手动筛选…...

阿里云服务器+域名备案全流程避坑指南(附小程序开发必备配置)

阿里云服务器与域名备案实战指南&#xff1a;从小程序开发到前后端部署全解析 第一次在阿里云上配置服务器并完成域名备案的经历&#xff0c;就像新手司机独自上高速——既兴奋又忐忑。记得去年我们团队开发校园服务小程序时&#xff0c;原本计划两周完成的服务器部署&#xff…...

CodeBlocks-25.03 在 Windows 上的完整配置与避坑指南

1. 为什么选择CodeBlocks-25.03&#xff1f; 如果你刚开始学习C/C编程&#xff0c;CodeBlocks绝对是个不错的选择。作为一个开源的集成开发环境&#xff08;IDE&#xff09;&#xff0c;它轻量级、跨平台&#xff0c;最重要的是完全免费。我十年前刚开始写代码时用的就是CodeBl…...

别再只用ChatGPT了!用JavaScript的Web Speech API给你的网页加个‘嘴’(附完整代码)

用Web Speech API给你的网页装个"智能语音助手"&#xff1a;从基础到实战 当我们在讨论网页交互创新时&#xff0c;大多数人会立刻想到复杂的AI对话系统。但你可能不知道&#xff0c;浏览器原生就内置了一个被严重低估的语音合成神器——Web Speech API。想象一下&am…...

汽车ECU BootLoader开发:基于CAN总线与MPC57XX系列MCU

汽车ECU BootLoader开发基于CAN总线通信MPC57XX系列MCU bootloader开发 文档54页 在汽车电子领域&#xff0c;ECU&#xff08;Electronic Control Unit&#xff09;的重要性不言而喻&#xff0c;而BootLoader则是ECU中关键的一环。今天咱们就来聊聊基于CAN总线通信&#xff0c…...

Raspotify多用户环境配置终极指南:在家庭网络中共享Spotify音乐服务

Raspotify多用户环境配置终极指南&#xff1a;在家庭网络中共享Spotify音乐服务 【免费下载链接】raspotify A Spotify Connect client that mostly Just Works™ 项目地址: https://gitcode.com/gh_mirrors/ra/raspotify 想要在家庭网络中打造一个完美的音乐共享系统吗…...