Transformer模型:WordEmbedding实现
前言
最近在学Transformer,学了理论的部分之后就开始学代码的实现,这里是跟着b站的up主的视频记的笔记,视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili
正文
首先导入所需要的包:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
关于Word Embedding,这里以序列建模为例,考虑source sentence、target sentence,构建序列,序列的字符以其在词表中的索引的形式表示。
首先使用定义batch_size的大小,并且使用torch.randint()函数随机生成序列长度,这里的src是生成原本的序列,tgt是生成目标的序列。
以机器翻译实现英文翻译为中文来说,src就是英文句子,tgt就是中文句子,这也就是规定了要翻译的英文句子的长度和翻译出来的句子长度。(举个例子而已,不用纠结为什么翻译要限制句子的长度)
batch_size = 2src_len=torch.randint(2,5,(batch_size,))
tgt_len=torch.randint(2,5,(batch_size,))
将生成的src_len、tgt_len输出:
tensor([2, 3]) 生成的原序列第一个句子长度为2,第二个句子长度为3 tensor([4, 4]) 生成的目标序列第一个句子长度为4,第二个句子长度为4
因为随机生成的,所以每次运行都会有新的结果,也就是生成的src和tgt两个序列,其子句的长度每次都是随机的,这里改成生成固定长度的序列:
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
将生成的src_len、tgt_len输出,此时就固定好了序列长度了:
tensor([11, 9], dtype=torch.int32) tensor([10, 11], dtype=torch.int32)
接着是要实现单词索引构成的句子,首先定义单词表的大小和序列的最大长度。
# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
以生成原序列为例,使用torch.randint()生成第一个句子和第二个句子,然后放到列表中:
src_seq = [torch.randint(1, max_num_src_words, (L,)) for L in src_len]
[tensor([5, 3, 7, 5, 6, 3, 4, 3]), tensor([1, 6, 3, 1, 1, 7, 4])]
可以发现生成的两个序列长度不一样(因为我们自己定义的时候就是不一样的),在这里需要使用F.pad()函数进行padding保证序列长度一致:
src_seq = [F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seg_len-L)) for L in src_len]
[tensor([8, 5, 2, 4, 6, 8, 1, 4, 0, 0, 0, 0]), tensor([5, 5, 5, 3, 7, 9, 3, 0, 0, 0, 0, 0])]
此时已经填充为同样的长度了,但是不同的句子各为一个张量,需要使用torch.cat()函数把不同句子的tensor转化为二维的tensor,在此之前需要先把每个张量变成二维的,使用torch.unsqueeze()函数:
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tensor([[9, 7, 7, 4, 7, 3, 9, 4, 7, 8, 8, 0],[1, 1, 5, 9, 5, 6, 2, 7, 4, 0, 0, 0]]) tensor([[3, 3, 2, 8, 3, 4, 1, 2, 9, 4, 0, 0],[1, 6, 3, 8, 5, 1, 5, 5, 1, 5, 3, 0]])
这里把tgt的也补充了,得到的就是src和tgt的内容各自在一个二维张量里(batch_size,max_seg_len),batch_size也就是句子数,max_seg_len也就是句子的单词数(分为src的长度跟tgt两种)。
补充:可以看到上面三次运行出来的结果都不一样,因为三次运行的时候,每次都是随机生成,所以结果肯定不一样,第三次为什么有两个二维的tensor是因为第三次把tgt的部分也补上去了,所以就有两个二维的tensor。
接下来就是构造embedding了,这里nn.Embedding()传入了两个参数,第一个是embedding的长度,也就是单词个数+1,+1的原因是因为有个0是作为填充的,第二个参数就是embedding的维度,也就是一个单词会被映射为多少维度的向量。
然后调用forward,得到我们的src和tgt的embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight) # 每一行代表一个embedding向量,第0行让给pad,从第1行到第行分配给各个单词,单词的索引是多少就取对应的行位置的向量
print(src_embedding) # 根据src_seq,从src_embedding_table获取得到的embedding vector,三维张量:batch_size、max_seq_len、model_dim
print(tgt_embedding)
此时src_embedding_table.weight的输出内容如下,第一行为填充(0)的向量:
tensor([[-0.3412, 1.5198, -1.7252, 0.6905, -0.3832, -0.8586, -2.0788, 0.3269],
[-0.5613, 0.3953, 1.6818, -2.0385, 1.1072, 0.2145, -0.9349, -0.7091],
[ 1.5881, -0.2389, -0.0347, 0.3808, 0.5261, 0.7253, 0.8557, -1.0020],
[-0.2725, 1.3238, -0.4087, 1.0758, 0.5321, -0.3466, -0.9051, -0.8938],
[-1.5393, 0.4966, -1.4887, 0.2795, -1.6751, -0.8635, -0.4689, -0.0827],
[ 0.6798, 0.1168, -0.5410, 0.5363, -0.0503, 0.4518, -0.3134, -0.6160],
[-1.1223, 0.3817, -0.6903, 0.0479, -0.6894, 0.7666, 0.9695, -1.0962],
[ 0.9608, 0.0764, 0.0914, 1.1949, -1.3853, 1.1089, -0.9282, -0.9793],
[-0.9118, -1.4221, -2.4675, -0.1321, 0.7458, -0.8015, 0.5114, -0.5023],
[-1.7504, 0.0824, 2.2088, -0.4486, 0.7324, 1.8790, 1.7644, 1.2731],
[-0.3791, 1.9915, -1.0117, 0.8238, -2.1784, -1.2824, -0.4275, 0.3202]],
requires_grad=True)
src_embedding的输出结果如下所示,往前看src_seq的第一个句子前三个为9 7 7,往前看第9+1行与第7+1行的向量,就是现在输出的前3个向量:
tensor([[[-1.7504, 0.0824, 2.2088, -0.4486, 0.7324, 1.8790, 1.7644, 1.2731],[ 0.9608, 0.0764, 0.0914, 1.1949, -1.3853, 1.1089, -0.9282, -0.9793],[ 0.9608, 0.0764, 0.0914, 1.1949, -1.3853, 1.1089, -0.9282, -0.9793],[-1.5393, 0.4966, -1.4887, 0.2795, -1.6751, -0.8635, -0.4689, -0.0827],[ 0.9608, 0.0764, 0.0914, 1.1949, -1.3853, 1.1089, -0.9282, -0.9793],[-0.2725, 1.3238, -0.4087, 1.0758, 0.5321, -0.3466, -0.9051, -0.8938],[-1.7504, 0.0824, 2.2088, -0.4486, 0.7324, 1.8790, 1.7644, 1.2731],[-1.5393, 0.4966, -1.4887, 0.2795, -1.6751, -0.8635, -0.4689, -0.0827],[ 0.9608, 0.0764, 0.0914, 1.1949, -1.3853, 1.1089, -0.9282, -0.9793],[-0.9118, -1.4221, -2.4675, -0.1321, 0.7458, -0.8015, 0.5114, -0.5023],[-0.9118, -1.4221, -2.4675, -0.1321, 0.7458, -0.8015, 0.5114, -0.5023],[-0.3412, 1.5198, -1.7252, 0.6905, -0.3832, -0.8586, -2.0788, 0.3269]],[[-0.5613, 0.3953, 1.6818, -2.0385, 1.1072, 0.2145, -0.9349, -0.7091],[-0.5613, 0.3953, 1.6818, -2.0385, 1.1072, 0.2145, -0.9349, -0.7091],[ 0.6798, 0.1168, -0.5410, 0.5363, -0.0503, 0.4518, -0.3134, -0.6160],[-1.7504, 0.0824, 2.2088, -0.4486, 0.7324, 1.8790, 1.7644, 1.2731],[ 0.6798, 0.1168, -0.5410, 0.5363, -0.0503, 0.4518, -0.3134, -0.6160],[-1.1223, 0.3817, -0.6903, 0.0479, -0.6894, 0.7666, 0.9695, -1.0962],[ 1.5881, -0.2389, -0.0347, 0.3808, 0.5261, 0.7253, 0.8557, -1.0020],[ 0.9608, 0.0764, 0.0914, 1.1949, -1.3853, 1.1089, -0.9282, -0.9793],[-1.5393, 0.4966, -1.4887, 0.2795, -1.6751, -0.8635, -0.4689, -0.0827],[-0.3412, 1.5198, -1.7252, 0.6905, -0.3832, -0.8586, -2.0788, 0.3269],[-0.3412, 1.5198, -1.7252, 0.6905, -0.3832, -0.8586, -2.0788, 0.3269],[-0.3412, 1.5198, -1.7252, 0.6905, -0.3832, -0.8586, -2.0788, 0.3269]]], grad_fn=<EmbeddingBackward>)
同理tgt_embedding的输出结果如下所示:
tensor([[[-1.3681, -0.1619, -0.3676, 0.4312, -1.3842, -0.6180, 0.3685, 1.6281],[-1.3681, -0.1619, -0.3676, 0.4312, -1.3842, -0.6180, 0.3685, 1.6281],[-2.6519, -0.8566, 1.2268, 2.6479, -0.2011, -0.1394, -0.2449, 1.0309],[-0.8919, 0.5235, -3.1833, 0.9388, -0.6213, -0.5146, 0.7913, 0.5126],[-1.3681, -0.1619, -0.3676, 0.4312, -1.3842, -0.6180, 0.3685, 1.6281],[-0.4984, 0.2948, -0.2804, -1.1943, -0.4495, 0.3793, -0.1562, -1.0122],[ 0.8976, 0.5226, 0.0286, 0.1434, -0.2600, -0.7661, 0.1225, -0.7869],[-2.6519, -0.8566, 1.2268, 2.6479, -0.2011, -0.1394, -0.2449, 1.0309],[ 2.2026, 1.8504, -0.6285, -0.0996, -0.0994, -0.0828, 0.6004, -0.3173],[-0.4984, 0.2948, -0.2804, -1.1943, -0.4495, 0.3793, -0.1562, -1.0122],[ 0.3637, 0.4256, 0.7674, 1.4321, -0.1164, -0.6032, -0.8182, -0.6119],[ 0.3637, 0.4256, 0.7674, 1.4321, -0.1164, -0.6032, -0.8182, -0.6119]],[[ 0.8976, 0.5226, 0.0286, 0.1434, -0.2600, -0.7661, 0.1225, -0.7869],[-1.0356, 0.8212, 1.0538, 0.4510, 0.2734, 0.3254, 0.4503, 0.1694],[-1.3681, -0.1619, -0.3676, 0.4312, -1.3842, -0.6180, 0.3685, 1.6281],[-0.8919, 0.5235, -3.1833, 0.9388, -0.6213, -0.5146, 0.7913, 0.5126],[-0.4783, -1.5936, 0.5033, 0.3483, -1.3354, 1.4553, -1.1344, -1.9280],[ 0.8976, 0.5226, 0.0286, 0.1434, -0.2600, -0.7661, 0.1225, -0.7869],[-0.4783, -1.5936, 0.5033, 0.3483, -1.3354, 1.4553, -1.1344, -1.9280],[-0.4783, -1.5936, 0.5033, 0.3483, -1.3354, 1.4553, -1.1344, -1.9280],[ 0.8976, 0.5226, 0.0286, 0.1434, -0.2600, -0.7661, 0.1225, -0.7869],[-0.4783, -1.5936, 0.5033, 0.3483, -1.3354, 1.4553, -1.1344, -1.9280],[-1.3681, -0.1619, -0.3676, 0.4312, -1.3842, -0.6180, 0.3685, 1.6281],[ 0.3637, 0.4256, 0.7674, 1.4321, -0.1164, -0.6032, -0.8182, -0.6119]]], grad_fn=<EmbeddingBackward>)
实际想要把文本句子嵌入到Embedding中,需要先根据自己的词典,将文本信息转化为每个词在词典中的位置,然后第0个位置依旧要让给Padding,得到索引然后构建Batch再去构造Embedding,以索引为输入得到每个样本的Embedding。
代码
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F# 句子数
batch_size = 2# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12# 模型的维度
model_dim = 8# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
print(src_len)
print(tgt_len)#单词索引构成的句子
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)),(0, max_tgt_seg_len-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)# 构造embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)
print(src_embedding)
print(tgt_embedding)
参考
torch.randint — PyTorch 2.3 documentation
torch.nn.functional.pad — PyTorch 2.3 文档
F.pad 的理解_domain:luyixian.cn-CSDN博客
嵌入 — PyTorch 2.3 文档
相关文章:
Transformer模型:WordEmbedding实现
前言 最近在学Transformer,学了理论的部分之后就开始学代码的实现,这里是跟着b站的up主的视频记的笔记,视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili 正文 首先导入所需要的包:…...
如何压缩pdf文件大小,怎么压缩pdf文件大小
在数字化时代,pdf文件因其稳定的格式和跨平台兼容性,成为了工作与学习中不可或缺的一部分。然而,随着pdf文件内容的丰富,pdf文件的体积也随之增大,给传输和存储带来了不少挑战。本文将深入探讨如何高效压缩pdf文件大小…...
Spring Boot集成Atomix快速入门Demo
1.什么是Atomix? Atomix是一个能用的Java框架,用来构建高可用的分布式系统。它是基于RAFT协议的实现,为用户提供了各种原子数据结构,比如map/set/integer等,这些数据结构都可以在整个集群中共享并保证一致性ÿ…...
Go语言map并发安全,互斥锁和读写锁谁更优?
并发编程是 Go 语言的一大特色,合理地使用锁对于保证数据一致性和提高程序性能至关重要。 在处理并发控制时,sync.Mutex(互斥锁)和 sync.RWMutex(读写锁)是两个常用的工具。理解它们各自的优劣及擅长的场景…...
Java多线程性能调优
Synchronized同步锁优化方法 1.6之前比较重量级,1.6后经过优化性能大大提升 使用Synchronized实现同步锁住要是两种方式:方法、代码块。 1.代码块 Synchronized在修饰同步代码块时,是由 monitorenter和monitorexit指令来实现同步的。进入mo…...
MacOS 通过Docker安装宝塔面板搭建PHP开发环境
1、docker拉取ubuntu系统 docker pull ubuntu2、运行容器 docker run -i -t -d --name bt -p 20:20 -p 21:21 -p 80:80 -p 443:443 -p 888:888 -p 8888:8888 -p 3306:3306 -p 6379:6379 --privilegedtrue -v /Users/oi/Sites:/www/wwwroot ubuntu-v 后的 /Users/oi/Sites 代表…...
Unity发布webgl之后修改StreamingAssets 内的配置文件读取到的还是之前的配置文件的解决方案
问题描述 unity发布webgl之后,修改在StreamingAssets 中的配置信息,修改之后读取的还是之前的配置信息 读取配置文件的代码IEnumerator IE_WebGL_LoadWebSocketServerCopnfig(){var uri new System.Uri(Path.Combine(Application.streamingAssetsPath…...
离线语音识别芯片在智能生活中的应用
离线语音识别芯片,这一技术正逐渐渗透到我们日常生活的每一个角落,为众多产品带来前所未有的智能体验。它能够应用到多种产品中,包括但不限于: 1、智能音箱:语音识别芯片作为智能音箱的核心,使用户…...
替换:show-overflow-tooltip=“true“ ,使用插槽tooltip,达到内容可复制
原生的show-overflow-tooltip“true” 不能满足条件,使用插槽自定义编辑; 旧code <el-table-column prop"reason" label"原因" align"center" :show-overflow-tooltip"true" /> <el-table-column pro…...
219.贪心算法:柠檬水找零(力扣)
代码解决 class Solution { public:bool lemonadeChange(vector<int>& bills) {int num50, num100; // 初始化5美元和10美元的计数器for(int i0; i < bills.size(); i) // 遍历所有账单{if(bills[i]5) // 如果账单是5美元{num5; // 增加5美元的计数continue; // …...
通过 Azure OpenAI 服务使用 GPT-35-Turbo and GPT-4(win版)
官方文档 Azure OpenAI 是微软提供的一项云服务,旨在将 OpenAI 的先进人工智能模型与 Azure 的基础设施和服务相结合。通过 Azure OpenAI,开发者和企业可以访问 OpenAI 的各种模型,如 GPT-3、Codex 和 DALL-E 等,并将其集成到自己…...
MySQL 面试真题(带答案)
MySQL 场景面试题 目录 场景1:用户注册和登录系统 1.1 数据库设计1.2 用户注册1.3 用户登录 场景2:订单管理系统 2.1 数据库设计2.2 创建订单2.3 查询订单 场景3:博客系统 3.1 数据库设计3.2 发布文章3.3 评论功能 场景1:用户…...
《A++ 敏捷开发》- 10 二八原则
团队成员协作,利用项目数据,分析根本原因,制定纠正措施,并立马尝试,判断是否有效,是改善的“基本功”。10-12章会探索里面的注意事项,13章会看两家公司的实施情况和常见问题。 如果已经获得高层…...
Spring Boot 框架知识汇总
1、什么是SpringBoot? 通过Spring Boot,可以轻松地创建独立的,基于生产级别的Spring的应用程序,您可以“运行"它们。大多数Spring Boot应用程序需要最少的Spring配置,集成了大量常用的第三方库配置,使…...
国产麒麟、uos在线编辑word文件并控制编辑区域(局部编辑)
windows系统也适用,该插件可同时支持windows和国产系统 在实际项目开发中,以下场景可能会用到Word局部编辑功能: 合同审批公文流转策划设计报告汇签单招投标(标书文件)其他,有模板且需要不同人员协作编辑…...
Go:基本变量与数据类型
目录 前言 前期准备 Hello World! 一、基本变量 1.1 声明变量 1.2 初始化变量 1.3 变量声明到初始化的过程 1.4 变量值交换 1.5 匿名变量 1.6 变量的作用域 二、数据类型 1.1 整型 1.2 浮点型 1.3 字符串 1.4 布尔类型 1.5 数据类型判断 1.6 数据类型转换 1.…...
计算器原生js
目录 1.HTML 2.CSS 2.JS 4.资源 5.运行截图 6.下载连接 7.注意事项 1.HTML <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-s…...
怎样将aac转换mp3格式?推荐四个aac转MP3的方法
怎样将aac转换mp3格式?当需要将aac格式音频转换为MP3格式时,有几种方法可以轻松实现这一目标。MP3是一种广泛支持的音频格式,几乎所有设备和平台都能播放MP3文件,包括各种音乐播放器、手机、平板电脑和汽车音响系统。而且它也提供…...
MongoDB - 查询操作符:比较查询、逻辑查询、元素查询、数组查询
文章目录 1. 构造数据2. MongoDB 比较查询操作符1. $eq 等于1.1 等于指定值1.2 嵌入式文档中的字段等于某个值1.3 数组元素等于某个值1.4 数组元素等于数组值 2. $ne 不等于3. $gt 大于3.1 匹配文档字段3.2 根据嵌入式文档字段执行更新 4. $gte 大于等于5. $lt 小于6. $lte 小于…...
html5——CSS高级选择器
目录 属性选择器 E[att^"value"] E[att$"http"] E[att*"http"] 关系选择器 子代: 相邻兄弟: 普通兄弟: 结构伪类选择器 链接伪类选择器 伪元素选择器 CSS的继承与层叠 CSS的继承性 CSS的层叠性 …...
XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
业务系统对接大模型的基础方案:架构设计与关键步骤
业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...
多模态2025:技术路线“神仙打架”,视频生成冲上云霄
文|魏琳华 编|王一粟 一场大会,聚集了中国多模态大模型的“半壁江山”。 智源大会2025为期两天的论坛中,汇集了学界、创业公司和大厂等三方的热门选手,关于多模态的集中讨论达到了前所未有的热度。其中,…...
为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?
在建筑行业,项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升,传统的管理模式已经难以满足现代工程的需求。过去,许多企业依赖手工记录、口头沟通和分散的信息管理,导致效率低下、成本失控、风险频发。例如&#…...
大语言模型如何处理长文本?常用文本分割技术详解
为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...
五年级数学知识边界总结思考-下册
目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...
工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配
AI3D视觉的工业赋能者 迁移科技成立于2017年,作为行业领先的3D工业相机及视觉系统供应商,累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成,通过稳定、易用、高回报的AI3D视觉系统,为汽车、新能源、金属制造等行…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用
1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...
CMake控制VS2022项目文件分组
我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...
