当前位置: 首页 > article >正文

深入理解 PyTorch 的 nn.Embedding:词向量映射及变量 weight 的更新机制

文章目录

  • 前言
  • 一、直接使用 `nn.Embedding` 获得变量
    • 1、典型场景
    • 2、示例代码:
    • 3、特点
  • 二、使用 `iou_token = nn.Embedding(1, transformer_dim)` 并访问 `iou_token.weight`
    • 1、典型场景
    • 2、示例代码:
    • 3、特点
  • 三、第一种方法在模型更新中会更新其值吗?
    • 1、默认行为
    • 2、示例代码:
    • 3、控制权重更新的方法
      • 方法 1:设置 `requires_grad = False`
      • 方法 2:加载预训练权重并冻结
      • 方法 3:在优化器中排除某些参数
  • 四、总结

前言

在深度学习领域,特别是在自然语言处理(NLP)中,nn.Embedding 是一个非常重要的模块,用于将离散的词汇(如单词或标记)映射为连续的向量表示。本文详细讲解了 nn.Embedding 的使用方法、其权重是否会在模型更新过程中被更新的问题,以及如何控制这些权重是否参与训练。

一、直接使用 nn.Embedding 获得变量

1、典型场景

这种用法通常用于处理离散的词汇表(如单词、token等),将这些离散的 token 映射为连续的向量表示。例如,在 NLP 任务中,输入是一批句子或标记序列,每个标记都有一个唯一的索引(ID)。通过 nn.Embedding,可以将这些索引映射为对应的词向量。

2、示例代码:

import torch
import torch.nn as nn# 假设词汇表大小为 10,每个词嵌入维度为 5
vocab_size = 10
embedding_dim = 5# 创建 Embedding 层
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)# 输入是一个批次的 token 索引
input_tokens = torch.tensor([2, 3, 5])  # 示例输入索引
embedded_vectors = embedding_layer(input_tokens)  # 获取词向量print(embedded_vectors)

3、特点

  • nn.Embedding 是一个可训练的参数层。
  • 输入是离散的 token 索引,输出是对应的连续向量表示。
  • 这种用法适用于需要批量处理 token 的场景,比如文本分类、机器翻译等任务。

二、使用 iou_token = nn.Embedding(1, transformer_dim) 并访问 iou_token.weight

1、典型场景

这种用法通常用于定义一些特殊的、全局共享的向量,而不是处理整个词汇表中的 token。常见的例子包括在目标检测任务中,定义一个可学习的 “特殊 token” 来表示某些特定的对象或区域(如 IoU 预测中的 token)。

2、示例代码:

import torch
import torch.nn as nn# 定义一个特殊的 token,维度为 transformer_dim
transformer_dim = 64
iou_token = nn.Embedding(num_embeddings=1, embedding_dim=transformer_dim)# 访问这个特殊 token 的权重
special_token_vector = iou_token.weight  # 形状为 [1, transformer_dim]print("Special Token Vector:", special_token_vector)

3、特点

  • iou_token 是一个 nn.Embedding 实例,但它的词汇表大小为 1(即只有一个 token)。
  • iou_token.weight 是这个特殊 token 的实际值,形状为 [1, embedding_dim]
  • 这种用法适用于需要定义一个可学习的、全局共享的向量的场景,而不是处理多个离散 token。

三、第一种方法在模型更新中会更新其值吗?

1、默认行为

默认情况下,nn.Embedding 的权重(即词向量)是模型的可训练参数,默认情况下会被优化器更新。

2、示例代码:

import torch
import torch.nn as nn
import torch.optim as optim# 创建一个 Embedding 层
vocab_size = 10
embedding_dim = 5
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)# 定义输入和目标
input_tokens = torch.tensor([2, 3, 5])  # 输入 token 索引
target = torch.randn(3, embedding_dim)  # 假设的目标向量# 定义优化器
optimizer = optim.SGD(embedding_layer.parameters(), lr=0.01)# 前向传播
embedded_vectors = embedding_layer(input_tokens)# 计算损失
loss_fn = nn.MSELoss()
loss = loss_fn(embedded_vectors, target)# 反向传播和更新
optimizer.zero_grad()
loss.backward()
optimizer.step()# 查看更新后的权重
print("Updated Embedding Weights:", embedding_layer.weight)

3、控制权重更新的方法

有时我们希望固定某些权重,不让它们参与训练。这可以通过以下方式实现:

方法 1:设置 requires_grad = False

embedding_layer.weight.requires_grad 设置为 False,可以阻止这些权重被更新。

embedding_layer.weight.requires_grad = False

方法 2:加载预训练权重并冻结

如果我们使用预训练的词向量(如 GloVe 或 Word2Vec),可以选择加载这些权重并冻结它们。

# 加载预训练权重
pretrained_weights = torch.load('glove_embeddings.pth')# 创建 Embedding 层并加载权重
embedding_layer = nn.Embedding.from_pretrained(pretrained_weights, freeze=True)

方法 3:在优化器中排除某些参数

我们可以在定义优化器时,排除某些参数,从而避免更新它们。

# 排除 embedding_layer 的权重
optimizer = optim.SGD([param for param in model.parameters() if param is not embedding_layer.weight],lr=0.01
)

四、总结

  • 默认情况下nn.Embedding 的权重是可训练的,会在每次反向传播后被更新。
  • 如果需要固定权重,可以通过设置 requires_grad = False、使用 from_pretrained 并设置 freeze=True 或在优化器中排除这些参数来实现。
  • 选择是否更新权重取决于任务需求:如果你希望模型从头学习词向量(如随机初始化的场景),让权重可训练;如果你使用预训练的词向量并希望保持它们不变,则固定权重。

相关文章:

深入理解 PyTorch 的 nn.Embedding:词向量映射及变量 weight 的更新机制

文章目录 前言一、直接使用 nn.Embedding 获得变量1、典型场景2、示例代码:3、特点 二、使用 iou_token nn.Embedding(1, transformer_dim) 并访问 iou_token.weight1、典型场景2、示例代码:3、特点 三、第一种方法在模型更新中会更新其值吗&#xff1f…...

10min速通Linux文件传输

实验环境 在Linux中传输文件需要借助网络以及sshd,我们可通过systemctl status sshd来查看sshd状态 若服务未开启我们可通过systemctl enable --now sshd来开启sshd服务 将/etc/ssh/sshd_config中的PermitRootLogin 状态修改为yes 传输文件 scp scp (Sec…...

dify windos,linux下载安装部署,提供百度云盘地址

dify1.0.1 windos安装包百度云盘地址 通过网盘分享的文件:dify-1.0.1.zip 链接: 百度网盘 请输入提取码 提取码: 1234 dify安装包 linux安装包百度云盘地址 通过网盘分享的文件:dify-1.0.1.tar.gz 链接: 百度网盘 请输入提取码 提取码: 1234 1.安装…...

使用 TFIDF+分类器 范式进行企业级文本分类(二)

1.开场白 上一期讲了 TF-IDF 的底层原理,简单讲了一下它可以将文本转为向量形式,并搭配相应分类器做文本分类,且即便如今的企业实践中也十分常见。详情请见我的上一篇文章 从One-Hot到TF-IDF(点我跳转) 光说不练假把…...

《车辆人机工程-汽车驾驶操纵实验》

汽车操纵装置有哪几种,各有什么特点 汽车操纵装置是驾驶员直接控制车辆行驶状态的关键部件,主要包括以下几种,其特点如下: 一、方向盘(转向操纵装置) 作用:控制车辆行驶方向,通过转…...

[ABC400F] Happy Birthday! 3 题解

考虑正难则反。问题转化为: 一个环上有 n n n 个物品,颜色分别为 c o l i col_i coli​,每次操作选择两个数 i , j i, j i,j 使得 ∀ k ∈ [ i , j ] , c o l k c o l i ∨ c o l k 0 \forall k \in [i, j], col_k col_i \lor col_k …...

python高级编程一(生成器与高级编程)

@TOC 生成器 生成器使用 通过列表⽣成式,我们可以直接创建⼀个列表。但是,受到内存限制,列表容量肯定是有限的。⽽且,创建⼀个包含100万个元素的列表,不仅占⽤很⼤的存储空间,如果我们仅仅需要访问前⾯⼏个元素,那后⾯绝⼤多数元素占 ⽤的空间都⽩⽩浪费了。所以,如果…...

Go 字符串四种拼接方式的性能对比

简介 使用完整的基准测试代码文件,可以直接运行来比较四种字符串拼接方法的性能。 for 索引 的方式 for range 的方式 strings.Join 的方式 strings.Builder 的方式 写一个基准测试文件 echo_bench_test.go package mainimport ("os""stri…...

windows安装fastbev环境时,安装mmdetection3d出现的问题总结

出现的问题如下: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3\include\crt/host_config.h(160): fatal error C1189: #error: -- unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2019 (inclusive) are supporte…...

单片机Day05---动态数码管显示01234567

一、原理图 数组索引段码值二进制显示内容00x3f0011 1111010x060000 0110120x5b0101 1011230x4f0100 1111340x660110 0110450x6d0110 1101560x7d0111 1101670x070000 0111780x7f0111 1111890x6f0110 11119100x770111 0111A110x7c0111 1100B120x390011 1001C130x5e0101 1110D140…...

【Python3教程】Python3基础篇之数据结构

博主介绍:✌全网粉丝22W+,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物联网、机器学习等设计与开发。 感兴趣的可…...

muduo库源码分析: One Loop Per Thread

One Loop Per Thread的含义就是,一个EventLoop和一个线程唯一绑定,和这个EventLoop有关的,被这个EventLoop管辖的一切操作都必须在这个EventLoop绑定线程中执行 1.在MainEventLoop中,负责新连接建立的操作都要在MainEventLoop线程…...

使用Python解决Logistic方程

引言 在数学和计算机科学中,Logistic 方程是描述人口增长、传播过程等现象的一种常见模型。它通常用于表示一种有限资源下的增长过程,比如动物种群、疾病传播等。本文将带领大家通过 Python 实现 Logistic 方程的求解,帮助你更好地理解这一经典数学模型。 1.什么是 Logist…...

AI Agent工程师认证-学习笔记(3)——【多Agent】MetaGPT

学习链接:【多Agent】MetaGPT学习教程 源代码链接(觉得很好,star一下):GitHub - 基于MetaGPT的多智能体入门与开发教程 MetaGPT链接:GitHub - MetaGPT 前期准备 1、获取MetaGPT (1)使用pip获取MetaGPT pip install metagpt==0.6.6#或者在国内加速安装镜像 #pip in…...

MCP结合高德地图完成配置

文章目录 1.MCP到底是什么2.cursor配置2.1配置之后的效果2.2如何进行正确的配置2.3高德地图获取key2.4选择匹配的模型 1.MCP到底是什么 作为学生,我们应该如何认识MCP?最近看到了好多跟MCP相关的文章,我觉得我们不应该盲目的追求热点的技术&…...

重读《人件》Peopleware -(5)Ⅰ管理人力资源Ⅳ-质量—若时间允许

20世纪的心理学理论认为,人类的性格主要由少数几个基本本能所主导:生存、自尊、繁衍、领地等。这些本能直接嵌入大脑的“固件”中。我们可以在没有强烈情感的情况下理智地考虑这些本能(就像你现在正在做的那样),但当我…...

文献总结:AAAI2025-UniV2X-End-to-end autonomous driving through V2X cooperation

UniV2X 一、文章基本信息二、文章背景三、UniV2X框架1. 车路协同自动驾驶问题定义2. 稀疏-密集混合形态数据3. 交叉视图数据融合(智能体融合)4. 交叉视图数据融合(车道融合)5. 交叉视图数据融合(占用融合)6…...

制造一只电子喵 (qwen2.5:0.5b 微调 LoRA 使用 llama-factory)

AI (神经网络模型) 可以认为是计算机的一种新的 “编程” 方式. 为了充分利用计算机, 只学习传统的编程 (编程语言/代码) 是不够的, 我们还要掌握 AI. 本文以 qwen2.5 和 llama-factory 举栗, 介绍语言模型 (LLM) 的微调 (LoRA SFT). 为了方便上手, 此处选择使用小模型 (qwen2…...

如何查询node inode上限是多少?

在 Linux 系统中,inode 上限由文件系统的类型和格式化时的参数决定。不同文件系统(如 ext4、XFS)有不同的查询方法。以下是详细操作步骤: 1. 确认文件系统类型 首先确定目标磁盘分区的文件系统类型(如 ext4、XFS&…...

Redis核心功能实现

前言 学习是个输入的过程,在进行输入之后再进行一些输出,比如写写文章,笔记,或者做一些技术串讲,虽然需要花费不少时间,但是好处很多,首先是能通过输出给自己的输入带来一些动力,然…...

驱动学习专栏--字符设备驱动篇--1_chrdevbase

字符设备驱动简介 字符设备是 Linux 驱动中最基本的一类设备驱动,字符设备就是一个一个字节,按照字节 流进行读写操作的设备,读写数据是分先后顺序的。比如我们最常见的点灯、按键、 IIC 、 SPI , LCD 等等都是字符设备&…...

Python及C++中的列表

一、Python中的列表(List) Python的列表是动态数组,内置于语言中,功能强大且易用,非常适合算法竞赛。 1. 基本概念 定义:列表是一个有序、可变的序列,可以存储任意类型的元素(整数…...

Oracle DROP、TRUNCATE 和 DELETE 原理

在 Oracle 11g 中,DROP、TRUNCATE 和 DELETE 是三种不同的数据清理操作,它们的底层原理和适用场景有显著差异 1. DELETE 的原理 类型:DML(数据操作语言) 功能:逐行删除表中符合条件的数据,保留…...

ida 使用记录

文章目录 伪代码-汇编hexstring快捷键 伪代码-汇编 流程图界面——F5——伪代码界面——再点Tab——流程图界面——再按空格——汇编界面流程图界面——空格——汇编界面 hex view - open subviews - hex dump string view - open subviews - string快捷键: sh…...

【连载3】基础智能体的进展与挑战综述

基础智能体的进展与挑战综述 从类脑智能到具备可进化性、协作性和安全性的系统 【翻译团队】刘军(liujunbupt.edu.cn) 钱雨欣玥 冯梓哲 李正博 李冠谕 朱宇晗 张霄天 孙大壮 黄若溪 2. 认知 人类认知是一种复杂的信息处理系统,它通过多个专门的神经回路协调运行…...

MacOs java环境配置+maven环境配置踩坑实录

oracl官网下载jdk 1.8的安装包 注意可能需要注册!!! 下载链接:下载地址点击 注意晚上就不要下载了 报错400 !!! 1.点击安装嘛 2.配置环境变量 export JAVA_HOME/Library/Java/Java…...

【Git】--- 企业级开发流程

Welcome to 9ilks Code World (๑•́ ₃ •̀๑) 个人主页: 9ilk (๑•́ ₃ •̀๑) 文章专栏: Git 本篇博客我们讲解Git在企业开发中的整体流程,理解Git在实际企业开发中的高效设计。 🏠 企业级开发流程 一个软件从零开始到最…...

SAP系统客户可回收包材库存管理

问题:客户可回收包材库存管理 现象:回收瓶无库存管理,在库数量以及在客户的库存数量没有统计,管理混乱。 解决方法: 客户可回收包装材料在SAP有标准的解决方案,在集团尚未启用该业务,首先…...

蓝桥杯嵌入式历年省赛客观题

一.第十五届客观题 第十四届省赛 十三届 十二届...

JDK的卸载与安装

卸载JDK 删除java的1安装目录 卸载JAVA_HOME 删除path下关于java的路径 java -version查看 安装JDK 百度搜索JDK,找到下载地址 同意协议 下载电脑对应版本 双击安装 记住安装路径 配置环境变量 我的电脑–>右键–>属性–>高级系统设置 环境变…...