当前位置: 首页 > news >正文

pytorch花式索引提取topk的张量

文章目录

  • pytorch花式索引提取topk的张量
    • 问题设定
    • 代码实现
      • 索引方法
      • gather方法
      • 验证
    • 补充知识
      • expand方法
      • gather方法
      • randint

pytorch花式索引提取topk的张量

问题设定

在这里插入图片描述
或者说,有一个(bs, dim, L)的大张量,索引的index形状为(bs, X),想得到一个(bs, dim, X)的reduced向量。我们在进行topk操作(以减少计算量)的时候经常碰到这种情况。
给出如下两种实现方法,分别使用花式索引(参考informer的代码)以及pytorch的gather方法

代码实现

索引方法

参考https://blog.csdn.net/qq_36560894/article/details/122005808

feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 结果和indices一致,说明在第二个channel上,每个样本的索引是一样的
bs,dim=feature.shape[:2]
bs,dim 
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape

在这里插入图片描述
在这里插入图片描述

gather方法

reduce_feature = torch.gather(feature, 2, indices_expand)

验证

两种方法得到的结果完全相同
在这里插入图片描述

补充知识

expand方法

在 PyTorch 中,expand() 方法用于扩展张量的大小。它会在不实际复制数据的情况下,重复张量的元素以填充新的形状。这个方法可以用于广播操作,以便在执行一些需要相同形状的张量之间的数学运算时,使它们具有相同的形状。

下面是使用 expand() 方法的基本用法:

import torch# 创建一个原始张量
x = torch.tensor([[1, 2, 3],[4, 5, 6]])# 使用 expand 扩展张量的大小
expanded_x = x.expand(2, 3, 4)  # 扩展成维度为(2, 3, 4)的张量print(expanded_x)

在上面的例子中,我们首先创建了一个形状为 (2, 3) 的原始张量 x。然后,我们使用 expand() 方法将其扩展成一个维度为 (2, 3, 4) 的新张量 expanded_x,该张量的形状是在原始张量形状的基础上每个维度都扩展了一倍。

需要注意的是,expand() 方法只能用于增加张量的大小,不能减小。另外,扩展后的张量与原始张量共享底层数据,因此在原始张量上进行的任何修改都会反映在扩展后的张量上,反之亦然。

gather方法

在 PyTorch 中,gather() 方法用于从输入张量中按照指定索引提取元素。这个方法通常用于根据索引收集特定的元素,例如根据类别索引从分类得分张量中获取对应类别的得分。

下面是使用 gather() 方法的基本用法:

import torch# 创建一个输入张量
input_tensor = torch.tensor([[1, 2],[3, 4],[5, 6]])# 创建一个索引张量
indices = torch.tensor([[0, 0],[1, 0]])# 使用 gather 方法根据索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)print(output_tensor)

在上面的例子中,我们首先创建了一个形状为 (3, 2) 的输入张量 input_tensor,以及一个形状为 (2, 2) 的索引张量 indices。然后,我们使用 gather() 方法从输入张量 input_tensor 中按照索引张量 indices 收集元素。

gather() 方法中,参数 dim 指定了在哪个维度上进行收集操作,而 index 参数指定了收集元素所使用的索引张量。

需要注意的是,索引张量 indices 的形状必须与输出张量的形状一致,或者是可以广播成与输出张量形状一致的形状。

randint

torch.randint() 是 PyTorch 中用于生成随机整数张量的函数。它可以生成一个张量,其中的元素是在指定范围内随机抽样的整数。

下面是 torch.randint() 的基本用法示例:

import torch# 生成一个形状为 (3, 3) 的随机整数张量,范围是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))print(random_integers)

在上面的示例中,我们使用了 torch.randint() 函数来生成一个形状为 (3, 3) 的随机整数张量,其中的元素取值范围在闭区间 [low, high) 内,即从 0 到 9。

torch.randint() 函数的主要参数包括:

  • low:生成的随机整数的最小值(包含)。
  • high:生成的随机整数的最大值(不包含)。
  • size:生成的张量的形状。

你也可以不指定 low 参数,默认情况下它为 0。此外,还可以使用其他参数来控制生成的随机整数张量的设备类型、数据类型等。

相关文章:

pytorch花式索引提取topk的张量

文章目录 pytorch花式索引提取topk的张量问题设定代码实现索引方法gather方法验证 补充知识expand方法gather方法randint pytorch花式索引提取topk的张量 问题设定 或者说,有一个(bs, dim, L)的大张量,索引的index形状为(bs, X),想得到一个(…...

Swagger2

Swagger2 引入依赖 <!-- springfox-swagger2 --><dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2</artifactId><version>2.10.5</version></dependency>编写配置 @Configuration public …...

2024/2/13

数组练习 1、选择题 1.1、若有定义语句&#xff1a;int a[3][6]; &#xff0c;按在内存中的存放顺序&#xff0c;a 数组的第10个元素是 D A&#xff09;a[0][4] B) a[1][3] C)a[0][3] D)a[1][4] 1.2、有数组 int a[5] {10&#xff0c;20&#xff0c;30&#xff0c;40&…...

【工具】Android|Android Studio 长颈鹿版本安装下载使用详解

版本&#xff1a;2022.3.1.22&#xff0c; https://redirector.gvt1.com/edgedl/android/studio/install/2022.3.1.22/android-studio-2022.3.1.22-windows.exe 前言 笔者曾多次安装并卸载Android Studio&#xff0c;反复被安卓模拟器劝退。现在差不多是第三次安装&#xff0c…...

第三代互联网web3.0

Web3.0&#xff0c;通常被称为第三代互联网&#xff0c;代表了互联网技术的下一个演进阶段。它主要基于区块链、去中心化和用户赋权的理念构建&#xff0c;旨在创造一个更加智能、开放且安全的网络环境。以下是Web3.0的一些关键特点&#xff1a; 1. **去中心化**&#xff1a;We…...

FL Studio版本升级-FL Studio怎么升级-FL Studio升级方案

已经是新年2024年了&#xff0c;但是但是依然有很多朋友还在用FL Studio12又或者FL Studio20&#xff0c;今天这篇文章教大家如何升级FL Studio21 FL Studio 21是Image Line公司开发的音乐编曲软件&#xff0c;除了软件以外&#xff0c;我们还提供了FL Studio的升级服务&#…...

服务降级(Sentinel)

服务降级 采用 SentinelResource 注解方式实现&#xff0c; 必要的 依赖必须引入 以及 切面Bean 接口代码 RequestMapping("/degrade")SentinelResource(value DEGRADE_RESOURCE_NAME, blockHandler "blockHandlerForDegrade",entryType EntryType.IN…...

Rust入门问题: use of undeclared crate or module `rand`

按照官网学rust&#xff0c;程序地址在这里&#xff0c; 写个猜数字游戏 - Rust 程序设计语言 简体中文版 程序内容也很简单&#xff0c; use std::io; use rand::Rng;fn main() {println!("Guess the number!");let secret_number rand::thread_rng().gen_range…...

2024.2.6 模拟实现 RabbitMQ —— 数据库操作

目录 引言 选择数据库 环境配置 设计数据库表 实现流程 封装数据库操作 针对 DataBaseManager 单元测试 引言 硬盘保存分为两个部分 数据库&#xff1a;交换机&#xff08;Exchange&#xff09;、队列&#xff08;Queue&#xff09;、绑定&#xff08;Binding&#xff0…...

dolphinscheduler海豚调度(一)简介快速体验

1、简介 Apache DolphinScheduler 是一个分布式易扩展的可视化DAG工作流任务调度开源系统。适用于企业级场景&#xff0c;提供了一个可视化操作任务、工作流和全生命周期数据处理过程的解决方案。 Apache DolphinScheduler 旨在解决复杂的大数据任务依赖关系&#xff0c;并为应…...

VTK 三维场景的基本要素(相机) vtkCamera

观众的眼睛好比三维渲染场景中的相机&#xff0c;在VTK中用vtkCamera类来表示。vtkCamera负责把三维场景投影到二维平面&#xff0c;如屏幕&#xff0c;相机投影示意图如下图所示。 1.与相机投影相关的要素主要有如下几个&#xff1a; 1&#xff09;相机位置: 相机所处的位置…...

小游戏和GUI编程(5) | SVG图像格式简介

小游戏和GUI编程(5) | SVG图像格式简介 0. 问题 Q1: SVG 是什么的缩写&#xff1f;Q2: SVG 是一种图像格式吗&#xff1f;Q3: SVG 相对于其他图像格式的优点和缺点是什么&#xff1f;Q4: 哪些工具可以查看 SVG 图像&#xff1f;Q5: SVG 图像格式的规范是怎样的&#xff1f;Q6…...

多机多卡运行nccl-tests和channel获取

nccl-tests 环境1. 安装nccl2. 安装openmpi3. 单机测试4. 多机测试mpirun多机多进程多节点运行nccl-testschannel获取 环境 Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-91-generic x86_64)cuda 11.8 cudnn 8nccl 2.15.1NVIDIA GeForce RTX 4090 *2 1. 安装nccl #查看cuda版本 nv…...

SQL,HQL刷题,尚硅谷

相关表数据&#xff1a; 1、score_info 2、student_info 题目及思路解析&#xff1a; 分组结果的条件 1、查询平均成绩大于60分的学生的学号和平均成绩 代码&#xff1a; selectstu_id,avg(score) score_avg from score_info group by stu_id having score_avg>60; 思路…...

DevOps:CI、CD、CB、CT、CD

目录 一、软件开发流程演化快速回顾 &#xff08;一&#xff09;瀑布模型 &#xff08;二&#xff09;原型模型 &#xff08;三&#xff09;螺旋模型 &#xff08;四&#xff09;增量模型 &#xff08;五&#xff09;敏捷开发 &#xff08;六&#xff09;DevOps 二、走…...

[leetcode经典算法题]删除有序数组中的重复项(双指针)

删除有序数组中的重复项 给你一个 非严格递增排列 的数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使每个元素 只出现一次 &#xff0c;返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素…...

【国产MCU】-CH32V307-触摸按键检测(TKEY)

触摸按键检测(TKEY) 文章目录 触摸按键检测(TKEY)1、TKEY介绍2、TKEY使用实例触摸检测控制(TKEY)单元,借助ADC 模块的电压转换功能,通过将电容量转换为电压量进行采样,实现触摸按键检测功能。检测通道复用ADC 的16 个外部通道,通过ADC 模块的单次转换模式实现触摸按键…...

Hive的小文件问题

目录 一、小文件产生的原因 二、小文件的危害 三、小文件的解决方案 3.1 小文件的预防 3.1.1 减少Map数量 3.1.2 减少Reduce的数量 3.2 已存在的小文件合并 3.2.1 方式一&#xff1a;insert overwrite (推荐) 3.2.2 方式二&#xff1a;concatenate 3.2.3 方式三&#xff…...

攻防世界——re2-cpp-is-awesome

64位 我先用虚拟机跑了一下这个程序&#xff0c;结果输出一串字符串flag ——没用 IDA打开后 F5也没有什么可看的 那我们就F12查看字符串找可疑信息 这里一下就看见了 __int64 __fastcall main(int a1, char **a2, char **a3) {char *v3; // rbx__int64 v4; // rax__int64 v…...

问山海——天涯海角——桃花渊boss攻击顺序

文章目录 桃花渊代码代码解读代码执行结果攻击顺序示意图 桃花渊 规划击杀各个boss顺序。 副本持续时间为30分钟&#xff0c;每个地方的boss被打死后&#xff0c;需要一定时间才能重新刷新。 只考虑其中两种boss&#xff0c;龟将和龟龙。各有四个。 其中我从一个boss地点到…...

基于Unity的地牢游戏开发

1.数字字符串转数字System.Globalization.NumberStyles hexNum; // 专门的枚举成员&#xff0c;解析16进制字符串 hexNum System.Globalization.NumberStyles.HexNumber;int.Parse(tileNums[i], hexNum);2.注意&#xff1a;文件读取是从上到下&#xff0c;而 Unity y轴 …...

JiYuTrainer高效实用指南:3步解锁极域电子教室控制,恢复电脑操作自由

JiYuTrainer高效实用指南&#xff1a;3步解锁极域电子教室控制&#xff0c;恢复电脑操作自由 【免费下载链接】JiYuTrainer 极域电子教室防控制软件, StudenMain.exe 破解 项目地址: https://gitcode.com/gh_mirrors/ji/JiYuTrainer 还在为课堂上被老师全屏控制电脑而烦…...

2026生鲜零售收银软件推荐:四大主流方案深度对比

开一家生鲜店&#xff0c;最让人头疼的往往不是进货渠道或选址&#xff0c;而是每天高峰期那台“卡住”的收银机。想象一下&#xff0c;周末傍晚顾客排成长龙&#xff0c;称重员手忙脚乱地输入代码&#xff0c;屏幕转圈加载&#xff0c;后面的顾客开始不耐烦地催促&#xff0c;…...

GaussDB GDS 搭建完全指南:从安装到启动,一文搞定数据迁移服务

在进行 GaussDB 跨库数据迁移时&#xff0c;GDS&#xff08;Gauss Data Service&#xff09; 是实现外表迁移的核心组件。本文将手把手带你完成 GDS 的下载、安装、配置与启动&#xff0c;确保数据迁移通道畅通无阻。 &#x1f4ce; 关联阅读&#xff1a;GaussDB GDS 外表迁移实…...

ComfyUI Segment Anything:零门槛实现智能图像分割的完整指南

ComfyUI Segment Anything&#xff1a;零门槛实现智能图像分割的完整指南 【免费下载链接】comfyui_segment_anything Based on GroundingDino and SAM, use semantic strings to segment any element in an image. The comfyui version of sd-webui-segment-anything. 项目地…...

Ardupilot无人船新手必看:从遥控器开关到地面站,3档模式设置保姆级教程

Ardupilot无人船控制模式全解析&#xff1a;从基础配置到高阶应用实战 第一次接触Ardupilot无人船时&#xff0c;最让人困惑的莫过于各种控制模式的区别与适用场景。作为开源自动驾驶系统的标杆&#xff0c;Ardupilot为无人船提供了多达14种控制模式&#xff0c;每种模式都有其…...

别再手动改hosts了!用Docker Compose一键部署Authelia SSO,顺便搞定Traefik反向代理

一键部署Authelia SSO与Traefik反向代理的Docker Compose实战指南 在当今复杂的网络环境中&#xff0c;管理多个Web应用的认证流程往往成为开发者的痛点。手动配置hosts文件、逐个设置访问权限不仅耗时耗力&#xff0c;还容易出错。本文将介绍如何利用Docker Compose快速搭建Au…...

告别手动画图!用Perl脚本自动化统计MS动力学模拟中的氢键(附脚本下载)

用Perl脚本实现MS动力学模拟中氢键的自动化统计与分析 在分子动力学模拟研究中&#xff0c;氢键作为影响材料性能的关键因素之一&#xff0c;其动态变化规律往往需要从海量轨迹数据中提取。传统手动分析方法不仅效率低下&#xff0c;还容易引入人为误差。本文将介绍如何利用Per…...

2026年电钢琴避坑指南|高性价比品牌型号推荐,新手必看!

电钢琴选购核心要点&#xff08;快速避坑&#xff09; 在推荐具体机型前&#xff0c;先明确4个选购关键指标&#xff0c;确保不踩坑&#xff1a; 1.键盘&#xff1a;必须88键逐级配重重锤键盘&#xff0c;避免毁手型。 2.复音数:至少128复音&#xff08;避免弹奏复杂曲目时丢…...

别再硬刚滑块了!一个Python脚本自动搞定淘宝X5SEC验证码

Python自动化破解淘宝X5SEC滑块验证码实战指南 淘宝作为国内最大的电商平台之一&#xff0c;其反爬机制一直处于行业领先水平。其中X5SEC滑块验证码是淘宝用来识别自动化程序的主要手段之一。对于需要批量采集商品数据或进行价格监控的开发者来说&#xff0c;频繁的手动滑块验证…...