当前位置: 首页 > news >正文

Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.pyimport torchdef random_masking(x, mask_ratio=0.75):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoreif __name__ == '__main__':x = torch.arange(64).reshape(1, 16, 4)random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4,  4,  4,  4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])

相关文章:

Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法 torch.gather() 就是在指定维度上收集value。 torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释: input (Tensor) – the source tensordim (int) – the axis along which to indexindex (Lo…...

用爬虫解决问题

使用Java进行网络爬虫开发是一种常见的做法,它可以帮助你从网站上自动抓取信息。Java语言因为其丰富的库支持(如Jsoup、HtmlUnit、Selenium等)和良好的跨平台性,成为实现爬虫的优选语言之一。下面我将简要介绍如何使用Java编写一个…...

机器学习-有监督学习

有监督学习是机器学习的一种主要范式,其基本思想是从有标签的训练数据中学习输入和输出之间的关系,然后利用学习到的模型对新的输入进行预测或分类。 有监督学习的过程如下: 1. 准备数据:首先,需要准备一组有标签的训练…...

【详细介绍下Visual Studio】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…...

【Golang】实现 Excel 文件下载功能

在当今的网络应用开发中,提供数据导出功能是一项常见的需求。Excel 作为一种广泛使用的电子表格格式,通常是数据导出的首选格式之一。在本教程中,我们将学习如何使用 Go 语言和 Gin Web 框架来创建一个 Excel 文件,并允许用户通过…...

设计模式2——原则篇:依赖倒转原则、单一职责原则、合成|聚合复用原则、开放-封闭原则、迪米特法则、里氏代换原则

设计模式2——设计原则篇 目录 一、依赖倒转原则 二、单一职责原则(SRP) 三、合成|聚合复用原则(CARP) 四、开放-封闭原则 五、迪米特法则(LoD) 六、里氏代换原则 七、接口隔离原则 八、总结 一、依赖…...

深入探讨布隆过滤器算法:高效的数据查找与去重工具

在处理海量数据时,我们经常需要快速地进行数据查找和去重操作。然而,传统的数据结构可能无法满足这些需求,特别是在数据量巨大的情况下。在这种情况下,布隆过滤器(Bloom Filter)算法就显得尤为重要和有效。…...

基于STC12C5A60S2系列1T 8051单片机实现一主单片机与一从单片机进行双向串口通信功能

基于STC12C5A60S2系列1T 8051单片机实现一主单片机与一从单片机进行双向串口通信功能 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机串口通信介绍STC12C5A60S2系列1T 8051单片机串口通信的结构基于STC12C5A60S2系列1T 8051单片机串口通信的特殊功能寄存器…...

ubuntu18.04安装docker容器

Ubuntu镜像下载 https://mirrors.huaweicloud.com/ubuntu-releases/ docker安装 # 第一步、卸载旧版本docker sudo apt-get remove docker docker-engine docker.io containerd runc# 第二步、更新及安装软件 luhost:~$ curl -fsSL https://get.docker.com -o get-docker.sh …...

202212青少年软件编程(Python)等级考试试卷(二级)

第 1 题 【单选题】 运行下列程序, 最终输出的结果是? ( ) info = {1:小明, 2:小黄,3:小兰}info[4] = 小红info[...

单播、组播、广播

​​​​​​ 概念 单播(Unicast) 单播是网络中最常用、最基本的通信方式。在单播通信中,数据包从一个节点发送到特定的另一个节点。换句话说,发送端和接收端之间建立一对一的连接,然后进行数据传输。 例如&#x…...

吴恩达深度学习笔记:深度学习的 实践层面 (Practical aspects of Deep Learning)1.13-1.14

目录 第二门课: 改善深层神经网络:超参数调试、正 则 化 以 及 优 化 (Improving Deep Neural Networks:Hyperparameter tuning, Regularization and Optimization)第一周:深度学习的 实践层面 (Practical aspects of Deep Learning)1.13 梯度检验&#…...

笔试强训未触及题目(个人向)

1.DP22 最长回文子序列 1.题目 2.解析 这是一个区间dp问题,我们让dp[i][j]表示在区间[i,j]内的最长子序列长度,如图: 3.代码 public class LongestArr {//DP22 最长回文子序列public static void main(String[] args) {Scanner…...

【YOLO改进】换遍MMDET主干网络之EfficientNet(基于MMYOLO)

EfficientNet EfficientNet是Google在2019年提出的一种新型卷积神经网络架构,其设计初衷是在保证模型性能的同时,尽可能地降低模型的复杂性和计算需求。EfficientNet的核心思想是通过均衡地调整网络的深度(层数)、宽度&#xff0…...

uniapp下拉选择组件

uniapp下拉选择组件 背景实现思路代码实现配置项使用尾巴 背景 最近遇到一个这样的需求,在输入框中输入关键字,通过接口查询到结果之后,以下拉框列表形式展现供用户选择。查询了下uni-app官网和项目中使用的uv-ui库,没找到符合条…...

高斯数据库创建函数的语法

CREATE FUNCTION 语法格式 •兼容PostgreSQL风格的创建自定义函数语法。 CREATE [ OR REPLACE ] FUNCTION function_name ( [ { argname [ argmode ] argtype [ { DEFAULT | : | } expression ]} [, …] ] ) [ RETURNS rettype [ DETERMINISTIC ] | RETURNS TABLE ( { column_…...

【.NET Core】你认识Attribute之CallerMemberName、CallerFilePath、CallerLineNumber三兄弟

你认识Attribute之CallerMemberName、CallerFilePath、CallerLineNumber三兄弟 文章目录 你认识Attribute之CallerMemberName、CallerFilePath、CallerLineNumber三兄弟一、概述二、CallerMemberNameAttribute类三、CallerFilePathAttribute 类四、CallerLineNumberAttribute 类…...

ubuntu删除opencv

要完全删除OpenCV 3.4.5版本,你可以按照以下步骤进行操作: 卸载OpenCV库: 首先,你需要卸载OpenCV 3.4.5版本。可以使用以下命令卸载OpenCV库: sudo apt-get purge libopencv*这将删除OpenCV库及其相关文件。 删除O…...

K8s源码分析(二)-K8s调度队列介绍

本文首发在个人博客上,欢迎来踩! 本次分析参考的K8s版本是 文章目录 调度队列简介调度队列源代码分析队列初始化QueuedPodInfo元素介绍ActiveQ源代码介绍UnschedulableQ源代码介绍**BackoffQ**源代码介绍队列弹出待调度的Pod队列增加新的待调度的Podpod调…...

OpenGL ES 面试高频知识点(二)

说说纹理常用的采样方式? 最邻近点采样(GL_NEAREST)和双线性采样(GL_LINEAR)。 GL_NEAREST 采样是 OpenGL 默认的纹理采样方式,OpenGL 会选择中心点最接近纹理坐标的那个像素,纹理放大的时候会有锯齿感或者颗粒感。 **GL_LINEAR 采样会基于纹理坐标附近的纹理像素,计…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建

制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

MVC 数据库

MVC 数据库 引言 在软件开发领域,Model-View-Controller(MVC)是一种流行的软件架构模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。这种模式有助于提高代码的可维护性和可扩展性。本文将深入探讨MVC架构与数据库之间的关系,以…...

postgresql|数据库|只读用户的创建和删除(备忘)

CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本&#xff1a; 3.8.1 语言&#xff1a; JavaScript/TypeScript、C、Java 环境&#xff1a;Window 参考&#xff1a;Java原生反射机制 您好&#xff0c;我是鹤九日&#xff01; 回顾 在上篇文章中&#xff1a;CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错

出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上&#xff0c;所以报错&#xff0c;到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本&#xff0c;cu、torch、cp 的版本一定要对…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

NFT模式:数字资产确权与链游经济系统构建

NFT模式&#xff1a;数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新&#xff1a;构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议&#xff1a;基于LayerZero协议实现以太坊、Solana等公链资产互通&#xff0c;通过零知…...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作

一、上下文切换 即使单核CPU也可以进行多线程执行代码&#xff0c;CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短&#xff0c;所以CPU会不断地切换线程执行&#xff0c;从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...

CSS | transition 和 transform的用处和区别

省流总结&#xff1a; transform用于变换/变形&#xff0c;transition是动画控制器 transform 用来对元素进行变形&#xff0c;常见的操作如下&#xff0c;它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...