当前位置: 首页 > news >正文

Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

6、WindowAttention类

6.1 构造函数

class WindowAttention(nn.Module):def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))coords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.window_size[0] - 1relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)
  1. dim:输入特征维度
  2. window_size:窗口大小
  3. num_heads:多头注意力头数
  4. head_dim:每头注意力的头数
  5. scale :缩放因子
  6. relative_position_bias_table:相对位置偏置表,它对每个头存储不同窗口位置之间的偏置,以模拟位置信息
  7. coords_h 、coords_w、coords:窗口内每个位置的坐标
  8. coords_flatten :将坐标展平,为计算相对位置做准备
  9. 第1个relative_coords:计算窗口内每个位置相对于其他位置的坐标差
  10. 第2个relative_coords:重排坐标差的维度以符合预期的格式
  11. relative_coords[:, :, 0]、relative_coords[:, :, 1]、relative_coords[:, :, 0]:调整坐标差,使其能够映射到相对位置偏置表中的索引
  12. relative_position_index :计算每对位置之间的相对位置索引
  13. register_buffer:将相对位置索引注册为模型的缓冲区,这样它就不会在训练过程中被更新
  14. qkv :创建一个线性层,用于生成QKV
  15. attn_drop、proj、proj_drop:初始化注意力dropout、输出投影层及其dropout
  16. trunc_normal_:使用截断正态分布初始化相对位置偏置表
  17. softmax :初始化softmax层,用于计算注意力权重

6.2 前向传播

    def forward(self, x, mask=None):B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x
  1. B_, N, C = x.shape原始输入: torch.Size([256, 49, 96]),B_, N, C即原始输入的维度
  2. qkv = self.qkv(x).reshape...qkv: torch.Size([3, 256, 3, 49, 32]),被重塑的一个五维张量,分别代表qkv三个维度、256个窗口、3个注意力头数但是不会一直是3越往后会越多、49是一个窗口有7*7=49元素、每个头的特征维度。在之前的Transformer以及Vision Transformer中,都是用x接上各自的全连接后分别生成QKV,这这里直接一起生成了。
  3. q: torch.Size([256, 3, 49, 32]),k: torch.Size([256, 3, 49, 32]),v: torch.Size([256, 3, 49, 32]),从qkv中分解出q、k、v,而且已经包含了多头注意力机制
  4. attn: torch.Size([256, 3, 49, 49]),attn是q和k的点积
  5. relative_position_bias: torch.Size([49, 49, 3]),从相对位置偏置表中索引出每对位置之间的偏置,并重塑以匹配注意力分数的形状
  6. relative_position_bias: torch.Size([3, 49, 49]),重新排列,位置编码在Transformer中一直当成偏置加进去的,而这个位置编码是对一个窗口的,所以每一个窗口的都对应了相同的位置编码
  7. attn: torch.Size([256, 3, 49, 49]),将位置编码加到注意力分数上,到这里就算完了全部的注意力机制了
  8. attn: torch.Size([256, 3, 49, 49]),掩码加到注意力分数上,使用softmax函数归一化注意力分数,得到注意力权重,应用注意力dropout
  9. x: torch.Size([256, 49, 96]),使用注意力权重对v向量进行重构,然后对结果进行转置和重塑
  10. x: torch.Size([256, 49, 96]),将加权的注意力输出通过一个线性投影层,应用输出dropout,这就是最后WindowAttention的输出,一共256个窗口,每个窗口有49个特征,每个特征对应96维的向量

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

相关文章:

Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

🚩🚩🚩Transformer实战-系列教程总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 点我下载源码 SwinTransformer 算法原理 SwinTransformer 源码解读1(项目配置/SwinTr…...

Jenkins(本地Windows上搭建)上传 Pipeline构建前端项目并将生成dist文件夹上传至指定服务器

下载安装jdk https://www.oracle.com/cn/java/technologies/downloads/#jdk21-windows 下载jenkins window版 双击安装 https://www.jenkins.io/download/thank-you-downloading-windows-installer-stable/ 网页输入 http://localhost:8088/ 输入密码、设置账号、安装推…...

Elasticsearch 安装和配置脚本文档

Elasticsearch 安装和配置脚本文档 目录 **Elasticsearch 安装和配置脚本文档**0.**概述**1.**使用方法:**2.**脚本步骤:**3. **完整代码如下:** 0.概述 此Bash脚本用于自动化在CentOS 7系统上安装和配置Elasticsearch(ES&#x…...

【Android辟邪】之:gradle——在项目间共享依赖关系版本

翻译和简单修改自:https://docs.gradle.org/current/userguide/platforms.html#sec:sharing-catalogs 建议看原文(有能力的话) 现在 Gradle 脚本可以使用两种语法编写:Kotlin 和 Groovy 本文只使用kotlin脚本语法,更…...

Qt 项目树工程,拷贝子项目dll到子项目exe运行路径

1、项目树工程 2、项目树列表 ---- BuildAll -------- App (exe) -------- Database (dll) 注:使用 子项目–>添加库–>内部库 的方式 3、qmake 内置的变量 $$OUT_PWD 表示输出文件(如可执行文件…...

进程间通信方式

1>内核提供的原始通信方式有三种 1)无名管道 2)有名管道 3)信号 2>System V提供了三种通信方式 4)消息队列 5)共享内存 6)信号量(信号灯集) 3>套接字通信 7)socke…...

[linux]:匿名管道和命名管道(什么是管道,怎么创建管道(函数),匿名管道和命名管道的区别,代码例子)

目录 一、匿名管道 1.什么是管道?什么是匿名管道? 2.怎么创建匿名管道(函数) 3.匿名管道的4种情况 4.匿名管道有5种特性 5.怎么使用匿名管道?匿名管道有什么用?(例子) 二、命名…...

Python调用matlab程序

matlab官网:https://ww2.mathworks.cn/?s_tidgn_logo matlab外部语言和库接口,包括 Python、Java、C、C、.NET 和 Web 服务。 matlab和python的版本 安装依赖配置 安装matlab的engine 找到matlab的安装目录:“xxx\ extern\engines\python…...

FlinkSql 窗口函数

Windowing TVF 以前用的是Grouped Window Functions(分组窗口函数),但是分组窗口函数只支持窗口聚合 现在FlinkSql统一都是用的是Windowing TVFs(窗口表值函数),Windowing TVFs更符合 SQL 标准且更加强大…...

十分钟GIS——geoserver+postgis+udig从零开始发布地图服务

1数据库部署 1.1PostgreSql安装 下载到安装文件后(postgresql-9.2.19-1-windows-x64.exe),双击安装。 指定安装目录,如下图所示 指定数据库文件存放目录位置,如下图所示 指定数据库访问管理员密码,如下图所…...

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Span组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Span组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Span组件 鸿蒙(HarmonyOS)作为Text组件的子组件&#xff0…...

Leetcode—42. 接雨水【困难】

2024每日刷题&#xff08;112&#xff09; Leetcode—42. 接雨水 空间复杂度为O(n)的算法思想 实现代码 class Solution { public:int trap(vector<int>& height) {int ans 0;int n height.size();vector<int> l(n);vector<int> r(n);for(int i 0; …...

[Python] opencv - 什么是直方图?如何绘制图像的直方图?如何对直方图进行均匀化处理?

什么是直方图&#xff1f; 直方图是一种统计图&#xff0c;用于展示数据的分布情况。它将数据按照一定的区间或者组进行划分&#xff0c;然后计算在每个区间或组内的数据频数或频率&#xff08;即数据出现的次数或占比&#xff09;&#xff0c;然后用矩形或者柱形图的形式将这…...

ppi rust开发 python调用

创建python的一个测试工程 python -m venv venv .\venv\Scripts\activatepip install cffi创建一个rust的lib项目 cargo new --lib pyrustlib.rs #[no_mangle] pub extern "C" fn rust_add(x: i32, y: i32) -> i32 {x y }Cargo.toml [package] name "p…...

网站后端开发 thinkphp6 入门教程合集(更新中)

thinkphp6 入门&#xff08;1&#xff09;--安装、路由规则、多应用模式 thinkphp6 入门&#xff08;1&#xff09;--安装、路由规则、多应用模式_软件工程小施同学的博客-CSDN博客 thinkphp6 入门&#xff08;2&#xff09;--视图、渲染html页面、赋值 thinkphp6 入门&#x…...

Web前端框架-Vue(初识)

文章目录 web前端三大主流框架**1.Angular****2.React****3.Vue**什么是Vue.js 为什么要学习流行框架框架和库和插件的区别一.简介指令v-cloakv-textv-htmlv-pre**v-once**v-onv-on事件函数中传入参数事件修饰符双向数据绑定v-model 按键修饰符自定义按键修饰符别名v-bind(属性…...

配置dns服务的正反向解析

服务端IP客户端IP网址192.168.153.137192.168.153.www.openlab.com 1&#xff1a;正向解析 1.1关闭客户端和服务端的安全软件&#xff0c;安装bind软件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# yum install bind -y [rootnod…...

小白水平理解面试经典题目LeetCode 71. Simplify Path【Stack类】

71. 简化路径 小白渣翻译 给定一个字符串 path &#xff0c;它是 Unix 风格文件系统中文件或目录的绝对路径&#xff08;以斜杠 ‘/’ 开头&#xff09;&#xff0c;将其转换为简化的规范路径。 在 Unix 风格的文件系统中&#xff0c;句点 ‘.’ 指的是当前目录&#xff0c;…...

电力负荷预测 | 电力系统负荷预测模型(Python线性回归、随机森林、支持向量机、BP神经网络、GRU、LSTM)

文章目录 效果一览文章概述源码设计参考资料效果一览 文章概述 电力系统负荷预测模型(Python线性回归、随机森林、支持向量机、BP神经网络、GRU、LSTM) 所谓预测,就是指通过对事物进行分析及研究,并运用合理的方法探索事物的发展变化规律,对其未来发展做出预先估计和判断。…...

YY调音台:音频后期处理

我从事影视后期处理的工作&#xff0c;主要负责音频、音效合成这块工作内容。在影视作品中&#xff0c;声音不仅仅是背景元素&#xff0c;它在叙事和创造情感氛围上发挥着至关重要的作用。我们的工作不仅要让听众听到声音&#xff0c;更要让他们通过声音感受到情感的波动和故事…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

(一)单例模式

一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...

关于easyexcel动态下拉选问题处理

前些日子突然碰到一个问题&#xff0c;说是客户的导入文件模版想支持部分导入内容的下拉选&#xff0c;于是我就找了easyexcel官网寻找解决方案&#xff0c;并没有找到合适的方案&#xff0c;没办法只能自己动手并分享出来&#xff0c;针对Java生成Excel下拉菜单时因选项过多导…...

【Veristand】Veristand环境安装教程-Linux RT / Windows

首先声明&#xff0c;此教程是针对Simulink编译模型并导入Veristand中编写的&#xff0c;同时需要注意的是老用户编译可能用的是Veristand Model Framework&#xff0c;那个是历史版本&#xff0c;且NI不会再维护&#xff0c;新版本编译支持为VeriStand Model Generation Suppo…...

如何配置一个sql server使得其它用户可以通过excel odbc获取数据

要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据&#xff0c;你需要完成以下配置步骤&#xff1a; ✅ 一、在 SQL Server 端配置&#xff08;服务器设置&#xff09; 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到&#xff1a;SQL Server 网络配…...

CppCon 2015 学习:Time Programming Fundamentals

Civil Time 公历时间 特点&#xff1a; 共 6 个字段&#xff1a; Year&#xff08;年&#xff09;Month&#xff08;月&#xff09;Day&#xff08;日&#xff09;Hour&#xff08;小时&#xff09;Minute&#xff08;分钟&#xff09;Second&#xff08;秒&#xff09; 表示…...

《Offer来了:Java面试核心知识点精讲》大纲

文章目录 一、《Offer来了:Java面试核心知识点精讲》的典型大纲框架Java基础并发编程JVM原理数据库与缓存分布式架构系统设计二、《Offer来了:Java面试核心知识点精讲(原理篇)》技术文章大纲核心主题:Java基础原理与面试高频考点Java虚拟机(JVM)原理Java并发编程原理Jav…...

React父子组件通信:Props怎么用?如何从父组件向子组件传递数据?

系列回顾&#xff1a; 在上一篇《React核心概念&#xff1a;State是什么&#xff1f;》中&#xff0c;我们学习了如何使用useState让一个组件拥有自己的内部数据&#xff08;State&#xff09;&#xff0c;并通过一个计数器案例&#xff0c;实现了组件的自我更新。这很棒&#…...

【QT控件】显示类控件

目录 一、Label 二、LCD Number 三、ProgressBar 四、Calendar Widget QT专栏&#xff1a;QT_uyeonashi的博客-CSDN博客 一、Label QLabel 可以用来显示文本和图片. 核心属性如下 代码示例: 显示不同格式的文本 1) 在界面上创建三个 QLabel 尺寸放大一些. objectName 分别…...