当前位置: 首页 > news >正文

【PyTorch单点知识】torch.nn.Embedding模块介绍:理解词向量与实现

文章目录

      • 0. 前言
      • 1. 基础介绍
        • 1.1 基本参数
        • 1.2 可选参数
        • 1.3 属性
        • 1.4 PyTorch源码注释
      • 2. 实例演示
      • 3. `embedding_dim`的合理设定
      • 4. 结论

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在自然语言处理(NLP)中,torch.nn.Embedding是PyTorch框架中一个至关重要的模块,用于将离散的词汇转换成连续的向量空间表示。这种转换允许模型捕捉词汇之间的语义关系,并在诸如情感分析、文本分类和机器翻译等任务中发挥关键作用。

本文将深入探讨torch.nn.Embedding的工作原理,并通过示例代码演示其在PyTorch中的使用。

1. 基础介绍

torch.nn.Embedding的本质是一个映射表(Lookup table),它用于储存自然语言词典嵌入向量的映射关系。

1.1 基本参数

torch.nn.Embedding的初始化接受两个基本参数:num_embeddingsembedding_dim

  • num_embeddings:这个参数直观理解为“要嵌入的自然语言的词汇数量”,表示上面所述的自然语言词典的大小,即可能的唯一词汇数量。比如英语中的常用单词,从abandon开始一共有3000个,那num_embeddings就可以设定为3000;
  • embedding_dim:表示每个词汇映射的嵌入向量的维度。
1.2 可选参数
  • padding_idx:用于指定词汇表中的填充词汇索引,该位置的向量将被初始化为零。
  • max_norm:用于限制嵌入向量的L2范数。
  • norm_type:用于指定范数类型。
  • scale_grad_by_freq:如果设置为True,则将梯度按词汇频率缩放。
  • sparse:如果设置为True,则将嵌入梯度标记为稀疏。
1.3 属性

torch.nn.Embedding 模块只有一个属性 weight。这个属性代表了嵌入层要学习的权重,即存储所有嵌入向量的矩阵。这是嵌入层的学习权重,形状为 (num_embeddings, embedding_dim),也就是上文所说的lookup table映射表。这些权重代表实际的嵌入向量,它们是可学习的参数,并且在训练过程中会被优化算法更新。默认情况下,weight 是从标准正态分布 N(0, 1) 随机初始化的。这意味着每个元素都独立地从均值为 0、标准差为 1 的正态分布中采样。

1.4 PyTorch源码注释

以下是nn.Embedding的源码注释,用于上面说明的参考:

Args:num_embeddings (int): size of the dictionary of embeddingsembedding_dim (int): the size of each embedding vectorpadding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;therefore, the embedding vector at :attr:`padding_idx` is not updated during training,i.e. it remains as a fixed "pad". For a newly constructed Embedding,the embedding vector at :attr:`padding_idx` will default to all zeros,but can be updated to another value to be used as the padding vector.max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`is renormalized to have norm :attr:`max_norm`.norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency ofthe words in the mini-batch. Default ``False``.sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.See Notes for more details regarding sparse gradients.Attributes:weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)initialized from :math:`\mathcal{N}(0, 1)`Shape:- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

2. 实例演示

这里我将给出一个简单的例子来说明如何使用 PyTorch 的 torch.nn.Embedding 模块创建一个嵌入层,并获取一些单词的嵌入向量。

假设我们有一个小型的词汇表,包含以下单词:

  • “the”
  • “cat”
  • “dog”
  • “sat”
  • “on”
  • “mat”

我们将这些单词映射到索引上,例如:

  • “the” -> 0
  • “cat” -> 1
  • “dog” -> 2
  • “sat” -> 3
  • “on” -> 4
  • “mat” -> 5

现在我们可以创建一个 torch.nn.Embedding 层,将这些单词映射到嵌入向量中。我们将使用一个 3 维的嵌入向量来表示每个单词。

下面是具体的代码示例:

import torch
import torch.nn as nn# 创建一个 Embedding 层
# num_embeddings: 词汇表的大小,这里是 6
# embedding_dim: 嵌入向量的维度,这里是 3
embedding = nn.Embedding(num_embeddings=6, embedding_dim=3)# 定义一些单词的索引
word_indices = torch.LongTensor([0, 1, 2, 3, 4, 5])  # "the", "cat", "dog", "sat", "on", "mat"# 通过索引获取嵌入向量
word_embeddings = embedding(word_indices)# 输出嵌入向量
print(word_embeddings)

运行上述代码后,word_embeddings 将是一个形状为 (6, 3) 的张量,其中每一行代表一个单词的嵌入向量。

tensor([[ 0.0439,  0.7314, -0.3546],[ 0.6975,  1.2725,  1.4042],[-1.7532, -2.0642, -0.1434],[ 0.2538,  1.1123, -0.8636],[-0.7238, -0.0585,  0.5242],[ 0.6485,  0.6885, -1.2045]], grad_fn=<EmbeddingBackward0>)

例如,word_embeddings[0] 对应于单词 “the” 的嵌入向量,word_embeddings[1] 对应于单词 “cat” 的嵌入向量,以此类推。

这就是一个简单的英语单词嵌入向量的例子。在实际应用中,词汇表会更大,嵌入向量的维度也会更高,而且通常会使用预训练的嵌入向量来初始化这些权重。

3. embedding_dim的合理设定

通过上文说明,我们可以轻松地掌握nn.Embedding模块的使用,但是这里有个问题:embedding_dim设定为多少比较合适呢?

这里首先要说明下嵌入向量:它应该是代表单词“语义”的向量,而不是像one-hot那样是简单的字母映射。

举个例子:meetmeat两个词,拼写十分接近,即它们的one-hot编码十分接近,但是它们的语义完全不同,也就是说嵌入向量应该相差很远。而hugeenormous情况刚好相反,它们的one-hot编码完全不同,而嵌入向量应该比较接近。

那回到embedding_dim的设定选择上来,我觉得可以参考以下3个方面来设定比较合理的embedding_dim

  1. 平衡信息量与过拟合风险
    • 信息量: 较高的 embedding_dim 可以捕获更多的信息和细微差别,从而提高模型的表达能力。然而,这也可能会导致过拟合,因为高维空间容易出现稀疏性问题。
    • 过拟合风险: 较低的 embedding_dim 可以减少参数数量,降低过拟合的风险,但可能会丢失一些信息。
  2. 考虑词汇表的大小
    • 较小的词汇表: 如果词汇表相对较小(例如几千个词),较低的 embedding_dim(如 50 或 100)可能就足够了。
    • 较大的词汇表: 对于较大的词汇表(例如几十万或更多),可以选择较高的 embedding_dim(如 200 至 500)以更好地捕捉语义信息。
  3. 实验验证
    • 交叉验证: 最终的选择通常需要通过实验来确定。使用交叉验证来评估不同 embedding_dim 下的模型性能,可以帮助找到最佳值。
    • 预训练嵌入: 如果有可用的预训练嵌入(如 Word2Vec、GloVe 或 FastText),可以考虑使用它们的维度作为参考。

一点点思考:在Embedding方法中,embedding_dim一般是要比num_embeddings小(很多)的,这会导致矩阵的秩不满,最终会导致Embedding方法中的单词可以通过线性变换变成另一个单词。比如把abandon的词向量×2得到get的词向量,而one-hot不会有这个问题,这是Embedding小小的局限性。

4. 结论

torch.nn.Embedding模块在PyTorch中为NLP任务提供了强大的工具,允许模型从词汇索引中学习有意义的向量表示。通过初始化和调用这个模块,我们可以轻松地将文本数据转换为适合深度学习模型的格式,从而挖掘文本数据中的丰富语义信息。

相关文章:

【PyTorch单点知识】torch.nn.Embedding模块介绍:理解词向量与实现

文章目录 0. 前言1. 基础介绍1.1 基本参数1.2 可选参数1.3 属性1.4 PyTorch源码注释 2. 实例演示3. embedding_dim的合理设定4. 结论 0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我自己学习的理解&#xff0c;虽然参考了他人的宝贵见解及成果&#xff0c;但…...

Jedis 操作 Redis 数据结构全攻略

Jedis 操作 Redis 数据结构全攻略 一 . 认识 RESP二 . 前置操作2.1 创建项目2.2 关于开放 Redis 端口的问题2.2.1 端口转发?2.2.2 端口配置 2.3 连接到 Redis 服务器 三 . 通用命令3.1 set 和 get3.2 exists 和 del3.3 keys3.4 expire、ttl、type 三 . string 相关命令3.1 mse…...

ctf.show靶场ssrf攻略

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 web351 解析:post传入url参数他就会访问。 解法: hackbar传入url参数写入https://127.0.0.1/flag.php web352 解析:post传入url参数&#xff0c;不能是127.0.0.1和localhost 解法:缩写127.1传入 web353 解析…...

在 PyTorch 中,如何使用 `pack_padded_sequence` 来提高模型训练的效率?

在PyTorch中&#xff0c;pack_padded_sequence 是一个非常有用的函数&#xff0c;它可以用来提高模型训练的效率&#xff0c;特别是在处理变长序列数据时。这个函数的主要作用是将填充后的序列数据打包&#xff0c;以便循环神经网络&#xff08;RNN&#xff09;可以更高效地处理…...

Gossip协议

主要用在Redis Cluster 节点间通信 Gossip协议&#xff0c;也称为流行病协议&#xff08;Epidemic Protocol&#xff09;&#xff0c;是一种在分布式系统中用于信息传播和故障探测的算法。 一、工作原理 随机选择传播对象 每个节点会定期随机选择一些其他节点作为传播对象。这…...

数据结构————双向链表

内存泄漏&#xff1a; 内存泄漏&#xff08;Memory Leak&#xff09;是指程序中已动态分配的内存由于某种原因程序未释放或无法释放&#xff0c;导致系统内存的浪费&#xff0c;严重时会导致程序运行缓慢甚至崩溃。这种情况在长时间运行的程序或大型系统中尤为常见&#xff0c…...

55 - I. 二叉树的深度

comments: true difficulty: 简单 edit_url: https://github.com/doocs/leetcode/edit/main/lcof/%E9%9D%A2%E8%AF%95%E9%A2%9855%20-%20I.%20%E4%BA%8C%E5%8F%89%E6%A0%91%E7%9A%84%E6%B7%B1%E5%BA%A6/README.md 面试题 55 - I. 二叉树的深度 题目描述 输入一棵二叉树的根节点…...

Redis——初识Redis

初识Redis Redis认识Redis 分布式系统单机架构为什么要引入分布式理解负载均衡数据库的读写分离引入主从数据库 引入缓存数据库分库分表业务拆分——微服务常见概念了解 Redis背景介绍特性应用场景Redis不能做的事情Redis客户端redis客户端的多种形态 Redis 认识Redis 存储数…...

Xshell or Xftp提示“要继续使用此程序,您必须应用最新的更新或使用新版本”

Xshell提示“要继续使用此程序,您必须应用最新的更新或使用新版本”&#xff0c;笔者版本是xshell 6 方法一&#xff1a;更改系统时间 对于Windows 10用户&#xff0c;首先找到系统日期&#xff0c;右键点击并选择“调整时间/日期”。将日期设定为上一年。完成调整后&#x…...

table用position: sticky固定多层表头,滑动滚动条border边框透明解决方法

问题&#xff1a;我发现&#xff0c;这个上下滑动有内容经过就会出现如图的情况。 解决的方法&#xff1a;用outline&#xff08;轮廓&#xff09;替代border,以达到我们想要的样式。 outline主要是在元素边框的外围设置轮廓&#xff0c;outline不占据空间&#xff0c;绘制于…...

基于飞桨paddle2.6.1+cuda11.7+paddleRS开发版的目标提取-道路数据集训练和预测代码

基于飞桨paddle2.6.1cuda11.7paddleRS开发版的目标提取-道路数据集训练和预测代码 预测结果&#xff1a; 预测影像&#xff1a; &#xff08;一&#xff09;准备道路数据集 下载数据集地址&#xff1a; https://aistudio.baidu.com/datasetdetail/56961 mass_road.zip …...

数学建模笔记—— 整数规划和0-1规划

数学建模笔记—— 整数规划和0-1规划 整数规划和0-1规划1. 模型原理1.1 基本概念1.2 线性整数规划求解1.3 线性0-1规划的求解 2. 典型例题2.1 背包问题2.2 指派问题 3. matlab代码实现3.1 背包问题3.2 指派问题 整数规划和0-1规划 1. 模型原理 1.1 基本概念 在规划问题中&am…...

[001-03-007].第26节:分布式锁迭代3->优化基于setnx命令实现的分布式锁-防锁的误删

我的博客大纲 我的后端学习大纲 1、问题分析&#xff1a; 1.1.问题&#xff1a; 1.锁的超时释放&#xff0c;可能会释放其他服务器的锁 1.2.场景&#xff1a; 1.如果业务逻辑的执行时间是7s。执行流程如下 1.index1业务逻辑没执行完&#xff0c;3秒后锁被自动释放。2.index…...

【Unity踩坑】为什么有Rigidbody的物体运行时位置会变化

先上图&#xff0c;不知你有没有注意过这个现象呢&#xff1f; 一个物体加上了Rigidbody组件&#xff0c;当勾选上Use Gravity时&#xff0c;运行后&#xff0c;这个物体的位置的值会有变化。这是为什么呢&#xff1f; 刚体由物理系统处理&#xff0c;因此它会对重力、碰撞等做…...

NGINX开启HTTP3,给web应用提个速

环境说明 linuxdockernginx版本:1.27 HTTP3/QUIC介绍 HTTP3是由IETF于2022年发布的一个标准&#xff0c;文档地址为&#xff1a;https://datatracker.ietf.org/doc/html/rfc9114 如rfc9114所述&#xff0c;http3主要基于QUIC协议实现&#xff0c;在具备高性能的同时又兼备了…...

秋招季!别浮躁!

好久没写了&#xff0c;今天兴致来了&#xff0c;众所周知我一旦想说话&#xff0c;就来这里疯狂写。 最近&#xff0c;我去了一家国企的研究院&#xff0c;听着是不是贼高大上&#xff0c;呵——这玩意儿把我分配到三级机构&#xff0c;我一个学计算机的&#xff0c;它不把我…...

Java的时间复杂度和空间复杂度和常见排序

目录 一丶时间复杂度 二丶空间复杂度 三丶Java常见排序 1. 冒泡排序&#xff08;Bubble Sort&#xff09; 2.插入排序&#xff08;Insertion Sort&#xff09; 3.希尔排序&#xff08;Shell Sort&#xff09; 4.选择排序&#xff08;Selection Sort&#xff09; 5.堆排序&am…...

Qt 学习第十天:标准对话框 页面布局

系统标准对话框 错误对话框 //错误对话框connect(createPro, &QAction::triggered, this, []{//参数1 父亲 参数2 标题 参数3 对话框内显示文本内容 。。。QMessageBox::critical(this, "报错!", "没加头文件!");}); 【运行结果】 信息对话框 co…...

体育数据API纳米足球数据API:足球数据接口文档API示例⑩

纳米体育数据的数据接口通过JSON拉流方式获取200多个国家的体育赛事实时数据或历史数据的编程接口&#xff0c;无请求次数限制&#xff0c;可按需购买&#xff0c;接口稳定高效&#xff1b; 覆盖项目包括足球、篮球、网球、电子竞技、奥运等专题、数据内容。纳米数据API2.0版本…...

[数据集][目标检测]高铁受电弓检测数据集VOC+YOLO格式1245张2类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;1245 标注数量(xml文件个数)&#xff1a;1245 标注数量(txt文件个数)&#xff1a;1245 标注…...

visual studio 2022更改主题为深色

visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中&#xff0c;选择 环境 -> 常规 &#xff0c;将其中的颜色主题改成深色 点击确定&#xff0c;更改完成...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

2025季度云服务器排行榜

在全球云服务器市场&#xff0c;各厂商的排名和地位并非一成不变&#xff0c;而是由其独特的优势、战略布局和市场适应性共同决定的。以下是根据2025年市场趋势&#xff0c;对主要云服务器厂商在排行榜中占据重要位置的原因和优势进行深度分析&#xff1a; 一、全球“三巨头”…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程&#xff0c;代码下载&#xff1a;这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中&#xff0c;**知识蒸馏&#xff08;Knowledge Distillation&#xff09;**被广泛应用&#xff0c;作为提升模型…...

Vite中定义@软链接

在webpack中可以直接通过符号表示src路径&#xff0c;但是vite中默认不可以。 如何实现&#xff1a; vite中提供了resolve.alias&#xff1a;通过别名在指向一个具体的路径 在vite.config.js中 import { join } from pathexport default defineConfig({plugins: [vue()],//…...

Python 训练营打卡 Day 47

注意力热力图可视化 在day 46代码的基础上&#xff0c;对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...

五、jmeter脚本参数化

目录 1、脚本参数化 1.1 用户定义的变量 1.1.1 添加及引用方式 1.1.2 测试得出用户定义变量的特点 1.2 用户参数 1.2.1 概念 1.2.2 位置不同效果不同 1.2.3、用户参数的勾选框 - 每次迭代更新一次 总结用户定义的变量、用户参数 1.3 csv数据文件参数化 1、脚本参数化 …...