当前位置: 首页 > news >正文

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Moduledef __init__(self, **kwargs):   #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args):  #*args:不定常的位置参数#若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)#初始化statedef init_state(self, enc_outputs, *args):raise NotImplementedError#前向传播,解码器比编码器多传入一个statedef forward(self, X, state):raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super().__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):"""enc_X:编码器需传入的数据dec_X:解码器需传入的数据"""enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoderdef __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)"""vocab_size:词汇表大小embed_size:嵌入层大小num_hiddens:隐藏层的神经元数量num_layers:隐藏层的层数dropout=0 : 默认所有的神经元参与计算"""#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):#在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)#此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入stateoutputs, state = self.rnn(X)#outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化输出层self.dense = nn.Linear(num_hiddens, vocab_size)#定义函数:获取状态statedef init_state(self, enc_outputs, *args):#编码器输出的结果有两个,第二个为statereturn enc_outputs[1]#前向传播def forward(self, X, state):#X的原始shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)# 把X和state拼接到一起. 方便计算. # X现在的形状(num_steps, batch_size, embed_size) , # state的形状(batch_size, num_hiddens)# 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并#此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)#神经网络层outputs, state = self.rnn(X_and_context, state)#输出层outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来#outputs的shape=(batch_size, num_steps, vocab_size)#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

相关文章:

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类 1.1创建编码器的基类 from torch import nn#构建编码器的基类 class Encoder(nn.Module): #继承父类nn.Moduledef __init__(self, **kwargs): #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args…...

锤炼核心技能以应对编程革命

一、引言  随着人工智能的快速发展,尤其是AIGC等大语言模型的涌现,AI辅助编程工具逐渐成为程序员的新伙伴。这一变革不仅引发了关于AI是否能取代部分编程工作的讨论,也促使程序员重新思考自己的职业发展和技能提升路径。在AI时代&#xff0…...

2024 go-zero社交项目实战

背景 一位商业大亨,他非常看好国内的社交产品赛道,想要造一款属于的社交产品,于是他找到了负责软件研发的小明。 小明跟张三一拍即合,小明决定跟张三大干一番。 社交产品MVP版本需求 MVP指:Minimum Viable Product&…...

js跑马灯效果、横向、纵向滚动效果

比如横向滚动&#xff0c;则在li标签里设置 display: table-cell;滚动效果 transform: translateX(-200px); <div id"div1" ><ul><li><img src"imgs/Snipaste_2022-11-22_18-13-13.png"></li><li><img src"i…...

C#基础(14)冒泡排序

前言 其实到上一节结构体我们就已经将c#的基础知识点大概讲完&#xff0c;接下来我们会讲解一些关于算法相关的东西。 我们一样来问一下gpt吧&#xff1a; Q:解释算法 A: 算法是一组有序的逻辑步骤&#xff0c;用于解决特定问题或执行特定任务。它可以是一个计算过程、一个…...

喜报 | 众数信科荣获2024年“火炬瞪羚企业”称号

近日&#xff0c;厦门火炬高新区公布2024年“火炬瞪羚企业”名单&#xff0c;众数&#xff08;厦门&#xff09;信息科技有限公司凭借在AI领域的综合实力、技术创新及典型场景应用等方面的卓越表现&#xff0c;成功入选。 瞪羚企业 一般指高成长性科技型企业&#xff0c;是跨过…...

中央企业数智化薪酬信息系统建设如何实现穿透式监管?

近年来&#xff0c;深化国有企业改革成为推动高质量发展的重要抓手&#xff0c;薪酬管理作为其中的关键领域&#xff0c;备受关注。国资委于近日发布了《关于加强中央企业薪酬管理信息系统建设的通知》&#xff0c;并召开了中央企业薪酬管理信息系统建设工作部署会议&#xff0…...

110Redis 简明教程--Redis 数据类型

Redis strings 字符串是一种最基本、最常用的 Redis 值类型。 Redis 字符串是二进制安全的&#xff0c;这意味着一个 Redis 字符串能包含任意类型的数据&#xff0c;例如&#xff1a; 一张经过 base64 编码的图片或者一个序列化的 Ruby 对象。通过这样的方式&#xff0c;Redis …...

Spring Data Rest 远程命令执行命令(CVE-2017-8046)

&#xff08;1&#xff09;访问 http://your-ip:8080/customers/1&#xff0c;然后抓取数据包&#xff0c;使用PATCH请求来修改 PATCH /customers/1 HTTP/1.1 Host: Accept-Encoding: gzip, deflate Accept: */* Accept-Language: en User-Agent: Mozilla/5.0 (compatible; MS…...

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-18

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-18 1. The Application of Large Language Models in Primary Healthcare Services and the Challenges W YAN, J HU, H ZENG, M LIU, W LIANG - Chinese General Practice, 2024 人工智能大语言模型在基层医疗…...

搜索算法:Fibonacci查找

### 什么是Fibonacci查找 Fibonacci查找是一种搜索算法&#xff0c;它结合了Fibonacci数列和二分查找的思想&#xff0c;用于在有序数组中查找目标值。它的主要优点是在某些情况下可以比普通二分查找更高效。 ### Fibonacci数列 Fibonacci数列是一个递归定义的数列&#xff0…...

软件验收测试报告有什么作用?第三方验收测试报告包括哪些内容?

在现代软件开发中&#xff0c;软件验收测试报告占据了极为重要的地位&#xff0c;不仅是软件交付过程中的一环&#xff0c;更是软件质量保障的关键工具。 软件验收测试报告是指在软件开发过程中&#xff0c;针对软件的功能、性能、安全等方面进行的一系列测试后&#xff0c;形…...

AI大模型教程 Prompt提示词工程 AI原生应用开发零基础入门到实战【2024超细超全,建议收藏】

在AGI&#xff08;通用人工智能&#xff09;时代&#xff0c;那些既精通AI技术、又具备编程能力和业务洞察力的复合型人才将成为最宝贵的资源。为此&#xff0c;我们提出了‘AI全栈工程师’这一概念&#xff0c;旨在更精准地描述这一复合型人才群体&#xff0c;而非过分夸大其词…...

Pinia的快捷使用方法

安装Pinia npm install pinia 在main.js里面引入并注册挂载使用 在src下创建一个store inex.js // index.js import { defineStore } from pinia import { computed, ref } from vue //更简洁的的模块化 transferringValuesBetweenComponents simulationModule //简单定义了…...

一文搞懂C++继承

一文搞懂C继承 1.继承的概念及定义1.1继承的概念1.2 继承定义1.2.1定义格式1.2.2继承关系和访问限定符1.2.3继承基类成员访问方式的变化 2.基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数4.1 构造函数4.2 拷贝构造4.3 赋值重载4.4 析构函数 5.继承与友元6. 继…...

MFC -文件类控件

前言 各位师傅大家好&#xff0c;我是qmx_07&#xff0c;今天给大家讲解MFC中的文件类 MFC文件类 在MFC中&#xff0c;CFILE 是基本的文件操作类&#xff0c;提供了读取、写入、打开、关闭等操作方法主要成员函数:Open(用于打开文件&#xff0c;设置模式 例如 只读 只写 读…...

Hbase操作手册

一&#xff1a;Hbase 创建数据库表 1.进入hbase shell 2.创建数据库表的命令&#xff1a;create 表名, 列族名1,列族名2,列族名N 3.如果想查看所有数据库表&#xff0c;可以使用list 命令&#xff1a; 4.可以看到&#xff0c;刚创建的数据库表user 已经在数据库表的列表中&…...

vue组件($refs对象,动态组件,插槽,自定义指令)

一、ref 1.ref引用 每个vue组件实例上&#xff0c;都包含一个$refs对象&#xff0c;里面存储着对应dom元素或组件的引用。默认情况下&#xff0c;组件的$refs指向一个空对象。 2.使用ref获取dom元素的引用 <template><h3 ref"myh3">ref组件</h3&g…...

构建高可用和高防御力的云服务架构第五部分:PolarDB(5/5)

引言 云计算与数据库服务 云计算作为一种革命性的技术&#xff0c;已经深刻改变了信息技术行业的面貌。它通过提供按需分配的计算资源&#xff0c;使得数据存储、处理和分析变得更加灵活和高效。在云计算的众多服务中&#xff0c;数据库服务扮演着核心角色。数据库服务不仅负…...

QT窗口无法激活弹出问题排查记录

问题背景 问题环境 操作系统: 银河麒麟V10SP1qt版本 : 5.12.12 碰见了一个问题应用最小化,然后激活程序窗口无法弹出 这里描述一下代码的逻辑,使用QLocalServer实现一个单例进程,具体的功能就是在已存在一个程序A进程时,再启动这个程序A,新的程序A进程会被杀死,然后激活已存…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

vscode里如何用git

打开vs终端执行如下&#xff1a; 1 初始化 Git 仓库&#xff08;如果尚未初始化&#xff09; git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法

树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作&#xff0c;无需更改相机配置。但是&#xff0c;一…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

基于TurtleBot3在Gazebo地图实现机器人远程控制

1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...

基于 TAPD 进行项目管理

起因 自己写了个小工具&#xff0c;仓库用的Github。之前在用markdown进行需求管理&#xff0c;现在随着功能的增加&#xff0c;感觉有点难以管理了&#xff0c;所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD&#xff0c;需要提供一个企业名新建一个项目&#…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...

Python+ZeroMQ实战:智能车辆状态监控与模拟模式自动切换

目录 关键点 技术实现1 技术实现2 摘要&#xff1a; 本文将介绍如何利用Python和ZeroMQ消息队列构建一个智能车辆状态监控系统。系统能够根据时间策略自动切换驾驶模式&#xff08;自动驾驶、人工驾驶、远程驾驶、主动安全&#xff09;&#xff0c;并通过实时消息推送更新车…...

在 Spring Boot 项目里,MYSQL中json类型字段使用

前言&#xff1a; 因为程序特殊需求导致&#xff0c;需要mysql数据库存储json类型数据&#xff0c;因此记录一下使用流程 1.java实体中新增字段 private List<User> users 2.增加mybatis-plus注解 TableField(typeHandler FastjsonTypeHandler.class) private Lis…...

【堆垛策略】设计方法

堆垛策略的设计是积木堆叠系统的核心&#xff0c;直接影响堆叠的稳定性、效率和容错能力。以下是分层次的堆垛策略设计方法&#xff0c;涵盖基础规则、优化算法和容错机制&#xff1a; 1. 基础堆垛规则 (1) 物理稳定性优先 重心原则&#xff1a; 大尺寸/重量积木在下&#xf…...