当前位置: 首页 > news >正文

【PyTorch单点知识】torch.nn.Embedding模块介绍:理解词向量与实现

文章目录

      • 0. 前言
      • 1. 基础介绍
        • 1.1 基本参数
        • 1.2 可选参数
        • 1.3 属性
        • 1.4 PyTorch源码注释
      • 2. 实例演示
      • 3. `embedding_dim`的合理设定
      • 4. 结论

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在自然语言处理(NLP)中,torch.nn.Embedding是PyTorch框架中一个至关重要的模块,用于将离散的词汇转换成连续的向量空间表示。这种转换允许模型捕捉词汇之间的语义关系,并在诸如情感分析、文本分类和机器翻译等任务中发挥关键作用。

本文将深入探讨torch.nn.Embedding的工作原理,并通过示例代码演示其在PyTorch中的使用。

1. 基础介绍

torch.nn.Embedding的本质是一个映射表(Lookup table),它用于储存自然语言词典嵌入向量的映射关系。

1.1 基本参数

torch.nn.Embedding的初始化接受两个基本参数:num_embeddingsembedding_dim

  • num_embeddings:这个参数直观理解为“要嵌入的自然语言的词汇数量”,表示上面所述的自然语言词典的大小,即可能的唯一词汇数量。比如英语中的常用单词,从abandon开始一共有3000个,那num_embeddings就可以设定为3000;
  • embedding_dim:表示每个词汇映射的嵌入向量的维度。
1.2 可选参数
  • padding_idx:用于指定词汇表中的填充词汇索引,该位置的向量将被初始化为零。
  • max_norm:用于限制嵌入向量的L2范数。
  • norm_type:用于指定范数类型。
  • scale_grad_by_freq:如果设置为True,则将梯度按词汇频率缩放。
  • sparse:如果设置为True,则将嵌入梯度标记为稀疏。
1.3 属性

torch.nn.Embedding 模块只有一个属性 weight。这个属性代表了嵌入层要学习的权重,即存储所有嵌入向量的矩阵。这是嵌入层的学习权重,形状为 (num_embeddings, embedding_dim),也就是上文所说的lookup table映射表。这些权重代表实际的嵌入向量,它们是可学习的参数,并且在训练过程中会被优化算法更新。默认情况下,weight 是从标准正态分布 N(0, 1) 随机初始化的。这意味着每个元素都独立地从均值为 0、标准差为 1 的正态分布中采样。

1.4 PyTorch源码注释

以下是nn.Embedding的源码注释,用于上面说明的参考:

Args:num_embeddings (int): size of the dictionary of embeddingsembedding_dim (int): the size of each embedding vectorpadding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;therefore, the embedding vector at :attr:`padding_idx` is not updated during training,i.e. it remains as a fixed "pad". For a newly constructed Embedding,the embedding vector at :attr:`padding_idx` will default to all zeros,but can be updated to another value to be used as the padding vector.max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`is renormalized to have norm :attr:`max_norm`.norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency ofthe words in the mini-batch. Default ``False``.sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.See Notes for more details regarding sparse gradients.Attributes:weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)initialized from :math:`\mathcal{N}(0, 1)`Shape:- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

2. 实例演示

这里我将给出一个简单的例子来说明如何使用 PyTorch 的 torch.nn.Embedding 模块创建一个嵌入层,并获取一些单词的嵌入向量。

假设我们有一个小型的词汇表,包含以下单词:

  • “the”
  • “cat”
  • “dog”
  • “sat”
  • “on”
  • “mat”

我们将这些单词映射到索引上,例如:

  • “the” -> 0
  • “cat” -> 1
  • “dog” -> 2
  • “sat” -> 3
  • “on” -> 4
  • “mat” -> 5

现在我们可以创建一个 torch.nn.Embedding 层,将这些单词映射到嵌入向量中。我们将使用一个 3 维的嵌入向量来表示每个单词。

下面是具体的代码示例:

import torch
import torch.nn as nn# 创建一个 Embedding 层
# num_embeddings: 词汇表的大小,这里是 6
# embedding_dim: 嵌入向量的维度,这里是 3
embedding = nn.Embedding(num_embeddings=6, embedding_dim=3)# 定义一些单词的索引
word_indices = torch.LongTensor([0, 1, 2, 3, 4, 5])  # "the", "cat", "dog", "sat", "on", "mat"# 通过索引获取嵌入向量
word_embeddings = embedding(word_indices)# 输出嵌入向量
print(word_embeddings)

运行上述代码后,word_embeddings 将是一个形状为 (6, 3) 的张量,其中每一行代表一个单词的嵌入向量。

tensor([[ 0.0439,  0.7314, -0.3546],[ 0.6975,  1.2725,  1.4042],[-1.7532, -2.0642, -0.1434],[ 0.2538,  1.1123, -0.8636],[-0.7238, -0.0585,  0.5242],[ 0.6485,  0.6885, -1.2045]], grad_fn=<EmbeddingBackward0>)

例如,word_embeddings[0] 对应于单词 “the” 的嵌入向量,word_embeddings[1] 对应于单词 “cat” 的嵌入向量,以此类推。

这就是一个简单的英语单词嵌入向量的例子。在实际应用中,词汇表会更大,嵌入向量的维度也会更高,而且通常会使用预训练的嵌入向量来初始化这些权重。

3. embedding_dim的合理设定

通过上文说明,我们可以轻松地掌握nn.Embedding模块的使用,但是这里有个问题:embedding_dim设定为多少比较合适呢?

这里首先要说明下嵌入向量:它应该是代表单词“语义”的向量,而不是像one-hot那样是简单的字母映射。

举个例子:meetmeat两个词,拼写十分接近,即它们的one-hot编码十分接近,但是它们的语义完全不同,也就是说嵌入向量应该相差很远。而hugeenormous情况刚好相反,它们的one-hot编码完全不同,而嵌入向量应该比较接近。

那回到embedding_dim的设定选择上来,我觉得可以参考以下3个方面来设定比较合理的embedding_dim

  1. 平衡信息量与过拟合风险
    • 信息量: 较高的 embedding_dim 可以捕获更多的信息和细微差别,从而提高模型的表达能力。然而,这也可能会导致过拟合,因为高维空间容易出现稀疏性问题。
    • 过拟合风险: 较低的 embedding_dim 可以减少参数数量,降低过拟合的风险,但可能会丢失一些信息。
  2. 考虑词汇表的大小
    • 较小的词汇表: 如果词汇表相对较小(例如几千个词),较低的 embedding_dim(如 50 或 100)可能就足够了。
    • 较大的词汇表: 对于较大的词汇表(例如几十万或更多),可以选择较高的 embedding_dim(如 200 至 500)以更好地捕捉语义信息。
  3. 实验验证
    • 交叉验证: 最终的选择通常需要通过实验来确定。使用交叉验证来评估不同 embedding_dim 下的模型性能,可以帮助找到最佳值。
    • 预训练嵌入: 如果有可用的预训练嵌入(如 Word2Vec、GloVe 或 FastText),可以考虑使用它们的维度作为参考。

一点点思考:在Embedding方法中,embedding_dim一般是要比num_embeddings小(很多)的,这会导致矩阵的秩不满,最终会导致Embedding方法中的单词可以通过线性变换变成另一个单词。比如把abandon的词向量×2得到get的词向量,而one-hot不会有这个问题,这是Embedding小小的局限性。

4. 结论

torch.nn.Embedding模块在PyTorch中为NLP任务提供了强大的工具,允许模型从词汇索引中学习有意义的向量表示。通过初始化和调用这个模块,我们可以轻松地将文本数据转换为适合深度学习模型的格式,从而挖掘文本数据中的丰富语义信息。

相关文章:

【PyTorch单点知识】torch.nn.Embedding模块介绍:理解词向量与实现

文章目录 0. 前言1. 基础介绍1.1 基本参数1.2 可选参数1.3 属性1.4 PyTorch源码注释 2. 实例演示3. embedding_dim的合理设定4. 结论 0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我自己学习的理解&#xff0c;虽然参考了他人的宝贵见解及成果&#xff0c;但…...

Jedis 操作 Redis 数据结构全攻略

Jedis 操作 Redis 数据结构全攻略 一 . 认识 RESP二 . 前置操作2.1 创建项目2.2 关于开放 Redis 端口的问题2.2.1 端口转发?2.2.2 端口配置 2.3 连接到 Redis 服务器 三 . 通用命令3.1 set 和 get3.2 exists 和 del3.3 keys3.4 expire、ttl、type 三 . string 相关命令3.1 mse…...

ctf.show靶场ssrf攻略

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 web351 解析:post传入url参数他就会访问。 解法: hackbar传入url参数写入https://127.0.0.1/flag.php web352 解析:post传入url参数&#xff0c;不能是127.0.0.1和localhost 解法:缩写127.1传入 web353 解析…...

在 PyTorch 中,如何使用 `pack_padded_sequence` 来提高模型训练的效率?

在PyTorch中&#xff0c;pack_padded_sequence 是一个非常有用的函数&#xff0c;它可以用来提高模型训练的效率&#xff0c;特别是在处理变长序列数据时。这个函数的主要作用是将填充后的序列数据打包&#xff0c;以便循环神经网络&#xff08;RNN&#xff09;可以更高效地处理…...

Gossip协议

主要用在Redis Cluster 节点间通信 Gossip协议&#xff0c;也称为流行病协议&#xff08;Epidemic Protocol&#xff09;&#xff0c;是一种在分布式系统中用于信息传播和故障探测的算法。 一、工作原理 随机选择传播对象 每个节点会定期随机选择一些其他节点作为传播对象。这…...

数据结构————双向链表

内存泄漏&#xff1a; 内存泄漏&#xff08;Memory Leak&#xff09;是指程序中已动态分配的内存由于某种原因程序未释放或无法释放&#xff0c;导致系统内存的浪费&#xff0c;严重时会导致程序运行缓慢甚至崩溃。这种情况在长时间运行的程序或大型系统中尤为常见&#xff0c…...

55 - I. 二叉树的深度

comments: true difficulty: 简单 edit_url: https://github.com/doocs/leetcode/edit/main/lcof/%E9%9D%A2%E8%AF%95%E9%A2%9855%20-%20I.%20%E4%BA%8C%E5%8F%89%E6%A0%91%E7%9A%84%E6%B7%B1%E5%BA%A6/README.md 面试题 55 - I. 二叉树的深度 题目描述 输入一棵二叉树的根节点…...

Redis——初识Redis

初识Redis Redis认识Redis 分布式系统单机架构为什么要引入分布式理解负载均衡数据库的读写分离引入主从数据库 引入缓存数据库分库分表业务拆分——微服务常见概念了解 Redis背景介绍特性应用场景Redis不能做的事情Redis客户端redis客户端的多种形态 Redis 认识Redis 存储数…...

Xshell or Xftp提示“要继续使用此程序,您必须应用最新的更新或使用新版本”

Xshell提示“要继续使用此程序,您必须应用最新的更新或使用新版本”&#xff0c;笔者版本是xshell 6 方法一&#xff1a;更改系统时间 对于Windows 10用户&#xff0c;首先找到系统日期&#xff0c;右键点击并选择“调整时间/日期”。将日期设定为上一年。完成调整后&#x…...

table用position: sticky固定多层表头,滑动滚动条border边框透明解决方法

问题&#xff1a;我发现&#xff0c;这个上下滑动有内容经过就会出现如图的情况。 解决的方法&#xff1a;用outline&#xff08;轮廓&#xff09;替代border,以达到我们想要的样式。 outline主要是在元素边框的外围设置轮廓&#xff0c;outline不占据空间&#xff0c;绘制于…...

基于飞桨paddle2.6.1+cuda11.7+paddleRS开发版的目标提取-道路数据集训练和预测代码

基于飞桨paddle2.6.1cuda11.7paddleRS开发版的目标提取-道路数据集训练和预测代码 预测结果&#xff1a; 预测影像&#xff1a; &#xff08;一&#xff09;准备道路数据集 下载数据集地址&#xff1a; https://aistudio.baidu.com/datasetdetail/56961 mass_road.zip …...

数学建模笔记—— 整数规划和0-1规划

数学建模笔记—— 整数规划和0-1规划 整数规划和0-1规划1. 模型原理1.1 基本概念1.2 线性整数规划求解1.3 线性0-1规划的求解 2. 典型例题2.1 背包问题2.2 指派问题 3. matlab代码实现3.1 背包问题3.2 指派问题 整数规划和0-1规划 1. 模型原理 1.1 基本概念 在规划问题中&am…...

[001-03-007].第26节:分布式锁迭代3->优化基于setnx命令实现的分布式锁-防锁的误删

我的博客大纲 我的后端学习大纲 1、问题分析&#xff1a; 1.1.问题&#xff1a; 1.锁的超时释放&#xff0c;可能会释放其他服务器的锁 1.2.场景&#xff1a; 1.如果业务逻辑的执行时间是7s。执行流程如下 1.index1业务逻辑没执行完&#xff0c;3秒后锁被自动释放。2.index…...

【Unity踩坑】为什么有Rigidbody的物体运行时位置会变化

先上图&#xff0c;不知你有没有注意过这个现象呢&#xff1f; 一个物体加上了Rigidbody组件&#xff0c;当勾选上Use Gravity时&#xff0c;运行后&#xff0c;这个物体的位置的值会有变化。这是为什么呢&#xff1f; 刚体由物理系统处理&#xff0c;因此它会对重力、碰撞等做…...

NGINX开启HTTP3,给web应用提个速

环境说明 linuxdockernginx版本:1.27 HTTP3/QUIC介绍 HTTP3是由IETF于2022年发布的一个标准&#xff0c;文档地址为&#xff1a;https://datatracker.ietf.org/doc/html/rfc9114 如rfc9114所述&#xff0c;http3主要基于QUIC协议实现&#xff0c;在具备高性能的同时又兼备了…...

秋招季!别浮躁!

好久没写了&#xff0c;今天兴致来了&#xff0c;众所周知我一旦想说话&#xff0c;就来这里疯狂写。 最近&#xff0c;我去了一家国企的研究院&#xff0c;听着是不是贼高大上&#xff0c;呵——这玩意儿把我分配到三级机构&#xff0c;我一个学计算机的&#xff0c;它不把我…...

Java的时间复杂度和空间复杂度和常见排序

目录 一丶时间复杂度 二丶空间复杂度 三丶Java常见排序 1. 冒泡排序&#xff08;Bubble Sort&#xff09; 2.插入排序&#xff08;Insertion Sort&#xff09; 3.希尔排序&#xff08;Shell Sort&#xff09; 4.选择排序&#xff08;Selection Sort&#xff09; 5.堆排序&am…...

Qt 学习第十天:标准对话框 页面布局

系统标准对话框 错误对话框 //错误对话框connect(createPro, &QAction::triggered, this, []{//参数1 父亲 参数2 标题 参数3 对话框内显示文本内容 。。。QMessageBox::critical(this, "报错!", "没加头文件!");}); 【运行结果】 信息对话框 co…...

体育数据API纳米足球数据API:足球数据接口文档API示例⑩

纳米体育数据的数据接口通过JSON拉流方式获取200多个国家的体育赛事实时数据或历史数据的编程接口&#xff0c;无请求次数限制&#xff0c;可按需购买&#xff0c;接口稳定高效&#xff1b; 覆盖项目包括足球、篮球、网球、电子竞技、奥运等专题、数据内容。纳米数据API2.0版本…...

[数据集][目标检测]高铁受电弓检测数据集VOC+YOLO格式1245张2类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;1245 标注数量(xml文件个数)&#xff1a;1245 标注数量(txt文件个数)&#xff1a;1245 标注…...

数据库智能运维:利用PyTorch LSTM预测数据库性能瓶颈

数据库智能运维&#xff1a;利用PyTorch LSTM预测数据库性能瓶颈 1. 引言&#xff1a;当数据库遇上AI预测 凌晨三点&#xff0c;运维工程师小李被刺耳的报警声惊醒——核心数据库又崩溃了。这已经是本月第三次因为性能瓶颈导致的业务中断&#xff0c;每次损失都超过百万。传统…...

GD32F407定时器实战:1ms中断精准控制LED闪烁(附源码与调试技巧)

GD32F407定时器实战&#xff1a;1ms中断精准控制LED闪烁&#xff08;附源码与调试技巧&#xff09; 1. 嵌入式定时器的核心价值与应用场景 在嵌入式系统开发中&#xff0c;定时器如同系统的心跳&#xff0c;为各类周期性任务提供精准的时间基准。以智能家居中的温控系统为例&…...

DNS负载均衡的5个认知误区:为什么你的轮询总不生效?(附排查指南)

DNS负载均衡的5个认知误区&#xff1a;为什么你的轮询总不生效&#xff1f;&#xff08;附排查指南&#xff09; 当我们在讨论DNS负载均衡时&#xff0c;常常会遇到一些根深蒂固的误解。这些误解不仅会影响系统设计决策&#xff0c;还可能导致运维人员在排查问题时走弯路。本文…...

WubiUEFI终极指南:如何在Windows中零风险安装Ubuntu系统

WubiUEFI终极指南&#xff1a;如何在Windows中零风险安装Ubuntu系统 【免费下载链接】wubiuefi fork of Wubi (https://launchpad.net/wubi) for UEFI support and for support of recent Ubuntu releases 项目地址: https://gitcode.com/gh_mirrors/wu/wubiuefi 你是否…...

[Windows 驱动] 深入解析进程名获取的多种内核方法

1. Windows驱动开发中的进程名获取基础 在Windows内核驱动开发中&#xff0c;获取进程名是最基础但至关重要的操作之一。想象一下&#xff0c;你正在开发一个安全监控驱动&#xff0c;需要实时检查哪些进程正在运行&#xff1b;或者你在开发一个性能优化工具&#xff0c;需要针…...

vLLM-v0.17.1部署实战教程:3步启用OpenAI兼容API服务

vLLM-v0.17.1部署实战教程&#xff1a;3步启用OpenAI兼容API服务 1. vLLM框架简介 vLLM是一个专为大型语言模型(LLM)设计的高性能推理和服务库&#xff0c;以其出色的速度和易用性著称。这个项目最初由加州大学伯克利分校的天空计算实验室开发&#xff0c;现在已经发展成为一…...

5个核心功能让网盘用户彻底解决下载速度慢的问题

5个核心功能让网盘用户彻底解决下载速度慢的问题 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 &#xff0c;支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天翼云盘 / 迅雷云盘 …...

水墨江南模型效果对比:不同参数下的笔触与渲染风格

水墨江南模型效果对比&#xff1a;不同参数下的笔触与渲染风格 最近在尝试用AI生成水墨画&#xff0c;发现一个挺有意思的现象&#xff1a;同一个“水墨江南”模型&#xff0c;用不同的参数设置&#xff0c;画出来的效果天差地别。有时候是寥寥几笔的写意小品&#xff0c;有时…...

别再纠结了!.NET后台任务调度,Hangfire和Quartz.NET到底怎么选?

Hangfire与Quartz.NET深度抉择指南&#xff1a;从业务场景到技术实现的精准匹配 在.NET生态系统中&#xff0c;后台任务调度是几乎所有企业级应用都无法绕开的核心需求。无论是电商平台的订单状态更新、金融系统的日终批处理&#xff0c;还是内容管理系统的定时数据同步&#x…...

Ubuntu:无网络环境下Docker离线部署全攻略

1. 离线部署Docker的核心挑战与解决方案 在完全隔离网络的环境中部署Docker&#xff0c;就像要在荒岛上搭建一个现代化厨房——所有食材和工具都得提前准备好。我经历过三次企业级离线部署&#xff0c;最深刻的一次是在某金融机构数据中心&#xff0c;他们的服务器甚至不允许插…...