机器翻译之Bahdanau注意力机制在Seq2Seq中的应用
目录
1.创建 添加了Bahdanau的decoder
2. 训练
3.定义评估函数BLEU
4.预测
5.知识点个人理解

1.创建 添加了Bahdanau的decoder
import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder): #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder): #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小, 相当于输入数据的特征数features, 也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据 (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层 (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播 (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2) #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], [] #创建空列表,用于存储数据for x in X: #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态 (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens) #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)
# print('query的shape:', query.shape) #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
# print('context的shape:', context.shape) #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
# print('x的shape:', x.shape) #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
# print('out的shape:', out.shape) #out的shape=(1, batch_size, num_hiddens)
# print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
# print('outputs的shape:', outputs.shape) #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) query的shape: torch.Size([4, 1, 16]) context的shape: torch.Size([4, 1, 16]) x的shape: torch.Size([4, 1, 24]) out的shape: torch.Size([1, 4, 16]) hidden_state的shape: torch.Size([2, 4, 16]) outputs的shape: torch.Size([7, 4, 10])Out[11]:
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))
2. 训练
#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围, range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score
4.预测
import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000
5.知识点个人理解









相关文章:
机器翻译之Bahdanau注意力机制在Seq2Seq中的应用
目录 1.创建 添加了Bahdanau的decoder 2. 训练 3.定义评估函数BLEU 4.预测 5.知识点个人理解 1.创建 添加了Bahdanau的decoder import torch from torch import nn import dltools#定义注意力解码器基类 class AttentionDecoder(dltools.Decoder): #继承dltools.Decoder写…...
MyBatis 入门教程-搭建入门工程
Maven作为一个优秀的项目构建和管理工具,在日常的开发中被大多数开发者使用,后续的项目也是基于Maven来构建。 创建一个Maven项目 利用IDEA创建项目工具来创建一个Maven项目 添加MyBatis的依赖 这里可以从Maven仓库地址中进行查看, https://mvnrepository.com/ 从这里可…...
CVE-2024-2389 未经身份验证的命令注入
什么是 Progress Flowmon? Progress Flowmon 是一种网络监控和分析工具,可提供对网络流量、性能和安全性的全面洞察。Flowmon 将 Nette PHP 框架用于其 Web 应用程序。 未经身份验证的路由 我们开始在“AllowedModulesDecider.php”文件中枚举未经身份验证的端点,这是一个描…...
C++初阶-list用法总结
目录 1.迭代器的分类 2.算法举例 3.push_back/emplace_back 4.insert/erase函数介绍 5.splice函数介绍 5.1用法一:把一个链表里面的数据给另外一个链表 5.2 用法二:调整链表当前的节点数据 6.unique去重函数介绍 1.迭代器的分类 我们的这个迭代器…...
【智能大数据分析 | 实验一】MapReduce实验:单词计数
【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈智能大数据分析 ⌋ ⌋ ⌋ 智能大数据分析是指利用先进的技术和算法对大规模数据进行深入分析和挖掘,以提取有价值的信息和洞察。它结合了大数据技术、人工智能(AI)、机器学习(ML&a…...
Git 版本控制--git restore和git reset
git restore 和 git reset 是 Git 版本控制系统中两个用于撤销更改的命令,但它们的作用范围和用途有所不同。 git restore git restore 是 Git 版本控制系统中的一个命令,用于撤销工作目录中的更改,但不影响暂存区(staging area…...
DBAPI如何实现插入数据前先判断数据是否存在,存在就更新,不存在就插入
DBAPI实现数据不存在即插入、存在即更新 场景 往数据库插入数据的时候,需要先判断一下记录是否在数据库已经存在,如果已经存在就更新记录,如果不存在,才插入数据。 实现方案 采用存储过程实现,以mysql为例子 创建存储过…...
【渗透测试】-灵当CRM系统-sql注入漏洞复现
文章目录 概要 灵当CRM系统sql注入漏洞: 具体实例: 技术名词解释 小结 概要 近期灵当CRM系统爆出sql注入漏洞,我们来进行nday复现。 灵当CRM系统sql注入漏洞: Python sqlmap.py -u "http://0.0.0.0:0000/c…...
c语言练习题1(数组和循环)
1实现一个对整形数组的冒泡排序 冒泡排序(Bubble Sort)是一种简单的排序算法。它重复地遍历要排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。遍历数列的工作是重复进行的,直到没有再需要交换的元…...
实验3 Hadoop集群运行环境搭建和使用
实验3 Hadoop集群运行环境搭建和使用 一、实验介绍 本节实验旨在引导学生通过实际操作搭建一个基本的Hadoop集群,并进行基本的使用验证。实验包括在集群节点上添加域名映射以实现节点间的相互识别,配置免密SSH登录以便无密码访问各节点,安装和配置JDK以满足Hadoop的运行需求…...
前端文件上传全过程
特别说明:ui框架使用的是蚂蚁的antd 这里主要是学习前端上传接口的传递参数包括前端上传之前对于代码的整理 一、第一步将前端页面画出来 源代码: /** 费用管理 - IT费用管理 - 费用数据上传 */ import { useState } from "react"; import {…...
MySQL中的函数简单总结,以及TCL语句的简单讲解
文章目录 一、函数1、ifnull2、if3、case4、exists 存在5、字符串函数(重点)6、数学函数7、日期函数 二、TCL语句1、创建用户2、赋予权限3、修改mysql允许远程登录 一、函数 1、ifnull 当前⾯的值是null的时候,使⽤后⾯的默认值 ifnull(字段…...
GPS在Linux下的使用(war driving的前置学习)
1.ls /dev/tty* 列出所有与 tty 相关的设备文件。这些设备文件通常对应终端设备 ttyUSB0是GPS端口 2.cat /dev/ttyUSB0 用于读取并显示连接到 /dev/ttyUSB0 串口设备发送的原始数据 这种是GPS定位不全的,要拿到更开阔的地方 这种是GPS定位全的 因为会持续输出…...
开发经验总结: 读写分离简单实现
背景 使用mysql的代理中间件,某些接口如果主从同步延迟大,容易出现逻辑问题。所以程序中没有直接使用这个中间件。 依赖程序逻辑,如果有一些接口可以走读库,需要一个可以显示指定读库的方式来连接读库,降低主库的压力…...
MySQL(面试题 - 同类型归纳面试题)
目录 一、MySQL 数据类型 1. 数据库存储日期格式时,如何考虑时区转换问题? 2. Blob和text有什么区别? 3. mysql里记录货币用什么字段类型比较好? 4. MySQL如何获取当前日期? 5. 你们数据库是否支持emoji表情存储&…...
【C++ Primer Plus习题】17.7
问题: 解答: #include <iostream> #include <vector> #include <string> #include <fstream> #include <algorithm>using namespace std;const int LIMIT 50;void ShowStr(const string& str); void GetStrs(ifstream& fin, vector<…...
vue3(整合版)
创建第一个vue项目 1.安装node.js cmd输入node查看是否安装成功 2.vscode开启一个终端,配置淘宝镜像 # 修改为淘宝镜像源 npm config set registry https://registry.npmmirror.com 输入如下命令创建第一个Vue项目 3.下载依赖,启动项目 访问5173端口 …...
复制他人 CSDN 文章到自己的博客
文章目录 0.前言步骤 0.前言 在复制别人文章发布时,记得表明转载哦 步骤 在需要复制的csdn 文章页面,打开浏览器开发者工具(F12)Ctrl F 查找"article_content"标签头 右键“Copy”->“Copy element”新建一个 tx…...
【算法——二分查找】
理论基础: 程序员面试经典题,二分搜索一个区间,区间查找 (LeetCode 34)_哔哩哔哩_bilibili 手把手带你撕出正确的二分法 | 二分查找法 | 二分搜索法 | LeetCode:704. 二分查找_哔哩哔哩_bilibili 这个是红蓝法,很牛…...
Cisco Packet Tracer的安装加汉化
这个工具学计算机网络的同学会用到 1.下载安装 网盘链接:https://pan.baidu.com/s/1CmnxAD9MkCtE7pc8Tjw0IA 提取码:frkb 点击第一个进行安装,按步骤来即可。 2.汉化 (1)复制chinese.ptl文件 (2&…...
AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例
使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件,常用于在两个集合之间进行数据转移,如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model:绑定右侧列表的值&…...
2.Vue编写一个app
1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...
linux arm系统烧录
1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...
C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。
1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj,再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...
JVM虚拟机:内存结构、垃圾回收、性能优化
1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...
深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用
文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...
C#中的CLR属性、依赖属性与附加属性
CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...
Golang——9、反射和文件操作
反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一:使用Read()读取文件2.3、方式二:bufio读取文件2.4、方式三:os.ReadFile读取2.5、写…...
