当前位置: 首页 > article >正文

PyTorch中的Flatten

在 PyTorch 中,Flatten 操作是将多维张量转换为一维向量的重要操作,常用于卷积神经网络(CNN)的全连接层之前。以下是 PyTorch 中实现 Flatten 的各种方法及其应用场景。

一、基本 Flatten 方法

1. 使用 torch.flatten() 函数

import torch# 创建一个4D张量 (batch_size, channels, height, width)
x = torch.randn(32, 3, 28, 28)  # 32张28x28的RGB图像# 展平整个张量
flattened = torch.flatten(x)  # 输出形状: [75264] (32*3*28*28)# 从指定维度开始展平
flattened = torch.flatten(x, start_dim=1)  # 输出形状: [32, 2352] (保持batch维度)

2. 使用 nn.Flatten 层

import torch.nn as nnflatten = nn.Flatten()  # 默认从第1维开始展平(保持batch维度)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 输出形状: [32, 2352]

 可以指定开始和结束维度:

flatten = nn.Flatten(start_dim=1, end_dim=2)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 输出形状: [32, 84, 28] (合并了第1和2维)

二、不同场景下的 Flatten 应用

1. CNN 中的典型用法

class CNN(nn.Module):def __init__(self):super().__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 16, 3),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(16, 32, 3),nn.ReLU(),nn.MaxPool2d(2))self.flatten = nn.Flatten()self.fc = nn.Linear(32 * 5 * 5, 10)  # 计算展平后的尺寸def forward(self, x):x = self.conv_layers(x)x = self.flatten(x)  # 形状从 [B, 32, 5, 5] 变为 [B, 800]x = self.fc(x)return x

 2. 手动计算展平后的尺寸

# 计算卷积层输出尺寸的辅助函数
def conv_output_size(input_size, kernel_size, stride=1, padding=0):return (input_size - kernel_size + 2 * padding) // stride + 1# 计算经过多层卷积和池化后的尺寸
h, w = 28, 28  # 输入尺寸
h = conv_output_size(h, 3)  # conv1: 26
w = conv_output_size(w, 3)  # conv1: 26
h = conv_output_size(h, 2, 2)  # pool1: 13
w = conv_output_size(w, 2, 2)  # pool1: 13
h = conv_output_size(h, 3)  # conv2: 11
w = conv_output_size(w, 3)  # conv2: 11
h = conv_output_size(h, 2, 2)  # pool2: 5
w = conv_output_size(w, 2, 2)  # pool2: 5
print(f"展平后的特征数: {32 * h * w}")  # 32 * 5 * 5 = 800

三、高级用法

1. 部分展平

# 只展平图像空间维度,保留通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(start_dim=2)  # 形状: [32, 3, 784]

 2. 自定义 Flatten 层

class ChannelLastFlatten(nn.Module):"""将通道维度移到最后的展平层"""def forward(self, x):# 输入形状: [B, C, H, W]x = x.permute(0, 2, 3, 1)  # [B, H, W, C]return x.reshape(x.size(0), -1)  # [B, H*W*C]

3. 展平特定维度

# 展平批量维度和通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(end_dim=1)  # 形状: [96, 28, 28] (32*3=96)

四、注意事项

  1. 维度计算:确保展平后的尺寸与全连接层的输入尺寸匹配

  2. 批量维度:通常保留第0维(batch维度)不被展平

  3. 内存连续性view()需要连续内存,必要时先调用contiguous()

  4. 替代方法x.view(x.size(0), -1)flatten(start_dim=1)的常见替代写法

五、性能比较

方法优点缺点
torch.flatten()官方推荐,可读性好
nn.Flatten()可作为网络层使用需要实例化对象
x.view()最简洁需要手动计算尺寸
x.reshape()自动处理内存连续性性能略低于view

六、示例代码

import torch
import torch.nn as nn# 定义一个包含Flatten的完整模型
class ImageClassifier(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.flatten = nn.Flatten()self.classifier = nn.Sequential(nn.Linear(256 * 4 * 4, 1024),  # 假设输入图像是32x32nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(1024, 10))def forward(self, x):x = self.features(x)x = self.flatten(x)x = self.classifier(x)return x# 使用示例
model = ImageClassifier()
input_tensor = torch.randn(16, 3, 32, 32)  # batch=16, 3通道, 32x32图像
output = model(input_tensor)
print(output.shape)  # 输出形状: [16, 10]

相关文章:

PyTorch中的Flatten

在 PyTorch 中,Flatten 操作是将多维张量转换为一维向量的重要操作,常用于卷积神经网络(CNN)的全连接层之前。以下是 PyTorch 中实现 Flatten 的各种方法及其应用场景。 一、基本 Flatten 方法 1. 使用 torch.flatten() 函数 import torch# 创建一个4…...

深入浅出动态规划:从基础到蓝桥杯实战(Java版)

引言:为什么你需要掌握动态规划? 动态规划(DP)是算法竞赛和面试中的常客,不仅能大幅提升解题效率(时间复杂度通常为O(n)或O(n))[4],更是解决复杂优化问题的利器。统计显示&#xff…...

VS Code-i18n Ally国际化插件

前言 本文借鉴:i18n Ally 插件帮你轻松搞定国际化需求-按模块划分i18n Ally 是一款 VS Code 插件,它能通过可视 - 掘金本来是没有准备将I18n Ally插件单独写一个博客的,但是了解过后,功能强大,使用方便,解决…...

YOLO中mode.predict()参数详解

Inference arguments: ArgumentTypeDefaultDescriptionsourcestr‘ultralytics/assets’指定推理的数据源。可以是图像路径、视频文件、目录、URL 或实时源的设备 ID。支持多种格式和数据源,可在不同类型的输入中灵活应用。conffloat0.25设置检测的最小置信度阈值。…...

收敛算法有多少?

收敛算法是指在迭代计算过程中,能够使序列或函数逐渐逼近某个极限值或最优解的算法。常见的收敛算法有以下几种: 梯度下降法(Gradient Descent) 原理:通过沿着目标函数的负梯度方向更新参数,使得目标函数…...

在亚马逊云科技上使用n8n快速构建个人AI NEWS助理

前言: N8n 是一个强大的工作流自动化工具,它允许您连接不同的应用程序、服务和系统,以创建自动化工作流程,并且采用了开源MIT协议,可以放心使用,他的官方网站也提供了很多的工作流,大家有兴趣的…...

STM32单片机入门学习——第27节: [9-3] USART串口发送串口发送+接收

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难,但我还是想去做! 本文写于:2025.04.08 STM32开发板学习——第27节: [9-3] USART串口发送&串口发送接收 前言开发板说…...

python 3.9 随机生成 以UTF-8 编码 的随机中文

理论实践 因为python3的默认编码为UTF-8,我们将‘浪’的utf8\u6d6a进行打印测试 print(\u6d6a) >>浪 中文匹配范围有两种 [\u4e00-\u9fa5]和[\u2E80-\u9FFF],后者包括了日韩地区的汉字 由于utf采用16进制,则需要进行一个进制的变换&a…...

数字电子技术基础(四十)——使用Digital软件和Multisim软件模拟显示译码器

目录 1 使用Digital软件模拟显示译码器 1.1 原理介绍 1.2 器件选择 1.3 电路运行 1.4 结果分析 2 使用Multisim软件模拟显示译码器 2.1 器件选择 2.2 电路运行 1 使用Digital软件模拟显示译码器 1.1 原理介绍 7448常用于驱动7段显示译码器。如下所示为7448驱动BS201A…...

第十四届蓝桥杯大赛软件赛国赛C/C++研究生组

研究生C国赛软件大赛 题一:混乘数字题二:钉板上的正方形题三:整数变换题四:躲炮弹题五:最大区间 题一:混乘数字 有一点像哈希表: 首先定义两个数组,拆分ab和n 然后令n a*b 查看两个…...

innodb如何实现mvcc的

InnoDB 实现 MVCC(多版本并发控制)的机制主要依赖于 Undo Log(回滚日志)、Read View(读视图) 和 隐藏的事务字段。以下是具体实现步骤和原理: 1. 核心数据结构 InnoDB 的每一行数据&#xff08…...

多模态大语言模型arxiv论文略读(四)

A Survey on Multimodal Large Language Models ➡️ 论文标题:A Survey on Multimodal Large Language Models ➡️ 论文作者:Shukang Yin, Chaoyou Fu, Sirui Zhao, Ke Li, Xing Sun, Tong Xu, Enhong Chen ➡️ 研究机构: 中国科学技术大学、腾讯优图…...

空对象模式(Null Object Pattern)在C#中的实现详解

一 、什么是空对象模式 空对象模模是靠”空对孔象式是书丯一种引施丼文行为,行凌,凌万成,个默疤"空象象象象来飞䛿引用用用用电从延盈盈甘仙丿引用用用职从延务在仅代砷易行行 」这种燕式亲如要目的片片 也说媚平父如如 核心思烟 定义一个人 派一个 � 创建…...

在kotlin的安卓项目中使用dagger

在 Kotlin 的 Android 项目中使用 ​​Dagger​​(特别是 ​​Dagger Hilt​​,官方推荐的简化版)进行依赖注入(DI)可以大幅提升代码的可测试性和模块化程度。 1. 配置 Dagger Hilt​​ ​​1.1 添加依赖​​ 在 bu…...

(三)链式工作流构建——打造智能对话的强大引擎

上一篇:(二)输入输出处理——打造智能对话的灵魂 在前两个阶段,我们已经搭建了一个基础的智能对话,并深入探讨了输入输出处理的细节。今天,我们将进入智能对话的高级阶段——链式工作流构建。这一阶段的目…...

python三大库之---pandas(二)

python三大库之—pandas(二) 文章目录 python三大库之---pandas(二)六,函数6.1、常用的统计学函数6.2重置索引6.3 遍历6.3.1DataFrame 遍历6.3.2 itertuples()6.3.3 使用属性遍历 6.4 排序6.4.1 sort_index6.4.2 sort_…...

php7.4.3连接MSsql server方法

需要下载安装Microsoft Drivers for PHP for SQL Server驱动, https://download.csdn.net/download/tjsoft/90568178 实操Win2008IISphp7.4.3连接SqlServer2008数据库所有安装包资源-CSDN文库 适用于 SQL Server 的 PHP 的 Microsoft 驱动程序支持与 SQL Server …...

Flask返回文件方法详解

在 Flask 中返回文件可以通过 send_file 或 send_from_directory 方法实现。以下是详细方法和示例: 1. 使用 send_file 返回文件 这是最直接的方法,适用于返回任意路径的文件。 from flask import Flask, send_fileapp = Flask(__name__)@app.route("/download")…...

JS中的Promise对象

基本概念 Promise 是 JavaScript 中用于处理异步操作的对象。它代表一个异步操作的最终完成及其结果值。Promise 提供了一种更优雅的方式来处理异步代码,避免了传统的回调地狱。 Promise 有三种状态 Pending(等待中):初始状态&…...

macOS设置定时播放眼保健操

文章目录 1. ✅方法一:直接基于日历2. 方法二:基于脚本2.1 音乐文件获取(ncm转mp3)2.2 创建播放音乐任务2.3 脚本实现定时播放 1. ✅方法一:直接基于日历 左侧新建一个日历,不然会和其他日历混淆,看起来会有点乱 然后…...

Python 小练习系列 | Vol.14:掌握偏函数 partial,用函数更丝滑!

🧩 Python 小练习系列 | Vol.14:掌握偏函数 partial,用函数更丝滑! 本节的 Python 小练习系列我们将聚焦一个 冷门但高能 的工具 —— functools.partial。它的作用类似于“函数的预设模板”,能帮你写出更加灵活、优雅…...

记录学习的第二十三天

老样子,每日一题开胃。 我一开始还想着暴力解一下试试呢,结果不太行😂 接着两道动态规划。 这道题我本来是想用最长递增子序列来做的,不过实在是太麻烦了,实在做不下去了。 然后看了题解,发现可以倒着数。 …...

Web品质 - 重要的HTML元素

Web品质 - 重要的HTML元素 在构建一个优秀的Web页面时,HTML元素的选择和运用至关重要。这些元素不仅影响页面的结构,还直接关系到页面的可用性、可访问性和SEO表现。本文将深入探讨一些关键的HTML元素,并解释它们在提升Web品质方面的重要性。 1. <html> 根元素 HTM…...

SpringBoot整合sa-token,Redis:解决重启项目丢失登录态问题

SpringBoot整合sa-token&#xff0c;Redis&#xff1a;解决重启项目丢失登录态问题 &#x1f525;1. 痛点直击&#xff1a;为什么登录状态会消失&#xff1f;2.实现方案2.1.导入依赖2.2.新增yml配置文件 3.效果图4.结语 &#x1f600;大家好&#xff01;我是向阳&#x1f31e;&…...

Python 字典和集合(子类化UserDict)

本章内容的大纲如下&#xff1a; 常见的字典方法 如何处理查找不到的键 标准库中 dict 类型的变种set 和 frozenset 类型 散列表的工作原理 散列表带来的潜在影响&#xff08;什么样的数据类型可作为键、不可预知的 顺序&#xff0c;等等&#xff09; 子类化UserDict 就创造自…...

npm fund 命令的作用

运行别人的项目遇到这个问题&#xff1a; npm fund 命令的作用 npm fund 是 npm 提供的命令&#xff0c;用于显示项目依赖中哪些包需要资金支持。这些信息来自包的 package.json 中定义的 funding 字段&#xff0c;目的是帮助开发者了解如何支持开源维护者。 典型场景示例 假…...

ES:账号、索引、ILM

目录 笔记1&#xff1a;账号权限查看、查看账号、创建账号等查看所有用户查看特定用户验证权限修改用户权限删除用户 笔记2&#xff1a;索引状态和内容的查看等查看所有索引查看特定索引内容查看索引映射查看索引设置查看索引统计信息查看ILM策略 笔记1&#xff1a;账号权限查看…...

哈希表(开散列)的实现

目录 引入 开散列的底层实现 哈希表的定义 哈希表的扩容 哈希表的插入 哈希表查找 哈希表的删除 引入 接上一篇&#xff0c;我们使用了闭散列的方法解决了哈希冲突&#xff0c;此篇文章将会使用开散列的方式解决哈希冲突&#xff0c;后面对unordered_set和unordered_map的…...

#在docker中启动mysql之类的容器时,没有挂载的数据...在后期怎么把数据导出外部

如果要导出 Docker 容器内的 整个目录&#xff08;包含所有文件及子目录&#xff09;&#xff0c;可以使用以下几种方法&#xff1a; 方法 1&#xff1a;使用 docker cp 直接复制目录到宿主机 适用场景&#xff1a;容器正在运行或已停止&#xff08;但未删除&#xff09;。 命…...

[蓝桥杯] 挖矿(CC++双语版)

题目链接 P10904 [蓝桥杯 2024 省 C] 挖矿 - 洛谷 题目理解 我们可以将这道题中矿洞的位置理解成为一个坐标轴&#xff0c;以题目样例绘出坐标轴&#xff1a; 样例&#xff1a; 输入的5为矿洞数量&#xff0c;4为可走的步数。第二行输入是5个矿洞的坐标。输出结果为在要求步数…...