当前位置: 首页 > news >正文

机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property    #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层  (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播   (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], []  #创建空列表,用于存储数据for x in X:  #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围,  range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score

 4.预测

import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

相关文章:

机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录 1.创建 添加了Bahdanau的decoder 2. 训练 3.定义评估函数BLEU 4.预测 5.知识点个人理解 1.创建 添加了Bahdanau的decoder import torch from torch import nn import dltools#定义注意力解码器基类 class AttentionDecoder(dltools.Decoder): #继承dltools.Decoder写…...

MyBatis 入门教程-搭建入门工程

Maven作为一个优秀的项目构建和管理工具,在日常的开发中被大多数开发者使用,后续的项目也是基于Maven来构建。 创建一个Maven项目 利用IDEA创建项目工具来创建一个Maven项目 添加MyBatis的依赖 这里可以从Maven仓库地址中进行查看, https://mvnrepository.com/ 从这里可…...

CVE-2024-2389 未经身份验证的命令注入

什么是 Progress Flowmon? Progress Flowmon 是一种网络监控和分析工具,可提供对网络流量、性能和安全性的全面洞察。Flowmon 将 Nette PHP 框架用于其 Web 应用程序。 未经身份验证的路由 我们开始在“AllowedModulesDecider.php”文件中枚举未经身份验证的端点,这是一个描…...

C++初阶-list用法总结

目录 1.迭代器的分类 2.算法举例 3.push_back/emplace_back 4.insert/erase函数介绍 5.splice函数介绍 5.1用法一:把一个链表里面的数据给另外一个链表 5.2 用法二:调整链表当前的节点数据 6.unique去重函数介绍 1.迭代器的分类 我们的这个迭代器…...

【智能大数据分析 | 实验一】MapReduce实验:单词计数

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈智能大数据分析 ⌋ ⌋ ⌋ 智能大数据分析是指利用先进的技术和算法对大规模数据进行深入分析和挖掘,以提取有价值的信息和洞察。它结合了大数据技术、人工智能(AI)、机器学习(ML&a…...

Git 版本控制--git restore和git reset

git restore 和 git reset 是 Git 版本控制系统中两个用于撤销更改的命令,但它们的作用范围和用途有所不同。 git restore git restore 是 Git 版本控制系统中的一个命令,用于撤销工作目录中的更改,但不影响暂存区(staging area…...

DBAPI如何实现插入数据前先判断数据是否存在,存在就更新,不存在就插入

DBAPI实现数据不存在即插入、存在即更新 场景 往数据库插入数据的时候,需要先判断一下记录是否在数据库已经存在,如果已经存在就更新记录,如果不存在,才插入数据。 实现方案 采用存储过程实现,以mysql为例子 创建存储过…...

【渗透测试】-灵当CRM系统-sql注入漏洞复现

文章目录 概要   灵当CRM系统sql注入漏洞:   具体实例:  技术名词解释  小结 概要 近期灵当CRM系统爆出sql注入漏洞,我们来进行nday复现。 灵当CRM系统sql注入漏洞: Python sqlmap.py -u "http://0.0.0.0:0000/c…...

c语言练习题1(数组和循环)

1实现一个对整形数组的冒泡排序 冒泡排序(Bubble Sort)是一种简单的排序算法。它重复地遍历要排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。遍历数列的工作是重复进行的,直到没有再需要交换的元…...

实验3 Hadoop集群运行环境搭建和使用

实验3 Hadoop集群运行环境搭建和使用 一、实验介绍 本节实验旨在引导学生通过实际操作搭建一个基本的Hadoop集群,并进行基本的使用验证。实验包括在集群节点上添加域名映射以实现节点间的相互识别,配置免密SSH登录以便无密码访问各节点,安装和配置JDK以满足Hadoop的运行需求…...

前端文件上传全过程

特别说明:ui框架使用的是蚂蚁的antd 这里主要是学习前端上传接口的传递参数包括前端上传之前对于代码的整理 一、第一步将前端页面画出来 源代码: /** 费用管理 - IT费用管理 - 费用数据上传 */ import { useState } from "react"; import {…...

MySQL中的函数简单总结,以及TCL语句的简单讲解

文章目录 一、函数1、ifnull2、if3、case4、exists 存在5、字符串函数(重点)6、数学函数7、日期函数 二、TCL语句1、创建用户2、赋予权限3、修改mysql允许远程登录 一、函数 1、ifnull 当前⾯的值是null的时候,使⽤后⾯的默认值 ifnull(字段…...

GPS在Linux下的使用(war driving的前置学习)

1.ls /dev/tty* 列出所有与 tty 相关的设备文件。这些设备文件通常对应终端设备 ttyUSB0是GPS端口 2.cat /dev/ttyUSB0 用于读取并显示连接到 /dev/ttyUSB0 串口设备发送的原始数据 这种是GPS定位不全的,要拿到更开阔的地方 这种是GPS定位全的 因为会持续输出…...

开发经验总结: 读写分离简单实现

背景 使用mysql的代理中间件,某些接口如果主从同步延迟大,容易出现逻辑问题。所以程序中没有直接使用这个中间件。 依赖程序逻辑,如果有一些接口可以走读库,需要一个可以显示指定读库的方式来连接读库,降低主库的压力…...

MySQL(面试题 - 同类型归纳面试题)

目录 一、MySQL 数据类型 1. 数据库存储日期格式时,如何考虑时区转换问题? 2. Blob和text有什么区别? 3. mysql里记录货币用什么字段类型比较好? 4. MySQL如何获取当前日期? 5. 你们数据库是否支持emoji表情存储&…...

【C++ Primer Plus习题】17.7

问题: 解答: #include <iostream> #include <vector> #include <string> #include <fstream> #include <algorithm>using namespace std;const int LIMIT 50;void ShowStr(const string& str); void GetStrs(ifstream& fin, vector<…...

vue3(整合版)

创建第一个vue项目 1.安装node.js cmd输入node查看是否安装成功 2.vscode开启一个终端&#xff0c;配置淘宝镜像 # 修改为淘宝镜像源 npm config set registry https://registry.npmmirror.com 输入如下命令创建第一个Vue项目 3.下载依赖&#xff0c;启动项目 访问5173端口 …...

复制他人 CSDN 文章到自己的博客

文章目录 0.前言步骤 0.前言 在复制别人文章发布时&#xff0c;记得表明转载哦 步骤 在需要复制的csdn 文章页面&#xff0c;打开浏览器开发者工具&#xff08;F12&#xff09;Ctrl F 查找"article_content"标签头 右键“Copy”->“Copy element”新建一个 tx…...

【算法——二分查找】

理论基础&#xff1a; 程序员面试经典题&#xff0c;二分搜索一个区间&#xff0c;区间查找 (LeetCode 34)_哔哩哔哩_bilibili 手把手带你撕出正确的二分法 | 二分查找法 | 二分搜索法 | LeetCode&#xff1a;704. 二分查找_哔哩哔哩_bilibili 这个是红蓝法&#xff0c;很牛…...

Cisco Packet Tracer的安装加汉化

这个工具学计算机网络的同学会用到 1.下载安装 网盘链接&#xff1a;https://pan.baidu.com/s/1CmnxAD9MkCtE7pc8Tjw0IA 提取码&#xff1a;frkb 点击第一个进行安装&#xff0c;按步骤来即可。 2.汉化 &#xff08;1&#xff09;复制chinese.ptl文件 &#xff08;2&…...

XCTF-web-easyupload

试了试php&#xff0c;php7&#xff0c;pht&#xff0c;phtml等&#xff0c;都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接&#xff0c;得到flag...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程&#xff0c;并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令&#xff0c;把数据流转换成Message&#xff0c;状态转变流程是&#xff1a;State::Created 》 St…...

Java多线程实现之Callable接口深度解析

Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解&#xff0c;适合用作学习或写简历项目背景说明。 &#x1f9e0; 一、概念简介&#xff1a;Solidity 合约开发 Solidity 是一种专门为 以太坊&#xff08;Ethereum&#xff09;平台编写智能合约的高级编…...

HDFS分布式存储 zookeeper

hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架&#xff0c;允许使用简单的变成模型跨计算机对大型集群进行分布式处理&#xff08;1.海量的数据存储 2.海量数据的计算&#xff09;Hadoop核心组件 hdfs&#xff08;分布式文件存储系统&#xff09;&a…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

LINUX 69 FTP 客服管理系统 man 5 /etc/vsftpd/vsftpd.conf

FTP 客服管理系统 实现kefu123登录&#xff0c;不允许匿名访问&#xff0c;kefu只能访问/data/kefu目录&#xff0c;不能查看其他目录 创建账号密码 useradd kefu echo 123|passwd -stdin kefu [rootcode caozx26420]# echo 123|passwd --stdin kefu 更改用户 kefu 的密码…...

华为OD机试-最短木板长度-二分法(A卷,100分)

此题是一个最大化最小值的典型例题&#xff0c; 因为搜索范围是有界的&#xff0c;上界最大木板长度补充的全部木料长度&#xff0c;下界最小木板长度&#xff1b; 即left0,right10^6; 我们可以设置一个候选值x(mid)&#xff0c;将木板的长度全部都补充到x&#xff0c;如果成功…...

Spring Boot + MyBatis 集成支付宝支付流程

Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例&#xff08;电脑网站支付&#xff09; 1. 添加依赖 <!…...