当前位置: 首页 > article >正文

PyTorch核心函数详解:gather与where的实战指南

PyTorch中的torch.gathertorch.where是处理张量数据的关键工具,前者实现基于索引的灵活数据提取,后者完成条件筛选与动态生成。本文通过典型应用场景和代码演示,深入解析两者的工作原理及使用技巧,帮助开发者提升数据处理的灵活性与效率。

在深度学习中,我们经常需要根据特定规则提取或生成数据。例如:

  • 从预测概率中提取Top-K类别索引
  • 根据掩码筛选有效数据点
  • 动态生成条件化张量

torch.gathertorch.where正是解决这类问题的核心函数。下文将结合图像处理、数据筛选等场景,详解它们的用法与差异。
在这里插入图片描述

一、torch.gather:基于索引的精准提取

功能描述

torch.gather(input, dim, index) 沿指定维度dim,根据index张量中的索引值,从input中提取对应元素,输出形状与index一致。

参数说明
  • input:源张量
  • dim:指定操作的维度
  • index:索引张量,其值必须为整数类型

核心规则

  • 索引穿透性:索引值直接映射源张量的位置,不改变维度
  • 广播机制:当index维度小于input时,会自动广播到匹配形状
  • 多维索引:支持通过多维索引张量提取复杂结构的数据

应用场景与示例

场景1:图像数据批量提取

假设需要从批量图像中提取特定位置的像素值:

# 假设images是形状为(2,3,3)的图像批次 (批次大小2,通道3,分辨率3x3)
images = torch.tensor([[[1,2,3],[4,5,6],[7,8,9]],  # 第一张图像[[10,11,12],[13,14,15],[16,17,18]]  # 第二张图像
])# 提取所有图像的第0行第1列像素 (shape: (2,))
pixels = torch.gather(images, dim=2, index=torch.tensor([[[0,1,0],[0,1,0]], [[0,1,0],[0,1,0]]]))
print(pixels)
# 输出: tensor([[1, 2, 1],
#                [10, 11, 10]])
场景2:从概率分布提取Top-K结果

在NLP任务中提取预测词ID:

logits = torch.tensor([[0.1, 0.4, 0.5], [0.3, 0.6, 0.1]])  # 2个样本的3个类别的概率
topk_indices = logits.topk(k=2, dim=1).indices  # 获取Top-2索引# 使用gather提取Top-2概率值
topk_probs = torch.gather(logits, dim=1, index=topk_indices)
print(topk_probs)
# 输出:
# tensor([[0.5, 0.4],
#         [0.6, 0.3]])

二、torch.where:条件驱动的动态生成

功能描述

torch.where(condition, x, y) 根据布尔条件condition,从张量xy中选择元素,生成与输入同形状的新张量。

参数说明
  • condition:布尔型张量,决定元素来源
  • x:满足条件时选择的元素来源
  • y:不满足条件时选择的元素来源

核心特性

  • 自动广播:支持不同形状的条件与输入张量
  • 元素级操作:逐元素比较生成动态结果
  • 类型转换:输出类型由xy决定

应用场景与示例

场景1:数据清洗与过滤

筛选出温度超过30℃且湿度低于60%的记录:

temperature = torch.tensor([25.0, 32.5, 28.0, 35.0])
humidity = torch.tensor([55.0, 58.0, 70.0, 50.0])# 生成布尔掩码
mask = (temperature > 30) & (humidity < 60)# 根据条件生成标签
labels = torch.where(mask, torch.tensor("High Risk"), torch.tensor("Normal"))
print(labels)
# 输出: tensor(['Normal', 'High Risk', 'Normal', 'Normal'], dtype=string)
场景2:图像二值化处理

将灰度图像转换为二值掩码:

gray_image = torch.tensor([[0.1, 0.8], [0.6, 0.3]], dtype=torch.float32)
threshold = 0.5# 生成二值掩码
binary_mask = torch.where(gray_image > threshold, torch.tensor(1.0), torch.tensor(0.0))
print(binary_mask)
# 输出:
# tensor([[0., 1.],
#         [1., 0.]])

三、函数对比与选择指南

特性torch.gathertorch.where
核心功能基于索引精确提取元素条件驱动动态生成元素
输入要求需显式提供索引张量需条件张量及候选值张量
维度匹配严格匹配索引与源张量维度自动广播兼容不同形状
典型应用多维数据查询、Top-K提取条件筛选、数据转换、掩码生成
性能消耗较高(涉及索引计算)较低(基于原生条件判断)

四、综合实战:图像语义分割后处理

任务需求

将模型输出的概率图转换为二值掩码,并提取连通区域标签。

解决方案

# 假设prob_map是模型输出的概率图 (H,W)
prob_map = torch.rand(256, 256) > 0.5  # 二值化处理# 使用where生成掩码
mask = torch.where(prob_map, torch.tensor(1), torch.tensor(0))# 使用gather提取连通区域标签(假设labels是预测的类别索引)
labels = torch.randint(0, 10, (256, 256))
selected_labels = torch.gather(labels, dim=0, index=mask.nonzero(as_tuple=True)[0])

五、注意事项与最佳实践

  1. 索引越界预防

    # 错误示例:索引超出范围会导致错误
    valid_indices = torch.clamp(indices, min=0, max=max_dim-1)
    
  2. 类型一致性

    # 确保index张量为整型
    index = index.long()  
    
  3. 内存优化

    # 优先使用in-place操作减少显存占用
    mask.masked_fill_(condition, value)
    

结语

torch.gathertorch.where作为PyTorch生态中的基石函数,在数据工程与模型开发中扮演着不可替代的角色。理解它们的底层逻辑与适用场景,能够帮助您:

  • 更高效地实现复杂数据操作
  • 优化模型推理与训练流程
  • 解决各类条件化数据处理难题

掌握这两把利器,您将在PyTorch开发中如鱼得水!

相关文章:

PyTorch核心函数详解:gather与where的实战指南

PyTorch中的torch.gather和torch.where是处理张量数据的关键工具&#xff0c;前者实现基于索引的灵活数据提取&#xff0c;后者完成条件筛选与动态生成。本文通过典型应用场景和代码演示&#xff0c;深入解析两者的工作原理及使用技巧&#xff0c;帮助开发者提升数据处理的灵活…...

《Operating System Concepts》阅读笔记:p636-p666

《Operating System Concepts》学习第 58 天&#xff0c;p636-p666 总结&#xff0c;总计 31 页。 一、技术总结 1.system and network threats (1)attack network traffic (2)denial of service (3)port scanning 2.symmetric/asymmetric encryption algorithm (1)symm…...

Go:接口

接口既约定 Go 语言中接口是抽象类型 &#xff0c;与具体类型不同 &#xff0c;不暴露数据布局、内部结构及基本操作 &#xff0c;仅提供一些方法 &#xff0c;拿到接口类型的值 &#xff0c;只能知道它能做什么 &#xff0c;即提供了哪些方法 。 func Fprintf(w io.Writer, …...

ESP32+Arduino入门(三):连接WIFI获取当前时间

ESP32内置了WIFI模块连接WIFI非常简单方便。 代码如下&#xff1a; #include <WiFi.h>const char* ssid "WIFI名称"; const char* password "WIFI密码";void setup() {Serial.begin(115200);WiFi.begin(ssid,password);while(WiFi.status() ! WL…...

FastAPI用户认证系统开发指南:从零构建安全API

前言 在现代Web应用开发中&#xff0c;用户认证系统是必不可少的功能。本文将带你使用FastAPI框架构建一个完整的用户认证系统&#xff0c;包含注册、登录、信息更新和删除等功能。我们将采用JWT&#xff08;JSON Web Token&#xff09;进行身份验证&#xff0c;并使用SQLite作…...

CSS高度坍塌?如何解决?

一、什么是高度坍塌&#xff1f; 高度坍塌&#xff08;Collapsing Margins&#xff09;是指当父元素没有设置边框&#xff08;border&#xff09;、内边距&#xff08;padding&#xff09;、内容&#xff08;content&#xff09;或清除浮动时&#xff0c;其子元素的 margin 会…...

【数据结构】之散列

一、定义与基本术语 &#xff08;一&#xff09;、定义 散列&#xff08;Hash&#xff09;是一种将键&#xff08;key&#xff09;通过散列函数映射到一个固定大小的数组中的技术&#xff0c;因为键值对的映射关系&#xff0c;散列表可以实现快速的插入、删除和查找操作。在这…...

空地机器人在复杂动态环境下,如何高效自主导航?

随着空陆两栖机器人(AGR)在应急救援和城市巡检等领域的应用范围不断扩大&#xff0c;其在复杂动态环境中实现自主导航的挑战也日益凸显。对此香港大学王俊铭基于阿木实验室P600无人机平台自主搭建了一整套空地两栖机器人&#xff0c;使用Prometheus开源框架完成算法的仿真验证与…...

python小记(十二):Python 中 Lambda函数详解

Python 中 Lambda函数详解 Lambda函数详解&#xff1a;从入门到实战一、什么是Lambda函数&#xff1f;二、Lambda的核心语法与特点1. 基础语法2. 与普通函数对比 三、Lambda的六大应用场景&#xff08;附代码示例&#xff09;1. 基本数学运算2. 列表排序与自定义规则3. 数据映射…...

第二十一讲 XGBoost 回归建模 + SHAP 可解释性分析(利用R语言内置数据集)

下面我将使用 R 语言内置的 mtcars 数据集&#xff0c;模拟一个完整的 XGBoost 回归建模 SHAP 可解释性分析 实战流程。我们将以预测汽车的油耗&#xff08;mpg&#xff09;为目标变量&#xff0c;构建 XGBoost 模型&#xff0c;并用 SHAP 来解释模型输出。 &#x1f697; 示例…...

数据分析实战案例:使用 Pandas 和 Matplotlib 进行居民用水

原创 IT小本本 IT小本本 2025年04月15日 18:31 北京 本文将使用 Matplotlib 及 Seaborn 进行数据可视化。探索如何清理数据、计算月度用水量并生成有价值的统计图表&#xff0c;以便更好地理解居民的用水情况。 数据处理与清理 读取 Excel 文件 首先&#xff0c;我们使用 pan…...

Asp.NET Core WebApi 创建带鉴权机制的Api

构建一个包含 JWT&#xff08;JSON Web Token&#xff09;鉴权的 Web API 是一种常见的做法&#xff0c;用于保护 API 端点并验证用户身份。以下是一个基于 ASP.NET Core 的完整示例&#xff0c;展示如何实现 JWT 鉴权。 1. 创建 ASP.NET Core Web API 项目 使用 .NET CLI 或 …...

hash.

Redis 自身就是键值对结构 Redis 自身的键值对结构就是通过 哈希 的方式来组织的 哈希类型中的映射关系通常称为 field-value&#xff0c;用于区分 Redis 整体的键值对&#xff08;key-value&#xff09;&#xff0c; 注意这里的 value 是指 field 对应的值&#xff0c;不是键…...

记录鸿蒙应用上架应用未配置图标的前景图和后景图标准要求尺寸1024px*1024px和标准要求尺寸1024px*1024px

审核报错【①应用未配置图标的前景图和后景图,标准要求尺寸1024px*1024px且需下载HUAWEI DevEco Studio 5.0.5.315或以上版本进行图标再处理、②应用在展开状态下存在页面左边距过大的问题, 应用在展开状态下存在页面右边距过大的问题, 当前页面左边距: 504 px, 当前页面右边距…...

golang-常见的语法错误

https://juejin.cn/post/6923477800041054221 看这篇文章 Golang 基础面试高频题详细解析【第一版】来啦&#xff5e; 大叔说码 for-range的坑 func main() { slice : []int{0, 1, 2, 3} m : make(map[int]*int) for key, val : range slice {m[key] &val }for k, v : …...

Google最新《Prompt Engineering》白皮书全解析

近期有幸拿到了Google最新发布的《Prompt Engineering》白皮书&#xff0c;这是一份由Lee Boonstra主笔&#xff0c;Michael Sherman、Yuan Cao、Erick Armbrust、Antonio Gulli等多位专家共同贡献的权威性指南&#xff0c;发布于2025年2月。今天我想和大家分享这份68页的宝贵资…...

如何快速部署基于Docker 的 OBDIAG 开发环境

很多开发者对 OceanBase的 SIG社区小组很有兴趣&#xff0c;但如何将OceanBase的各类工具部署在开发环境&#xff0c;对于不少开发者而言都是比较蛮烦的事情。例如&#xff0c;像OBDIAG&#xff0c;其在WINDOWS系统上配置较繁琐&#xff0c;需要单独搭建C开发环境。此外&#x…...

[LeetCode 1306] 跳跃游戏3(Ⅲ)

题面&#xff1a; LeetCode 1306 思路&#xff1a; 只要能跳到其中一个0即可&#xff0c;和跳跃游戏1/2完全不同了&#xff0c;记忆化暴搜即可。 时间复杂度&#xff1a; O ( n ) O(n) O(n) 空间复杂度&#xff1a; O ( n ) O(n) O(n) 代码&#xff1a; dfs vector<…...

spring-ai-alibaba使用Agent实现智能机票助手

示例目标是使用 Spring AI Alibaba 框架开发一个智能机票助手&#xff0c;它可以帮助消费者完成机票预定、问题解答、机票改签、取消等动作&#xff0c;具体要求为&#xff1a; 基于 AI 大模型与用户对话&#xff0c;理解用户自然语言表达的需求支持多轮连续对话&#xff0c;能…...

STM32平衡车开发实战教程:从零基础到项目精通

STM32平衡车开发实战教程&#xff1a;从零基础到项目精通 一、项目概述与基本原理 1.1 平衡车工作原理 平衡车是一种基于倒立摆原理的两轮自平衡小车&#xff0c;其核心控制原理类似于人类保持平衡的过程。当人站立不稳时&#xff0c;会通过腿部肌肉的快速调整来维持平衡。平…...

使用DeepSeek AI高效降低论文重复率

一、论文查重原理与DeepSeek降重机制 1.1 主流查重系统工作原理 文本比对算法:连续字符匹配(通常13-15字符)语义识别技术:检测同义替换和结构调整参考文献识别:区分合理引用与不当抄袭跨语言检测:中英文互译内容识别1.2 DeepSeek降重核心技术 深度语义理解:分析句子核心…...

linux多线(进)程编程——(7)消息队列

前言 现在修真界大家的沟通手段已经越来越丰富了&#xff0c;有了匿名管道&#xff0c;命名管道&#xff0c;共享内存等多种方式。但是随着深入使用人们逐渐发现了这些传音术的局限性。 匿名管道&#xff1a;只能在有血缘关系的修真者&#xff08;进程&#xff09;间使用&…...

WinForm真入门(14)——ListView控件详解

一、ListView 控件核心概念与功能 ‌ListView‌ 是 WinForm 中用于展示结构化数据的多功能列表控件&#xff0c;支持多列、多视图模式及复杂交互&#xff0c;常用于文件资源管理器、数据报表等场景‌。 核心特点‌&#xff1a; 支持 ‌5种视图模式‌&#xff1a;Details&…...

Python + Playwright:规避常见的UI自动化测试反模式

Python + Playwright:规避常见的UI自动化测试反模式 前言反模式一:整体式页面对象(POM)反模式二:具有逻辑的页面对象 - POM 的“越界”行为反模式三:基于 UI 的测试设置 - 缓慢且脆弱的“舞台搭建”反模式四:功能测试过载 - “试图覆盖一切”的测试反模式之间的关联与核…...

从服务器多线程批量下载文件到本地

1、客户端安装 aria2 下载地址&#xff1a;aria2 解压文件&#xff0c;然后将文件目录添加到系统环境变量Path中&#xff0c;然后打开cmd&#xff0c;输入&#xff1a;aria2c 文件地址&#xff0c;就可以下载文件了 2、服务端配置nginx文件服务器 server {listen 8080…...

循环神经网络 - 深层循环神经网络

如果将深度定义为网络中信息传递路径长度的话&#xff0c;循环神经网络可以看作既“深”又“浅”的网络。 一方面来说&#xff0c;如果我们把循环网络按时间展开&#xff0c;长时间间隔的状态之间的路径很长&#xff0c;循环网络可以看作一个非常深的网络。 从另一方面来 说&…...

linux运维篇-Ubuntu(debian)系操作系统创建源仓库

适用范围 适用于Ubuntu&#xff08;Debian&#xff09;及其衍生版本的linux系统 例如&#xff0c;国产化操作系统kylin-desktop-v10 简介 先来看下我们需要创建出来的仓库目录结构 Deb_conf_test apt源的主目录 conf 配置文件存放目录 conf目录下存放两个配置文件&…...

深度学习之微积分

2.4.1 导数和微分 2.4.2 偏导数 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/17227e00adb14472902baba4da675aed.png 2.4.3 梯度 具体证明&#xff0c;矩阵-向量积...

20242817李臻《Linux⾼级编程实践》第7周

20242817李臻《Linux⾼级编程实践》第7周 一、AI对学习内容的总结 第八章&#xff1a;多线程编程 8.1 多线程概念 进程与线程的区别&#xff1a; 进程是资源分配单位&#xff0c;拥有独立的地址空间、全局变量、打开的文件等。线程是调度单位&#xff0c;在同一进程内的线程…...

浙江大学:DeepSeek如何引领智慧医疗的革新之路?|48页PPT下载方法

导 读INTRODUCTION 随着人工智能技术的飞速发展&#xff0c;DeepSeek等大模型正在引领医疗行业进入一个全新的智慧医疗时代。这些先进的技术不仅正在改变医疗服务的提供方式&#xff0c;还在提高医疗质量和效率方面展现出巨大潜力。 想象一下&#xff0c;当你走进医院&#xff…...