【AI模型学习】上/下采样
文章目录
- 分割中的上/下采样
- 下采样
- SegFormer和PVT(使用卷积)
- Swin-Unet(使用 Patch Merging)
- 上采样
- SegFormer(interpolate)
- Swin-Unet(Patch Expanding)
- 逐级interpolate的方式
- 反卷的方式
在基于Transformer架构的图像分割模型(如 SegFormer、Swin-Unet)中,上采样和下采样结构几乎是标准配置。
分割中的上/下采样
为什么需要下采样?
-
提取高层语义特征:
Transformer擅长全局建模,结合下采样可以:降低分辨率;聚焦于更宽范围的上下文。 -
减少计算成本:
原始输入图像太大,直接送入多层Transformer(特别是多头注意力)会导致计算量和显存爆炸。
为什么要上采样?
-
恢复空间分辨率:
Segmentation任务最终要输出与输入图像同样大小的分割mask; -
细粒度定位:
但如果没有上采样、跳跃连接或融合,容易失去细节;所以上采样常结合UNet-like结构来补偿细节损失。
模型 | 下采样方式 | 上采样方式 |
---|---|---|
SETR | ViT backbone patchify | 多层反卷积上采样 |
SegFormer | MLP Mixer + 4阶段卷积下采样 | 多层插值 + FFN |
Swin-Unet | Swin Transformer 下采样 | Patch expanding + skip连接 |
下采样
SegFormer和PVT(使用卷积)
# 输入:img.shape = [B, 3, 512, 512]# Stage 1
x1 = Conv2d(3, 32, kernel_size=7, stride=4, padding=3)(img) # → [B, 32, 128, 128]
x1 = x1.flatten(2).transpose(1, 2) # → [B, 16384, 32]
x1 = TransformerBlock(x1)# Stage 2
x2 = Conv2d(32, 64, kernel_size=3, stride=2, padding=1)(x1_reshaped) # → [B, 64, 64, 64]
x2 = x2.flatten(2).transpose(1, 2) # → [B, 4096, 64]
x2 = TransformerBlock(x2)# 后面还有 Stage3、Stage4 类似
Shape 演化:
Stage1: [B, 128×128=16384, 32]
Stage2: [B, 64×64=4096, 64]
Stage3: [B, 32×32=1024, 160]
Stage4: [B, 16×16=256, 256]
Swin-Unet(使用 Patch Merging)
# 初始 patch embedding(patch_size=4)
x = Conv2d(3, 96, kernel_size=4, stride=4)(img) # [B, 96, 128, 128]
x = x.flatten(2).transpose(1, 2) # → [B, 16384, 96]# Stage 1
x = SwinBlock(x) # [B, 16384, 96]
x = PatchMerging(x) # → [B, 4096, 192]# Stage 2
x = SwinBlock(x) # [B, 4096, 192]
x = PatchMerging(x) # → [B, 1024, 384]# Stage 3
x = SwinBlock(x)
x = PatchMerging(x) # → [B, 256, 768]
Shape 演化:
Stage0: [B, 128×128=16384, 96]
Stage1: [B, 64×64=4096, 192]
Stage2: [B, 32×32=1024, 384]
Stage3: [B, 16×16=256, 768]
过程不难,只是不好描述,可以看相关教程,这里就把代码贴出来
class PatchMerging(nn.Module):def __init__(self, in_dim):super().__init__()self.reduction = nn.Linear(in_dim * 4, in_dim * 2)def forward(self, x, H, W):# x: [B, H*W, C] → [B, H, W, C]x = x.view(B, H, W, C)# 拆分四个方向的 tokenx0 = x[:, 0::2, 0::2, :] # top-leftx1 = x[:, 1::2, 0::2, :] # bottom-leftx2 = x[:, 0::2, 1::2, :] # top-rightx3 = x[:, 1::2, 1::2, :] # bottom-rightx = torch.cat([x0, x1, x2, x3], dim=-1) # → [B, H/2, W/2, 4C]x = x.view(B, -1, 4 * C) # → [B, H/2*W/2, 4C]x = self.reduction(x) # → [B, H/2*W/2, 2C]return x
上采样
SegFormer(interpolate)
def forward(self, x1, x2, x3, x4): # 输入来自4个Stage:# x1: [B, 128*128, 32]# x2: [B, 64*64, 64]# x3: [B, 32*32, 160]# x4: [B, 16*16, 256]B = x1.shape[0]# === 1. Linear Projection:通道都投影为 256 ===_x1 = self.linear1(x1).permute(0, 2, 1).reshape(B, 256, 128, 128) # [B, 256, 128, 128]_x2 = self.linear2(x2).permute(0, 2, 1).reshape(B, 256, 64, 64) # [B, 256, 64, 64]_x3 = self.linear3(x3).permute(0, 2, 1).reshape(B, 256, 32, 32) # [B, 256, 32, 32]_x4 = self.linear4(x4).permute(0, 2, 1).reshape(B, 256, 16, 16) # [B, 256, 16, 16]# === 2. 上采样到统一大小 ===_x2 = F.interpolate(_x2, size=(128, 128), mode='bilinear', align_corners=False) # [B, 256, 128, 128]_x3 = F.interpolate(_x3, size=(128, 128), mode='bilinear', align_corners=False) # [B, 256, 128, 128]_x4 = F.interpolate(_x4, size=(128, 128), mode='bilinear', align_corners=False) # [B, 256, 128, 128]# === 3. 拼接所有层 ===fused = torch.cat([_x1, _x2, _x3, _x4], dim=1) # [B, 4*256=1024, 128, 128]# === 4. 1x1卷积融合通道数 ===out = self.fuse_conv(fused) # [B, 256, 128, 128]return out
Swin-Unet(Patch Expanding)
看图也能看出来,十分经典的U-Net结构。
在上采样阶段
输入:
一个高语义 token,维度为 [4C]
,是上一步 Patch Merging 得到的。
-
Linear 映射
将[4C]
投影为[C] × 4
,也就是还原为 2×2 patch 每格的C
维向量。 -
reshape → [H, W, 2, 2, C] → [2H, 2W, C]
把这 4 个 token 安排到一个新的空间位置(上采样 ×2)。 -
最终输出为:
Token 数量 × 4 , 通道数 ÷ 2 \text{Token 数量} \times 4,\quad \text{通道数} \div 2 Token 数量×4,通道数÷2
class PatchExpanding(nn.Module):def __init__(self, in_dim, expand_ratio=2):super().__init__()# Linear: [B, H*W, in_dim] → [B, H*W, out_dim = (expand_ratio^2) * out_channels]# 例如:in_dim = 512,expand_ratio = 2 → 输出 4×C = 1024self.linear = nn.Linear(in_dim, in_dim // 2 * expand_ratio**2)self.expand_ratio = expand_ratiodef forward(self, x, H, W):# x: [B, H*W, C]B, N, C = x.shapeR = self.expand_ratio # 通常为 2# 线性投影:C → 4 * (C/2),也就是 [B, H*W, 4*C'],每个 token 展开为 2×2 的 patchx = self.linear(x) # [B, H*W, 4*C'] = [B, H*W, R*R*(C//2)]# reshape 成图像形式,带有 2×2 子结构 → [B, H, W, R, R, C']x = x.view(B, H, W, R, R, C // 2) # [B, H, W, 2, 2, C//2]# 调整顺序,将 2×2 子结构移入空间维度 → [B, H*2, W*2, C//2]x = x.permute(0, 1, 3, 2, 4, 5) # [B, H, 2, W, 2, C//2]x = x.reshape(B, H * R, W * R, C // 2) # [B, 2H, 2W, C//2]# flatten 成 token 序列形式(可再送入 Transformer)→ [B, 4*H*W, C//2]x = x.view(B, -1, C // 2) # [B, 4*H*W, C//2]return x
逐级interpolate的方式
-
输入来自编码器 4 个 stage:
x4
:[16×16, 512] ← 最深层x3
:[32×32, 320]x2
:[64×64, 128]x1
:[128×128, 64] ← 最浅层
-
通道统一:
每个特征图先通过 1×1 卷积或 Linear 映射,统一成相同维度(如全部 → 256 或 512) -
上采样与融合(逐级):
f4 = Conv(x4) # [16×16] f3 = F.interpolate(f4, scale=2) + Conv(x3) # → [32×32] f2 = F.interpolate(f3, scale=2) + Conv(x2) # → [64×64] f1 = F.interpolate(f2, scale=2) + Conv(x1) # → [128×128]
反卷的方式
很经典的设计,不必过多介绍。
相关文章:

【AI模型学习】上/下采样
文章目录 分割中的上/下采样下采样SegFormer和PVT(使用卷积)Swin-Unet(使用 Patch Merging) 上采样SegFormer(interpolate)Swin-Unet(Patch Expanding)逐级interpolate的方式反卷的方…...

Unity Shader入门(更新中)
参考书籍:UnityShader入门精要(冯乐乐著) 参考视频:Bilibili《Unity Shader 入门精要》 写在前面:前置知识需要一些计算机组成原理、线性代数、Unity的基础 这篇记录一些学历过程中的理解和笔记(更新中&…...

嵌入式学习的第二十六天-系统编程-文件IO+目录
一、文件IO相关函数 1.read/write cp #include <fcntl.h> #include <sys/stat.h> #include <sys/types.h> #include <stdio.h> #include<unistd.h> #include<string.h>int main(int argc, char **argv) {if(argc<3){fprintf(stderr, …...

珠宝课程小程序源码介绍
这款珠宝课程小程序源码,基于ThinkPHPFastAdminUniApp开发,功能丰富且实用。ThinkPHP提供稳定高效的后台服务,保障数据安全与处理速度;FastAdmin助力快速搭建管理后台,提升开发效率;UniApp则让小程序能多端…...

KNN模型思想与实现
KNN算法简介 核心思想:通过样本在特征空间中k个最相似样本的多数类别来决定其类别归属。"附近的邻居确定你的属性"是核心逻辑 决策依据:采用"多数表决"原则,即统计k个最近邻样本中出现次数最多的类别 样本相似性度量 …...
【信息系统项目管理师】第15章:项目风险管理 - 55个经典题目及详解
更多内容请见: 备考信息系统项目管理师-专栏介绍和目录 文章目录 【第1题】【第2题】【第3题】【第4题】【第5题】【第6题】【第7题】【第8题】【第9题】【第10题】【第11题】【第12题】【第13题】【第14题】【第15题】【第16题】【第17题】【第18题】【第19题】【第20题】【第…...

fscan教程1-存活主机探测与端口扫描
实验目的 本实验主要介绍fscan工具信息收集功能,对同一网段的主机进行存活探测以及常见服务扫描。 技能增长 通过本次实验的学习,了解信息收集的过程,掌握fscan工具主机探测和端口扫描功能。 预备知识 fscan工具有哪些作用? …...
蓝桥杯1447 砝码称重
问题描述 你有一架天平和 N 个砝码,这 N 个砝码重量依次是 W1,W2,⋅⋅⋅,WN。 请你计算一共可以称出多少种不同的重量? 注意砝码可以放在天平两边。 输入格式 输入的第一行包含一个整数 N。 第二行包含 N 个整数:W1,W2,W3,⋅⋅⋅,WN…...

腾讯2025年校招笔试真题手撕(三)
一、题目 今天正在进行赛车车队选拔,每一辆赛车都有一个不可以改变的速度。现在需要选取速度差距在10以内的车队(车队中速度的最大值减去最小值不大于10),用于迎宾。车队的选拔按照的是人越多越好的原则,给出n辆车的速…...

怎样通过神经网络估计股票走向
本博文将教会你如何通过神经网络建立股票模型并对其进行未来趋势估计,尽管博主已通过此方法取得一定利润,但是建议大家不要过分相信AI。本博文仅用于代码学习,请大家谨慎投资。 一、通过爬虫爬取股票往年数据 在信息爆炸的当今时代…...

【RocketMQ 生产者和消费者】- 生产者启动源码-上报生产者和消费者心跳信息到 broker(3)
文章目录 1. 前言2. sendHeartbeatToAllBrokerWithLock 上报心跳信息3. prepareHeartbeatData 准备心跳数据4. sendHearbeat 发送心跳上报请求5. broker 处理心跳请求5.1 heartBeat 处理心跳包5.2 createTopicInSendMessageBackMethod 创建重传 topic5.3 registerConsumer 注册…...

Python----循环神经网络(Word2Vec的优化)
一、负采样 基本思想: 在训练过程中,对于每个正样本(中心词和真实上下文词组成的词对),随机采样少量(如5-20个)负样本(中心词与非上下文词组成的词对)。 模型通过区分正…...

Simon J.D. Prince《Understanding Deep Learning》
学习神经网络和深度学习推荐这本书,这本书站位非常高,且很多问题都深入剖析了,甩其他同类书籍几条街。 多数书,不深度分析、没有知识体系,知识点零散、章节之间孤立。还有一些人Tian所谓的权威,醒醒吧。 …...

开搞:第四个微信小程序:图上县志
原因:我换了一个微信号来搞,因为用同一个用户,备案只能一个个的来。这样不行。所以我换了一个。原来注册过小程序。现在修改即可。注意做好计划后,速度备案和审核,不然你时间浪费不起。30元花起。 结构: -…...
模型评估与调优(PyTorch)
文章目录 模型评估方法混淆矩阵混淆矩阵中的指标ROC曲线(受试者工作特征)AUCR平方残差均方误差(MSE)均方根误差(RMSE)平均绝对误差(MAE) 模型调优方法交叉验证(CV&#x…...
sockaddr结构体详解
在网络编程中,sockaddr 结构体用于表示套接字的地址信息。由于不同协议(如 IPv4、IPv6、Unix 域套接字)的地址格式不同,实际使用中通常通过以下三种变体结构来处理不同类型的地址: 1. 通用地址结构:struct …...

Seata源码—7.Seata TCC模式的事务处理一
大纲 1.Seata TCC分布式事务案例配置 2.Seata TCC案例服务提供者启动分析 3.TwoPhaseBusinessAction注解扫描源码 4.Seata TCC案例分布式事务入口分析 5.TCC核心注解扫描与代理创建入口源码 6.TCC动态代理拦截器TccActionInterceptor 7.Action拦截处理器ActionIntercept…...

【语法】C++的map/set
目录 平衡二叉搜索树 set insert() find() erase() swap() map insert() 迭代器 erase() operator[] multiset和multimap 在之前学习的STL中,string,vector,list,deque,array都是序列式容器,它们的…...
【FAQ】HarmonyOS SDK 闭源开放能力 —Live View Kit (3)
1.问题描述: 通过Push Kit创建实况窗之后,再更新实况窗失败,平台查询提示“实况窗端更新失败,通知未创建或已经过期”。 解决方案: 通过Push Kit更新实况窗内容的过程是自动更新的。客户端在创建本地实况窗后&#…...

vue vite textarea标签按下Shift+Enter 换行输入,只按Enter则提交的实现思路
注意input标签不能实现,需要用textarea标签 直接看代码 <template><textareav-model"message"keydown.enter"handleEnter"placeholder"ShiftEnter 换行,Enter 提交"></textarea> </template>&l…...
MySQL多线程备份工具mysqlpump详解!
MySQLPUMP备份工具详解 1. 概述 MySQLPump 是 MySQL 5.7 引入的一个客户端备份工具,用于替代传统的 mysqldump 工具。它提供了并行处理、进度状态显示、更好的压缩支持等新特性,能够更高效地执行 MySQL 数据库备份操作。 2. 主要特性 并行处理&#x…...
创建信任所有证书的HttpClient:Java 实现 HTTPS 接口调用,等效于curl -k
在 Java 生态中,HttpClient 和 Feign 都是调用第三方接口的常用工具,但它们的定位、设计理念和使用场景有显著差异。以下是详细对比: DIFF1. 定位与抽象层级 特性HttpClientFeign层级底层 HTTP 客户端库(处理原始请求/响应&#…...
Redisson分布式集合原理及应用
Redisson是一个用于Redis的Java客户端,它简化了复杂的数据结构和分布式服务的使用。 适用场景对比 数据结构适用场景优点RList消息队列、任务队列、历史记录分布式共享、阻塞操作、分页查询RMap缓存、配置中心、键值关联数据支持键值对、分布式事务、TTLRSet去重集…...

深入理解 PlaNet(Deep Planning Network):基于python从零实现
引言:基于模型的强化学习与潜在动态 基于模型的强化学习(Model-based Reinforcement Learning)旨在通过学习环境动态的模型来提高样本效率。这个模型可以用来进行规划,让智能体在不需要与真实环境进行每一次决策交互的情况下&…...
精益数据分析(75/126):用户反馈的科学解读与试验驱动迭代——Rally的双向验证方法论
精益数据分析(75/126):用户反馈的科学解读与试验驱动迭代——Rally的双向验证方法论 在创业的黏性阶段,用户反馈是优化产品的重要依据,但如何避免被表面反馈误导?如何将反馈转化为可落地的迭代策略&#x…...

仿腾讯会议——视频发送接收
1、 添加音频模块 2、刷新图片,触发重绘 3、 等比例缩放视频帧 4、 新建视频对象 5、在中介者内定义发送视频帧的函数 6、完成发送视频的函数 7、 完成开启/关闭视频 8、绑定视频的信号槽函数 9、 完成开启/关闭视频 10、 完成发送视频 11、 完成刷新图片显示 12、完…...

从3.7V/5V到7.4V,FP6291在应急供电智能门锁中的应用
在智能家居蓬勃发展的当下,智能门锁以其便捷、安全的特性,成为现代家庭安防的重要组成部分。在智能门锁电量耗尽的情况下,应急电源外接移动电源(USB5V输入) FP6291升压到7.4V供电可应急开锁。增强用户在锁具的安全性、…...
java后端-海外登录(谷歌/FaceBook)
前言 由于最近公司的项目要在海外运行,因此需要对接海外的登录,目前就是谷歌和facebook两种,后面支付也是需要的,后续再进行书写 谷歌登录 这个相对比较容易,而且只提供给安卓即可,废话就不多说了,直接贴解决方案 引入maven依赖 <dependency> <groupId>com.go…...

【人工智障生成日记1】从零开始训练本地小语言模型
🎯 从零开始训练本地小语言模型:MiniGPT TinyStories(4090Ti) 🧭 项目背景 本项目旨在以学习为目的,从头构建一个完整的本地语言模型训练管线。目标是: ✅ 不依赖外部云计算✅ 完全本地运行…...

Selenium-Java版(frame切换/窗口切换)
frame切换/窗口切换 前言 切换到frame 原因 解决 切换回原来的主html 切换到新的窗口 问题 解决 回到原窗口 法一 法二 示例 前言 参考教程:Python Selenium Web自动化 2024版 - 自动化测试 爬虫_哔哩哔哩_bilibili 上期文章:Sel…...