机器翻译之Bahdanau注意力机制在Seq2Seq中的应用
目录
1.创建 添加了Bahdanau的decoder
2. 训练
3.定义评估函数BLEU
4.预测
5.知识点个人理解
1.创建 添加了Bahdanau的decoder
import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder): #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder): #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小, 相当于输入数据的特征数features, 也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据 (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层 (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播 (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2) #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], [] #创建空列表,用于存储数据for x in X: #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态 (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens) #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)
# print('query的shape:', query.shape) #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
# print('context的shape:', context.shape) #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
# print('x的shape:', x.shape) #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
# print('out的shape:', out.shape) #out的shape=(1, batch_size, num_hiddens)
# print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
# print('outputs的shape:', outputs.shape) #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) outputs的shape: torch.Size([7, 4, 10])Out[11]:
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))
2. 训练
#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
3.定义评估函数BLEU
def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围, range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score
4.预测
import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000 i lost . => ("j'ai perdu .", []), bleu 1.000 he's calm . => ('il est bon .', []), bleu 0.658 i'm home . => ('je suis chez moi .', []), bleu 1.000
5.知识点个人理解
相关文章:

机器翻译之Bahdanau注意力机制在Seq2Seq中的应用
目录 1.创建 添加了Bahdanau的decoder 2. 训练 3.定义评估函数BLEU 4.预测 5.知识点个人理解 1.创建 添加了Bahdanau的decoder import torch from torch import nn import dltools#定义注意力解码器基类 class AttentionDecoder(dltools.Decoder): #继承dltools.Decoder写…...

MyBatis 入门教程-搭建入门工程
Maven作为一个优秀的项目构建和管理工具,在日常的开发中被大多数开发者使用,后续的项目也是基于Maven来构建。 创建一个Maven项目 利用IDEA创建项目工具来创建一个Maven项目 添加MyBatis的依赖 这里可以从Maven仓库地址中进行查看, https://mvnrepository.com/ 从这里可…...

CVE-2024-2389 未经身份验证的命令注入
什么是 Progress Flowmon? Progress Flowmon 是一种网络监控和分析工具,可提供对网络流量、性能和安全性的全面洞察。Flowmon 将 Nette PHP 框架用于其 Web 应用程序。 未经身份验证的路由 我们开始在“AllowedModulesDecider.php”文件中枚举未经身份验证的端点,这是一个描…...

C++初阶-list用法总结
目录 1.迭代器的分类 2.算法举例 3.push_back/emplace_back 4.insert/erase函数介绍 5.splice函数介绍 5.1用法一:把一个链表里面的数据给另外一个链表 5.2 用法二:调整链表当前的节点数据 6.unique去重函数介绍 1.迭代器的分类 我们的这个迭代器…...

【智能大数据分析 | 实验一】MapReduce实验:单词计数
【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈智能大数据分析 ⌋ ⌋ ⌋ 智能大数据分析是指利用先进的技术和算法对大规模数据进行深入分析和挖掘,以提取有价值的信息和洞察。它结合了大数据技术、人工智能(AI)、机器学习(ML&a…...

Git 版本控制--git restore和git reset
git restore 和 git reset 是 Git 版本控制系统中两个用于撤销更改的命令,但它们的作用范围和用途有所不同。 git restore git restore 是 Git 版本控制系统中的一个命令,用于撤销工作目录中的更改,但不影响暂存区(staging area…...

DBAPI如何实现插入数据前先判断数据是否存在,存在就更新,不存在就插入
DBAPI实现数据不存在即插入、存在即更新 场景 往数据库插入数据的时候,需要先判断一下记录是否在数据库已经存在,如果已经存在就更新记录,如果不存在,才插入数据。 实现方案 采用存储过程实现,以mysql为例子 创建存储过…...

【渗透测试】-灵当CRM系统-sql注入漏洞复现
文章目录 概要 灵当CRM系统sql注入漏洞: 具体实例: 技术名词解释 小结 概要 近期灵当CRM系统爆出sql注入漏洞,我们来进行nday复现。 灵当CRM系统sql注入漏洞: Python sqlmap.py -u "http://0.0.0.0:0000/c…...

c语言练习题1(数组和循环)
1实现一个对整形数组的冒泡排序 冒泡排序(Bubble Sort)是一种简单的排序算法。它重复地遍历要排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。遍历数列的工作是重复进行的,直到没有再需要交换的元…...

实验3 Hadoop集群运行环境搭建和使用
实验3 Hadoop集群运行环境搭建和使用 一、实验介绍 本节实验旨在引导学生通过实际操作搭建一个基本的Hadoop集群,并进行基本的使用验证。实验包括在集群节点上添加域名映射以实现节点间的相互识别,配置免密SSH登录以便无密码访问各节点,安装和配置JDK以满足Hadoop的运行需求…...

前端文件上传全过程
特别说明:ui框架使用的是蚂蚁的antd 这里主要是学习前端上传接口的传递参数包括前端上传之前对于代码的整理 一、第一步将前端页面画出来 源代码: /** 费用管理 - IT费用管理 - 费用数据上传 */ import { useState } from "react"; import {…...

MySQL中的函数简单总结,以及TCL语句的简单讲解
文章目录 一、函数1、ifnull2、if3、case4、exists 存在5、字符串函数(重点)6、数学函数7、日期函数 二、TCL语句1、创建用户2、赋予权限3、修改mysql允许远程登录 一、函数 1、ifnull 当前⾯的值是null的时候,使⽤后⾯的默认值 ifnull(字段…...

GPS在Linux下的使用(war driving的前置学习)
1.ls /dev/tty* 列出所有与 tty 相关的设备文件。这些设备文件通常对应终端设备 ttyUSB0是GPS端口 2.cat /dev/ttyUSB0 用于读取并显示连接到 /dev/ttyUSB0 串口设备发送的原始数据 这种是GPS定位不全的,要拿到更开阔的地方 这种是GPS定位全的 因为会持续输出…...

开发经验总结: 读写分离简单实现
背景 使用mysql的代理中间件,某些接口如果主从同步延迟大,容易出现逻辑问题。所以程序中没有直接使用这个中间件。 依赖程序逻辑,如果有一些接口可以走读库,需要一个可以显示指定读库的方式来连接读库,降低主库的压力…...

MySQL(面试题 - 同类型归纳面试题)
目录 一、MySQL 数据类型 1. 数据库存储日期格式时,如何考虑时区转换问题? 2. Blob和text有什么区别? 3. mysql里记录货币用什么字段类型比较好? 4. MySQL如何获取当前日期? 5. 你们数据库是否支持emoji表情存储&…...

【C++ Primer Plus习题】17.7
问题: 解答: #include <iostream> #include <vector> #include <string> #include <fstream> #include <algorithm>using namespace std;const int LIMIT 50;void ShowStr(const string& str); void GetStrs(ifstream& fin, vector<…...

vue3(整合版)
创建第一个vue项目 1.安装node.js cmd输入node查看是否安装成功 2.vscode开启一个终端,配置淘宝镜像 # 修改为淘宝镜像源 npm config set registry https://registry.npmmirror.com 输入如下命令创建第一个Vue项目 3.下载依赖,启动项目 访问5173端口 …...

复制他人 CSDN 文章到自己的博客
文章目录 0.前言步骤 0.前言 在复制别人文章发布时,记得表明转载哦 步骤 在需要复制的csdn 文章页面,打开浏览器开发者工具(F12)Ctrl F 查找"article_content"标签头 右键“Copy”->“Copy element”新建一个 tx…...

【算法——二分查找】
理论基础: 程序员面试经典题,二分搜索一个区间,区间查找 (LeetCode 34)_哔哩哔哩_bilibili 手把手带你撕出正确的二分法 | 二分查找法 | 二分搜索法 | LeetCode:704. 二分查找_哔哩哔哩_bilibili 这个是红蓝法,很牛…...

Cisco Packet Tracer的安装加汉化
这个工具学计算机网络的同学会用到 1.下载安装 网盘链接:https://pan.baidu.com/s/1CmnxAD9MkCtE7pc8Tjw0IA 提取码:frkb 点击第一个进行安装,按步骤来即可。 2.汉化 (1)复制chinese.ptl文件 (2&…...

MMain函数定义为WinMain函数看port1632.h和pwin32.h文件
编译win2k3的源代码的时候有时候看到MMain函数 ..//public/sdk/inc/port1632.h #if defined(WIN16) /* ---------------- Maps to windows 3.0 and 3.1 16-bit APIs ----------------*/ #include "ptypes16.h" #include "pwin16.h" #include "plan16.…...

单词搜索问题(涉及递归等)
目录 一题目: 二思路解释: 三解答代码: 一题目: newcode题目链接: 单词搜索_牛客题霸_牛客网 二思路解释: 思路:个人理解是找到word中的第一个元素,然后去递归的上下左右查找&am…...

Redis的一些通用指令
首先我们需要先连接客户端服务器,此时我们需要通过redis-cli和redis服务器进行交互,输入ping来确保通路的流畅 (一)get和set redis中最核心的两个命令就是get和set,get就是根据key来取出对应value,set就是把…...

C++中vector类的使用
目录 1.vector类常用接口说明 1.1默认成员函数 1.1.1构造函数(constructor) 1.1.2 赋值运算符重载(operator()) 2. vector对象的访问及遍历操作(Iterators and Element access) 3.vector类对象的容量操作(Capacity) 4. vector类对象的修改及相关操作(Modifiers and Stri…...

cmaklist流程控制——调试及发布
cmaklist流程控制 目前只会配置-编译调试-打包发布,并且不会workflow控制 后续学习配置-编译调试-测试-打包发布,workflow控制,理解整个流程,目前对流程控制理解也不够。 1.CMake Presets 先于Cmakelist文件,指导项…...

制作一个能对话能跳舞的otto机器人
OTTO机器人是一个开源外壳,硬件和软件的桌面机器人项目,非常适合新手研究和拓展。记住,他是一个能移动有表情能声音的机器人。 b站有很多演示和组装的视频,我就不多说了,照着做就好,因为硬件我也是刚入门&…...

git配置SSH
1 打开cmd窗口 2 在窗口中输入如下命令: 配置用户名: git config --global user.name “gyk” 配置邮箱: git config --global user.email “247929163qq.com” 继续在Git命令窗口中输入如下命令,即可生成SSH公钥和私钥 ss…...

mozilla/pdf.js view.html加载指定页码
mozilla/pdf.js view.html加载指定页码 在Mozilla’s PDF.js中,如果你想要在viewer.html加载时直接跳转到指定的页码,你可以通过修改URL来实现。 PDF.js使用查询参数来处理URL,其中page参数用于指定页码。你可以通过修改URL的查询字符串来设…...

Qt之QFuture理解
结构 #mermaid-svg-J9J683RG8QjtEqoM {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-J9J683RG8QjtEqoM .error-icon{fill:#552222;}#mermaid-svg-J9J683RG8QjtEqoM .error-text{fill:#552222;stroke:#552222;}#merm…...

求二叉树的高度(递归和非递归)
假设二叉树采用二叉链表存储结构,设计一个算法求二叉树的高度。 递归: int getTreeHight(BiTree T){if(TNULL){return 0;}else {int lh getTreeHight(T->lchild);int rh getTreeHight(T->rchild);return (lh>rh?lh:rh)1;}}时间复杂度O(n)&a…...