当前位置: 首页 > news >正文

手撕Transformer -- Day7 -- Decoder

手撕Transformer – Day7 – Decoder

Transformer 网络结构图

目录

  • 手撕Transformer -- Day7 -- Decoder
    • Transformer 网络结构图
    • Decoder 代码
      • Part1 库函数
      • Part2 实现一个解码器Decoder,作为一个类
      • Part3 测试
    • 参考

在这里插入图片描述

Transformer 网络结构

Decoder 代码

Part1 库函数

# 该板块主要是对解码器进行串接,实现得到解码器部分
# 输入为x,还没嵌入的,但是PAD好的输入,输出需要对注意力值进行线性转化和softmax,最后得到一个单维向量,长度为词库大小。
'''
# Part1 导入库函数
'''
import torch
from torch import nn
from dataset import train_dataset, de_vocab, en_vocab, de_preprocess, en_preprocess,PAD_IDX
from encoder import Encoder
from decoder_block import DecoderBlock
from emb import EmbeddingWithPosition

Part2 实现一个解码器Decoder,作为一个类

'''
# Part2 设计解码器的类
'''class Decoder(nn.Module):def __init__(self, en_vocab_size, emd_size, nums_decoder_block, head, q_k_size, v_size, f_size):super().__init__()self.nums_decoder_block=nums_decoder_block# 首先对x进行编码self.emd = EmbeddingWithPosition(vocab_size=en_vocab_size, emd_size=emd_size)# 然后输入n个编码器self.decoder_list = nn.ModuleList()for _ in range(nums_decoder_block):self.decoder_list.append(DecoderBlock(head=head, emd_size=emd_size, q_k_size=q_k_size, v_size=v_size, f_size=f_size))# 然后需要线性化和softmax,目前是(batch_size,q_sqen_len,emd)# 得到(batch_size,vocab_size)self.linear1=nn.Linear(emd_size,en_vocab_size)self.softmax=nn.Softmax(-1)def forward(self, x, encoder_z,encoder_x): # encoder_x是编码器的输入(batch_size,q_seq_len)# x(batch_size,q_sqen_len)# 首先对解码器输入的padding位置进行掩码设置。mask1=(x==PAD_IDX).unsqueeze(1) # (batch_size,1,q_seq_len)mask1.expand(-1,x.size()[1],-1)  # (batch_size,q_seq_len,q_seq_len)# 然后要对解码器的输入的上半部分也取True然后和mask1或一下(也就是符号|),注意True表示需要隐藏的位置。# 注意:torch.tril 和 torch.triu 的区别就是决定矩阵的上半部分(不包含对角线)还是下半部分(不包含对角线)置为0,diagonal=1,表示置0的区域向上移动一行mask1=mask1 | torch.triu(torch.ones(mask1.size()[-1],mask1.size()[-1]),diagonal=1).bool().unsqueeze(0).expand(mask1.size()[0],-1,-1)# 然后对编码器的mask2进行掩码设置。在交叉注意力中,Padding 掩码的区域由K 和 V 的来源决定,# 而不是由Q 的来源决定。这确保了来自Q 的查询只关注K 中有效的信息位置。mask2 = (encoder_x == PAD_IDX).unsqueeze(1) # (batch_size,1,q_seq_len)mask2.expand(-1, encoder_x.size()[1], -1) # (batch_size,1,q_seq_len)x=self.emd(x)  # (batch_size,q_sqen_len,emd)# 进入解码器output=xfor i in range(self.nums_decoder_block):output = self.decoder_list[i](output,encoder_z,mask1,mask2)# 输出进行线性层和softmaxoutput=self.linear1(output)output=self.softmax(output)return output

Part3 测试

if __name__ == '__main__':# 取2个de句子转词ID序列,输入给encoderde_tokens1, de_ids1 = de_preprocess(train_dataset[0][0])de_tokens2, de_ids2 = de_preprocess(train_dataset[1][0])# 对应2个en句子转词ID序列,再做embedding,输入给decoderen_tokens1, en_ids1 = en_preprocess(train_dataset[0][1])en_tokens2, en_ids2 = en_preprocess(train_dataset[1][1])# de句子组成batch并padding对齐if len(de_ids1) < len(de_ids2):de_ids1.extend([PAD_IDX] * (len(de_ids2) - len(de_ids1)))elif len(de_ids1) > len(de_ids2):de_ids2.extend([PAD_IDX] * (len(de_ids1) - len(de_ids2)))enc_x_batch = torch.tensor([de_ids1, de_ids2], dtype=torch.long)print('enc_x_batch batch:', enc_x_batch.size())# en句子组成batch并padding对齐if len(en_ids1) < len(en_ids2):en_ids1.extend([PAD_IDX] * (len(en_ids2) - len(en_ids1)))elif len(en_ids1) > len(en_ids2):en_ids2.extend([PAD_IDX] * (len(en_ids1) - len(en_ids2)))dec_x_batch = torch.tensor([en_ids1, en_ids2], dtype=torch.long)print('dec_x_batch batch:', dec_x_batch.size())# Encoder编码,输出每个词的编码向量enc = Encoder(vocab_size=len(de_vocab), emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8, nums_encoderblock=3)enc_outputs = enc(enc_x_batch)print('encoder outputs:', enc_outputs.size())# Decoder编码,输出每个词对应下一个词的概率dec = Decoder(en_vocab_size=len(en_vocab), emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8, nums_decoder_block=3)enc_outputs = dec(dec_x_batch, enc_outputs, enc_x_batch)print(enc_outputs)print('decoder outputs:', enc_outputs.size())

参考

视频讲解:transformer-带位置信息的词嵌入向量_哔哩哔哩_bilibili

github代码库:github.com

相关文章:

手撕Transformer -- Day7 -- Decoder

手撕Transformer – Day7 – Decoder Transformer 网络结构图 目录 手撕Transformer -- Day7 -- DecoderTransformer 网络结构图Decoder 代码Part1 库函数Part2 实现一个解码器Decoder&#xff0c;作为一个类Part3 测试 参考 Transformer 网络结构 Decoder 代码 Part1 库函数…...

C#异步和多线程,Thread,Task和async/await关键字--12

目录 一.多线程和异步的区别 1.多线程 2.异步编程 多线程和异步的区别 二.Thread,Task和async/await关键字的区别 1.Thread 2.Task 3.async/await 三.Thread,Task和async/await关键字的详细对比 1.Thread和Task的详细对比 2.Task 与 async/await 的配合使用 3. asy…...

使用分割 Mask 和 K-means 聚类获取天空的颜色

引言 在计算机视觉领域&#xff0c;获取天空的颜色是一个常见任务&#xff0c;广泛应用于天气分析、环境感知和图像增强等场景。本篇博客将介绍如何通过已知的天空区域 Mask 提取天空像素&#xff0c;并使用 K-means 聚类分析天空颜色&#xff0c;最终根据颜色占比查表得到主导…...

145.《redis原生超详细使用》

文章目录 什么是redisredis 安装启动redis数据类型redis key操作key 的增key 的查key 的改key 的删key 是否存在key 查看所有key 「设置」过期时间key 「查看」过期时间key 「移除」过期时间key 「查看」数据类型key 「匹配」符合条件的keykey 「移动」到其他数据库 redis数据类…...

Pytorch基础教程:从零实现手写数字分类

文章目录 1.Pytorch简介2.理解tensor2.1 一维矩阵2.2 二维矩阵2.3 三维矩阵 3.创建tensor3.1 你可以直接从一个Python列表或NumPy数组创建一个tensor&#xff1a;3.2 创建特定形状的tensor3.3 创建三维tensor3.4 使用随机数填充tensor3.5 指定tensor的数据类型 4.tensor基本运算…...

【SH】Xiaomi9刷Windows10系统研发记录 、手机刷Windows系统教程、小米9重装win10系统

文章目录 参考资料云盘资料软硬件环境手机解锁刷机驱动绑定账号和设备解锁手机 Mindows工具箱安装工具箱和修复下载下载安卓和woa资源包第三方Recovery 一键安装Windows准备工作创建分区安装系统 效果展示Windows和Android一键互换Win切换安卓安卓切换Win 删除分区 参考资料 解…...

excel仅复制可见单元格,仅复制筛选后内容

背景 我们经常需要将内容分给不同的人&#xff0c;做完后需要合并 遇到情况如下 那是因为直接选择了整列&#xff0c;当然不可以了。 下面提供几种方法&#xff0c;应该都可以 直接选中要复制区域然后复制&#xff0c;不要选中最上面的列alt;选中可见单元格正常复制&#xff…...

HBASE学习(一)

1.HBASE基础架构&#xff0c; 1.1 参考&#xff1a; HBase集群架构与读写优化&#xff1a;理解核心机制与性能提升-CSDN博客 1.2问题&#xff1a; 1.FLUSH对hbase的影响 2. HLog和memstore的区别 hlog中存储的是操作记录&#xff0c;比如写、删除。而memstor中存储的是写入…...

element select 绑定一个对象{}

背景&#xff1a; select组件的使用&#xff0c;适用广泛的基础单选 v-model 的值为当前被选中的 el-option 的 value 属性值。但是我们这里想绑定一个对象&#xff0c;一个el-option对应的对象。 <el-select v-model"state.form.modelA" …...

Sprint Boot教程之五十八:动态启动/停止 Kafka 监听器

Spring Boot – 动态启动/停止 Kafka 监听器 当 Spring Boot 应用程序启动时&#xff0c;Kafka Listener 的默认行为是开始监听某个主题。但是&#xff0c;有些情况下我们不想在应用程序启动后立即启动它。 要动态启动或停止 Kafka Listener&#xff0c;我们需要三种主要方法…...

C:JSON-C简介

介绍 JSON-C是一个用于处理JSON格式数据的C语言库&#xff0c;提供了一系列操作JSON数据的函数。 一、json参数类型 typedef enum json_type { json_type_null, json_type_boolean, json_type_double, json_type_int, json_type_object, json_type_ar…...

业务幂等性技术架构体系之消息幂等深入剖析

在系统中当使用消息队列时&#xff0c;无论做哪种技术选型&#xff0c;有很多问题是无论如何也不能忽视的&#xff0c;如&#xff1a;消息必达、消息幂等等。本文以典型的RabbitMQ为例&#xff0c;讲解如何保证消息幂等的可实施解决方案&#xff0c;其他MQ选型均可参考。 一、…...

【Go】Go Gin框架初识(一)

1. 什么是Gin框架 Gin框架&#xff1a;是一个由 Golang 语言开发的 web 框架&#xff0c;能够极大提高开发 web 应用的效率&#xff01; 1.1 什么是web框架 web框架体系图&#xff08;前后端不分离&#xff09;如下图所示&#xff1a; 从上图中我们可以发现一个Web框架最重要…...

2024年合肥市科普日小学组市赛第一题题解

9304&#xff1a;数字加密&#xff08;encrypt&#xff09;(1) 【问题描述】 在信息科技课堂上&#xff0c;小肥正在思考“数字加密”实验项目。项目需要加密n个正整数&#xff0c;对每一个正整数x加密的规则是&#xff0c;将x的每一位数字都替换为x的最大数字。例如&#xff0…...

【MySQL实战】mysql_exporter+Prometheus+Grafana

要在Prometheus和Grafana中监控MySQL数据库&#xff0c;如下图&#xff1a; 可以使用mysql_exporter。 以下是一些步骤来设置和配置这个监控环境&#xff1a; 1. 安装和配置Prometheus&#xff1a; - 下载和安装Prometheus。 - 在prometheus.yml中配置MySQL通过添加以下内…...

Wireshark 使用教程:网络分析从入门到精通

一、引言 在网络技术的广阔领域中&#xff0c;网络协议分析是一项至关重要的技能。Wireshark 作为一款开源且功能强大的网络协议分析工具&#xff0c;被广泛应用于网络故障排查、网络安全检测以及网络协议研究等诸多方面。本文将深入且详细地介绍 Wireshark 的使用方法&#x…...

如何在前端给视频进行去除绿幕并替换背景?-----Vue3!!

最近在做这个这项目奇店桶装水小程序V1.3.9安装包骑手端V2.0.1小程序前端 最近&#xff0c;我在进行前端开发时&#xff0c;遇到了一个难题“如何给前端的视频进行去除绿幕并替换背景”。这是一个“数字人项目”所需&#xff0c;我一直在冥思苦想。终于有了一个解决方法…...

使用中间件自动化部署java应用

为了实现你在 IntelliJ IDEA 中打包项目并通过工具推送到两个 Docker 服务器&#xff08;172.168.0.1 和 172.168.0.12&#xff09;&#xff0c;并在推送后自动或手动重启容器&#xff0c;我们可以按照以下步骤进行操作&#xff1a; 在 IntelliJ IDEA 中配置 Maven 或 Gradle 打…...

pytorch张量分块投影示例代码

张量的投影操作 背景 张量投影 是深度学习中常见的操作,将输入张量通过线性变换映射到另一个空间。例如: Y=W⋅X+b 其中: X: 输入张量(形状可能为 (B,M,K),即批量维度、序列维度、特征维度)。W: 权重矩阵((K,N),将 K 维投影到 N 维)。b: 偏置向量(可选,(N,))。Y:…...

Visual Studio 同一解决方案 同时运行 多个项目

方案一 方案二...

零基础友好:快马AI为你定制专属visual studio code图文安装与上手教程

作为一名从零开始学习编程的新手&#xff0c;我深刻体会到安装开发环境是很多人遇到的第一个"拦路虎"。最近在InsCode(快马)平台上发现了一个特别适合新手的Visual Studio Code安装教程项目&#xff0c;它完全解决了我的困惑。下面分享我的学习笔记&#xff0c;希望能…...

从自动驾驶到AR眼镜:聊聊PSMNet这个双目立体匹配的‘老将’现在还能怎么用

PSMNet在2024年的技术重生&#xff1a;从经典立体匹配到轻量化落地的实战指南 六年前&#xff0c;当PSMNet在CVPR 2018上首次亮相时&#xff0c;其金字塔池化模块和堆叠沙漏3D CNN架构刷新了KITTI榜单的精度记录。如今&#xff0c;在Transformer大行其道的时代&#xff0c;这个…...

视觉问答技术全解析:从原理到实践的LAVIS框架应用指南

视觉问答技术全解析&#xff1a;从原理到实践的LAVIS框架应用指南 【免费下载链接】LAVIS LAVIS - A One-stop Library for Language-Vision Intelligence 项目地址: https://gitcode.com/gh_mirrors/la/LAVIS 技术原理&#xff1a;机器如何"看懂"并"回答…...

Qwen3-TTS多语言语音合成实测:一键部署,生成10种语言的逼真语音

Qwen3-TTS多语言语音合成实测&#xff1a;一键部署&#xff0c;生成10种语言的逼真语音 1. 开篇&#xff1a;语音合成新体验 想象一下&#xff0c;只需输入一段文字&#xff0c;就能让电脑用10种不同语言"开口说话"&#xff0c;而且声音自然得几乎分辨不出是机器生…...

HelixDB部署与运维:从本地开发到生产环境的完整流程

HelixDB部署与运维&#xff1a;从本地开发到生产环境的完整流程 【免费下载链接】helix-db HelixDB is a powerful, graph-vector database built entirely in Rust for millisecond query latency and ease of use. 项目地址: https://gitcode.com/gh_mirrors/he/helix-db …...

intv_ai_mk11行业落地:教育机构课件辅助生成、HR招聘文案批量产出案例

intv_ai_mk11行业落地&#xff1a;教育机构课件辅助生成、HR招聘文案批量产出案例 1. 模型能力与行业价值 intv_ai_mk11作为一款基于Llama架构的文本生成模型&#xff0c;在教育培训和人力资源领域展现出独特的实用价值。这个开箱即用的解决方案特别适合需要快速处理大量文本…...

避坑指南:QT5的QListView复选框居中/对齐问题解决方案(含TableView对比)

QT5复选框对齐终极指南&#xff1a;从QListView到TableView的完美排版方案 在QT5界面开发中&#xff0c;复选框控件的视觉对齐问题堪称"程序员强迫症终结者"——明明功能已经实现&#xff0c;却总在UI细节上栽跟头。本文将带您深入解决QListView和TableView中复选框居…...

洛雪音乐音源修复实战指南:从零开始的插件化解决方案

洛雪音乐音源修复实战指南&#xff1a;从零开始的插件化解决方案 【免费下载链接】New_lxmusic_source 六音音源修复版 项目地址: https://gitcode.com/gh_mirrors/ne/New_lxmusic_source 当你点击播放按钮却只看到加载动画无限循环&#xff0c;当搜索结果永远停留在&qu…...

Lingbot-Depth-Pretrain-ViTL-14 Anaconda环境搭建:创建隔离的Python开发与推理环境

Lingbot-Depth-Pretrain-ViTL-14 Anaconda环境搭建&#xff1a;创建隔离的Python开发与推理环境 你是不是也遇到过这种情况&#xff1a;好不容易跟着教程跑通了一个AI项目&#xff0c;结果过两天想跑另一个项目时&#xff0c;发现各种库版本冲突&#xff0c;报错满天飞&#x…...

Arduino智能小车避坑指南:从TB6612驱动到HC-05蓝牙,新手最容易搞错的5个硬件连接点

Arduino智能小车避坑实战&#xff1a;5个硬件连接致命细节与示波器级调试方案 刚拿到Arduino套件的新手们&#xff0c;总会在论坛里发出同样的灵魂拷问&#xff1a;"为什么我的小车要么瘫着不动&#xff0c;要么像醉汉一样乱撞&#xff1f;"这个问题背后&#xff0c;…...