当前位置: 首页 > news >正文

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Moduledef __init__(self, **kwargs):   #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args):  #*args:不定常的位置参数#若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)#初始化statedef init_state(self, enc_outputs, *args):raise NotImplementedError#前向传播,解码器比编码器多传入一个statedef forward(self, X, state):raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super().__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):"""enc_X:编码器需传入的数据dec_X:解码器需传入的数据"""enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoderdef __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)"""vocab_size:词汇表大小embed_size:嵌入层大小num_hiddens:隐藏层的神经元数量num_layers:隐藏层的层数dropout=0 : 默认所有的神经元参与计算"""#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):#在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)#此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入stateoutputs, state = self.rnn(X)#outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化输出层self.dense = nn.Linear(num_hiddens, vocab_size)#定义函数:获取状态statedef init_state(self, enc_outputs, *args):#编码器输出的结果有两个,第二个为statereturn enc_outputs[1]#前向传播def forward(self, X, state):#X的原始shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)# 把X和state拼接到一起. 方便计算. # X现在的形状(num_steps, batch_size, embed_size) , # state的形状(batch_size, num_hiddens)# 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并#此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)#神经网络层outputs, state = self.rnn(X_and_context, state)#输出层outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来#outputs的shape=(batch_size, num_steps, vocab_size)#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

相关文章:

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类 1.1创建编码器的基类 from torch import nn#构建编码器的基类 class Encoder(nn.Module): #继承父类nn.Moduledef __init__(self, **kwargs): #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args…...

锤炼核心技能以应对编程革命

一、引言  随着人工智能的快速发展,尤其是AIGC等大语言模型的涌现,AI辅助编程工具逐渐成为程序员的新伙伴。这一变革不仅引发了关于AI是否能取代部分编程工作的讨论,也促使程序员重新思考自己的职业发展和技能提升路径。在AI时代&#xff0…...

2024 go-zero社交项目实战

背景 一位商业大亨,他非常看好国内的社交产品赛道,想要造一款属于的社交产品,于是他找到了负责软件研发的小明。 小明跟张三一拍即合,小明决定跟张三大干一番。 社交产品MVP版本需求 MVP指:Minimum Viable Product&…...

js跑马灯效果、横向、纵向滚动效果

比如横向滚动&#xff0c;则在li标签里设置 display: table-cell;滚动效果 transform: translateX(-200px); <div id"div1" ><ul><li><img src"imgs/Snipaste_2022-11-22_18-13-13.png"></li><li><img src"i…...

C#基础(14)冒泡排序

前言 其实到上一节结构体我们就已经将c#的基础知识点大概讲完&#xff0c;接下来我们会讲解一些关于算法相关的东西。 我们一样来问一下gpt吧&#xff1a; Q:解释算法 A: 算法是一组有序的逻辑步骤&#xff0c;用于解决特定问题或执行特定任务。它可以是一个计算过程、一个…...

喜报 | 众数信科荣获2024年“火炬瞪羚企业”称号

近日&#xff0c;厦门火炬高新区公布2024年“火炬瞪羚企业”名单&#xff0c;众数&#xff08;厦门&#xff09;信息科技有限公司凭借在AI领域的综合实力、技术创新及典型场景应用等方面的卓越表现&#xff0c;成功入选。 瞪羚企业 一般指高成长性科技型企业&#xff0c;是跨过…...

中央企业数智化薪酬信息系统建设如何实现穿透式监管?

近年来&#xff0c;深化国有企业改革成为推动高质量发展的重要抓手&#xff0c;薪酬管理作为其中的关键领域&#xff0c;备受关注。国资委于近日发布了《关于加强中央企业薪酬管理信息系统建设的通知》&#xff0c;并召开了中央企业薪酬管理信息系统建设工作部署会议&#xff0…...

110Redis 简明教程--Redis 数据类型

Redis strings 字符串是一种最基本、最常用的 Redis 值类型。 Redis 字符串是二进制安全的&#xff0c;这意味着一个 Redis 字符串能包含任意类型的数据&#xff0c;例如&#xff1a; 一张经过 base64 编码的图片或者一个序列化的 Ruby 对象。通过这样的方式&#xff0c;Redis …...

Spring Data Rest 远程命令执行命令(CVE-2017-8046)

&#xff08;1&#xff09;访问 http://your-ip:8080/customers/1&#xff0c;然后抓取数据包&#xff0c;使用PATCH请求来修改 PATCH /customers/1 HTTP/1.1 Host: Accept-Encoding: gzip, deflate Accept: */* Accept-Language: en User-Agent: Mozilla/5.0 (compatible; MS…...

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-18

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-18 1. The Application of Large Language Models in Primary Healthcare Services and the Challenges W YAN, J HU, H ZENG, M LIU, W LIANG - Chinese General Practice, 2024 人工智能大语言模型在基层医疗…...

搜索算法:Fibonacci查找

### 什么是Fibonacci查找 Fibonacci查找是一种搜索算法&#xff0c;它结合了Fibonacci数列和二分查找的思想&#xff0c;用于在有序数组中查找目标值。它的主要优点是在某些情况下可以比普通二分查找更高效。 ### Fibonacci数列 Fibonacci数列是一个递归定义的数列&#xff0…...

软件验收测试报告有什么作用?第三方验收测试报告包括哪些内容?

在现代软件开发中&#xff0c;软件验收测试报告占据了极为重要的地位&#xff0c;不仅是软件交付过程中的一环&#xff0c;更是软件质量保障的关键工具。 软件验收测试报告是指在软件开发过程中&#xff0c;针对软件的功能、性能、安全等方面进行的一系列测试后&#xff0c;形…...

AI大模型教程 Prompt提示词工程 AI原生应用开发零基础入门到实战【2024超细超全,建议收藏】

在AGI&#xff08;通用人工智能&#xff09;时代&#xff0c;那些既精通AI技术、又具备编程能力和业务洞察力的复合型人才将成为最宝贵的资源。为此&#xff0c;我们提出了‘AI全栈工程师’这一概念&#xff0c;旨在更精准地描述这一复合型人才群体&#xff0c;而非过分夸大其词…...

Pinia的快捷使用方法

安装Pinia npm install pinia 在main.js里面引入并注册挂载使用 在src下创建一个store inex.js // index.js import { defineStore } from pinia import { computed, ref } from vue //更简洁的的模块化 transferringValuesBetweenComponents simulationModule //简单定义了…...

一文搞懂C++继承

一文搞懂C继承 1.继承的概念及定义1.1继承的概念1.2 继承定义1.2.1定义格式1.2.2继承关系和访问限定符1.2.3继承基类成员访问方式的变化 2.基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数4.1 构造函数4.2 拷贝构造4.3 赋值重载4.4 析构函数 5.继承与友元6. 继…...

MFC -文件类控件

前言 各位师傅大家好&#xff0c;我是qmx_07&#xff0c;今天给大家讲解MFC中的文件类 MFC文件类 在MFC中&#xff0c;CFILE 是基本的文件操作类&#xff0c;提供了读取、写入、打开、关闭等操作方法主要成员函数:Open(用于打开文件&#xff0c;设置模式 例如 只读 只写 读…...

Hbase操作手册

一&#xff1a;Hbase 创建数据库表 1.进入hbase shell 2.创建数据库表的命令&#xff1a;create 表名, 列族名1,列族名2,列族名N 3.如果想查看所有数据库表&#xff0c;可以使用list 命令&#xff1a; 4.可以看到&#xff0c;刚创建的数据库表user 已经在数据库表的列表中&…...

vue组件($refs对象,动态组件,插槽,自定义指令)

一、ref 1.ref引用 每个vue组件实例上&#xff0c;都包含一个$refs对象&#xff0c;里面存储着对应dom元素或组件的引用。默认情况下&#xff0c;组件的$refs指向一个空对象。 2.使用ref获取dom元素的引用 <template><h3 ref"myh3">ref组件</h3&g…...

构建高可用和高防御力的云服务架构第五部分:PolarDB(5/5)

引言 云计算与数据库服务 云计算作为一种革命性的技术&#xff0c;已经深刻改变了信息技术行业的面貌。它通过提供按需分配的计算资源&#xff0c;使得数据存储、处理和分析变得更加灵活和高效。在云计算的众多服务中&#xff0c;数据库服务扮演着核心角色。数据库服务不仅负…...

QT窗口无法激活弹出问题排查记录

问题背景 问题环境 操作系统: 银河麒麟V10SP1qt版本 : 5.12.12 碰见了一个问题应用最小化,然后激活程序窗口无法弹出 这里描述一下代码的逻辑,使用QLocalServer实现一个单例进程,具体的功能就是在已存在一个程序A进程时,再启动这个程序A,新的程序A进程会被杀死,然后激活已存…...

node.js 版本管理

在Node.js开发中&#xff0c;版本管理是一个非常重要的环节&#xff0c;特别是当你需要同时维护多个项目&#xff0c;而这些项目又依赖于不同版本的Node.js时。以下是一些常用的Node.js版本管理工具和方法&#xff1a; 1. NVM (Node Version Manager) NVM是Node.js版本管理的…...

使用Python实现图形学曲线和曲面的NURBS算法

目录 使用Python实现图形学曲线和曲面的NURBS算法引言NURBS曲线的数学原理1. NURBS曲线定义2. 权重的作用 NURBS曲线的Python实现1. 类结构设计2. 代码实现3. 代码详解使用示例 NURBS曲面的扩展NURBS曲面类实现 总结 使用Python实现图形学曲线和曲面的NURBS算法 引言 NURBS&a…...

SpringBoot3

文章目录 一、为什么要学习SpringBoot二、SpringBoot介绍2.1 约定优于配置2.2 SpringBoot中的约定三、SpringBoot快速入门3.1 快速构建SpringBoot3.1.1 选择构建项目的类型3.1.2 项目的描述3.1.3 指定SpringBoot版本和需要的依赖3.1.4 导入依赖3.1.5 编写了Controller3.1.6 测试…...

【Text2SQL】领域优质论文分享

解读论文&#xff1a;Enhancing Few-shot Text-to-SQL Capabilities of Large Language Models: A Study on Prompt Design Strategies 1. 重要贡献 这篇论文的主要贡献在于提出了一种新的方法来增强大型语言模型&#xff08;LLMs&#xff09;在少量样本&#xff08;Few-shot…...

2024全国研究生数学建模竞赛(数学建模研赛)ABCDEF题深度建模+全解全析+完整文章

全国研究生数学建模竞赛&#xff08;数学建模研赛&#xff09;于9月21日8时正式开赛&#xff0c;赛程4天半&#xff0c;咱这边会在开赛后第一时间给出对今年的6道赛题的评价、分析和解答。包括ABCDEF题深度建模全解全析完整文章&#xff0c;详情可以点击底部的卡片来获取哦。 …...

Java项目中异常处理的最佳实践

1. 异常分类 首先&#xff0c;理解异常的不同类型是合理处理异常的基础。Java中的异常大致可以分为两大类&#xff1a; 受检异常&#xff08;Checked Exceptions&#xff09;&#xff1a;这些异常必须被捕获或声明抛出&#xff0c;例如IOException。非受检异常&#xff08;Un…...

CSS基本概念以及CSS的多种引入方式

CSS基本概念 CSS是层叠样式表&#xff0c;又叫级联样式表&#xff0c;简称样式表。CSS的文件后缀为.css&#xff0c;CSS用于HTML文档中元素样式的定义。 CSS的基本语法 CSS的规则由2个主要的部分构成&#xff1a;选择器以及一条或者多条声明。 选测器通常是你血药改变样式的…...

TiDB 简单集群部署拓扑文件

TiDB集群部署 服务器环境部署拓扑 都2024了还在为分库分表烦恼吗&#x1f618;&#xff0c;用分布式数据库TiDB、OceanBase、华为 GaussDB&#xff0c;你就使劲往里存数据。 早下班、少脱发、脱单&#xff01; &#x1f64f;&#x1f3fb;&#x1f64f;&#x1f3fb;&#x1f6…...

十三 系统架构设计(考点篇)

1 软件架构的概念 一个程序和计算系统软件体系结构是指系统的一个或者多个结构。结构中包括软件的构件&#xff0c;构件 的外部可见属性以及它们之间的相互关系。 体系结构并非可运行软件。确切地说&#xff0c;它是一种表达&#xff0c;使软件工程师能够&#xff1a; (1)分…...

Java-数据结构-二叉树-习题(三)  ̄へ ̄

文本目录&#xff1a; ❄️一、习题一(前序遍历非递归)&#xff1a; ▶ 思路&#xff1a; ▶ 代码&#xff1a; ❄️二、习题二(中序遍历非递归)&#xff1a; ▶ 思路&#xff1a; ▶ 代码&#xff1a; ❄️三、习题三(后序遍历非递归)&#xff1a; ▶ 思路&#xff1a; …...