当前位置: 首页 > news >正文

【代码】Swan-Transformer 代码详解(待完成)

1. 局部注意力  Window Attention (W-MSA Module)

class WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size  # [Mh, Mw]print(self.window_size)self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask: Optional[torch.Tensor] = None):"""Args:x: input features with shape of (num_windows*B, Mh*Mw, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""# [batch_size*num_windows, Mh*Mw, total_embed_dim]B_, N, C = x.shape# qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]# reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]# permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)# transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]q = q * self.scaleattn = (q @ k.transpose(-2, -1))# relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:# mask: [nW, Mh*Mw, Mh*Mw]nW = mask.shape[0]  # num_windows# attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]# mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]# transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]# reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x

相关文章:

【代码】Swan-Transformer 代码详解(待完成)

1. 局部注意力 Window Attention (W-MSA Module) class WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number…...

iframe.contentDocument 和document.documentElement的区别

iframe.contentDocument 和 document.documentElement 是用于访问不同内容的两个不同的对象或属性。 1. iframe.contentDocument 内容: iframe.contentDocument 代表的是 <iframe> 元素所嵌入的文档的 Document 对象。它允许你访问和操作嵌入的文档&#xff08;即 ifram…...

计算机操作员试题(中篇)

计算机操作员试题(中篇) 335.在 Excel中,把鼠标指向被选中单元格边框,当指变成箭头时,拖动鼠标到目标单 元格时,将完成( )操作。 (A)删除 (B)移动 ©自动填充 (D)复制 336.在 Excel 工作表的单元格中,如想输入数字字符串 070615 (例如学号),则应输 入()。 (A) 0007…...

车规级MCU「换道」竞赛

汽车芯片&#xff0c;尤其是MCU市场正在进入拐点期。 本周&#xff0c;总部位于荷兰的汽车芯片制造商—恩智浦&#xff08;NXP&#xff09;半导体总裁兼首席执行官Kurt Sievers在公司第二季度财报电话会议上告诉投资者&#xff0c;由于汽车需求停滞不前&#xff0c;该公司正在努…...

数学生物学-2-离散时间模型(Discrete Time Models)

上一篇介绍了一个指数增长模型。然而&#xff0c;我们也看到&#xff0c;在现实情况下&#xff0c;细菌培养的增长是在离散的时间&#xff08;在这种情况下是小时&#xff09;进行测量的&#xff0c;种群并没有无限增长&#xff0c;而是趋于以S形曲线趋于平稳&#xff0c;称为“…...

免费开源!AI视频自动剪辑已成现实!效率提升80%,打工人福音!(附详细教程)

大家好&#xff0c;我是程序员X小鹿&#xff0c;前互联网大厂程序员&#xff0c;自由职业2年&#xff0c;也一名 AIGC 爱好者&#xff0c;持续分享更多前沿的「AI 工具」和「AI副业玩法」&#xff0c;欢迎一起交流~ 想象一下&#xff0c;假设老板给你布置了一项任务&#xff1a…...

NtripShare全站仪自动化监测之气象改正

最近有幸和自动化监测领域权威专家进行交流&#xff0c;讨论到全站仪气象改正的问题&#xff0c;因为有些观点与专家不太一致&#xff0c;所以再次温习了一下全站仪气象改正的技术细节。 气象改正的概念 全站仪一般利用光波进行测距&#xff0c;首先仪器会处理测距光波的相位漂…...

【人工智能】项目案例分析:使用自动编码器进行信用卡欺诈检测

一、项目背景 信用卡欺诈是金融行业面临的一个重要问题&#xff0c;快速且准确的欺诈检测对于保护消费者和金融机构的利益至关重要。本项目旨在通过利用自动编码器&#xff08;Autoencoder&#xff09;这一无监督学习算法&#xff0c;来检测信用卡交易中的欺诈行为&#xff0c…...

【工控】线扫相机小结

背景简介 我目前接触到的线扫相机有两种形式: 无采集卡,数据通过网线传输。 配备采集卡,使用PCIe接口。 第一种形式的数据通过网线传输,速度较慢,因此扫描和生成图像的速度都较慢,参数设置主要集中在相机本身。第二种形式的相机配备采集卡,通常速度更快,但由于相机和…...

将Web应用部署到Tomcat根目录的三种方法

将应用部署到Tomcat根目录的三种方法 将应用部署到Tomcat根目录的目的是可以通过"http://[ip]:[port]"直接访问应用&#xff0c;而不是使用"http://[ip]:[port]/[appName]"上下文路径进行访问。 方法一&#xff1a;&#xff08;最简单直接的方法&#xff0…...

工业和信息化部教育与考试中心计算机相关专业介绍

国家工信部的认证证书在行业内享有较高声誉。 此外&#xff0c;还设有专门的工业和信息化技术技能人才数据库查询服务&#xff0c;进一步方便了个人和企业对相关职业能力证书的查询需求。 序号 专业工种 级别 备注 1 JAVA程序员 初级 职业技术 2 电子…...

第二证券:生物天然气线上交易达成 创新探索互联互通、气证合一

8月20日&#xff0c;上海石油天然气生意中心在国内立异推出生物天然气线上生意。当日&#xff0c;绿气新动力&#xff08;北京&#xff09;有限公司&#xff08;简称“绿气新动力”&#xff09;挂单的1500万立方米生物天然气被百事食物&#xff08;我国&#xff09;有限公司&am…...

重磅!RISC-V+OpenHarmony平板电脑发布

仟江水商业电讯&#xff08;8月18日 北京 委托发布&#xff09;RISC-V作为历史上全球发展速度最快、创新最为活跃的开放指令架构&#xff0c;正在不断拓展高性能计算领域的边界。OpenHarmony是由开放原子开源基金会孵化并运营的开源项目&#xff0c;已成为发展速度最快的智能终…...

[DL]深度学习_扩散模型

扩散模型原理 深入浅出扩散模型 一、概念简介 1、Denoising Diffusion Probalistic Models&#xff0c;DDPM 1.1 扩散模型运行原理 首先sample一个都是噪声的图片向量&#xff0c;这个向量的shape和要生成的图像大小相同。通过Denoise过程来一步一步有规律的滤去噪声。Den…...

AI学习记录 - 如何快速构造一个简单的token词汇表

创作不易&#xff0c;有用的话点个赞 先直接贴代码&#xff0c;我们再慢慢分析&#xff0c;代码来自openai的图像分类模型的一小段 def bytes_to_unicode():"""Returns list of utf-8 byte and a corresponding list of unicode strings.The reversible bpe c…...

JAVA中的数组流ByteArrayOutputStream

Java 中的 ByteArrayOutputStream 是一个字节数组输出流&#xff0c;它允许应用程序以字节的形式写入数据到一个字节数组缓冲区中。以下是对 ByteArrayOutputStream 的详细介绍&#xff0c;包括其构造方法、方法、使用示例以及运行结果。 一、ByteArrayOutputStream 概述 Byt…...

S3C2440中断处理

一、中断处理机制概述 中断是CPU在执行程序过程中&#xff0c;遇到急需处理的事件时&#xff0c;暂时停止当前程序的执行&#xff0c;转而执行处理该事件的中断服务程序&#xff0c;并在处理完毕后返回原程序继续执行的过程。S3C2440提供了丰富的中断源&#xff0c;包括内部中…...

《数据分析与知识发现》

《数据分析与知识发现》介绍 1 期刊定位 《数据分析与知识发现》&#xff08;Data Analysis and Knowledge Discovery&#xff09;是由中国科学院主管、中国科学院文献情报中心主办的学术性专业期刊。期刊创刊于2017年&#xff0c;由《现代图书情报技术》&#xff08;1985-20…...

IaaS,PaaS,aPaaS,SaaS,FaaS,如何区分?

​IaaS, PaaS&#xff0c;SaaS&#xff0c;aPaaS 还有一种 FaaS &#xff0c;这几个都是云服务中常见的 5 大类型&#xff1a; IaaS&#xff1a;基础架构即服务&#xff0c;Infrastructure as a Service PaaS&#xff1a;平台即服务&#xff0c;Platform as a Service aPaaS&…...

软件测试工具分享

要想在测试中旗开得胜&#xff0c;趁手的“武器”那是相当重要&#xff08;说人话&#xff0c;要保证测试质量和效率&#xff0c;测试工具也很重要&#xff09;。现在&#xff0c;小酋打算亮一亮自己的武器库&#xff0c;希望不要闪瞎你的眼&#xff08;天上在打雷&#xff0c;…...

内存分配函数malloc kmalloc vmalloc

内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架&#xff0c;相比 MapReduce 具有以下核心优势&#xff1a; 内存计算&#xff1a;数据可常驻内存&#xff0c;迭代计算性能提升 10-100 倍&#xff08;文档段落&#xff1a;3-79…...

.Net Framework 4/C# 关键字(非常用,持续更新...)

一、is 关键字 is 关键字用于检查对象是否于给定类型兼容,如果兼容将返回 true,如果不兼容则返回 false,在进行类型转换前,可以先使用 is 关键字判断对象是否与指定类型兼容,如果兼容才进行转换,这样的转换是安全的。 例如有:首先创建一个字符串对象,然后将字符串对象隐…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在 GPU 上对图像执行 均值漂移滤波&#xff08;Mean Shift Filtering&#xff09;&#xff0c;用于图像分割或平滑处理。 该函数将输入图像中的…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

Go 语言并发编程基础:无缓冲与有缓冲通道

在上一章节中&#xff0c;我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道&#xff0c;它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好&#xff0…...

AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机

这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机&#xff0c;因为在使用过程中发现 Airsim 对外部监控相机的描述模糊&#xff0c;而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置&#xff0c;最后在源码示例中找到了&#xff0c;所以感…...

纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join

纯 Java 项目&#xff08;非 SpringBoot&#xff09;集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事&#xff0c;必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后&#xff0c;我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集&#xff0c;就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...