【Block总结】掩码窗口自注意力 (M-WSA)

摘要
论文链接:https://arxiv.org/pdf/2404.07846
论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在处理图像时的局限性,特别是在图像去噪和恢复任务中。M-WSA 通过引入掩码机制,确保在计算注意力时遵循盲点要求,从而避免信息泄露。
设计原理
-
窗口自注意力:M-WSA 基于窗口自注意力(Window Self-Attention, WSA)的概念,将输入图像划分为多个不重叠的窗口。在每个窗口内,计算自注意力以捕捉局部特征。这种方法的计算复杂度相对较低,适合处理高分辨率图像。
-
掩码机制:为了满足盲点要求,M-WSA 在计算注意力时应用了掩码。具体而言,掩码限制了每个像素只能关注其窗口内的特定像素,从而避免了对盲点信息的访问。这一设计确保了网络在去噪时不会泄露噪声信息。
-
扩张卷积模拟:M-WSA 的掩码设计模仿了扩张卷积的感受野,使得网络能够在保持计算效率的同时,捕捉到更大范围的上下文信息。这种方法有效地扩展了网络的感受野,增强了特征提取能力。

优势
-
高效性:通过限制注意力计算在窗口内,M-WSA 显著降低了计算复杂度,使其适用于大规模图像处理任务。
-
信息保护:掩码机制确保了盲点信息不被泄露,从而提高了去噪效果,特别是在处理具有空间相关噪声的图像时。
-
灵活性:M-WSA 可以与其他网络架构结合使用,增强其在各种视觉任务中的表现,尤其是在自我监督学习和图像恢复领域。
实验结果
在多个真实世界的图像去噪数据集上进行的实验表明,M-WSA 显著提高了去噪性能,超越了传统的卷积网络和其他自注意力机制。这一结果表明,M-WSA 在处理复杂噪声模式时具有良好的适应性和有效性。
代码
Masked Window-Based Self-Attention (M-WSA) 通过结合窗口自注意力和掩码机制,为图像去噪和恢复任务提供了一种有效的解决方案。其设计不仅提高了计算效率,还确保了信息的安全性,展示了在自我监督学习中的广泛应用潜力。代码:
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsumdef to(x):return {'device': x.device, 'dtype': x.dtype}def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_hclass FixedPosEmb(nn.Module):def __init__(self, window_size, overlap_window_size):super().__init__()self.window_size = window_sizeself.overlap_window_size = overlap_window_sizeattention_mask_table = torch.zeros((window_size + overlap_window_size - 1),(window_size + overlap_window_size - 1))attention_mask_table[0::2, :] = float('-inf')attention_mask_table[:, 0::2] = float('-inf')attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size)coords_w = torch.arange(self.window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Wwcoords_flatten_1 = torch.flatten(coords, 1) # 2, Wh*Wwcoords_h = torch.arange(self.overlap_window_size)coords_w = torch.arange(self.overlap_window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten_2 = torch.flatten(coords, 1)relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.overlap_window_size - 1 # shift to start from 0relative_coords[:, :, 1] += self.overlap_window_size - 1relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Wwself.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(1, self.window_size ** 2, self.overlap_window_size ** 2), requires_grad=False)def forward(self):return self.attention_maskclass DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return outif __name__ == "__main__":dim = 64window_size = 8overlap_ratio = 0.5num_heads = 2dim_head = 16# 初始化 DilatedOCA 模块oca_attention = DilatedOCA(dim=dim,window_size=window_size,overlap_ratio=overlap_ratio,num_heads=num_heads,dim_head=dim_head,bias=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")oca_attention = oca_attention.to(device)print(oca_attention)x = torch.randn(1, 32, 640, 480).to(device)# 前向传播output = oca_attention(x)print("input张量形状:", x.shape)print("output张量形状:", output.shape)
DilatedOCA模块详解
代码结构
import torch
import torch.nn as nn
from einops import rearrange
- 导入库:首先导入 PyTorch 和 einops 库。
einops用于简化张量的重排操作。
模块定义
class DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
-
初始化方法:
__init__方法定义了模块的结构。-
dim:输入特征的通道数。 -
window_size:窗口的大小,用于空间注意力计算。 -
overlap_ratio:重叠窗口的比例,决定了窗口之间的重叠程度。 -
num_heads:空间注意力的头数。 -
dim_head:每个头的维度。
-
-
层的定义:
-
self.unfold:用于将输入张量展开为重叠窗口的操作。 -
self.qkv:一个 1x1 的卷积层,用于生成查询(Q)、键(K)和值(V)三个特征图。 -
self.project_out:一个 1x1 的卷积层,用于将输出特征映射回原始通道数。 -
self.rel_pos_emb和self.fixed_pos_emb:用于位置编码的模块,增强模型对空间位置的感知。
-
前向传播
def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return out
-
输入形状:
x的形状为(batch_size, channels, height, width),其中b是批量大小,c是通道数,h和w是图像的高度和宽度。 -
特征提取:
-
qkv = self.qkv(x):通过qkv层生成 Q、K、V 特征图。 -
qs, ks, vs = qkv.chunk(3, dim=1):将 Q、K、V 特征图沿通道维度分离。
-
-
空间注意力计算:
-
qs被重排为适合空间注意力计算的格式。 -
ks和vs通过unfold操作展开为重叠窗口。
-
-
分头处理:
- 使用
einops.rearrange将 Q、K、V 的形状调整为适合多头自注意力计算的格式。
- 使用
-
计算注意力:
-
qs = qs * self.scale:对 Q 进行缩放以提高稳定性。 -
spatial_attn = (qs @ ks.transpose(-2, -1)):计算注意力分数。 -
spatial_attn += self.rel_pos_emb(qs)和spatial_attn += self.fixed_pos_emb():添加位置编码以增强空间感知。 -
spatial_attn = spatial_attn.softmax(dim=-1):对注意力分数进行 softmax 归一化。
-
-
输出计算:
out = (spatial_attn @ vs):使用注意力权重对 V 进行加权求和,得到最终输出。
-
重排输出:
out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...):将输出重排回原始形状。
-
最终投影:
out = self.project_out(out):通过投影层将输出映射回原始通道数。
总结
DilatedOCA 模块结合了扩张卷积和空间注意力机制,通过重叠窗口的设计增强了对图像局部特征的捕捉能力。该模块在图像处理任务中具有广泛的应用潜力,尤其是在需要精细特征提取的场景中。
相关文章:
【Block总结】掩码窗口自注意力 (M-WSA)
摘要 论文链接:https://arxiv.org/pdf/2404.07846 论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在…...
用 HTML5 Canvas 和 JavaScript 实现雪花飘落特效
这篇文章将带您深入解析使用 HTML5 Canvas 和 JavaScript 实现动态雪花特效的代码原理。 1,效果展示 该效果模拟了雪花从天而降的动态场景,具有以下特点: 雪花数量、大小、透明度和下落速度随机。雪花会在屏幕底部重置到顶部,形成循环效果。随窗口大小动态调整,始终覆盖…...
【cocos creator】【ts】事件派发系统
触发使用: EventTool.emit(“onClick”) 需要监听的地方,onload调用: EventTool.on(“onClick”, this.onClickEvent, this) /**事件派发*/class EventTool {protected static _instance: EventTool null;public static get Instance(): Eve…...
《探索鸿蒙Next上开发人工智能游戏应用的技术难点》
在科技飞速发展的当下,鸿蒙Next系统为应用开发带来了新的机遇与挑战,开发一款运行在鸿蒙Next上的人工智能游戏应用更是备受关注。以下是在开发过程中可能会遇到的一些技术难点: 鸿蒙Next系统适配性 多设备协同:鸿蒙Next的一大特色…...
CSS | CSS实现两栏布局(左边定宽 右边自适应,左右成比自适应)
目录 一、左边定宽 右边自适应 1.浮动 2.利用浮动margin 3.定位margin 4.flex布局 5.table 布局 二、左右成比自适应 1:1 1flex布局 table布局 1:2 flex布局 <div class"father"><div class"left">左边自适应</div><div class"r…...
acwing_3195_有趣的数
acwing_3195_有趣的数 // // Created by HUAWEI on 2024/11/17. // #include<iostream> #include<cstring> #include<algorithm>#define int long longusing namespace std;const int N 1000 50; const int MOD 1e9 7; int C[N][N]; //组合数signed mai…...
Liunx-搭建安装VSOMEIP环境教程 执行 运行VSOMEIP示例demo
本文安装环境为Liunx,搭建安装VSOMEIP环境并运行基础例子。 1. 安装基础环境 使用apt-get来安装基础环境,受网络影响可以分开多次安装。环境好的也可以一次性执行。 sudo apt-get install gcc g sudo apt-get install cmake sudo apt-get install lib…...
Git | git revert命令详解
关注:CodingTechWork 引言 Git 是一个强大的版本控制工具,广泛应用于现代软件开发中。它为开发人员提供了多种功能来管理代码、协作开发和版本控制。在 Git 中,有时我们需要撤销或回退某些提交,而git revert 是一个非常有用的命令…...
ASP.NET Core 中,Cookie 认证在集群环境下的应用
在 ASP.NET Core 中,Cookie 认证在集群环境下的应用通常会遇到一些挑战。主要的问题是 Cookie 存储在客户端的浏览器中,而认证信息(比如 Session 或身份令牌)通常是保存在 Cookie 中,多个应用实例需要共享这些 Cookie …...
Flyte工作流平台调研(五)——扩展集成
系列文章: Flyte工作流平台调研(一)——整体架构 Flyte工作流平台调研(二)——核心概念说明 Flyte工作流平台调研(三)——核心组件原理 Flyte工作流平台调研(四)——…...
【AUTOSAR 基础软件】软件组件的建立与使用(“代理”SWC)
基础软件往往需要建立一些“代理”SWC来完成一些驱动的抽象工作(Complex_Device_Driver_Sw或者Ecu_Abstraction_Sw等),或建立Application Sw Component来补齐基础软件需要提供的功能实现。当面对具体的项目时,基础软件开发人员还可…...
java通过ocr实现识别pdf中的文字
需求:识别pdf文件中的中文 根据github项目mymonstercat 改造,先将pdf文件转为png文件存于临时文件夹,然后通过RapidOcr转为文字,最后删除临时文件夹 1、引入依赖 <dependency><groupId>org.apache.pdfbox</groupId><artifactId&g…...
Git 命令代码管理详解
一、Git 初相识:版本控制的神器 在当今的软件开发领域,版本控制如同基石般重要,而 Git 无疑是其中最耀眼的明珠。它由 Linus Torvalds 在 2005 年创造,最初是为了更好地管理 Linux 内核源代码。随着时间的推移,Git 凭借…...
Docker的安装和使用
容器技术 容器与虚拟机的区别 虚拟机 (VM) VM包含完整的操作系统,并在虚拟化层之上运行多个操作系统实例。 VM需要更多的系统资源(CPU、内存、存储)来管理这些操作系统实例。 容器 (Container) 容器共享主机操作系统的内核,具…...
Flink系统知识讲解之:Flink内存管理详解
Flink系统知识讲解之:Flink内存管理详解 在现阶段,大部分开源的大数据计算引擎都是用Java或者是基于JVM的编程语言实现的,如Apache Hadoop、Apache Spark、Apache Drill、Apache Flink等。Java语言的好处是不用考虑底层,降低了程…...
使用JMeter模拟多IP发送请求!
你是否曾遇到过这样的场景:使用 JMeter 进行压力测试时,单一 IP 被服务器限流或者屏蔽?这时,如何让 JMeter 模拟多个 IP 发送请求,成功突破测试限制,成为测试工程师必须攻克的难题。今天,我们就…...
【Ubuntu与Linux操作系统:六、软件包管理】
第6章 软件包管理 6.1 Linux软件安装基础 Linux的软件包是以二进制或源码形式发布的程序集合,包含程序文件和元数据。软件包管理器是Linux系统的重要工具,用于安装、更新和卸载软件。 1. 常见的软件包管理器: DEB 系统(如Ubunt…...
【数据结构-堆】力扣1834. 单线程 CPU
给你一个二维数组 tasks ,用于表示 n 项从 0 到 n - 1 编号的任务。其中 tasks[i] [enqueueTimei, processingTimei] 意味着第 i 项任务将会于 enqueueTimei 时进入任务队列,需要 processingTimei 的时长完成执行。 现…...
【前端动效】原生js实现拖拽排课效果
目录 1. 效果展示 2. 效果分析 2.1 关键点 2.2 实现方法 3. 代码实现 3.1 html部分 3.2 css部分 3.3 js部分 3.4 完整代码 4. 总结 1. 效果展示 如图所示,页面左侧有一个包含不同课程(如语文、数学等)的列表,页面右侧…...
C#使用OpenTK绘制3D可拖动旋转图形三棱锥
接上篇,绘制着色矩形 C#使用OpenTK绘制一个着色矩形-CSDN博客 上一篇安装OpenTK.GLControl后,这里可以直接拖动控件GLControl 我们会发现GLControl继承于UserControl //// 摘要:// OpenGL-aware WinForms control. The WinForms designer will always call the default//…...
后进先出(LIFO)详解
LIFO 是 Last In, First Out 的缩写,中文译为后进先出。这是一种数据结构的工作原则,类似于一摞盘子或一叠书本: 最后放进去的元素最先出来 -想象往筒状容器里放盘子: (1)你放进的最后一个盘子(…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】
微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个生活电费的缴纳和查询小程序
一、项目初始化与配置 1. 创建项目 ohpm init harmony/utility-payment-app 2. 配置权限 // module.json5 {"requestPermissions": [{"name": "ohos.permission.INTERNET"},{"name": "ohos.permission.GET_NETWORK_INFO"…...
今日科技热点速览
🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用
1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...
AI书签管理工具开发全记录(十九):嵌入资源处理
1.前言 📝 在上一篇文章中,我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源,方便后续将资源打包到一个可执行文件中。 2.embed介绍 🎯 Go 1.16 引入了革命性的 embed 包,彻底改变了静态资源管理的…...
USB Over IP专用硬件的5个特点
USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中,从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备(如专用硬件设备),从而消除了直接物理连接的需要。USB over IP的…...
Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成
一个面向 Java 开发者的 Sring-Ai 示例工程项目,该项目是一个 Spring AI 快速入门的样例工程项目,旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计,每个模块都专注于特定的功能领域,便于学习和…...
Leetcode33( 搜索旋转排序数组)
题目表述 整数数组 nums 按升序排列,数组中的值 互不相同 。 在传递给函数之前,nums 在预先未知的某个下标 k(0 < k < nums.length)上进行了 旋转,使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...
