【Block总结】掩码窗口自注意力 (M-WSA)

摘要
论文链接:https://arxiv.org/pdf/2404.07846
论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在处理图像时的局限性,特别是在图像去噪和恢复任务中。M-WSA 通过引入掩码机制,确保在计算注意力时遵循盲点要求,从而避免信息泄露。
设计原理
-
窗口自注意力:M-WSA 基于窗口自注意力(Window Self-Attention, WSA)的概念,将输入图像划分为多个不重叠的窗口。在每个窗口内,计算自注意力以捕捉局部特征。这种方法的计算复杂度相对较低,适合处理高分辨率图像。
-
掩码机制:为了满足盲点要求,M-WSA 在计算注意力时应用了掩码。具体而言,掩码限制了每个像素只能关注其窗口内的特定像素,从而避免了对盲点信息的访问。这一设计确保了网络在去噪时不会泄露噪声信息。
-
扩张卷积模拟:M-WSA 的掩码设计模仿了扩张卷积的感受野,使得网络能够在保持计算效率的同时,捕捉到更大范围的上下文信息。这种方法有效地扩展了网络的感受野,增强了特征提取能力。

优势
-
高效性:通过限制注意力计算在窗口内,M-WSA 显著降低了计算复杂度,使其适用于大规模图像处理任务。
-
信息保护:掩码机制确保了盲点信息不被泄露,从而提高了去噪效果,特别是在处理具有空间相关噪声的图像时。
-
灵活性:M-WSA 可以与其他网络架构结合使用,增强其在各种视觉任务中的表现,尤其是在自我监督学习和图像恢复领域。
实验结果
在多个真实世界的图像去噪数据集上进行的实验表明,M-WSA 显著提高了去噪性能,超越了传统的卷积网络和其他自注意力机制。这一结果表明,M-WSA 在处理复杂噪声模式时具有良好的适应性和有效性。
代码
Masked Window-Based Self-Attention (M-WSA) 通过结合窗口自注意力和掩码机制,为图像去噪和恢复任务提供了一种有效的解决方案。其设计不仅提高了计算效率,还确保了信息的安全性,展示了在自我监督学习中的广泛应用潜力。代码:
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsumdef to(x):return {'device': x.device, 'dtype': x.dtype}def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_hclass FixedPosEmb(nn.Module):def __init__(self, window_size, overlap_window_size):super().__init__()self.window_size = window_sizeself.overlap_window_size = overlap_window_sizeattention_mask_table = torch.zeros((window_size + overlap_window_size - 1),(window_size + overlap_window_size - 1))attention_mask_table[0::2, :] = float('-inf')attention_mask_table[:, 0::2] = float('-inf')attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size)coords_w = torch.arange(self.window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Wwcoords_flatten_1 = torch.flatten(coords, 1) # 2, Wh*Wwcoords_h = torch.arange(self.overlap_window_size)coords_w = torch.arange(self.overlap_window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten_2 = torch.flatten(coords, 1)relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.overlap_window_size - 1 # shift to start from 0relative_coords[:, :, 1] += self.overlap_window_size - 1relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Wwself.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(1, self.window_size ** 2, self.overlap_window_size ** 2), requires_grad=False)def forward(self):return self.attention_maskclass DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return outif __name__ == "__main__":dim = 64window_size = 8overlap_ratio = 0.5num_heads = 2dim_head = 16# 初始化 DilatedOCA 模块oca_attention = DilatedOCA(dim=dim,window_size=window_size,overlap_ratio=overlap_ratio,num_heads=num_heads,dim_head=dim_head,bias=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")oca_attention = oca_attention.to(device)print(oca_attention)x = torch.randn(1, 32, 640, 480).to(device)# 前向传播output = oca_attention(x)print("input张量形状:", x.shape)print("output张量形状:", output.shape)
DilatedOCA模块详解
代码结构
import torch
import torch.nn as nn
from einops import rearrange
- 导入库:首先导入 PyTorch 和 einops 库。
einops用于简化张量的重排操作。
模块定义
class DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
-
初始化方法:
__init__方法定义了模块的结构。-
dim:输入特征的通道数。 -
window_size:窗口的大小,用于空间注意力计算。 -
overlap_ratio:重叠窗口的比例,决定了窗口之间的重叠程度。 -
num_heads:空间注意力的头数。 -
dim_head:每个头的维度。
-
-
层的定义:
-
self.unfold:用于将输入张量展开为重叠窗口的操作。 -
self.qkv:一个 1x1 的卷积层,用于生成查询(Q)、键(K)和值(V)三个特征图。 -
self.project_out:一个 1x1 的卷积层,用于将输出特征映射回原始通道数。 -
self.rel_pos_emb和self.fixed_pos_emb:用于位置编码的模块,增强模型对空间位置的感知。
-
前向传播
def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return out
-
输入形状:
x的形状为(batch_size, channels, height, width),其中b是批量大小,c是通道数,h和w是图像的高度和宽度。 -
特征提取:
-
qkv = self.qkv(x):通过qkv层生成 Q、K、V 特征图。 -
qs, ks, vs = qkv.chunk(3, dim=1):将 Q、K、V 特征图沿通道维度分离。
-
-
空间注意力计算:
-
qs被重排为适合空间注意力计算的格式。 -
ks和vs通过unfold操作展开为重叠窗口。
-
-
分头处理:
- 使用
einops.rearrange将 Q、K、V 的形状调整为适合多头自注意力计算的格式。
- 使用
-
计算注意力:
-
qs = qs * self.scale:对 Q 进行缩放以提高稳定性。 -
spatial_attn = (qs @ ks.transpose(-2, -1)):计算注意力分数。 -
spatial_attn += self.rel_pos_emb(qs)和spatial_attn += self.fixed_pos_emb():添加位置编码以增强空间感知。 -
spatial_attn = spatial_attn.softmax(dim=-1):对注意力分数进行 softmax 归一化。
-
-
输出计算:
out = (spatial_attn @ vs):使用注意力权重对 V 进行加权求和,得到最终输出。
-
重排输出:
out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...):将输出重排回原始形状。
-
最终投影:
out = self.project_out(out):通过投影层将输出映射回原始通道数。
总结
DilatedOCA 模块结合了扩张卷积和空间注意力机制,通过重叠窗口的设计增强了对图像局部特征的捕捉能力。该模块在图像处理任务中具有广泛的应用潜力,尤其是在需要精细特征提取的场景中。
相关文章:
【Block总结】掩码窗口自注意力 (M-WSA)
摘要 论文链接:https://arxiv.org/pdf/2404.07846 论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在…...
用 HTML5 Canvas 和 JavaScript 实现雪花飘落特效
这篇文章将带您深入解析使用 HTML5 Canvas 和 JavaScript 实现动态雪花特效的代码原理。 1,效果展示 该效果模拟了雪花从天而降的动态场景,具有以下特点: 雪花数量、大小、透明度和下落速度随机。雪花会在屏幕底部重置到顶部,形成循环效果。随窗口大小动态调整,始终覆盖…...
【cocos creator】【ts】事件派发系统
触发使用: EventTool.emit(“onClick”) 需要监听的地方,onload调用: EventTool.on(“onClick”, this.onClickEvent, this) /**事件派发*/class EventTool {protected static _instance: EventTool null;public static get Instance(): Eve…...
《探索鸿蒙Next上开发人工智能游戏应用的技术难点》
在科技飞速发展的当下,鸿蒙Next系统为应用开发带来了新的机遇与挑战,开发一款运行在鸿蒙Next上的人工智能游戏应用更是备受关注。以下是在开发过程中可能会遇到的一些技术难点: 鸿蒙Next系统适配性 多设备协同:鸿蒙Next的一大特色…...
CSS | CSS实现两栏布局(左边定宽 右边自适应,左右成比自适应)
目录 一、左边定宽 右边自适应 1.浮动 2.利用浮动margin 3.定位margin 4.flex布局 5.table 布局 二、左右成比自适应 1:1 1flex布局 table布局 1:2 flex布局 <div class"father"><div class"left">左边自适应</div><div class"r…...
acwing_3195_有趣的数
acwing_3195_有趣的数 // // Created by HUAWEI on 2024/11/17. // #include<iostream> #include<cstring> #include<algorithm>#define int long longusing namespace std;const int N 1000 50; const int MOD 1e9 7; int C[N][N]; //组合数signed mai…...
Liunx-搭建安装VSOMEIP环境教程 执行 运行VSOMEIP示例demo
本文安装环境为Liunx,搭建安装VSOMEIP环境并运行基础例子。 1. 安装基础环境 使用apt-get来安装基础环境,受网络影响可以分开多次安装。环境好的也可以一次性执行。 sudo apt-get install gcc g sudo apt-get install cmake sudo apt-get install lib…...
Git | git revert命令详解
关注:CodingTechWork 引言 Git 是一个强大的版本控制工具,广泛应用于现代软件开发中。它为开发人员提供了多种功能来管理代码、协作开发和版本控制。在 Git 中,有时我们需要撤销或回退某些提交,而git revert 是一个非常有用的命令…...
ASP.NET Core 中,Cookie 认证在集群环境下的应用
在 ASP.NET Core 中,Cookie 认证在集群环境下的应用通常会遇到一些挑战。主要的问题是 Cookie 存储在客户端的浏览器中,而认证信息(比如 Session 或身份令牌)通常是保存在 Cookie 中,多个应用实例需要共享这些 Cookie …...
Flyte工作流平台调研(五)——扩展集成
系列文章: Flyte工作流平台调研(一)——整体架构 Flyte工作流平台调研(二)——核心概念说明 Flyte工作流平台调研(三)——核心组件原理 Flyte工作流平台调研(四)——…...
【AUTOSAR 基础软件】软件组件的建立与使用(“代理”SWC)
基础软件往往需要建立一些“代理”SWC来完成一些驱动的抽象工作(Complex_Device_Driver_Sw或者Ecu_Abstraction_Sw等),或建立Application Sw Component来补齐基础软件需要提供的功能实现。当面对具体的项目时,基础软件开发人员还可…...
java通过ocr实现识别pdf中的文字
需求:识别pdf文件中的中文 根据github项目mymonstercat 改造,先将pdf文件转为png文件存于临时文件夹,然后通过RapidOcr转为文字,最后删除临时文件夹 1、引入依赖 <dependency><groupId>org.apache.pdfbox</groupId><artifactId&g…...
Git 命令代码管理详解
一、Git 初相识:版本控制的神器 在当今的软件开发领域,版本控制如同基石般重要,而 Git 无疑是其中最耀眼的明珠。它由 Linus Torvalds 在 2005 年创造,最初是为了更好地管理 Linux 内核源代码。随着时间的推移,Git 凭借…...
Docker的安装和使用
容器技术 容器与虚拟机的区别 虚拟机 (VM) VM包含完整的操作系统,并在虚拟化层之上运行多个操作系统实例。 VM需要更多的系统资源(CPU、内存、存储)来管理这些操作系统实例。 容器 (Container) 容器共享主机操作系统的内核,具…...
Flink系统知识讲解之:Flink内存管理详解
Flink系统知识讲解之:Flink内存管理详解 在现阶段,大部分开源的大数据计算引擎都是用Java或者是基于JVM的编程语言实现的,如Apache Hadoop、Apache Spark、Apache Drill、Apache Flink等。Java语言的好处是不用考虑底层,降低了程…...
使用JMeter模拟多IP发送请求!
你是否曾遇到过这样的场景:使用 JMeter 进行压力测试时,单一 IP 被服务器限流或者屏蔽?这时,如何让 JMeter 模拟多个 IP 发送请求,成功突破测试限制,成为测试工程师必须攻克的难题。今天,我们就…...
【Ubuntu与Linux操作系统:六、软件包管理】
第6章 软件包管理 6.1 Linux软件安装基础 Linux的软件包是以二进制或源码形式发布的程序集合,包含程序文件和元数据。软件包管理器是Linux系统的重要工具,用于安装、更新和卸载软件。 1. 常见的软件包管理器: DEB 系统(如Ubunt…...
【数据结构-堆】力扣1834. 单线程 CPU
给你一个二维数组 tasks ,用于表示 n 项从 0 到 n - 1 编号的任务。其中 tasks[i] [enqueueTimei, processingTimei] 意味着第 i 项任务将会于 enqueueTimei 时进入任务队列,需要 processingTimei 的时长完成执行。 现…...
【前端动效】原生js实现拖拽排课效果
目录 1. 效果展示 2. 效果分析 2.1 关键点 2.2 实现方法 3. 代码实现 3.1 html部分 3.2 css部分 3.3 js部分 3.4 完整代码 4. 总结 1. 效果展示 如图所示,页面左侧有一个包含不同课程(如语文、数学等)的列表,页面右侧…...
C#使用OpenTK绘制3D可拖动旋转图形三棱锥
接上篇,绘制着色矩形 C#使用OpenTK绘制一个着色矩形-CSDN博客 上一篇安装OpenTK.GLControl后,这里可以直接拖动控件GLControl 我们会发现GLControl继承于UserControl //// 摘要:// OpenGL-aware WinForms control. The WinForms designer will always call the default//…...
C++_核心编程_多态案例二-制作饮品
#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为:煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例,提供抽象制作饮品基类,提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...
Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动
一、前言说明 在2011版本的gb28181协议中,拉取视频流只要求udp方式,从2016开始要求新增支持tcp被动和tcp主动两种方式,udp理论上会丢包的,所以实际使用过程可能会出现画面花屏的情况,而tcp肯定不丢包,起码…...
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...
黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门  float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的: // 矩形结构体…...
Psychopy音频的使用
Psychopy音频的使用 本文主要解决以下问题: 指定音频引擎与设备;播放音频文件 本文所使用的环境: Python3.10 numpy2.2.6 psychopy2025.1.1 psychtoolbox3.0.19.14 一、音频配置 Psychopy文档链接为Sound - for audio playback — Psy…...
python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?
Redis 的发布订阅(Pub/Sub)模式与专业的 MQ(Message Queue)如 Kafka、RabbitMQ 进行比较,核心的权衡点在于:简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...
STM32---外部32.768K晶振(LSE)无法起振问题
晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...
