【python函数】torch.nn.Embedding函数用法图解
学习SAM模型的时候,第一次看见了nn.Embedding函数,以前接触CV比较多,很少学习词嵌入方面的,找了一些资料一开始也不是很理解,多看了两遍后,突然顿悟,特此记录。
SAM中PromptEncoder中运用nn.Embedding:
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
torch.nn.Embedding官方页面
1. torch.nn.Embedding介绍
(1)词嵌入简介
关于词嵌入,这篇文章讲的挺清楚的,相比于One-hot编码,Embedding方式更方便计算,例如在“就在江湖之上”整个词典中,要编码“江湖”两个字,One-hot编码需要 [ l e n g t h , w o r d _ c o u n t ] {[length, word\_count]} [length,word_count] 大小的张量,其中 w o r d _ c o u n t {word\_count} word_count 为词典中所有词的总数,而Embedding方式的嵌入维度 e m b e d d i n g _ d i m {embedding\_dim} embedding_dim 可远远小于 w o r d _ c o u n t {word\_count} word_count 。在运用Embedding方式编码的词典时,只需要词的索引,下图例子中: “江湖”——>[2, 3]

(2)重要参数介绍
nn.embedding就相当于一个词典嵌入表:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
常用参数:
① num_embeddings (int): 词典中词的总数
② embedding_dim (int): 词典中每个词的嵌入维度
③ padding_idx (int, optional): 填充索引,在padding_idx处的嵌入向量在训练过程中没有更新,即它是一个固定的“pad”。对于新构造的Embedding,在padding_idx处的嵌入向量将默认为全零,但可以更新为另一个值以用作填充向量。
输入: I n p u t ( ∗ ) {Input(∗)} Input(∗): IntTensor 或者 LongTensor,为任意size的张量,包含要提取的所有词索引。
输出: O u t p u t ( ∗ , H ) {Output(∗, H)} Output(∗,H): ∗ {∗} ∗ 为输入张量的size, H {H} H = embedding_dim
2. torch.nn.Embedding用法
(1)基本用法
官方例子如下:
import torch
import torch.nn as nnembedding = nn.Embedding(10, 3)
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])y = embedding(x)print('权重:\n', embedding.weight)
print('输出:')
print(y)
查看权重与输出,打印如下:
权重:Parameter containing:
tensor([[ 1.4212, 0.6127, -1.1126],[ 0.4294, -1.0121, -1.8348],[-0.0315, -1.2234, -0.4589],[ 0.6131, -0.4381, 0.1253],[-1.0621, -0.1466, 1.7412],[ 1.0708, -0.7888, -0.0177],[-0.5979, 0.6465, 0.6508],[-0.5608, -0.3802, -0.4206],[ 1.1516, 0.4091, 1.2477],[-0.5753, 0.1394, 2.3447]], requires_grad=True)
输出:
tensor([[[ 0.4294, -1.0121, -1.8348],[-0.0315, -1.2234, -0.4589],[-1.0621, -0.1466, 1.7412],[ 1.0708, -0.7888, -0.0177]],[[-1.0621, -0.1466, 1.7412],[ 0.6131, -0.4381, 0.1253],[-0.0315, -1.2234, -0.4589],[-0.5753, 0.1394, 2.3447]]], grad_fn=<EmbeddingBackward0>)
家人们,发现了什么,输入 x {x} x 的 s i z e {size} size 大小为 [ 2 , 4 ] {[2, 4]} [2,4] ,输出 y {y} y 的 s i z e {size} size 大小为 [ 2 , 4 , 3 ] {[2, 4, 3]} [2,4,3] ,下图清晰的展示出nn.Embedding干了个什么事儿:

nn.Embedding相当于是一本词典,本例中,词典中一共有10个词 X 0 {X_0} X0~ X 9 {X_9} X9,每个词的嵌入维度为3,输入 x {x} x 中记录词在词典中的索引,输出 y {y} y 为输入 x {x} x 经词典编码后的映射。
注意:此时存在一个问题,词索引是不能超出词典的最大容量的,即本例中,输入 x {x} x 中的数值取值范围为 [ 0 , 9 ] {[0, 9]} [0,9]。
(2)自定义词典权重
如上所示,在未定义时,nn.Embedding的自动初始化权重满足 N ( 0 , 1 ) {N(0,1)} N(0,1) 分布,此外,nn.Embedding的权重也可以通过from_pretrained来自定义:
import torch
import torch.nn as nnweight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
embedding = nn.Embedding.from_pretrained(weight)
x = torch.LongTensor([1, 0, 0])
y = embedding(x)
print(y)
输出为:
tensor([[4.0000, 5.1000, 6.3000],[1.0000, 2.3000, 3.0000],[1.0000, 2.3000, 3.0000]])
(3)padding_idx用法
padding_idx可用于指定词典中哪一个索引的词填充为0。
import torch
import torch.nn as nnembedding = nn.Embedding(10, 3, padding_idx=5)
x = torch.LongTensor([[5, 2, 0, 5]])
y = embedding(x)
print('权重:\n', embedding.weight)
print('输出:')
print(y)
输出为:
权重:Parameter containing:
tensor([[ 0.1831, -0.0200, 0.7023],[ 0.2751, -0.1189, -0.3325],[-0.5242, -0.2230, -1.1677],[-0.4078, -1.2141, 1.3185],[ 0.8973, -0.9650, 0.5420],[ 0.0000, 0.0000, 0.0000],[ 0.0597, 0.6810, -0.2595],[ 0.6543, -0.6242, 0.2337],[-0.0780, -0.9607, -0.0618],[ 0.2801, -0.6041, -1.4143]], requires_grad=True)
输出:
tensor([[[ 0.0000, 0.0000, 0.0000],[-0.5242, -0.2230, -1.1677],[ 0.1831, -0.0200, 0.7023],[ 0.0000, 0.0000, 0.0000]]], grad_fn=<EmbeddingBackward0>)
词典中,被padding_idx标定后的词嵌入向量可被重新定义:
import torch
import torch.nn as nnpadding_idx=2
embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
print('权重:\n', embedding.weight)with torch.no_grad():embedding.weight[padding_idx] = torch.tensor([1.1, 2.2, 3.3])
print('权重:\n', embedding.weight)
输出为:
权重:Parameter containing:
tensor([[ 0.7247, 0.7553, -1.8226],[-1.3304, -0.5025, 0.5237],[ 0.0000, 0.0000, 0.0000]], requires_grad=True)
权重:Parameter containing:
tensor([[ 0.7247, 0.7553, -1.8226],[-1.3304, -0.5025, 0.5237],[ 1.1000, 2.2000, 3.3000]], requires_grad=True)
相关文章:
【python函数】torch.nn.Embedding函数用法图解
学习SAM模型的时候,第一次看见了nn.Embedding函数,以前接触CV比较多,很少学习词嵌入方面的,找了一些资料一开始也不是很理解,多看了两遍后,突然顿悟,特此记录。 SAM中PromptEncoder中运用nn.Emb…...
with ldid... /opt/MonkeyDev/bin/md: line 326: ldid: command not found
吐槽傻逼xcode 根据提示 执行了这个脚本/opt/MonkeyDev/bin/md 往这里面添加你brew install 安装文件的目录即可...
[golang gui]fyne框架代码示例
1、下载GO Go语言中文网 golang安装包 - 阿里镜像站(镜像站使用方法:查找最新非rc版本的golang安装包) golang安装包 - 中科大镜像站 go二进制文件下载 - 南京大学开源镜像站 Go语言官网(Google中国) Go语言官网(Go团队) 截至目前(2023年9月17日&#x…...
2000-2018年各省能源消费和碳排放数据
2000-2018年各省能源消费和碳排放数据 1、时间:2000-2018年 2、范围:30个省市 3、指标:id、year、ENERGY、COAL、碳排放倒数*100 4、来源:能源年鉴 5、指标解释: 2018年碳排放和能源数据为插值法推算得到 碳排放…...
C# ref 学习1
ref 关键字用在四种不同的上下文中; 1.在方法签名和方法调用中,按引用将参数传递给方法。 2.在方法签名中,按引用将值返回给调用方。 3.在成员正文中,指示引用返回值是否作为调用方欲修改的引用被存储在本地,或在一般…...
MQ - 08 基础篇_消费者客户端SDK设计(下)
文章目录 导图Pre概述消费分组协调者消费分区分配策略轮询粘性自定义消费确认确认后删除数据确认后保存消费进度数据消费失败处理从服务端拉取数据失败本地业务数据处理失败提交位点信息失败总结导图 Pre...
Flutter层对于Android 13存储权限的适配问题
感觉很久没有写博客了,不对,的确是很久没有写博客了。原因我不怎么想说,玩物丧志了。后面渐渐要恢复之前的写作节奏。今天来聊聊我最近遇到的一个问题: Android 13版本对于storage权限的控制问题。 我们都知道,Andro…...
Android kotlin开源项目-功能标题目录
目录 一、BRVAH二、开源项目1、RV列表动效(标题目录)2、拖拽与侧滑(标题目录)3、数据库(标题目录)4、树形图(多级菜单)(标题目录)5、轮播图与头条(标题目录)6…...
Linux下,基于TCP与UDP协议,不同进程下单线程通信服务器
C语言实现Linux下,基于TCP与UDP协议,不同进程下单线程通信服务器 一、TCP单线程通信服务器 先运行server端,再运行client端输入"exit" 是退出 1.1 server_TCP.c **#include <my_head.h>#define PORT 6666 #define IP &qu…...
qt功能自己创作
按钮按下三秒禁用 void MainWindow::on_pushButton_5_clicked(){// 锁定界面setWidgetsEnabled(ui->centralwidget, false);// 创建一个定时器,等待3秒后解锁界面QTimer::singleShot(3000, this, []() {setWidgetsEnabled(ui->centralwidget, true);;//ui-&g…...
Linux网络编程:使用UDP和TCP协议实现网络通信
目录 一. 端口号的概念 二. 对于UDP和TCP协议的认识 三. 网络字节序 3.1 字节序的概念 3.2 网络通信中的字节序 3.3 本地地址格式和网络地址格式 四. socket编程的常用函数 4.1 sockaddr结构体 4.2 socket编程常见函数的功能和使用方法 五. UDP协议实现网络通信 5.…...
【后端速成 Vue】初识指令(上)
前言: Vue 会根据不同的指令,针对标签实现不同的功能。 在 Vue 中,指定就是带有 v- 前缀 的特殊 标签属性,比如: <div v-htmlstr> </div> 这里问题就来了,既然 Vue 会更具不同的指令&#…...
爬虫 — Scrapy-Redis
目录 一、背景1、数据库的发展历史2、NoSQL 和 SQL 数据库的比较 二、Redis1、特性2、作用3、应用场景4、用法5、安装及启动6、Redis 数据库简单使用7、Redis 常用五大数据类型7.1 Redis-String7.2 Redis-List (单值多value)7.3 Redis-Hash7.4 Redis-Set (不重复的)7.5 Redis-Z…...
tcpdump常用命令
需要安装 tcpdump wireshark ifconfig找到网卡名称 eth0, ens192... tcpdump需要root权限 网卡eth0 经过221.231.92.240:80的流量写入到http.cap tcpdump -i eth0 host 221.231.92.240 and port 80 -vvv -w http.cap ssh登录到主机查看排除ssh 22端口的报文 tcpdump -i …...
计算机网络运输层网络层补充
1 CDMA是码分多路复用技术 和CMSA不是一个东西 UPD是只确保发送 但是接收端收到之后(使用检验和校验 除了检验的部分相加 对比检验和是否相等。如果不相同就丢弃。 复用和分用是发生在上层和下层的问题。通过比如时分多路复用 频分多路复用等。TCP IP 应用层的IO多路复用。网…...
java CAS详解(深入源码剖析)
CAS是什么 CAS是compare and swap的缩写,即我们所说的比较交换。该操作的作用就是保证数据一致性、操作原子性。 cas是一种基于锁的操作,而且是乐观锁。在java中锁分为乐观锁和悲观锁。悲观锁是将资源锁住,等之前获得锁的线程释放锁之后&am…...
1786_MTALAB代码生成把通用函数生成独立文件
全部学习汇总: GitHub - GreyZhang/g_matlab: MATLAB once used to be my daily tool. After many years when I go back and read my old learning notes I felt maybe I still need it in the future. So, start this repo to keep some of my old learning notes…...
2023/09/19 qt day3
头文件 #ifndef WIDGET_H #define WIDGET_H #include <QWidget> #include <QDebug> #include <QTime> #include <QTimer> #include <QPushButton> #include <QTextEdit> #include <QLineEdit> #include <QLabel> #include &l…...
Docker 学习总结(78)—— Docker Rootless 让你的容器更安全
前言 在以 root 用户身份运行 Docker 会带来一些潜在的危害和安全风险,这些风险包括: 容器逃逸:如果一个容器以 root 权限运行,并且它包含了漏洞或者被攻击者滥用,那么攻击者可能会成功逃出容器,并在宿主系统上执行恶意操作。这会导致宿主系统的安全性受到威胁。 特权升…...
如何使用ArcGIS Pro将等高线转DEM
通常情况下,我们拿到的等高线数据一般都是CAD格式,如果要制作三维地形模型,使用栅格格式的DEM数据是更好的选择,这里就为大家介绍一下如何使用ArcGIS Pro将等高线转DEM,希望能对你有所帮助。 创建TIN 在工具箱中选择“…...
K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...
12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
EtherNet/IP转DeviceNet协议网关详解
一,设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络,本网关连接到EtherNet/IP总线中做为从站使用,连接到DeviceNet总线中做为从站使用。 在自动…...
QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...
Redis数据倾斜问题解决
Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...
Swagger和OpenApi的前世今生
Swagger与OpenAPI的关系演进是API标准化进程中的重要篇章,二者共同塑造了现代RESTful API的开发范式。 本期就扒一扒其技术演进的关键节点与核心逻辑: 🔄 一、起源与初创期:Swagger的诞生(2010-2014) 核心…...
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列,以便知晓哪些列包含有价值的数据,…...
技术栈RabbitMq的介绍和使用
目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...
初探Service服务发现机制
1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能:服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源…...
Netty从入门到进阶(二)
二、Netty入门 1. 概述 1.1 Netty是什么 Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients. Netty是一个异步的、基于事件驱动的网络应用框架,用于…...
