当前位置: 首页 > news >正文

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Moduledef __init__(self, **kwargs):   #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args):  #*args:不定常的位置参数#若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)#初始化statedef init_state(self, enc_outputs, *args):raise NotImplementedError#前向传播,解码器比编码器多传入一个statedef forward(self, X, state):raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super().__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):"""enc_X:编码器需传入的数据dec_X:解码器需传入的数据"""enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoderdef __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)"""vocab_size:词汇表大小embed_size:嵌入层大小num_hiddens:隐藏层的神经元数量num_layers:隐藏层的层数dropout=0 : 默认所有的神经元参与计算"""#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):#在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)#此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入stateoutputs, state = self.rnn(X)#outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化输出层self.dense = nn.Linear(num_hiddens, vocab_size)#定义函数:获取状态statedef init_state(self, enc_outputs, *args):#编码器输出的结果有两个,第二个为statereturn enc_outputs[1]#前向传播def forward(self, X, state):#X的原始shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)# 把X和state拼接到一起. 方便计算. # X现在的形状(num_steps, batch_size, embed_size) , # state的形状(batch_size, num_hiddens)# 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并#此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)#神经网络层outputs, state = self.rnn(X_and_context, state)#输出层outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来#outputs的shape=(batch_size, num_steps, vocab_size)#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

相关文章:

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类 1.1创建编码器的基类 from torch import nn#构建编码器的基类 class Encoder(nn.Module): #继承父类nn.Moduledef __init__(self, **kwargs): #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args…...

锤炼核心技能以应对编程革命

一、引言  随着人工智能的快速发展,尤其是AIGC等大语言模型的涌现,AI辅助编程工具逐渐成为程序员的新伙伴。这一变革不仅引发了关于AI是否能取代部分编程工作的讨论,也促使程序员重新思考自己的职业发展和技能提升路径。在AI时代&#xff0…...

2024 go-zero社交项目实战

背景 一位商业大亨,他非常看好国内的社交产品赛道,想要造一款属于的社交产品,于是他找到了负责软件研发的小明。 小明跟张三一拍即合,小明决定跟张三大干一番。 社交产品MVP版本需求 MVP指:Minimum Viable Product&…...

js跑马灯效果、横向、纵向滚动效果

比如横向滚动&#xff0c;则在li标签里设置 display: table-cell;滚动效果 transform: translateX(-200px); <div id"div1" ><ul><li><img src"imgs/Snipaste_2022-11-22_18-13-13.png"></li><li><img src"i…...

C#基础(14)冒泡排序

前言 其实到上一节结构体我们就已经将c#的基础知识点大概讲完&#xff0c;接下来我们会讲解一些关于算法相关的东西。 我们一样来问一下gpt吧&#xff1a; Q:解释算法 A: 算法是一组有序的逻辑步骤&#xff0c;用于解决特定问题或执行特定任务。它可以是一个计算过程、一个…...

喜报 | 众数信科荣获2024年“火炬瞪羚企业”称号

近日&#xff0c;厦门火炬高新区公布2024年“火炬瞪羚企业”名单&#xff0c;众数&#xff08;厦门&#xff09;信息科技有限公司凭借在AI领域的综合实力、技术创新及典型场景应用等方面的卓越表现&#xff0c;成功入选。 瞪羚企业 一般指高成长性科技型企业&#xff0c;是跨过…...

中央企业数智化薪酬信息系统建设如何实现穿透式监管?

近年来&#xff0c;深化国有企业改革成为推动高质量发展的重要抓手&#xff0c;薪酬管理作为其中的关键领域&#xff0c;备受关注。国资委于近日发布了《关于加强中央企业薪酬管理信息系统建设的通知》&#xff0c;并召开了中央企业薪酬管理信息系统建设工作部署会议&#xff0…...

110Redis 简明教程--Redis 数据类型

Redis strings 字符串是一种最基本、最常用的 Redis 值类型。 Redis 字符串是二进制安全的&#xff0c;这意味着一个 Redis 字符串能包含任意类型的数据&#xff0c;例如&#xff1a; 一张经过 base64 编码的图片或者一个序列化的 Ruby 对象。通过这样的方式&#xff0c;Redis …...

Spring Data Rest 远程命令执行命令(CVE-2017-8046)

&#xff08;1&#xff09;访问 http://your-ip:8080/customers/1&#xff0c;然后抓取数据包&#xff0c;使用PATCH请求来修改 PATCH /customers/1 HTTP/1.1 Host: Accept-Encoding: gzip, deflate Accept: */* Accept-Language: en User-Agent: Mozilla/5.0 (compatible; MS…...

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-18

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-18 1. The Application of Large Language Models in Primary Healthcare Services and the Challenges W YAN, J HU, H ZENG, M LIU, W LIANG - Chinese General Practice, 2024 人工智能大语言模型在基层医疗…...

搜索算法:Fibonacci查找

### 什么是Fibonacci查找 Fibonacci查找是一种搜索算法&#xff0c;它结合了Fibonacci数列和二分查找的思想&#xff0c;用于在有序数组中查找目标值。它的主要优点是在某些情况下可以比普通二分查找更高效。 ### Fibonacci数列 Fibonacci数列是一个递归定义的数列&#xff0…...

软件验收测试报告有什么作用?第三方验收测试报告包括哪些内容?

在现代软件开发中&#xff0c;软件验收测试报告占据了极为重要的地位&#xff0c;不仅是软件交付过程中的一环&#xff0c;更是软件质量保障的关键工具。 软件验收测试报告是指在软件开发过程中&#xff0c;针对软件的功能、性能、安全等方面进行的一系列测试后&#xff0c;形…...

AI大模型教程 Prompt提示词工程 AI原生应用开发零基础入门到实战【2024超细超全,建议收藏】

在AGI&#xff08;通用人工智能&#xff09;时代&#xff0c;那些既精通AI技术、又具备编程能力和业务洞察力的复合型人才将成为最宝贵的资源。为此&#xff0c;我们提出了‘AI全栈工程师’这一概念&#xff0c;旨在更精准地描述这一复合型人才群体&#xff0c;而非过分夸大其词…...

Pinia的快捷使用方法

安装Pinia npm install pinia 在main.js里面引入并注册挂载使用 在src下创建一个store inex.js // index.js import { defineStore } from pinia import { computed, ref } from vue //更简洁的的模块化 transferringValuesBetweenComponents simulationModule //简单定义了…...

一文搞懂C++继承

一文搞懂C继承 1.继承的概念及定义1.1继承的概念1.2 继承定义1.2.1定义格式1.2.2继承关系和访问限定符1.2.3继承基类成员访问方式的变化 2.基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数4.1 构造函数4.2 拷贝构造4.3 赋值重载4.4 析构函数 5.继承与友元6. 继…...

MFC -文件类控件

前言 各位师傅大家好&#xff0c;我是qmx_07&#xff0c;今天给大家讲解MFC中的文件类 MFC文件类 在MFC中&#xff0c;CFILE 是基本的文件操作类&#xff0c;提供了读取、写入、打开、关闭等操作方法主要成员函数:Open(用于打开文件&#xff0c;设置模式 例如 只读 只写 读…...

Hbase操作手册

一&#xff1a;Hbase 创建数据库表 1.进入hbase shell 2.创建数据库表的命令&#xff1a;create 表名, 列族名1,列族名2,列族名N 3.如果想查看所有数据库表&#xff0c;可以使用list 命令&#xff1a; 4.可以看到&#xff0c;刚创建的数据库表user 已经在数据库表的列表中&…...

vue组件($refs对象,动态组件,插槽,自定义指令)

一、ref 1.ref引用 每个vue组件实例上&#xff0c;都包含一个$refs对象&#xff0c;里面存储着对应dom元素或组件的引用。默认情况下&#xff0c;组件的$refs指向一个空对象。 2.使用ref获取dom元素的引用 <template><h3 ref"myh3">ref组件</h3&g…...

构建高可用和高防御力的云服务架构第五部分:PolarDB(5/5)

引言 云计算与数据库服务 云计算作为一种革命性的技术&#xff0c;已经深刻改变了信息技术行业的面貌。它通过提供按需分配的计算资源&#xff0c;使得数据存储、处理和分析变得更加灵活和高效。在云计算的众多服务中&#xff0c;数据库服务扮演着核心角色。数据库服务不仅负…...

QT窗口无法激活弹出问题排查记录

问题背景 问题环境 操作系统: 银河麒麟V10SP1qt版本 : 5.12.12 碰见了一个问题应用最小化,然后激活程序窗口无法弹出 这里描述一下代码的逻辑,使用QLocalServer实现一个单例进程,具体的功能就是在已存在一个程序A进程时,再启动这个程序A,新的程序A进程会被杀死,然后激活已存…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能&#xff0c;我们需要对它的功能特点进行分析&#xff1a; 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具&#xff1a; mysql&#xff1a;关系型数据库&am…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制&#xff08;1&#xff09;三次握手①握手过程②对握手过程的理解 &#xff08;2&#xff09;四次挥手&#xff08;3&#xff09;握手和挥手的触发&#xff08;4&#xff09;状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力

引言&#xff1a; 在人工智能快速发展的浪潮中&#xff0c;快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型&#xff08;LLM&#xff09;。该模型代表着该领域的重大突破&#xff0c;通过独特方式融合思考与非思考…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个生活电费的缴纳和查询小程序

一、项目初始化与配置 1. 创建项目 ohpm init harmony/utility-payment-app 2. 配置权限 // module.json5 {"requestPermissions": [{"name": "ohos.permission.INTERNET"},{"name": "ohos.permission.GET_NETWORK_INFO"…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观&#xff0c;可持续性好高效率高复用&#xff0c;可移植性好高内聚&#xff0c;低耦合没有冗余规范性&#xff0c;代码有规可循&#xff0c;可以看出自己当时的思考过程特殊排版&#xff0c;特殊语法&#xff0c;特殊指令&#xff0c;必须…...

Web后端基础(基础知识)

BS架构&#xff1a;Browser/Server&#xff0c;浏览器/服务器架构模式。客户端只需要浏览器&#xff0c;应用程序的逻辑和数据都存储在服务端。 优点&#xff1a;维护方便缺点&#xff1a;体验一般 CS架构&#xff1a;Client/Server&#xff0c;客户端/服务器架构模式。需要单独…...