当前位置: 首页 > news >正文

Pytorch深度解析:Transformer嵌入层源码逐行解读

前言

本部分博客需要先阅读博客:
《Transformer实现以及Pytorch源码解读(一)-数据输入篇》
作为知识储备。

Embedding使用方式

如下面的代码中所示,embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的过程中输入两个参数:vocab_size和embedding_dim。其中的vocab_size是指输入的数据集合中总共涉及多少个去重后的单词;embedding_dim是指,每个单词你希望用多少维度的向量表示。随后,实例化的embedding在forward中被调用self.embeddings(inputs)。

class Transformer(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):super(Transformer, self).__init__()# 词嵌入层self.embedding_dim = embedding_dimself.embeddings = nn.Embedding(vocab_size, embedding_dim)self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)# 编码层:使用Transformerencoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)# 输出层self.output = nn.Linear(hidden_dim, num_class)def forward(self, inputs, lengths):inputs = torch.transpose(inputs, 0, 1)hidden_states = self.embeddings(inputs)hidden_states = self.position_embedding(hidden_states)attention_mask = length_to_mask(lengths) == Falsehidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)logits = self.output(hidden_states)log_probs = F.log_softmax(logits, dim=-1)return log_probs

数据被怎样变换了?

如下图所示,第一个tensor表示input,该input表示一个句子( sentence),只是该句子中的单词用整数进行了代替,相同的整数表示相同的单词。而每个1在embedding之后,变成了相同过的向量。

我们将以上的代码重新的运行一遍,发现表示1的向量改变了,这说明embedding 的过程不是确定的,而是随机的。

数据是怎样被变化的?

Embedding类在调用过程中主要涉及到以下几个核心方法:_
init
,rest_parameters,forward:

Embedding类的初始化过程如下所示。当_weight没有的情况下调用Parameter初始化一个空的向量,该向量的维度与输入数据中的去重单词个数(num_bembeddings)一样。然后调用reset_parameters方法。

 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,sparse: bool = False, _weight: Optional[Tensor] = None,device=None, dtype=None) -> None:factory_kwargs = {'device': device, 'dtype': dtype}super(Embedding, self).__init__()self.num_embeddings = num_embeddingsself.embedding_dim = embedding_dimif padding_idx is not None:if padding_idx > 0:assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'elif padding_idx < 0:assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'padding_idx = self.num_embeddings + padding_idxself.padding_idx = padding_idxself.max_norm = max_normself.norm_type = norm_typeself.scale_grad_by_freq = scale_grad_by_freqif _weight is None:self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))# print("===========================================1")# print(self.weight)#将self.weight进行nornal归一化self.reset_parameters()print("===========================================2")print(self.weight)else:assert list(_weight.shape) == [num_embeddings, embedding_dim], \'Shape of weight does not match num_embeddings and embedding_dim'self.weight = Parameter(_weight)self.sparse = sparse

reset_parameters的实现如下所示,主要是调用了init.norma_方法。

    def reset_parameters(self) -> None:init.normal_(self.weight)self._fill_padding_idx_with_zero()

init.normal_又调用了torch.nn.init中的normal方法。该方法将空的self.weight矩阵填充为一个符合 (0,1)正太分布的矩阵。

N

(

mean

,

std

2

)

.

\mathcal{N}(\text{mean}, \text{std}^2).

N

(

mean

,

std

2

)

.

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:r"""Fills the input Tensor with values drawn from the normaldistribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.Args:tensor: an n-dimensional `torch.Tensor`mean: the mean of the normal distributionstd: the standard deviation of the normal distributionExamples:>>> w = torch.empty(3, 5)>>> nn.init.normal_(w)"""return _no_grad_normal_(tensor, mean, std)

继续追踪_no_grad_normal_(tensor, mean, std)我们发现,该方法是通过c++实现,所在的源码文件目录为:

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {explicit Fan(Tensor& tensor) {const auto dimensions = tensor.ndimension();TORCH_CHECK(dimensions >= 2,"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");if (dimensions == 2) {in = tensor.size(1);out = tensor.size(0);} else {in = tensor.size(1) * tensor[0][0].numel();out = tensor.size(0) * tensor[0][0].numel();}}int64_t in;int64_t out;
};
Tensor normal_(Tensor tensor, double mean, double std) {NoGradGuard guard;return tensor.normal_(mean, std);
}

forward方法的c++实现如下所示。

torch::Tensor EmbeddingImpl::forward(const Tensor& input) {return F::detail::embedding(input,weight,options.padding_idx(),options.max_norm(),options.norm_type(),options.scale_grad_by_freq(),options.sparse());
}

继续追踪,发现weight中的每个变量被下面的c++代码填充了正太分布的随机数。

void normal_kernel(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());templates::cpu::normal_kernel(self, mean, std, generator);
}

随机数的生成调用如下的代码,首先询问:目前代码是在什么设备上运行,并调用cpu或者gup上的随机数生成方法。

template <typename T>
static inline T * check_generator(c10::optional<Generator> gen) {TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");return gen->get<T>();
}/*** Utility function used in tensor implementations, which* supplies the default generator to tensors, if an input generator* is not supplied. The input Generator* is also static casted to* the backend generator type (CPU/CUDAGeneratorImpl etc.)*/
template <typename T>
static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) {return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
}

至此,embedding的每个随机数的生成过程都清楚了。

总结

Embedding的过程,其实就是为每个单词对应一个向量的过程。该向量为(0,1)正太分布,该矩阵在Embedding的实例化过程就已经被初始化完成。在调用Embedding示例的时候即forward开始工作的时候,只是做了一个匹配的过程,也就是将<字典,向量>的对应关系应用到input上。前期解读该部分源码的困惑是一只找不到forward中的对应处理过程,以为embedding的处理逻辑是在forward的阶段展开的,显然这种想法是不对的。Pytorch的架构设计的的确优雅!

相关文章:

Pytorch深度解析:Transformer嵌入层源码逐行解读

前言 本部分博客需要先阅读博客&#xff1a; 《Transformer实现以及Pytorch源码解读&#xff08;一&#xff09;-数据输入篇》 作为知识储备。 Embedding使用方式 如下面的代码中所示&#xff0c;embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的…...

HSP_10章 Python面向对象编程oop_基础部分

文章目录 P107 类与实例的关系1.类与实例的关系示意图2.类与实例的代码分析 P109 对象形式和传参机制1. 类与对象的区别和联系2. 属性/成员变量3. 类的定义和使用4. 对象的传递机制 P110 对象的布尔值P111 成员方法1. 基本介绍2. 成员方法的定义和基本使用3.注意事项和使用细节…...

JavaWeb系列十七: jQuery选择器 上

jQuery选择器 jQuery基本选择器jquery层次选择器基础过滤选择器内容过滤选择器可见度过滤选择器 选择器是jQuery的核心, 在jQuery中, 对事件处理, 遍历 DOM和Ajax 操作都依赖于选择器jQuery选择器的优点 $(“#id”) 等价于 document.getElementById(“id”);$(“tagName”) 等价…...

Gone框架介绍30 - 使用`goner/gin`提供Web服务

gone是可以高效开发Web服务的Golang依赖注入框架 github地址&#xff1a;https://github.com/gone-io/gone 文档地址&#xff1a;https://goner.fun/zh/ 使用goner/gin提供Web服务 文章目录 使用goner/gin提供Web服务注册相关的Goners编写Controller挂载路由路由处理函数io.Rea…...

Lipowerline5.0 雷达电力应用软件下载使用

1.配网数据处理分析 针对配网线路点云数据&#xff0c;优化了分类算法&#xff0c;支持杆塔、导线、交跨线、建筑物、地面点和其他线路的自动分类&#xff1b;一键生成危险点报告和交跨报告&#xff1b;还能生成点云数据采集航线和自主巡检航线。 获取软件安装包联系邮箱:289…...

STM32学习之一:什么是STM32

目录 1.什么是STM32 2.STM32命名规则 3.STM32外设资源 4. STM32的系统架构 5. 从0到1搭建一个STM32工程 学习stm32已经很久了&#xff0c;因为种种原因&#xff0c;也有很久一段时间没接触过stm32了。等我捡起来的时候&#xff0c;发现很多都已经忘记了&#xff0c;重新捡…...

AI绘画Stable Diffusion 超强一键去除图片中的物体,免费使用!

大家好&#xff0c;我是设计师阿威 在生成图像时总有一些不完美的小瑕疵&#xff0c;比如多余的物体或碍眼的水印&#xff0c;它们破坏了图片的美感。但别担心&#xff0c;今天我们将介绍一款神奇的工具——sd-webui-cleaner&#xff0c;它可以帮助我们使用Stable Diffusion轻…...

零基础STM32单片机编程入门(一)初识STM32单片机

文章目录 一.概要二.单片机型号命名规则三.STM32F103系统架构四.STM32F103C8T6单片机启动流程五.STM32F103C8T6单片机主要外设资源六.编程过程中芯片数据手册的作用1.单片机外设资源情况2.STM32单片机内部框图3.STM32单片机管脚图4.STM32单片机每个管脚可配功能5.单片机功耗数据…...

Github上前十大开源Rust项目

在github上排名前十的Rust开源项目整理出来与大家共享&#xff0c;以当前的Star数为准。 Deno Deno 是 V8 上的安全 TypeScript 运行时。Deno 是一个建立在V8、Rust和Tokio之上的 JavaScript、TypeScript 和 WebAssembly 的运行时环境&#xff0c;具有自带安全的设置和出色的开…...

硬件开发笔记(二十):AD21导入外部下载的元器件原理图库、封装库和3D模型

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/139707771 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…...

竞赛选题 python opencv 深度学习 指纹识别算法实现

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; python opencv 深度学习 指纹识别算法实现 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;4分创新点&#xff1a;4分 该项目较为新颖…...

RK3568开发笔记(三):瑞芯微RK3588芯片介绍,入手开发板的核心板介绍

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/139905873 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…...

EtherCAT主站IgH解析(二)-- 如何将Igh移植到Linux/Windows/RTOS等多操作系统

版权声明&#xff1a;本文为本文为博主原创文章&#xff0c;转载请注明出处 https://www.cnblogs.com/wsg1100 如有错误&#xff0c;欢迎指正。 本文简单介绍如何将 igh 移植到 zephyr、freertos、rtems、rtthread等RTOS &#xff0c;甚至 windows 上。 ##前言 目前&#xff0…...

ansible copy模块参选选项

copy模块用于将文件从ansible控制节点&#xff08;管理主机&#xff09;或者远程主机复制到远程主机上。其操作类似于scp&#xff08;secure copy protocol&#xff09;。 关键参数标红。 参数&#xff1a; src:&#xff08;source&#xff1a;源&#xff09; 要复制到远程…...

展厅设计主要的六大要素

1、从创意开始 展示设计的开始必须创意在先。根据整体的风格思路进行创意&#xff0c;首先要考虑的是主体的造型、大小高度位置以及它和周围展厅的关系。另外其他道具设计制作与运作方式也必须在创意中有明确的体现。 2、平面感 平面感是指对展示艺术设计平面图纸审美和功能两个…...

【数据结构与算法】最小生成树,Prim算法,Kruskal算法 详解

最小生成树的实际应用背景。 最节省经费的前提下&#xff0c;在n个城市之间建立通信联络网。 Kruskal算法&#xff08;基于并查集&#xff09; void init() {for (int i 1; i < n; i) {pre[i] i;} }ll root(ll a) {ll i a;while (pre[i] ! i) {i pre[i];}return i p…...

【启明智显产品分享】Model3工业级HMI芯片详解系列专题(三):安全、稳定、高防护

芯片作为电子设备的核心部件&#xff0c;&#xff0c;根据不同的应用领域被分为不同等级。工业级芯片适用于工业自动化、控制系统和仪器仪表等领域&#xff0c;对芯片的安全、稳定、防护能力等等有着较高的要求。这些芯片往往需要具备更宽的工业温度范围&#xff0c;能够在更恶…...

【面试干货】Java中的四种引用类型:强引用、软引用、弱引用和虚引用

【面试干货】Java中的四种引用类型&#xff1a;强引用、软引用、弱引用和虚引用 1、强引用&#xff08;Strong Reference&#xff09;2、软引用&#xff08;Soft Reference&#xff09;3、弱引用&#xff08;Weak Reference&#xff09;4、虚引用&#xff08;Phantom Reference…...

对称/非对称加密

对称加密和非对称加密是两种主要的加密方式&#xff0c;用于保护数据的机密性和完整性。它们在密钥的使用和管理上有着显著的不同。 对称加密 原理 对称加密&#xff08;Symmetric Encryption&#xff09;使用相同的密钥进行加密和解密。这意味着发送方和接收方必须共享相同…...

DDei在线设计器-API-DDeiSheet

DDeiSheet DDeiSheet是代表一个页签&#xff0c;一个页签含有一个DDeiStage用于显示图形。   DDeiSheet实例包含了一个页签的所有数据&#xff0c;在获取后可以通过它访问其他内容。DDeiFile中的sheets属性记录了当前文件的页签列表。 一个DDeiFile实例至少包含一个DDeiSheet…...

【WiFi帧结构】

文章目录 帧结构MAC头部管理帧 帧结构 Wi-Fi的帧分为三部分组成&#xff1a;MAC头部frame bodyFCS&#xff0c;其中MAC是固定格式的&#xff0c;frame body是可变长度。 MAC头部有frame control&#xff0c;duration&#xff0c;address1&#xff0c;address2&#xff0c;addre…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

工程地质软件市场:发展现状、趋势与策略建议

一、引言 在工程建设领域&#xff0c;准确把握地质条件是确保项目顺利推进和安全运营的关键。工程地质软件作为处理、分析、模拟和展示工程地质数据的重要工具&#xff0c;正发挥着日益重要的作用。它凭借强大的数据处理能力、三维建模功能、空间分析工具和可视化展示手段&…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

写一个shell脚本,把局域网内,把能ping通的IP和不能ping通的IP分类,并保存到两个文本文件里

写一个shell脚本&#xff0c;把局域网内&#xff0c;把能ping通的IP和不能ping通的IP分类&#xff0c;并保存到两个文本文件里 脚本1 #!/bin/bash #定义变量 ip10.1.1 #循环去ping主机的IP for ((i1;i<10;i)) doping -c1 $ip.$i &>/dev/null[ $? -eq 0 ] &&am…...

前端工具库lodash与lodash-es区别详解

lodash 和 lodash-es 是同一工具库的两个不同版本&#xff0c;核心功能完全一致&#xff0c;主要区别在于模块化格式和优化方式&#xff0c;适合不同的开发环境。以下是详细对比&#xff1a; 1. 模块化格式 lodash 使用 CommonJS 模块格式&#xff08;require/module.exports&a…...

Oracle实用参考(13)——Oracle for Linux物理DG环境搭建(2)

13.2. Oracle for Linux物理DG环境搭建 Oracle 数据库的DataGuard技术方案,业界也称为DG,其在数据库高可用、容灾及负载分离等方面,都有着非常广泛的应用,对此,前面相关章节已做过较为详尽的讲解,此处不再赘述。 需要说明的是, DG方案又分为物理DG和逻辑DG,两者的搭建…...

Unity-ECS详解

今天我们来了解Unity最先进的技术——ECS架构&#xff08;EntityComponentSystem&#xff09;。 Unity官方下有源码&#xff0c;我们下载源码后来学习。 ECS 与OOP&#xff08;Object-Oriented Programming&#xff09;对应&#xff0c;ECS是一种完全不同的编程范式与数据架构…...

数据库优化实战指南:提升性能的黄金法则

在现代软件系统中&#xff0c;数据库性能直接影响应用的响应速度和用户体验。面对数据量激增、访问压力增大&#xff0c;数据库性能瓶颈经常成为项目痛点。如何科学有效地优化数据库&#xff0c;提升查询效率和系统稳定性&#xff0c;是每位开发与运维人员必备的技能。 本文结…...

触发DMA传输错误中断问题排查

在STM32项目中&#xff0c;集成BLE模块后触发DMA传输错误中断&#xff08;DMA2_Stream1_IRQHandler进入错误流程&#xff09;&#xff0c;但单独运行BLE模块时正常&#xff0c;表明问题可能源于原有线程与BLE模块的交互冲突。以下是逐步排查与解决方案&#xff1a; 一、问题根源…...