0基础跟德姆(dom)一起学AI 自然语言处理18-解码器部分实现
1 解码器介绍
解码器部分:
- 由N个解码器层堆叠而成
- 每个解码器层由三个子层连接结构组成
- 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
- 第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
- 第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接
- 说明:
- 解码器层中的各个部分,如,多头注意力机制,规范化层,前馈全连接网络,子层连接结构都与编码器中的实现相同. 因此这里可以直接拿来构建解码器层.
2 解码器层
2.1 解码器层的作用
- 作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程.
2.2 解码器层的代码实现
# 解码器层类 DecoderLayer 实现思路分析
# init函数 (self, size, self_attn, src_attn, feed_forward, dropout)# 词嵌入维度尺寸大小size 自注意力机制层对象self_attn 一般注意力机制层对象src_attn 前馈全连接层对象feed_forward# clones3子层连接结构 self.sublayer = clones(SublayerConnection(size,dropout),3)
# forward函数 (self, x, memory, source_mask, target_mask)# 数据经过子层连接结构1 self.sublayer[0](x, lambda x:self.self_attn(x, x, x, target_mask))# 数据经过子层连接结构2 self.sublayer[1](x, lambda x:self.src_attn(x, m, m, source_mask))# 数据经过子层连接结构3 self.sublayer[2](x, self.feed_forward)class DecoderLayer(nn.Module):def __init__(self, size, self_attn, src_attn, feed_forward, dropout):super(DecoderLayer, self).__init__()# 词嵌入维度尺寸大小self.size = size# 自注意力机制层对象 q=k=vself.self_attn = self_attn# 一遍注意力机制对象 q!=k=vself.src_attn = src_attn# 前馈全连接层对象self.feed_forward = feed_forward# clones3子层连接结构self.sublayer = clones(SublayerConnection(size, dropout), 3)def forward(self, x, memory, source_mask, target_mask):m = memory# 数据经过子层连接结构1x = self.sublayer[0](x, lambda x:self.self_attn(x, x, x, target_mask))# 数据经过子层连接结构2x = self.sublayer[1](x, lambda x:self.src_attn (x, m, m, source_mask))# 数据经过子层连接结构3x = self.sublayer[2](x, self.feed_forward)return x
- 函数调用
def dm_test_DecoderLayer():d_model = 512vocab = 1000 # 词表大小是1000# 输入x 是一个使用Variable封装的长整型张量, 形状是2 x 4x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))emb = Embeddings(d_model, vocab)embr = emb(x)dropout = 0.2max_len = 60 # 句子最大长度x = embr # [2, 4, 512]pe = PositionalEncoding(d_model, dropout, max_len)pe_result = pe(x)x = pe_result # 获取位置编码器层 编码以后的结果# 类的实例化参数与解码器层类似, 相比多出了src_attn, 但是和self_attn是同一个类.head = 8d_ff = 64size = 512self_attn = src_attn = MultiHeadedAttention(head, d_model, dropout)# 前馈全连接层也和之前相同ff = PositionwiseFeedForward(d_model, d_ff, dropout)x = pe_result# 产生编码器结果 # 注意此函数返回编码以后的结果 要有返回值en_result = dm_test_Encoder()memory = en_resultmask = Variable(torch.zeros(8, 4, 4))source_mask = target_mask = mask# 实例化解码器层 对象dl = DecoderLayer(size, self_attn, src_attn, ff, dropout)# 对象调用dl_result = dl(x, memory, source_mask, target_mask)print(dl_result.shape)print(dl_result)
- 输出效果
torch.Size([2, 4, 512])
tensor([[[-27.4382, 0.6516, 6.6735, ..., -42.2930, -44.9728, 0.1264],[-28.7835, 26.4919, -0.5608, ..., 0.5652, -2.9634, 9.7438],[-19.6998, 13.5164, 45.8216, ..., 23.9127, 22.0259, 34.0195],[ -0.1647, 0.2331, -36.4173, ..., -20.0557, 29.4576, 2.5048]],[[ 29.1466, 50.7677, 26.4624, ..., -39.1015, -27.9200, 19.6819],[-10.7069, 28.0897, -0.4107, ..., -35.7795, 9.6881, 0.3228],[ -6.9027, -16.0590, -0.8897, ..., 4.0253, 2.5961, 37.4659],[ 9.8892, 32.7008, -6.6772, ..., -11.4273, -21.4676, 32.5692]]],grad_fn=<AddBackward0>)
2.3 解码器层总结¶
-
学习了解码器层的作用:
- 作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程.
-
学习并实现了解码器层的类: DecoderLayer
- 类的初始化函数的参数有5个, 分别是size,代表词嵌入的维度大小, 同时也代表解码器层的尺寸,第二个是self_attn,多头自注意力对象,也就是说这个注意力机制需要Q=K=V,第三个是src_attn,多头注意力对象,这里Q!=K=V, 第四个是前馈全连接层对象,最后就是droupout置0比率.
- forward函数的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量mermory, 以及源数据掩码张量和目标数据掩码张量.
- 最终输出了由编码器输入和目标数据一同作用的特征提取结果.
3 解码器
3.1 解码器的作用
- 根据编码器的结果以及上一次预测的结果, 对下一次可能出现的'值'进行特征表示.
3.2 解码器的代码分析
# 解码器类 Decoder 实现思路分析
# init函数 (self, layer, N):# self.layers clones N个解码器层clones(layer, N)# self.norm 定义规范化层 LayerNorm(layer.size)
# forward函数 (self, x, memory, source_mask, target_mask)# 数据以此经过各个子层 x = layer(x, memory, source_mask, target_mask)# 数据最后经过规范化层 return self.norm(x)# 返回处理好的数据class Decoder(nn.Module):def __init__(self, layer, N):# 参数layer 解码器层对象# 参数N 解码器层对象的个数super(Decoder, self).__init__()# clones N个解码器层self.layers = clones(layer, N)# 定义规范化层self.norm = LayerNorm(layer.size)def forward(self, x, memory, source_mask, target_mask):# 数据以此经过各个子层for layer in self.layers:x = layer(x, memory, source_mask, target_mask)# 数据最后经过规范化层return self.norm(x)
- 函数调用
# 测试 解码器
def dm_test_Decoder():d_model = 512vocab = 1000 # 词表大小是1000# 输入x 是一个使用Variable封装的长整型张量, 形状是2 x 4x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))emb = Embeddings(d_model, vocab)embr = emb(x)dropout = 0.2max_len = 60 # 句子最大长度x = embr # [2, 4, 512]pe = PositionalEncoding(d_model, dropout, max_len)pe_result = pe(x)x = pe_result # 获取位置编码器层 编码以后的结果# 分别是解码器层layer和解码器层的个数Nsize = 512d_model = 512head = 8d_ff = 64dropout = 0.2c = copy.deepcopy# 多头注意力对象attn = MultiHeadedAttention(head, d_model)# 前馈全连接层ff = PositionwiseFeedForward(d_model, d_ff, dropout)# 解码器层layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)N = 6# 输入参数与解码器层的输入参数相同x = pe_result# 产生编码器结果en_result = demo238_test_Encoder()memory = en_result# 掩码对象mask = Variable(torch.zeros(8, 4, 4))# sorce掩码 target掩码source_mask = target_mask = mask# 创建 解码器 对象de = Decoder(layer, N)# 解码器对象 解码de_result = de(x, memory, source_mask, target_mask)print(de_result)print(de_result.shape)
- 输出结果
tensor([[[ 0.1853, -0.8858, -0.0393, ..., -1.4989, -1.4008, 0.8456],[-1.0841, -0.0777, 0.0836, ..., -1.5568, 1.4074, -0.0848],[-0.4107, -0.1306, -0.0069, ..., -0.2370, -0.1259, 0.7591],[ 1.2895, 0.2655, 1.1799, ..., -0.2413, 0.9087, 0.4055]],[[ 0.3645, -0.3991, -1.2862, ..., -0.7078, -0.1457, -1.0457],[ 0.0146, -0.0639, -1.2143, ..., -0.7865, -0.1270, 0.5623],[ 0.0685, -0.1465, -0.1354, ..., 0.0738, -0.9769, -1.4295],[ 0.3168, 0.6305, -0.1549, ..., 1.0969, 1.8775, -0.5154]]],grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])
相关文章:

0基础跟德姆(dom)一起学AI 自然语言处理18-解码器部分实现
1 解码器介绍 解码器部分: 由N个解码器层堆叠而成每个解码器层由三个子层连接结构组成第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接第三个子层连接结构包括一个前馈全连接子层…...

我的创作纪念日——我与CSDN一起走过的365天
目录 一、机缘:旅程的开始 二、收获:沿路的花朵 三、日常:不断前行中 四、成就:一点小确幸 五、憧憬:梦中的重点 一、机缘:旅程的开始 最开始开始写博客是在今年一二月份的时候,也就是上一…...

C++:bfs解决多源最短路与拓扑排序问题习题
1. 多源最短路 其实就是将所有源头都加入队列, 01矩阵 LCR 107. 01 矩阵 - 力扣(LeetCode) 思路 求每个元素到离其最近0的距离如果我们将1当做源头加入队列的话,无法处理多个连续1的距离存储,我们反其道而行之&…...
【面试题】JVM部分[2025/1/13 ~ 2025/1/19]
JVM部分[2025/1/13 ~ 2025/1/19] 1. JVM 由哪些部分组成?2. Java 的类加载过程是怎样的?3. 请你介绍下 JVM 内存模型,分为哪些区域?各区域的作用是什么?4. JVM 垃圾回收调优的主要目标是什么?5. 如何对 Jav…...
文献综述相关ChatGPT提示词分享
文献综述 ChatGPT 可以帮助提高文献综述的有效性和全面性。ChatGPT可以高效搜索和审查与宝子们课题研究相关的文献资料来源。一些给力的插件工具还可以帮助您总结复杂的研究论文并提取信息以更快更好地消化信息。合理的运用ChatGPT和GPTs可以提高文献综述的清晰度和质量&#…...

Excel 技巧14 - 如何批量删除表格中的空行(★)
本文讲如何批量删除表格中的空行。 1,如何批量删除表格中的空行 要点就是按下F5,然后选择空值条件以定位所有空行,然后删除即可。 按下F5 点 定位条件 选 空值,点确认 这样就选中了空行 然后点右键,选 删除 选中 下方…...
图片生成Prompt编写技巧
1. 图片情绪(场景氛围) 一张图片一般都会有一个情绪基调,因为作画本质上也是在传达一些情绪,一般都会借助图片的氛围去转达。例如:比如家庭聚会一般是欢乐、喜乐融融。断壁残垣一般是悲凉。还有萧瑟、孤寂等。 2.补充细…...

【STM32-学习笔记-4-】PWM、输入捕获(PWMI)
文章目录 1、PWMPWM配置 2、输入捕获配置3、编码器 1、PWM PWM配置 配置时基单元配置输出比较单元配置输出PWM波的端口 #include "stm32f10x.h" // Device headervoid PWM_Init(void) { //**配置输出PWM波的端口**********************************…...

TOSUN同星TsMaster使用入门——3、使用系统变量及c小程序结合panel面板发送报文
本篇内容将介绍TsMaster中常用的Panel面板控件以及使用Panel控件通过系统变量以及c小程序来修改信号的值,控制报文的发送等。 目录 一、常用的Panel控件介绍 1.1系统——启动停止按钮 1.2 显示控件——文本框 1.3 显示控件——分组框 1.4 读写控件——按钮 1.…...

【Web】2025-SUCTF个人wp
目录 SU_blog SU_photogallery SU_POP SU_blog 先是注册功能覆盖admin账号 以admin身份登录,拿到读文件的权限 ./article?filearticles/..././..././..././..././..././..././etc/passwd ./article?filearticles/..././..././..././..././..././..././proc/1…...

React进阶之react.js、jsx模板语法及babel编译
React React介绍React官网初识React学习MVCMVVM JSX外部的元素props和内部的状态statepropsstate 生命周期constructorgetDerivedStateFromPropsrendercomponentDidMount()shouldComponentUpdategetSnapshotBeforeUpdate(prevProps, prevState) 创建项目CRA:create-…...

在Linux上如何让ollama在GPU上运行模型
之前一直在 Mac 上使用 ollama 所以没注意,最近在 Ubuntu 上运行发现一直在 CPU 上跑。我一开始以为是超显存了,因为 Mac 上如果超内存的话,那么就只用 CPU,但是我发现 Llama3.2 3B 只占用 3GB,这远没有超。看了一下命…...

R 语言科研绘图第 20 期 --- 箱线图-配对
在发表科研论文的过程中,科研绘图是必不可少的,一张好看的图形会是文章很大的加分项。 为了便于使用,本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中,获取方式: R 语言科研绘图模板 --- sciRplothttps://mp.…...

suctf2025
Suctf2025 --2标识为看的wp,没环境复现了 所有参考资料将在文本末尾标明 WEB SU_photogallery 思路👇 构造一个压缩包,解压出我们想解压的部分,然后其他部分是损坏的,这样是不是就可以让整个解压过程是出错的从而…...

Quinlan C4.5剪枝U(0,6)U(1,16)等置信上限如何计算?
之前看到Quinlan中关于C4.5决策树算法剪枝环节中,关于错误率e置信区间估计,为啥 当E=0时,U(0,1)=0.75,U(0,6)=0.206,U(0,9)=0.143? 而当E不为0时,比如U(1,16)=0.157,如图: 关于C4.5决策树,Quinlan写了一本书,如下: J. Ross Quinlan (Auth.) - C4.5. Programs f…...

计算机组成原理--笔记二
目录 一.计算机系统的工作原理 二.计算机的性能指标 1.存储器的性能指标 2.CPU的性能指标 3.系统整体的性能指标(静态) 4.系统整体的性能指标(动态) 三.进制计算 1.任意进制 > 十进制 2.二进制 <> 八、十六进制…...
麒麟系统中删除权限不够的文件方法
在麒麟系统中删除权限不够的文件,可以尝试以下几种方法: 通过修改文件权限删除 打开终端:点击左下角的“终端”图标,或者通过搜索功能找到并打开终端 。定位文件:使用cd命令切换到文件所在的目录 。修改文件权限&…...

自定义提示确认弹窗-vue
最初可运行代码 弹窗组件代码: (后来发现以下代码可运行,但打包 typescript 类型检查出错,可打包的代码在文末) <template><div v-if"isVisible" class"dialog"><div class&quo…...

运行fastGPT 第五步 配置FastGPT和上传知识库 打造AI客服
运行fastGPT 第五步 配置FastGPT和上传知识库 打造AI客服 根据上一步的步骤,已经调试了ONE API的接口,下面,我们就登陆fastGPT吧 http://xxx.xxx.xxx.xxx:3000/ 这个就是你的fastGPT后台地址,可以在configer文件中找到。 账号是…...

CSS 合法颜色值
CSS 颜色 CSS 中的颜色可以通过以下方法指定: 十六进制颜色带透明度的十六进制颜色RGB 颜色RGBA 颜色HSL 颜色HSLA 颜色预定义/跨浏览器的颜色名称使用 currentcolor 关键字 十六进制颜色 用 #RRGGBB 规定十六进制颜色,其中 RR(红色&…...
设计模式和设计原则回顾
设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

【入坑系列】TiDB 强制索引在不同库下不生效问题
文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...

Psychopy音频的使用
Psychopy音频的使用 本文主要解决以下问题: 指定音频引擎与设备;播放音频文件 本文所使用的环境: Python3.10 numpy2.2.6 psychopy2025.1.1 psychtoolbox3.0.19.14 一、音频配置 Psychopy文档链接为Sound - for audio playback — Psy…...

DBAPI如何优雅的获取单条数据
API如何优雅的获取单条数据 案例一 对于查询类API,查询的是单条数据,比如根据主键ID查询用户信息,sql如下: select id, name, age from user where id #{id}API默认返回的数据格式是多条的,如下: {&qu…...
是否存在路径(FIFOBB算法)
题目描述 一个具有 n 个顶点e条边的无向图,该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序,确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数,分别表示n 和 e 的值(1…...

GC1808高性能24位立体声音频ADC芯片解析
1. 芯片概述 GC1808是一款24位立体声音频模数转换器(ADC),支持8kHz~96kHz采样率,集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器,适用于高保真音频采集场景。 2. 核心特性 高精度:24位分辨率,…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
Python 包管理器 uv 介绍
Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...