【Block总结】掩码窗口自注意力 (M-WSA)
摘要
论文链接:https://arxiv.org/pdf/2404.07846
论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在处理图像时的局限性,特别是在图像去噪和恢复任务中。M-WSA 通过引入掩码机制,确保在计算注意力时遵循盲点要求,从而避免信息泄露。
设计原理
-
窗口自注意力:M-WSA 基于窗口自注意力(Window Self-Attention, WSA)的概念,将输入图像划分为多个不重叠的窗口。在每个窗口内,计算自注意力以捕捉局部特征。这种方法的计算复杂度相对较低,适合处理高分辨率图像。
-
掩码机制:为了满足盲点要求,M-WSA 在计算注意力时应用了掩码。具体而言,掩码限制了每个像素只能关注其窗口内的特定像素,从而避免了对盲点信息的访问。这一设计确保了网络在去噪时不会泄露噪声信息。
-
扩张卷积模拟:M-WSA 的掩码设计模仿了扩张卷积的感受野,使得网络能够在保持计算效率的同时,捕捉到更大范围的上下文信息。这种方法有效地扩展了网络的感受野,增强了特征提取能力。
优势
-
高效性:通过限制注意力计算在窗口内,M-WSA 显著降低了计算复杂度,使其适用于大规模图像处理任务。
-
信息保护:掩码机制确保了盲点信息不被泄露,从而提高了去噪效果,特别是在处理具有空间相关噪声的图像时。
-
灵活性:M-WSA 可以与其他网络架构结合使用,增强其在各种视觉任务中的表现,尤其是在自我监督学习和图像恢复领域。
实验结果
在多个真实世界的图像去噪数据集上进行的实验表明,M-WSA 显著提高了去噪性能,超越了传统的卷积网络和其他自注意力机制。这一结果表明,M-WSA 在处理复杂噪声模式时具有良好的适应性和有效性。
代码
Masked Window-Based Self-Attention (M-WSA) 通过结合窗口自注意力和掩码机制,为图像去噪和恢复任务提供了一种有效的解决方案。其设计不仅提高了计算效率,还确保了信息的安全性,展示了在自我监督学习中的广泛应用潜力。代码:
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsumdef to(x):return {'device': x.device, 'dtype': x.dtype}def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_hclass FixedPosEmb(nn.Module):def __init__(self, window_size, overlap_window_size):super().__init__()self.window_size = window_sizeself.overlap_window_size = overlap_window_sizeattention_mask_table = torch.zeros((window_size + overlap_window_size - 1),(window_size + overlap_window_size - 1))attention_mask_table[0::2, :] = float('-inf')attention_mask_table[:, 0::2] = float('-inf')attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size)coords_w = torch.arange(self.window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Wwcoords_flatten_1 = torch.flatten(coords, 1) # 2, Wh*Wwcoords_h = torch.arange(self.overlap_window_size)coords_w = torch.arange(self.overlap_window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten_2 = torch.flatten(coords, 1)relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.overlap_window_size - 1 # shift to start from 0relative_coords[:, :, 1] += self.overlap_window_size - 1relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Wwself.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(1, self.window_size ** 2, self.overlap_window_size ** 2), requires_grad=False)def forward(self):return self.attention_maskclass DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return outif __name__ == "__main__":dim = 64window_size = 8overlap_ratio = 0.5num_heads = 2dim_head = 16# 初始化 DilatedOCA 模块oca_attention = DilatedOCA(dim=dim,window_size=window_size,overlap_ratio=overlap_ratio,num_heads=num_heads,dim_head=dim_head,bias=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")oca_attention = oca_attention.to(device)print(oca_attention)x = torch.randn(1, 32, 640, 480).to(device)# 前向传播output = oca_attention(x)print("input张量形状:", x.shape)print("output张量形状:", output.shape)
DilatedOCA模块详解
代码结构
import torch
import torch.nn as nn
from einops import rearrange
- 导入库:首先导入 PyTorch 和 einops 库。
einops
用于简化张量的重排操作。
模块定义
class DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
-
初始化方法:
__init__
方法定义了模块的结构。-
dim
:输入特征的通道数。 -
window_size
:窗口的大小,用于空间注意力计算。 -
overlap_ratio
:重叠窗口的比例,决定了窗口之间的重叠程度。 -
num_heads
:空间注意力的头数。 -
dim_head
:每个头的维度。
-
-
层的定义:
-
self.unfold
:用于将输入张量展开为重叠窗口的操作。 -
self.qkv
:一个 1x1 的卷积层,用于生成查询(Q)、键(K)和值(V)三个特征图。 -
self.project_out
:一个 1x1 的卷积层,用于将输出特征映射回原始通道数。 -
self.rel_pos_emb
和self.fixed_pos_emb
:用于位置编码的模块,增强模型对空间位置的感知。
-
前向传播
def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return out
-
输入形状:
x
的形状为(batch_size, channels, height, width)
,其中b
是批量大小,c
是通道数,h
和w
是图像的高度和宽度。 -
特征提取:
-
qkv = self.qkv(x)
:通过qkv
层生成 Q、K、V 特征图。 -
qs, ks, vs = qkv.chunk(3, dim=1)
:将 Q、K、V 特征图沿通道维度分离。
-
-
空间注意力计算:
-
qs
被重排为适合空间注意力计算的格式。 -
ks
和vs
通过unfold
操作展开为重叠窗口。
-
-
分头处理:
- 使用
einops.rearrange
将 Q、K、V 的形状调整为适合多头自注意力计算的格式。
- 使用
-
计算注意力:
-
qs = qs * self.scale
:对 Q 进行缩放以提高稳定性。 -
spatial_attn = (qs @ ks.transpose(-2, -1))
:计算注意力分数。 -
spatial_attn += self.rel_pos_emb(qs)
和spatial_attn += self.fixed_pos_emb()
:添加位置编码以增强空间感知。 -
spatial_attn = spatial_attn.softmax(dim=-1)
:对注意力分数进行 softmax 归一化。
-
-
输出计算:
out = (spatial_attn @ vs)
:使用注意力权重对 V 进行加权求和,得到最终输出。
-
重排输出:
out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...)
:将输出重排回原始形状。
-
最终投影:
out = self.project_out(out)
:通过投影层将输出映射回原始通道数。
总结
DilatedOCA
模块结合了扩张卷积和空间注意力机制,通过重叠窗口的设计增强了对图像局部特征的捕捉能力。该模块在图像处理任务中具有广泛的应用潜力,尤其是在需要精细特征提取的场景中。
相关文章:

【Block总结】掩码窗口自注意力 (M-WSA)
摘要 论文链接:https://arxiv.org/pdf/2404.07846 论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在…...
用 HTML5 Canvas 和 JavaScript 实现雪花飘落特效
这篇文章将带您深入解析使用 HTML5 Canvas 和 JavaScript 实现动态雪花特效的代码原理。 1,效果展示 该效果模拟了雪花从天而降的动态场景,具有以下特点: 雪花数量、大小、透明度和下落速度随机。雪花会在屏幕底部重置到顶部,形成循环效果。随窗口大小动态调整,始终覆盖…...
【cocos creator】【ts】事件派发系统
触发使用: EventTool.emit(“onClick”) 需要监听的地方,onload调用: EventTool.on(“onClick”, this.onClickEvent, this) /**事件派发*/class EventTool {protected static _instance: EventTool null;public static get Instance(): Eve…...
《探索鸿蒙Next上开发人工智能游戏应用的技术难点》
在科技飞速发展的当下,鸿蒙Next系统为应用开发带来了新的机遇与挑战,开发一款运行在鸿蒙Next上的人工智能游戏应用更是备受关注。以下是在开发过程中可能会遇到的一些技术难点: 鸿蒙Next系统适配性 多设备协同:鸿蒙Next的一大特色…...

CSS | CSS实现两栏布局(左边定宽 右边自适应,左右成比自适应)
目录 一、左边定宽 右边自适应 1.浮动 2.利用浮动margin 3.定位margin 4.flex布局 5.table 布局 二、左右成比自适应 1:1 1flex布局 table布局 1:2 flex布局 <div class"father"><div class"left">左边自适应</div><div class"r…...
acwing_3195_有趣的数
acwing_3195_有趣的数 // // Created by HUAWEI on 2024/11/17. // #include<iostream> #include<cstring> #include<algorithm>#define int long longusing namespace std;const int N 1000 50; const int MOD 1e9 7; int C[N][N]; //组合数signed mai…...
Liunx-搭建安装VSOMEIP环境教程 执行 运行VSOMEIP示例demo
本文安装环境为Liunx,搭建安装VSOMEIP环境并运行基础例子。 1. 安装基础环境 使用apt-get来安装基础环境,受网络影响可以分开多次安装。环境好的也可以一次性执行。 sudo apt-get install gcc g sudo apt-get install cmake sudo apt-get install lib…...
Git | git revert命令详解
关注:CodingTechWork 引言 Git 是一个强大的版本控制工具,广泛应用于现代软件开发中。它为开发人员提供了多种功能来管理代码、协作开发和版本控制。在 Git 中,有时我们需要撤销或回退某些提交,而git revert 是一个非常有用的命令…...
ASP.NET Core 中,Cookie 认证在集群环境下的应用
在 ASP.NET Core 中,Cookie 认证在集群环境下的应用通常会遇到一些挑战。主要的问题是 Cookie 存储在客户端的浏览器中,而认证信息(比如 Session 或身份令牌)通常是保存在 Cookie 中,多个应用实例需要共享这些 Cookie …...

Flyte工作流平台调研(五)——扩展集成
系列文章: Flyte工作流平台调研(一)——整体架构 Flyte工作流平台调研(二)——核心概念说明 Flyte工作流平台调研(三)——核心组件原理 Flyte工作流平台调研(四)——…...

【AUTOSAR 基础软件】软件组件的建立与使用(“代理”SWC)
基础软件往往需要建立一些“代理”SWC来完成一些驱动的抽象工作(Complex_Device_Driver_Sw或者Ecu_Abstraction_Sw等),或建立Application Sw Component来补齐基础软件需要提供的功能实现。当面对具体的项目时,基础软件开发人员还可…...
java通过ocr实现识别pdf中的文字
需求:识别pdf文件中的中文 根据github项目mymonstercat 改造,先将pdf文件转为png文件存于临时文件夹,然后通过RapidOcr转为文字,最后删除临时文件夹 1、引入依赖 <dependency><groupId>org.apache.pdfbox</groupId><artifactId&g…...

Git 命令代码管理详解
一、Git 初相识:版本控制的神器 在当今的软件开发领域,版本控制如同基石般重要,而 Git 无疑是其中最耀眼的明珠。它由 Linus Torvalds 在 2005 年创造,最初是为了更好地管理 Linux 内核源代码。随着时间的推移,Git 凭借…...
Docker的安装和使用
容器技术 容器与虚拟机的区别 虚拟机 (VM) VM包含完整的操作系统,并在虚拟化层之上运行多个操作系统实例。 VM需要更多的系统资源(CPU、内存、存储)来管理这些操作系统实例。 容器 (Container) 容器共享主机操作系统的内核,具…...

Flink系统知识讲解之:Flink内存管理详解
Flink系统知识讲解之:Flink内存管理详解 在现阶段,大部分开源的大数据计算引擎都是用Java或者是基于JVM的编程语言实现的,如Apache Hadoop、Apache Spark、Apache Drill、Apache Flink等。Java语言的好处是不用考虑底层,降低了程…...

使用JMeter模拟多IP发送请求!
你是否曾遇到过这样的场景:使用 JMeter 进行压力测试时,单一 IP 被服务器限流或者屏蔽?这时,如何让 JMeter 模拟多个 IP 发送请求,成功突破测试限制,成为测试工程师必须攻克的难题。今天,我们就…...
【Ubuntu与Linux操作系统:六、软件包管理】
第6章 软件包管理 6.1 Linux软件安装基础 Linux的软件包是以二进制或源码形式发布的程序集合,包含程序文件和元数据。软件包管理器是Linux系统的重要工具,用于安装、更新和卸载软件。 1. 常见的软件包管理器: DEB 系统(如Ubunt…...

【数据结构-堆】力扣1834. 单线程 CPU
给你一个二维数组 tasks ,用于表示 n 项从 0 到 n - 1 编号的任务。其中 tasks[i] [enqueueTimei, processingTimei] 意味着第 i 项任务将会于 enqueueTimei 时进入任务队列,需要 processingTimei 的时长完成执行。 现…...

【前端动效】原生js实现拖拽排课效果
目录 1. 效果展示 2. 效果分析 2.1 关键点 2.2 实现方法 3. 代码实现 3.1 html部分 3.2 css部分 3.3 js部分 3.4 完整代码 4. 总结 1. 效果展示 如图所示,页面左侧有一个包含不同课程(如语文、数学等)的列表,页面右侧…...

C#使用OpenTK绘制3D可拖动旋转图形三棱锥
接上篇,绘制着色矩形 C#使用OpenTK绘制一个着色矩形-CSDN博客 上一篇安装OpenTK.GLControl后,这里可以直接拖动控件GLControl 我们会发现GLControl继承于UserControl //// 摘要:// OpenGL-aware WinForms control. The WinForms designer will always call the default//…...

利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...
React Native 导航系统实战(React Navigation)
导航系统实战(React Navigation) React Navigation 是 React Native 应用中最常用的导航库之一,它提供了多种导航模式,如堆栈导航(Stack Navigator)、标签导航(Tab Navigator)和抽屉…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
VTK如何让部分单位不可见
最近遇到一个需求,需要让一个vtkDataSet中的部分单元不可见,查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行,是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示,主要是最后一个参数,透明度…...
Robots.txt 文件
什么是robots.txt? robots.txt 是一个位于网站根目录下的文本文件(如:https://example.com/robots.txt),它用于指导网络爬虫(如搜索引擎的蜘蛛程序)如何抓取该网站的内容。这个文件遵循 Robots…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...

Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
C# SqlSugar:依赖注入与仓储模式实践
C# SqlSugar:依赖注入与仓储模式实践 在 C# 的应用开发中,数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护,许多开发者会选择成熟的 ORM(对象关系映射)框架,SqlSugar 就是其中备受…...

C++:多态机制详解
目录 一. 多态的概念 1.静态多态(编译时多态) 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1).协变 2).析构函数的重写 5.override 和 final关键字 1&#…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...