当前位置: 首页 > article >正文

PyTorch嵌入层(nn.Embedding)

在 PyTorch 中,nn.Embedding 层(即 model.user_embedding)除了 .weight 这个核心属性外,还有其他属性和方法。以下是完整的解析:


1. 主要属性

(1) weight(核心参数)
  • 作用:存储所有嵌入向量的可训练权重矩阵。
  • 形状(num_embeddings, embedding_dim)
  • 示例
    print(model.user_embedding.weight.shape)  # 输出:torch.Size([3, 4])
    
(2) num_embeddings
  • 作用:返回嵌入向量的总数(即用户/物品的数量)。
  • 示例
    print(model.user_embedding.num_embeddings)  # 输出:3
    
(3) embedding_dim
  • 作用:返回每个嵌入向量的维度。
  • 示例
    print(model.user_embedding.embedding_dim)  # 输出:4
    
(4) padding_idx(可选)
  • 作用:如果设置了 padding_idx,则对应的嵌入向量会被强制设为 0 且不参与训练。
  • 示例
    # 初始化时设置 padding_idx=0
    self.user_embedding = nn.Embedding(3, 4, padding_idx=0)
    print(model.user_embedding.padding_idx)  # 输出:0
    print(model.user_embedding.weight[0])    # 输出:tensor([0., 0., 0., 0.], grad_fn=<SelectBackward>)
    

2. 主要方法

(1) forward(input)
  • 作用:根据输入的 ID 返回对应的嵌入向量。
  • 示例
    input_ids = torch.tensor([0, 1, 2])  # 查询用户 0、1、2 的向量
    embeddings = model.user_embedding(input_ids)  # 返回 shape (3, 4)
    
(2) reset_parameters()
  • 作用:重新随机初始化权重(通常在训练前调用)。
  • 内部逻辑:默认使用均匀分布 U ( − k , k ) U(-\sqrt{k}, \sqrt{k}) U(k ,k ),其中 k = 1 embedding_dim k = \frac{1}{\text{embedding\_dim}} k=embedding_dim1
  • 示例
    model.user_embedding.reset_parameters()
    
(3) extra_repr()
  • 作用:返回层的额外信息(用于 print 时显示)。
  • 示例
    print(model.user_embedding.extra_repr())  
    # 输出:'num_embeddings=3, embedding_dim=4'
    

3. 其他底层属性(一般无需直接操作)

  • _parameters:存储所有可训练参数(包括 weight)。
  • _buffers:存储非可训练参数(如 BatchNorm 的 running_mean)。
  • training:布尔值,表示是否处于训练模式。

4. 完整属性/方法列表

可以通过 dir() 查看所有属性和方法:

print(dir(model.user_embedding))

输出示例:

['__class__', '__delattr__', '__dir__', ..., 'weight', 'num_embeddings', 'embedding_dim', 'padding_idx', 'forward', 'reset_parameters']

5. 关键总结

属性/方法用途示例值/调用方式
.weight核心权重矩阵shape=(3, 4)
.num_embeddings嵌入向量的总数(用户数)3
.embedding_dim每个向量的维度4
.padding_idx指定填充索引(可选)None0
.forward(input)查询嵌入向量model.user_embedding([0, 1])
.reset_parameters()重新初始化权重model.user_embedding.reset_parameters()

6. 常见问题

Q:如何修改嵌入向量?
  • 直接操作 .weight
    # 将用户 0 的向量置零
    model.user_embedding.weight.data[0] = torch.zeros(4)
    
Q:如何冻结嵌入层?
  • 禁用梯度:
    model.user_embedding.weight.requires_grad = False
    
Q:padding_idx 和普通索引有什么区别?
  • padding_idx 对应的向量会固定为 0,且不参与梯度更新。

掌握这些属性和方法后,你可以更灵活地操作嵌入层! 🚀

相关文章:

PyTorch嵌入层(nn.Embedding)

在 PyTorch 中&#xff0c;nn.Embedding 层&#xff08;即 model.user_embedding&#xff09;除了 .weight 这个核心属性外&#xff0c;还有其他属性和方法。以下是完整的解析&#xff1a; 1. 主要属性 (1) weight&#xff08;核心参数&#xff09; 作用&#xff1a;存储所有…...

AIGC7——AIGC驱动的视听内容定制化革命:从Sora到商业化落地

引言&#xff1a;个性化视听时代的到来 2024年&#xff0c;OpenAI发布视频生成模型Sora&#xff0c;可生成60秒高清视频&#xff1b;中国团队推出的Vidu模型实现16秒镜头连贯生成。这些突破标志着AIGC正式进入高质量视听内容定制化阶段。据Gartner预测&#xff0c;到2027年&am…...

接上文,SpringBoot的线程池配置以及JVM监控

接上篇文章&#xff0c; 拿SpringBoot举个例 1.1 默认线程池的隐患 Spring Boot的Async默认使用SimpleAsyncTaskExecutor&#xff08;无复用线程&#xff09;&#xff0c;频繁创建/销毁线程易引发性能问题。 1.2 自定义线程池配置 Configuration EnableAsync public class A…...

《AI大模型应知应会100篇》加餐篇:LlamaIndex 与 LangChain 的无缝集成

加餐篇&#xff1a;LlamaIndex 与 LangChain 的无缝集成 问题背景&#xff1a;在实际应用中&#xff0c;开发者常常需要结合多个框架的优势。例如&#xff0c;使用 LangChain 管理复杂的业务逻辑链&#xff0c;同时利用 LlamaIndex 的高效索引和检索能力构建知识库。本文在基于…...

部署大模型实战:如何巧妙权衡效果、成本与延迟?

目录 部署大模型实战&#xff1a;如何巧妙权衡效果、成本与延迟&#xff1f; 一、为什么要进行权衡&#xff1f; 二、权衡的三个关键维度 三、如何进行有效权衡&#xff1f;&#xff08;实操策略&#xff09; &#xff08;一&#xff09;明确需求场景与优先级 &#xff08…...

元素三大等待

硬性等待&#xff08;强制等待&#xff09; 线程休眠&#xff0c;强制等待 Thread.sleep(long millis);这是最简单的等待方式&#xff0c;使用time.sleep()方法来实现。在代码中强制等待一定的时间&#xff0c;不论元素是否已经加载完成&#xff0c;都会等待指定的时间后才继…...

【DY】信息化集成化信号采集与处理系统;生物信号采集处理系统一体机

MD3000-C信息化一体机生物信号采集处理系统 实验平台技术指标 01、整机外形尺寸&#xff1a;1680mm(L)*750mm(w)*2260mm(H)&#xff1b; 02、实验台操作面积&#xff1a;750(w)*1340(L&#xff09;&#xff08;长*宽&#xff09;&#xff1b; 03、实验台面离地高度&#xf…...

康谋分享 | 仿真驱动、数据自造:巧用合成数据重构智能座舱

随着汽车向智能化、场景化加速演进&#xff0c;智能座舱已成为人车交互的核心承载。从驾驶员注意力监测到儿童遗留检测&#xff0c;从乘员识别到安全带状态判断&#xff0c;座舱内的每一次行为都蕴含着巨大的安全与体验价值。 然而&#xff0c;这些感知系统要在多样驾驶行为、…...

YOLO学习笔记 | 基于YOLOv5的车辆行人重识别算法研究(附matlab代码)

基于YOLOv5的车辆行人重识别算法研究 🥥🥥🥥🥥🥥🥥🥥🥥🥥🥥🥥🥥🥥🥥 摘要 本文提出了一种基于YOLOv5的车辆行人重识别(ReID)算法,结合目标检测与特征匹配技术,实现高效的多目标跟踪与识别。通过引入注意力机制、优化损失函数和轻量化网络结构…...

Vue 数据传递流程图指南

今天&#xff0c;我们探讨一下 Vue 中的组件传值问题。这不仅是我们在日常开发中经常遇到的核心问题&#xff0c;也是面试过程中经常被问到的重要知识点。无论你是初学者还是有一定经验的开发者&#xff0c;掌握这些传值方式都将帮助你更高效地构建和维护 Vue 应用 目录 1. 父…...

Node.js 与 MySQL:深入理解与高效实践

Node.js 与 MySQL:深入理解与高效实践 引言 随着互联网技术的飞速发展,Node.js 作为一种高性能的服务端JavaScript运行环境,因其轻量级、单线程和事件驱动等特点,受到了广大开发者的青睐。MySQL 作为一款开源的关系型数据库管理系统,以其稳定性和可靠性著称。本文将深入…...

鸿蒙NEXT开发缓存工具类(ArkTs)

import { ObjectUtil } from ./ObjectUtil;/*** 缓存工具类** 该类提供了一组静态方法&#xff0c;用于操作缓存数据。* 主要功能包括&#xff1a;获取缓存数据、存储缓存数据、删除缓存数据、检查键是否存在、判断缓存是否为空以及清空缓存。** author CSDN-鸿蒙布道师* since…...

【C语言】strstr查找字符串函数

一、函数介绍 strstr 是 C 语言标准库 <string.h> 中的字符串查找函数&#xff0c;用于在主字符串中查找子字符串的首次出现位置。若找到子串&#xff0c;返回其首次出现的地址&#xff1b;否则返回 NULL。它是处理字符串匹配问题的核心工具之一。 二、函数原型 char …...

使用pkexec 和其策略文件安全提权执行外部程序

‌一、pkexec 基本机制‌ pkexec 是 Linux 桌面环境下基于 ‌PolicyKit‌ 的安全提权工具&#xff0c;可通过交互式图形界面获取用户授权后&#xff0c;以 root 权限执行指定程序。其核心特点包括&#xff1a; ‌图形化密码输入‌&#xff1a;调用时自动弹出系统认证对话框&a…...

NVIDIA显卡

NVIDIA显卡作为全球GPU技术的标杆&#xff0c;其产品线覆盖消费级、专业级、数据中心、移动计算等多个领域&#xff0c;技术迭代贯穿架构创新、AI加速、光线追踪等核心方向。以下从技术演进、产品矩阵、核心技术、生态布局四个维度展开深度解析&#xff1a; 一、技术演进&…...

机器学习、深度学习和神经网络

机器学习、深度学习和神经网络 术语及相关概念 在深入了解人工智能&#xff08;AI&#xff09;的工作原理以及它的各种应用之前&#xff0c;让我们先区分一下与AI密切相关的一些术语和概念&#xff1a;人工智能、机器学习、深度学习和神经网络。这些术语有时会被交替使用&#…...

数字孪生在智慧城市中的前端呈现与 UI 设计思路

一、数字孪生技术在智慧城市中的应用与前端呈现 数字孪生技术通过创建城市的虚拟副本&#xff0c;实现了对城市运行状态的实时监控、分析与预测。在智慧城市中&#xff0c;数字孪生技术的应用包括交通流量监测、环境质量分析、基础设施管理等。其前端呈现主要依赖于Web3D技术、…...

黑莓手机有望回归:搭载 Android 15、支持 AI

据 3 月 31 日快科技消息&#xff0c;有博主称一家英国的初创公司正悄悄努力复活 BlackBerry Classic 及 OnwardMobility 未完成的产品。 从爆料的信息看&#xff0c;黑莓新手机将具备 5G、AMOLED 显示屏、12GB RAM 和 256GB 或 512GB 存储空间等高端配置&#xff0c;同时运行 …...

Android OpenGLES 360全景图片渲染(球体内部)

概述 360度全景图是一种虚拟现实技术&#xff0c;它通过对现实场景进行多角度拍摄后&#xff0c;利用计算机软件将这些照片拼接成一个完整的全景图像。这种技术能够让观看者在虚拟环境中以交互的方式查看整个周围环境&#xff0c;就好像他们真的站在那个位置一样。在Android设备…...

LETTERS(DFS)

【题目描述】 给出一个rowcolrowcol的大写字母矩阵&#xff0c;一开始的位置为左上角&#xff0c;你可以向上下左右四个方向移动&#xff0c;并且不能移向曾经经过的字母。问最多可以经过几个字母。 【输入】 第一行&#xff0c;输入字母矩阵行数RR和列数SS&#xff0c;1≤R,S≤…...

嵌入式海思Hi3861连接华为物联网平台操作方法

1.1 实验目的 快速演示 1、认识轻量级HarmonyOS——LiteOS-M 2、初步掌握华为云物联网平台的使用 3、快速驱动海思Hi3861 WIFI芯片,连接互联网并登录物联网平台...

CMDB平台(进阶篇):3D机房大屏全景解析

在数字化转型的浪潮中&#xff0c;数据中心作为企业信息架构的核心&#xff0c;其高效、智能的管理成为企业竞争力的关键因素之一&#xff0c;其运维管理方式也正经历着革命性的变革。传统基于二维平面图表的机房监控方式已难以满足现代企业对运维可视化、智能化的需求。乐维CM…...

NVM 多版本Node.js 管理全指南(Windows系统)

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家、全栈领域优质创作者、高级开发工程师、高级信息系统项目管理师、系统架构师&#xff0c;数学与应用数学专业&#xff0c;10年以上多种混合语言开发经验&#xff0c;从事DICOM医学影像开发领域多年&#xff0c;熟悉DICOM协议及…...

C,C++语言缓冲区溢出的产生和预防

缓冲区溢出的定义 缓冲区是内存中用于存储数据的一块连续区域&#xff0c;在 C 和 C 里&#xff0c;常使用数组、指针等方式来操作缓冲区。而缓冲区溢出指的是当程序向缓冲区写入的数据量超出了该缓冲区本身能够容纳的最大数据量时&#xff0c;额外的数据就会覆盖相邻的内存区…...

《Linux内存管理:实验驱动的深度探索》【附录】【实验环境搭建 2】【vscode搭建调试内核环境】

1. 如何调试我们的内核 1. GDB调试 安装gdb sudo apt-get install gdb-multiarchgdb-multiarch是多架构版本&#xff0c;可以通过set architecture aarch64指定架构 QEMU参数修改添加-s -S #!/usr/bin/shqemu-7.2.0-rc1/build/aarch64-softmmu/qemu-system-aarch64 \-nogr…...

Flutter项目之登录注册功能实现

目录&#xff1a; 1、页面效果2、登录两种状态界面3、中间按钮部分4、广告区域5、最新资讯6、登录注册页联调6.1、网络请求工具类6.2、注册页联调6.3、登录问题分析6.4、本地缓存6.5、共享token6.6、登录页联调6.7、退出登录 1、页面效果 import package:flutter/material.dart…...

mybatis 自带的几个插入接口的区别

研究这个的原由是应为需求对一张表新增了一个有默认值的字段&#xff0c;然后调用插入接口的时候发现这个字段没有传默认值但是还是以null值入库了&#xff0c;数据库中设置的默认值没有生效。 通过排查之后发现是使用了insertUseGeneratedKeys 方法进行插入&#xff0c;此方法…...

ctfshow VIP题目限免 源码泄露

根据题目提示是源代码泄露&#xff0c;右键查看页面源代码发现了 flag...

移动神器RAX3000M路由器变身家庭云之七:增加打印服务,电脑手机无线打印

系列文章目录&#xff1a; 移动神器RAX3000M路由器变身家庭云之一&#xff1a;开通SSH&#xff0c;安装新软件包 移动神器RAX3000M路由器变身家庭云之二&#xff1a;安装vsftpd 移动神器RAX3000M路由器变身家庭云之三&#xff1a;外网访问家庭云 移动神器RAX3000M路由器不刷固…...

《函数基础与内存机制深度剖析:从 return 语句到各类经典编程题详解》

一、问答题 &#xff08;1&#xff09;使用函数的好处是什么&#xff1f; 1.提升代码的复用性 2.提升代码的可维护性 3.增强代码的可读性 4.提高代码的灵活性 5.方便进行单元测试 &#xff08;2&#xff09;如何定义一个函数&#xff1f;如何调用一个函数&#xff1f; 在Pytho…...