PyTorch核心函数详解:gather与where的实战指南
PyTorch中的
torch.gather和torch.where是处理张量数据的关键工具,前者实现基于索引的灵活数据提取,后者完成条件筛选与动态生成。本文通过典型应用场景和代码演示,深入解析两者的工作原理及使用技巧,帮助开发者提升数据处理的灵活性与效率。
在深度学习中,我们经常需要根据特定规则提取或生成数据。例如:
- 从预测概率中提取Top-K类别索引
- 根据掩码筛选有效数据点
- 动态生成条件化张量
torch.gather和torch.where正是解决这类问题的核心函数。下文将结合图像处理、数据筛选等场景,详解它们的用法与差异。

一、torch.gather:基于索引的精准提取
功能描述
torch.gather(input, dim, index) 沿指定维度dim,根据index张量中的索引值,从input中提取对应元素,输出形状与index一致。
参数说明
input:源张量dim:指定操作的维度index:索引张量,其值必须为整数类型
核心规则
- 索引穿透性:索引值直接映射源张量的位置,不改变维度
- 广播机制:当
index维度小于input时,会自动广播到匹配形状 - 多维索引:支持通过多维索引张量提取复杂结构的数据
应用场景与示例
场景1:图像数据批量提取
假设需要从批量图像中提取特定位置的像素值:
# 假设images是形状为(2,3,3)的图像批次 (批次大小2,通道3,分辨率3x3)
images = torch.tensor([[[1,2,3],[4,5,6],[7,8,9]], # 第一张图像[[10,11,12],[13,14,15],[16,17,18]] # 第二张图像
])# 提取所有图像的第0行第1列像素 (shape: (2,))
pixels = torch.gather(images, dim=2, index=torch.tensor([[[0,1,0],[0,1,0]], [[0,1,0],[0,1,0]]]))
print(pixels)
# 输出: tensor([[1, 2, 1],
# [10, 11, 10]])
场景2:从概率分布提取Top-K结果
在NLP任务中提取预测词ID:
logits = torch.tensor([[0.1, 0.4, 0.5], [0.3, 0.6, 0.1]]) # 2个样本的3个类别的概率
topk_indices = logits.topk(k=2, dim=1).indices # 获取Top-2索引# 使用gather提取Top-2概率值
topk_probs = torch.gather(logits, dim=1, index=topk_indices)
print(topk_probs)
# 输出:
# tensor([[0.5, 0.4],
# [0.6, 0.3]])
二、torch.where:条件驱动的动态生成
功能描述
torch.where(condition, x, y) 根据布尔条件condition,从张量x或y中选择元素,生成与输入同形状的新张量。
参数说明
condition:布尔型张量,决定元素来源x:满足条件时选择的元素来源y:不满足条件时选择的元素来源
核心特性
- 自动广播:支持不同形状的条件与输入张量
- 元素级操作:逐元素比较生成动态结果
- 类型转换:输出类型由
x和y决定
应用场景与示例
场景1:数据清洗与过滤
筛选出温度超过30℃且湿度低于60%的记录:
temperature = torch.tensor([25.0, 32.5, 28.0, 35.0])
humidity = torch.tensor([55.0, 58.0, 70.0, 50.0])# 生成布尔掩码
mask = (temperature > 30) & (humidity < 60)# 根据条件生成标签
labels = torch.where(mask, torch.tensor("High Risk"), torch.tensor("Normal"))
print(labels)
# 输出: tensor(['Normal', 'High Risk', 'Normal', 'Normal'], dtype=string)
场景2:图像二值化处理
将灰度图像转换为二值掩码:
gray_image = torch.tensor([[0.1, 0.8], [0.6, 0.3]], dtype=torch.float32)
threshold = 0.5# 生成二值掩码
binary_mask = torch.where(gray_image > threshold, torch.tensor(1.0), torch.tensor(0.0))
print(binary_mask)
# 输出:
# tensor([[0., 1.],
# [1., 0.]])
三、函数对比与选择指南
| 特性 | torch.gather | torch.where |
|---|---|---|
| 核心功能 | 基于索引精确提取元素 | 条件驱动动态生成元素 |
| 输入要求 | 需显式提供索引张量 | 需条件张量及候选值张量 |
| 维度匹配 | 严格匹配索引与源张量维度 | 自动广播兼容不同形状 |
| 典型应用 | 多维数据查询、Top-K提取 | 条件筛选、数据转换、掩码生成 |
| 性能消耗 | 较高(涉及索引计算) | 较低(基于原生条件判断) |
四、综合实战:图像语义分割后处理
任务需求
将模型输出的概率图转换为二值掩码,并提取连通区域标签。
解决方案
# 假设prob_map是模型输出的概率图 (H,W)
prob_map = torch.rand(256, 256) > 0.5 # 二值化处理# 使用where生成掩码
mask = torch.where(prob_map, torch.tensor(1), torch.tensor(0))# 使用gather提取连通区域标签(假设labels是预测的类别索引)
labels = torch.randint(0, 10, (256, 256))
selected_labels = torch.gather(labels, dim=0, index=mask.nonzero(as_tuple=True)[0])
五、注意事项与最佳实践
-
索引越界预防:
# 错误示例:索引超出范围会导致错误 valid_indices = torch.clamp(indices, min=0, max=max_dim-1) -
类型一致性:
# 确保index张量为整型 index = index.long() -
内存优化:
# 优先使用in-place操作减少显存占用 mask.masked_fill_(condition, value)
结语
torch.gather和torch.where作为PyTorch生态中的基石函数,在数据工程与模型开发中扮演着不可替代的角色。理解它们的底层逻辑与适用场景,能够帮助您:
- 更高效地实现复杂数据操作
- 优化模型推理与训练流程
- 解决各类条件化数据处理难题
掌握这两把利器,您将在PyTorch开发中如鱼得水!
相关文章:
PyTorch核心函数详解:gather与where的实战指南
PyTorch中的torch.gather和torch.where是处理张量数据的关键工具,前者实现基于索引的灵活数据提取,后者完成条件筛选与动态生成。本文通过典型应用场景和代码演示,深入解析两者的工作原理及使用技巧,帮助开发者提升数据处理的灵活…...
《Operating System Concepts》阅读笔记:p636-p666
《Operating System Concepts》学习第 58 天,p636-p666 总结,总计 31 页。 一、技术总结 1.system and network threats (1)attack network traffic (2)denial of service (3)port scanning 2.symmetric/asymmetric encryption algorithm (1)symm…...
Go:接口
接口既约定 Go 语言中接口是抽象类型 ,与具体类型不同 ,不暴露数据布局、内部结构及基本操作 ,仅提供一些方法 ,拿到接口类型的值 ,只能知道它能做什么 ,即提供了哪些方法 。 func Fprintf(w io.Writer, …...
ESP32+Arduino入门(三):连接WIFI获取当前时间
ESP32内置了WIFI模块连接WIFI非常简单方便。 代码如下: #include <WiFi.h>const char* ssid "WIFI名称"; const char* password "WIFI密码";void setup() {Serial.begin(115200);WiFi.begin(ssid,password);while(WiFi.status() ! WL…...
FastAPI用户认证系统开发指南:从零构建安全API
前言 在现代Web应用开发中,用户认证系统是必不可少的功能。本文将带你使用FastAPI框架构建一个完整的用户认证系统,包含注册、登录、信息更新和删除等功能。我们将采用JWT(JSON Web Token)进行身份验证,并使用SQLite作…...
CSS高度坍塌?如何解决?
一、什么是高度坍塌? 高度坍塌(Collapsing Margins)是指当父元素没有设置边框(border)、内边距(padding)、内容(content)或清除浮动时,其子元素的 margin 会…...
【数据结构】之散列
一、定义与基本术语 (一)、定义 散列(Hash)是一种将键(key)通过散列函数映射到一个固定大小的数组中的技术,因为键值对的映射关系,散列表可以实现快速的插入、删除和查找操作。在这…...
空地机器人在复杂动态环境下,如何高效自主导航?
随着空陆两栖机器人(AGR)在应急救援和城市巡检等领域的应用范围不断扩大,其在复杂动态环境中实现自主导航的挑战也日益凸显。对此香港大学王俊铭基于阿木实验室P600无人机平台自主搭建了一整套空地两栖机器人,使用Prometheus开源框架完成算法的仿真验证与…...
python小记(十二):Python 中 Lambda函数详解
Python 中 Lambda函数详解 Lambda函数详解:从入门到实战一、什么是Lambda函数?二、Lambda的核心语法与特点1. 基础语法2. 与普通函数对比 三、Lambda的六大应用场景(附代码示例)1. 基本数学运算2. 列表排序与自定义规则3. 数据映射…...
第二十一讲 XGBoost 回归建模 + SHAP 可解释性分析(利用R语言内置数据集)
下面我将使用 R 语言内置的 mtcars 数据集,模拟一个完整的 XGBoost 回归建模 SHAP 可解释性分析 实战流程。我们将以预测汽车的油耗(mpg)为目标变量,构建 XGBoost 模型,并用 SHAP 来解释模型输出。 🚗 示例…...
数据分析实战案例:使用 Pandas 和 Matplotlib 进行居民用水
原创 IT小本本 IT小本本 2025年04月15日 18:31 北京 本文将使用 Matplotlib 及 Seaborn 进行数据可视化。探索如何清理数据、计算月度用水量并生成有价值的统计图表,以便更好地理解居民的用水情况。 数据处理与清理 读取 Excel 文件 首先,我们使用 pan…...
Asp.NET Core WebApi 创建带鉴权机制的Api
构建一个包含 JWT(JSON Web Token)鉴权的 Web API 是一种常见的做法,用于保护 API 端点并验证用户身份。以下是一个基于 ASP.NET Core 的完整示例,展示如何实现 JWT 鉴权。 1. 创建 ASP.NET Core Web API 项目 使用 .NET CLI 或 …...
hash.
Redis 自身就是键值对结构 Redis 自身的键值对结构就是通过 哈希 的方式来组织的 哈希类型中的映射关系通常称为 field-value,用于区分 Redis 整体的键值对(key-value), 注意这里的 value 是指 field 对应的值,不是键…...
记录鸿蒙应用上架应用未配置图标的前景图和后景图标准要求尺寸1024px*1024px和标准要求尺寸1024px*1024px
审核报错【①应用未配置图标的前景图和后景图,标准要求尺寸1024px*1024px且需下载HUAWEI DevEco Studio 5.0.5.315或以上版本进行图标再处理、②应用在展开状态下存在页面左边距过大的问题, 应用在展开状态下存在页面右边距过大的问题, 当前页面左边距: 504 px, 当前页面右边距…...
golang-常见的语法错误
https://juejin.cn/post/6923477800041054221 看这篇文章 Golang 基础面试高频题详细解析【第一版】来啦~ 大叔说码 for-range的坑 func main() { slice : []int{0, 1, 2, 3} m : make(map[int]*int) for key, val : range slice {m[key] &val }for k, v : …...
Google最新《Prompt Engineering》白皮书全解析
近期有幸拿到了Google最新发布的《Prompt Engineering》白皮书,这是一份由Lee Boonstra主笔,Michael Sherman、Yuan Cao、Erick Armbrust、Antonio Gulli等多位专家共同贡献的权威性指南,发布于2025年2月。今天我想和大家分享这份68页的宝贵资…...
如何快速部署基于Docker 的 OBDIAG 开发环境
很多开发者对 OceanBase的 SIG社区小组很有兴趣,但如何将OceanBase的各类工具部署在开发环境,对于不少开发者而言都是比较蛮烦的事情。例如,像OBDIAG,其在WINDOWS系统上配置较繁琐,需要单独搭建C开发环境。此外&#x…...
[LeetCode 1306] 跳跃游戏3(Ⅲ)
题面: LeetCode 1306 思路: 只要能跳到其中一个0即可,和跳跃游戏1/2完全不同了,记忆化暴搜即可。 时间复杂度: O ( n ) O(n) O(n) 空间复杂度: O ( n ) O(n) O(n) 代码: dfs vector<…...
spring-ai-alibaba使用Agent实现智能机票助手
示例目标是使用 Spring AI Alibaba 框架开发一个智能机票助手,它可以帮助消费者完成机票预定、问题解答、机票改签、取消等动作,具体要求为: 基于 AI 大模型与用户对话,理解用户自然语言表达的需求支持多轮连续对话,能…...
STM32平衡车开发实战教程:从零基础到项目精通
STM32平衡车开发实战教程:从零基础到项目精通 一、项目概述与基本原理 1.1 平衡车工作原理 平衡车是一种基于倒立摆原理的两轮自平衡小车,其核心控制原理类似于人类保持平衡的过程。当人站立不稳时,会通过腿部肌肉的快速调整来维持平衡。平…...
使用DeepSeek AI高效降低论文重复率
一、论文查重原理与DeepSeek降重机制 1.1 主流查重系统工作原理 文本比对算法:连续字符匹配(通常13-15字符)语义识别技术:检测同义替换和结构调整参考文献识别:区分合理引用与不当抄袭跨语言检测:中英文互译内容识别1.2 DeepSeek降重核心技术 深度语义理解:分析句子核心…...
linux多线(进)程编程——(7)消息队列
前言 现在修真界大家的沟通手段已经越来越丰富了,有了匿名管道,命名管道,共享内存等多种方式。但是随着深入使用人们逐渐发现了这些传音术的局限性。 匿名管道:只能在有血缘关系的修真者(进程)间使用&…...
WinForm真入门(14)——ListView控件详解
一、ListView 控件核心概念与功能 ListView 是 WinForm 中用于展示结构化数据的多功能列表控件,支持多列、多视图模式及复杂交互,常用于文件资源管理器、数据报表等场景。 核心特点: 支持 5种视图模式:Details&…...
Python + Playwright:规避常见的UI自动化测试反模式
Python + Playwright:规避常见的UI自动化测试反模式 前言反模式一:整体式页面对象(POM)反模式二:具有逻辑的页面对象 - POM 的“越界”行为反模式三:基于 UI 的测试设置 - 缓慢且脆弱的“舞台搭建”反模式四:功能测试过载 - “试图覆盖一切”的测试反模式之间的关联与核…...
从服务器多线程批量下载文件到本地
1、客户端安装 aria2 下载地址:aria2 解压文件,然后将文件目录添加到系统环境变量Path中,然后打开cmd,输入:aria2c 文件地址,就可以下载文件了 2、服务端配置nginx文件服务器 server {listen 8080…...
循环神经网络 - 深层循环神经网络
如果将深度定义为网络中信息传递路径长度的话,循环神经网络可以看作既“深”又“浅”的网络。 一方面来说,如果我们把循环网络按时间展开,长时间间隔的状态之间的路径很长,循环网络可以看作一个非常深的网络。 从另一方面来 说&…...
linux运维篇-Ubuntu(debian)系操作系统创建源仓库
适用范围 适用于Ubuntu(Debian)及其衍生版本的linux系统 例如,国产化操作系统kylin-desktop-v10 简介 先来看下我们需要创建出来的仓库目录结构 Deb_conf_test apt源的主目录 conf 配置文件存放目录 conf目录下存放两个配置文件&…...
深度学习之微积分
2.4.1 导数和微分 2.4.2 偏导数 