当前位置: 首页 > news >正文

【AI大模型】深入Transformer架构:解码器部分的实现与解析

目录

🍔 解码器介绍

🍔 解码器层

2.1 解码器层的作用

2.2 解码器层的代码实现

2.3 解码器层总结

🍔 解码器

3.1 解码器的作用

3.2 解码器的代码分析

3.3 解码器总结


学习目标

🍀 了解解码器中各个组成部分的作用.

🍀 掌握解码器中各个组成部分的实现过程.

🍔 解码器介绍

解码器部分:

  • 由N个解码器层堆叠而成
  • 每个解码器层由三个子层连接结构组成
  • 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
  • 第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
  • 第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接


  • 说明:
  • 解码器层中的各个部分,如,多头注意力机制,规范化层,前馈全连接网络,子层连接结构都与编码器中的实现相同. 因此这里可以直接拿来构建解码器层.

🍔 解码器层

2.1 解码器层的作用

  • 作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程.

2.2 解码器层的代码实现

# 使用DecoderLayer的类实现解码器层
class DecoderLayer(nn.Module):def __init__(self, size, self_attn, src_attn, feed_forward, dropout):"""初始化函数的参数有5个, 分别是size,代表词嵌入的维度大小, 同时也代表解码器层的尺寸,第二个是self_attn,多头自注意力对象,也就是说这个注意力机制需要Q=K=V, 第三个是src_attn,多头注意力对象,这里Q!=K=V, 第四个是前馈全连接层对象,最后就是droupout置0比率."""super(DecoderLayer, self).__init__()# 在初始化函数中, 主要就是将这些输入传到类中self.size = sizeself.self_attn = self_attnself.src_attn = src_attnself.feed_forward = feed_forward# 按照结构图使用clones函数克隆三个子层连接对象.self.sublayer = clones(SublayerConnection(size, dropout), 3)def forward(self, x, memory, source_mask, target_mask):"""forward函数中的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量mermory, 以及源数据掩码张量和目标数据掩码张量."""# 将memory表示成m方便之后使用m = memory# 将x传入第一个子层结构,第一个子层结构的输入分别是x和self-attn函数,因为是自注意力机制,所以Q,K,V都是x,# 最后一个参数是目标数据掩码张量,这时要对目标数据进行遮掩,因为此时模型可能还没有生成任何目标数据,# 比如在解码器准备生成第一个字符或词汇时,我们其实已经传入了第一个字符以便计算损失,# 但是我们不希望在生成第一个字符时模型能利用这个信息,因此我们会将其遮掩,同样生成第二个字符或词汇时,# 模型只能使用第一个字符或词汇信息,第二个字符以及之后的信息都不允许被模型使用.x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))# 接着进入第二个子层,这个子层中常规的注意力机制,q是输入x; k,v是编码层输出memory, # 同样也传入source_mask,但是进行源数据遮掩的原因并非是抑制信息泄漏,而是遮蔽掉对结果没有意义的字符而产生的注意力值,# 以此提升模型效果和训练速度. 这样就完成了第二个子层的处理.x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))# 最后一个子层就是前馈全连接子层,经过它的处理后就可以返回结果.这就是我们的解码器层结构.return self.sublayer[2](x, self.feed_forward)
  • 实例化参数:
# 类的实例化参数与解码器层类似, 相比多出了src_attn, 但是和self_attn是同一个类.
head = 8
size = 512
d_model = 512
d_ff = 64
dropout = 0.2
self_attn = src_attn = MultiHeadedAttention(head, d_model, dropout)# 前馈全连接层也和之前相同 
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
  • 输入参数:
# x是来自目标数据的词嵌入表示, 但形式和源数据的词嵌入表示相同, 这里使用per充当.
x = pe_result# memory是来自编码器的输出
memory = en_result# 实际中source_mask和target_mask并不相同, 这里为了方便计算使他们都为mask
mask = Variable(torch.zeros(8, 4, 4))
source_mask = target_mask = mask
  • 调用:
dl = DecoderLayer(size, self_attn, src_attn, ff, dropout)
dl_result = dl(x, memory, source_mask, target_mask)
print(dl_result)
print(dl_result.shape)
  • 输出效果:
tensor([[[ 1.9604e+00,  3.9288e+01, -5.2422e+01,  ...,  2.1041e-01,-5.5063e+01,  1.5233e-01],[ 1.0135e-01, -3.7779e-01,  6.5491e+01,  ...,  2.8062e+01,-3.7780e+01, -3.9577e+01],[ 1.9526e+01, -2.5741e+01,  2.6926e-01,  ..., -1.5316e+01,1.4543e+00,  2.7714e+00],[-2.1528e+01,  2.0141e+01,  2.1999e+01,  ...,  2.2099e+00,-1.7267e+01, -1.6687e+01]],[[ 6.7259e+00, -2.6918e+01,  1.1807e+01,  ..., -3.6453e+01,-2.9231e+01,  1.1288e+01],[ 7.7484e+01, -5.0572e-01, -1.3096e+01,  ...,  3.6302e-01,1.9907e+01, -1.2160e+00],[ 2.6703e+01,  4.4737e+01, -3.1590e+01,  ...,  4.1540e-03,5.2587e+00,  5.2382e+00],[ 4.7435e+01, -3.7599e-01,  5.0898e+01,  ...,  5.6361e+00,3.5891e+01,  1.5697e+01]]], grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])

2.3 解码器层总结

  • 学习了解码器层的作用:

    • 作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程.
  • 学习并实现了解码器层的类: DecoderLayer

    • 类的初始化函数的参数有5个, 分别是size,代表词嵌入的维度大小, 同时也代表解码器层的尺寸,第二个是self_attn,多头自注意力对象,也就是说这个注意力机制需要Q=K=V,第三个是src_attn,多头注意力对象,这里Q!=K=V, 第四个是前馈全连接层对象,最后就是droupout置0比率.
    • forward函数的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量mermory, 以及源数据掩码张量和目标数据掩码张量.
    • 最终输出了由编码器输入和目标数据一同作用的特征提取结果.

🍔 解码器

3.1 解码器的作用

  • 根据编码器的结果以及上一次预测的结果, 对下一次可能出现的'值'进行特征表示.

3.2 解码器的代码分析

# 使用类Decoder来实现解码器
class Decoder(nn.Module):def __init__(self, layer, N):"""初始化函数的参数有两个,第一个就是解码器层layer,第二个是解码器层的个数N."""super(Decoder, self).__init__()# 首先使用clones方法克隆了N个layer,然后实例化了一个规范化层. # 因为数据走过了所有的解码器层后最后要做规范化处理. self.layers = clones(layer, N)self.norm = LayerNorm(layer.size)def forward(self, x, memory, source_mask, target_mask):"""forward函数中的参数有4个,x代表目标数据的嵌入表示,memory是编码器层的输出,source_mask, target_mask代表源数据和目标数据的掩码张量"""# 然后就是对每个层进行循环,当然这个循环就是变量x通过每一个层的处理,# 得出最后的结果,再进行一次规范化返回即可. for layer in self.layers:x = layer(x, memory, source_mask, target_mask)return self.norm(x)
  • 实例化参数:
# 分别是解码器层layer和解码器层的个数N
size = 512
d_model = 512
head = 8
d_ff = 64
dropout = 0.2
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)
N = 8
  • 输入参数:
# 输入参数与解码器层的输入参数相同
x = pe_result
memory = en_result
mask = Variable(torch.zeros(8, 4, 4))
source_mask = target_mask = mask
  • 调用:
de = Decoder(layer, N)
de_result = de(x, memory, source_mask, target_mask)
print(de_result)
print(de_result.shape)
  • 输出效果:
tensor([[[ 0.9898, -0.3216, -1.2439,  ...,  0.7427, -0.0717, -0.0814],[-0.7432,  0.6985,  1.5551,  ...,  0.5232, -0.5685,  1.3387],[ 0.2149,  0.5274, -1.6414,  ...,  0.7476,  0.5082, -3.0132],[ 0.4408,  0.9416,  0.4522,  ..., -0.1506,  1.5591, -0.6453]],[[-0.9027,  0.5874,  0.6981,  ...,  2.2899,  0.2933, -0.7508],[ 1.2246, -1.0856, -0.2497,  ..., -1.2377,  0.0847, -0.0221],[ 3.4012, -0.4181, -2.0968,  ..., -1.5427,  0.1090, -0.3882],[-0.1050, -0.5140, -0.6494,  ..., -0.4358, -1.2173,  0.4161]]],grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])

3.3 解码器总结

  • 学习了解码器的作用:

    • 根据编码器的结果以及上一次预测的结果, 对下一次可能出现的'值'进行特征表示.
  • 学习并实现了解码器的类: Decoder

    • 类的初始化函数的参数有两个,第一个就是解码器层layer,第二个是解码器层的个数N.
    • forward函数中的参数有4个,x代表目标数据的嵌入表示,memory是编码器层的输出,src_mask, tgt_mask代表源数据和目标数据的掩码张量.
    • 输出解码过程的最终特征表示.

💘若能为您的学习之旅添一丝光亮,不胜荣幸💘

🐼期待您的宝贵意见,让我们共同进步共同成长🐼

相关文章:

【AI大模型】深入Transformer架构:解码器部分的实现与解析

目录 &#x1f354; 解码器介绍 &#x1f354; 解码器层 2.1 解码器层的作用 2.2 解码器层的代码实现 2.3 解码器层总结 &#x1f354; 解码器 3.1 解码器的作用 3.2 解码器的代码分析 3.3 解码器总结 学习目标 &#x1f340; 了解解码器中各个组成部分的作用. &#…...

前端html js css 基础巩固3

一个这样的首页 滑动显示 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title>&l…...

如在下载自己的需要的rmp包呢

下载地址&#xff1a;https://pkgs.org/和https://rpmfind.net/linux/rpm2html/search.php 根基自己的需要进行下载使用。...

Android TextView实现一串文字特定几个字改变颜色

遇到一个需求&#xff0c;让Android端实现给定一个字符串指定下标的几个字颜色与其他字颜色不一致。 主要是用ForegroundColorSpan这个API来传入颜色值&#xff0c;用SpannableString来设置指定索引下标的字的颜色值。 这里通过给定一个输入文字描述框&#xff0c;要求输入指定…...

桃子叶片病害分类检测数据集(猫脸码客 第221期)

桃子叶片病害分类检测数据集 一、引言 桃子作为世界上广泛种植的果树之一&#xff0c;其叶片的健康状况直接关系到果实的产量和品质。然而&#xff0c;桃子叶片易受多种病害的侵袭&#xff0c;这些病害不仅影响叶片的光合作用&#xff0c;还可能导致果实减产、品质下降&#…...

Vue--》掌握自定义依赖引入的最佳实践

在现代前端开发中&#xff0c;vue凭借其灵活性和高效性&#xff0c;已成为开发者们的宠儿&#xff0c;然而随着项目的复杂度提升&#xff0c;如何高效地管理和引入依赖&#xff0c;尤其是自定义引入依赖&#xff0c;成为了许多开发者面临的一大挑战。无论是为了优化加载速度&am…...

repo 命令大全详解(第十四篇 repo overview)

repo overview 命令用于显示当前项目的概览信息&#xff0c;帮助用户快速了解项目的状态和分支信息。 参数分类及解释 基本参数 [--current-branch]: 可选&#xff0c;仅考虑已检出的分支。 示例: repo overview --current-branch [<project>...]: 可选&#xff0c;指定…...

【设计模式】深入理解Python中的抽象工厂设计模式

深入理解Python中的抽象工厂设计模式 设计模式是软件开发中解决常见问题的经典方案&#xff0c;而**抽象工厂模式&#xff08;Abstract Factory Pattern&#xff09;**是其中非常重要的一种创建型模式。抽象工厂模式的主要作用是提供一个接口&#xff0c;创建一系列相关或依赖…...

网站建设完成后,多久需要升级迭代一次

网站建设完成后&#xff0c;一般每隔几个月就会进行一次迭代升级。以下是关于网站迭代周期和原因的具体分析&#xff1a; 更新频率&#xff1a;网站在建设完成后&#xff0c;一般每隔几个月就会进行一次迭代升级。这种周期性的更新有助于保持网站的现代感和竞争力。更新目的&a…...

一个整型数组里除了两个数字之外,其他的数字都出现了两次。请写程序找出这两个只出现一次的数字

这里写目录标题 问题详情分析问题代码展示 问题详情 剑指 Offer 56&#xff1a; 一个整型数组 nums 里除两个数字之外&#xff0c;其他数字都出现了两次。请写程序找出这两个只出现一次的数字。要求时间复杂度是O(n)&#xff0c;空间复杂度是O(1)。 示例&#xff1a; 输入&a…...

Vue基本学习2

Vue使用方法 <script src"js/vue.js"></script><script>/*** Mode1:数据模型&#xff0c;负责数据存储(后台业务逻辑/数据库)* View:视图层&#xff0c;负责页面展示(HTML)* View Model(Vue):负责业务逻辑处理(比如Ajax请求等)* view 与 Model 数…...

创作者等级权益说明

创作者等级权益说明 一、如何查看创作者等级权益二、等级权益对照表 一、如何查看创作者等级权益 step1&#xff1a;鼠标移动至头像&#xff0c;显示如下图的浮窗 step2&#xff1a;点击我的等级&#xff0c;即跳转到创作者等级权益页面 图1.1 我的等级 图1.2 创作者等级权益…...

基于SpringBoot+Vue+uniapp微信小程序的校园反诈骗微信小程序的详细设计和实现(源码+lw+部署文档+讲解等)

项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念&#xff0c;提供了一套默认的配置&#xff0c;让开发者可以更专注于业务逻辑而不是配置文件。Spring Boot 通过自动化配置和约…...

统一修改UI库样式的几种方式

统一修改element组件库样式的几种方式。主题 | Element Plus 通过css变量设置 【CSS扩展】VUE如何使用或修改element plus中自带的CSS全局变量来定义样式:root {--hc-text-color-placeholder: #5f84a2;--hc-text-color-regular: #fff;--hc-text-color-primary: #fff;--hc-bg-c…...

ICM20948 DMP代码详解(88)

接前一篇文章:ICM20948 DMP代码详解(87) 本回继续对inv_convert_androidSensor_to_control函数进行解析。为了便于理解和回顾,再次贴出inv_convert_androidSensor_to_control函数源码,在EMD-Core\sources\Invn\Devices\Drivers\ICM20948\Icm20948DataBaseControl.c中,如下…...

字节跳动实习生投毒自家大模型细节曝光 影响到底有多大?

10月19日&#xff0c;字节跳动大模型训练遭实习生攻击一事引发广泛关注。据多位知情人士透露&#xff0c;字节跳动某技术团队在今年6月遭遇了一起内部技术袭击事件&#xff0c;一名实习生因对团队资源分配不满&#xff0c;使用攻击代码破坏了团队的模型训练任务。 据悉&#xf…...

【路径规划】蚁群算法优化bp神经网络回归预测

摘要 本文提出了一种基于蚁群算法&#xff08;ACO&#xff09;优化 BP 神经网络的回归预测方法&#xff0c;用于路径规划中的预测问题。通过蚁群算法优化神经网络的初始权值和阈值&#xff0c;提高了神经网络的训练效率和预测精度。实验结果表明&#xff0c;该方法能够有效提升…...

如何在OceanBase中新增系统变量及应用实践

因为系统变量涉及复杂的工程文件&#xff0c;为防止新增变量操作对软件系统的潜在影响&#xff0c;OceanBase为多数开发者设计了一套高效的编程框架。此框架允许开发者在新增及使用系统变量时&#xff0c;仅需专注于变量定义的细节。具体来说&#xff0c;通过运行一个Python脚本…...

Olap数据处理

一、OLAP 是什么 1. OLAP的定义 OLAP&#xff08;Online Analytical Processing&#xff0c;联机分析处理&#xff09;是一种软件技术&#xff0c;它主要专注于复杂的分析操作&#xff0c;帮助分析人员、管理人员或执行人员从多角度对信息进行快速、一致、交互地存取&#xf…...

Tailwind Starter Kit 一款极简的前端快速启动模板

Tailwind Starter Kit 是基于TailwindCSS实现的一款开源的、使用简单的极简模板扩展。会用Tailwincss就可以快速入手使用。Tailwind Starter Kit 是免费开源的。它不会在原始的TailwindCSS框架中更改或添加任何CSS。它具有多个HTML元素&#xff0c;并附带了ReactJS、Vue和Angul…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢&#xff0c;博主的学习进度也是步入了Java Mybatis 框架&#xff0c;目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正确的建议&…...

【网络安全产品大调研系列】2. 体验漏洞扫描

前言 2023 年漏洞扫描服务市场规模预计为 3.06&#xff08;十亿美元&#xff09;。漏洞扫描服务市场行业预计将从 2024 年的 3.48&#xff08;十亿美元&#xff09;增长到 2032 年的 9.54&#xff08;十亿美元&#xff09;。预测期内漏洞扫描服务市场 CAGR&#xff08;增长率&…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制&#xff08;1&#xff09;三次握手①握手过程②对握手过程的理解 &#xff08;2&#xff09;四次挥手&#xff08;3&#xff09;握手和挥手的触发&#xff08;4&#xff09;状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

工程地质软件市场:发展现状、趋势与策略建议

一、引言 在工程建设领域&#xff0c;准确把握地质条件是确保项目顺利推进和安全运营的关键。工程地质软件作为处理、分析、模拟和展示工程地质数据的重要工具&#xff0c;正发挥着日益重要的作用。它凭借强大的数据处理能力、三维建模功能、空间分析工具和可视化展示手段&…...

(二)原型模式

原型的功能是将一个已经存在的对象作为源目标,其余对象都是通过这个源目标创建。发挥复制的作用就是原型模式的核心思想。 一、源型模式的定义 原型模式是指第二次创建对象可以通过复制已经存在的原型对象来实现,忽略对象创建过程中的其它细节。 📌 核心特点: 避免重复初…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。

1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj&#xff0c;再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

Netty从入门到进阶(二)

二、Netty入门 1. 概述 1.1 Netty是什么 Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients. Netty是一个异步的、基于事件驱动的网络应用框架&#xff0c;用于…...