6、关于Medical-Transformer
6、关于Medical-Transformer
Axial-Attention原文链接:Axial-attention
Medical-Transformer原文链接:Medical-Transformer
Medical-Transformer实际上是Axial-Attention在医学领域的运行,只是在这基础上增加了门机制,实际上也就是在原来Axial-attention基础之上增加权重机制,虚弱位置信息对于数据的影响,发现虚弱之后的效果比Axial-Attention机制效果更好
Axial-Attention
Axial-Attention与传统Transformer的self-attention相比较,将2D计算转成1D计算,Axial-attention机制,对于qkv的计算,做出了简化,仅仅某个点的横竖两个方向上的特殊,同时在qkv的基础上加上了各自位置特征,这些特征都是更新学习的。
Axial-attention模型架构图
左图为传统的self-attention机制,右图为Axial-attention机制,对于qkv都加上rq,rk,rv这样的位置参数,这些参数都是可以更新的,也就是说,每个的q在和自己对应的横竖轴反向进行计算的时候,q会和自己rq先进行权重计算,同样的k和v也会进行同样的计算,随后进行q和k进行计算得到权重,计算过程和原来的self-attention机制是一样的。
class AxialAttention(nn.Module):def forward(self, x):# 前向传播函数# 如果设置了 width 参数,调整张量维度顺序if self.width:x = x.permute(0, 2, 1, 3) # 调整维度顺序else:x = x.permute(0, 3, 1, 2) # N, W, C, H 调整为 N, C, H, WN, W, C, H = x.shape # 获取张量形状x = x.contiguous().view(N * W, C, H) # 重新调整形状,合并 N 和 W 维度# 通过x获得对应的qkv 批归一化后计算 qkvqkv = self.bn_qkv(self.qkv_transform(x)) q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),[self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2) # 将 qkv 拆分为 q, k, v# 计算位置嵌入all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0) # 拆分嵌入# 计算 QR, KR, QK 相似性,分别计算得出rq,rkqr = torch.einsum('bgci,cij->bgij', q, q_embedding) # QR: q 和 q_embedding 的爱因斯坦求和kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) # KR: k 和 k_embedding 的爱因斯坦求和,并转置# q和k进行计算,得到最后的权重qk = torch.einsum('bgci, bgcj->bgij', q, k) # QK: q 和 k 之间的点积# 将 QR, KR, QK 相似性进行堆叠,连在一起进行计算stacked_similarity = torch.cat([qk, qr, kr], dim=1) # 将 qk, qr, kr 连接起来stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1) # 批归一化并调整形状# similarity为q和k计算得出权重关系similarity = F.softmax(stacked_similarity, dim=3) # 在第 3 维度上计算 softmax# 将q和v计算出来权重和v加权求和sv = torch.einsum('bgij,bgcj->bgci', similarity, v) # 将相似度与 v 进行求和# v与位置信息结合sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) # 将similarity与 v_embedding 进行求和# 将位置加权后的v和q和k计算结果与v加权的合并,并调整形状输出stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H) # 合并 sv 和 sve,并调整形状output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2) # 批归一化并调整形状# 恢复维度顺序if self.width:output = output.permute(0, 2, 1, 3) # 调整维度顺序else:output = output.permute(0, 2, 3, 1) # 调整维度顺序# 如果步长大于 1,应用池化操作if self.stride > 1:output = self.pooling(output) # 池化return output # 返回输出
横竖轴计算过程
先通过卷积把特征图缩小,然后横竖轴计算时,是将横轴一起进行计算,然后再进行纵轴计算的,完成计算后,通过1x1卷积将特征图还原为原来的大小,在传入下一层进行计算。
Medical-Transformer
Medical-Transformer架构图
Medical-Transformer实际就是Axial-attention在医学图像分割领域的应用,medical-tranformer大模型架构采用整个图像进行Axial-attention特征提取,同时也将图像分成多个窗口,对每个窗口进行axial-attention特征提取,窗口由于计算量小,可以多进行几层Axial-attention,最终将整个图像特征和窗口特征融合,完成整个的特征提取,值得一提的是在进行窗口Axial-attention时,qkv都没有加上位置编码(也就是下面部分的图像)。
主体架构
class medt_net(nn.Module):def _forward_impl(self, x):xin = x.clone() # 保存输入数据的副本x = self.conv1(x) # 第一个卷积层x = self.bn1(x) # 第一个批归一化层x = self.relu(x) # ReLU 激活函数x = self.conv2(x) # 第二个卷积层x = self.bn2(x) # 第二个批归一化层x = self.relu(x) # ReLU 激活函数x = self.conv3(x) # 第三个卷积层x = self.bn3(x) # 第三个批归一化层x = self.relu(x) # ReLU 激活函数x1 = self.layer1(x) # 第一个残差层 实际上就是 Gated Axial Attention Layerx2 = self.layer2(x1) # 第二个残差层 同样是 Gated Axial Attention Layer# 对输入进行插值放大,并通过解码器处理x = F.relu(F.interpolate(self.decoder4(x2), scale_factor=(2, 2), mode='bilinear'))x = torch.add(x, x1) # 将放大的特征图与 x1 相加x = F.relu(F.interpolate(self.decoder5(x), scale_factor=(2, 2), mode='bilinear'))# 以上完成就是图上方整个图像的卷积过程# -------------------------------------------------------------------------------------------x_loc = x.clone() # 生成一个本地副本# 下面对图像进行切分,分别对每个窗口进行局部处理,实际上是16个窗口for i in range(0, 4):for j in range(0, 4):x_p = xin[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)] # 提取32x32的局部patch# 逐层卷积处理patchx_p = self.conv1_p(x_p)x_p = self.bn1_p(x_p)x_p = self.relu(x_p)x_p = self.conv2_p(x_p)x_p = self.bn2_p(x_p)x_p = self.relu(x_p)x_p = self.conv3_p(x_p)x_p = self.bn3_p(x_p)x_p = self.relu(x_p)# 进行四个x1_p = self.layer1_p(x_p) # 第一个残差层(patch-wise) 这里进行的axial-attention在进行qkv计算时,qkv都没有加入位置信息计算x2_p = self.layer2_p(x1_p) # 第二个残差层(patch-wise)x3_p = self.layer3_p(x2_p) # 第三个残差层(patch-wise)x4_p = self.layer4_p(x3_p) # 第四个残差层(patch-wise)# 对patch进行插值放大并通过解码器处理x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x4_p) # 将放大的特征图与 x4_p 相加x_p = F.relu(F.interpolate(self.decoder2_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x3_p) # 将放大的特征图与 x3_p 相加x_p = F.relu(F.interpolate(self.decoder3_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x2_p) # 将放大的特征图与 x2_p 相加x_p = F.relu(F.interpolate(self.decoder4_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x1_p) # 将放大的特征图与 x1_p 相加x_p = F.relu(F.interpolate(self.decoder5_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_loc[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)] = x_p # 将局部处理后的结果放回原始位置# 将整个图片的axial-attention,和每个窗口得出的结果进行结合x = torch.add(x, x_loc) # 将全局和局部特征进行融合x = F.relu(self.decoderf(x)) # 通过最终的解码器层x = self.adjust(F.relu(x)) # 调整输出return x # 返回最终输出
Gated Axial Attention Layer
从架构图中可以看出,就是在Axial-attention的基础上,加上了门机制,说白了,也就是在qkv和各自的rq,rk,rv计算完成后,再进行下一步计算之前,进行了一个加权计算,虚弱了位置变量对特征提取结果的影响。
横向或纵向Gated Axial-attention过程
注意里面qr,kr实际上就是图片中的rq,rk,而
class AxialAttention_dynamic(nn.Module):def forward(self, x):# 判断是否需要对宽度维度进行变换if self.width:x = x.permute(0, 2, 1, 3) # 交换维度顺序,形状变为 [N, C, W, H]else:x = x.permute(0, 3, 1, 2) # 交换维度顺序,形状变为 [N, W, C, H]N, W, C, H = x.shape # 获取输入张量的形状x = x.contiguous().view(N * W, C, H) # 将张量变形为 [N * W, C, H]print(x.shape) # 输出形状: [64, 16, 64]# 变换操作qkv = self.bn_qkv(self.qkv_transform(x)) # 对qkv进行批归一化print(qkv.shape) # 输出形状: [64, 32, 64]# 将qkv张量拆分为q、k、v,分别表示查询、键和值q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)print(q.shape) # 输出q的形状: [64, 8, 1, 64]print(k.shape) # 输出k的形状: [64, 8, 1, 64]print(v.shape) # 输出v的形状: [64, 8, 2, 64],v有两份# 计算位置嵌入all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)print(all_embeddings.shape) # 输出嵌入的形状: [4, 64, 64],共有4份q_embedding, k_embedding, v_embedding =torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)print(q_embedding.shape) # 输出q的位置嵌入形状: [1, 64, 64]print(k_embedding.shape) # 输出k的位置嵌入形状: [1, 64, 64]print(v_embedding.shape) # 输出v的位置嵌入形状: [2, 64, 64],v有两份位置编码# 计算q与位置嵌入的乘积qr = torch.einsum('bgci,cij->bgij', q, q_embedding)print(qr.shape) # 输出qr的形状: [64, 8, 64, 64]# 计算k与位置嵌入的乘积,并进行转置kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)print(kr.shape) # 输出kr的形状: [64, 8, 64, 64]# 计算q和k的点积qk = torch.einsum('bgci, bgcj->bgij', q, k)print(qk.shape) # 输出qk的形状: [64, 8, 64, 64]# 对qr和kr进行初始化,使用self.f_qr和self.f_kr作为初始化的权重qr = torch.mul(qr, self.f_qr)print(qr.shape) # 输出qr的形状: [64, 8, 64, 64]kr = torch.mul(kr, self.f_kr)print(kr.shape) # 输出kr的形状: [64, 8, 64, 64]# 将qk、qr和kr拼接起来stacked_similarity = torch.cat([qk, qr, kr], dim=1)print(stacked_similarity.shape) # 输出拼接后的形状: [64, 24, 64, 64]# 进行批归一化,重新变形为[N * W, 3, groups, H, H],并对维度1求和stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)print(stacked_similarity.shape) # 输出归一化后的形状: [64, 8, 64, 64]# 计算相似度similarity = F.softmax(stacked_similarity, dim=3)print(similarity.shape) # 输出相似度的形状: [64, 8, 64, 64]# 使用相似度与v相乘,获得加权后的值sv = torch.einsum('bgij,bgcj->bgci', similarity, v)print(sv.shape) # 输出加权后的形状: [64, 8, 2, 64]# 使用相似度与v的位置嵌入相乘sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)print(sve.shape) # 输出位置嵌入加权后的形状: [64, 8, 2, 64]# 对sv和sve进行初始化sv = torch.mul(sv, self.f_sv)print(sv.shape) # 输出sv的形状: [64, 8, 2, 64]sve = torch.mul(sve, self.f_sve)print(sve.shape) # 输出sve的形状: [64, 8, 2, 64]# 将sv和sve拼接在一起,并重新变形为[N * W, out_planes * 2, H]stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)print(stacked_output.shape) # 输出拼接后的形状: [64, 32, 64]# 进行批归一化,并变形为[N, W, out_planes, 2, H],对维度-2求和output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)print(output.shape) # 输出归一化后的形状: [1, 64, 16, 64]# 根据宽度调整维度顺序if self.width:output = output.permute(0, 2, 1, 3)else:output = output.permute(0, 2, 3, 1)print(output.shape) # 输出最终的形状: [1, 16, 64, 64]# 如果步幅大于1,进行池化操作if self.stride > 1:output = self.pooling(output)return output
相关文章:

6、关于Medical-Transformer
6、关于Medical-Transformer Axial-Attention原文链接:Axial-attention Medical-Transformer原文链接:Medical-Transformer Medical-Transformer实际上是Axial-Attention在医学领域的运行,只是在这基础上增加了门机制,实际上也就…...

19_单片机开发常用工具的使用
工欲善其事必先利其器,我们做单片机开发的时候,不管是调试电路还是调试程序,都需要借助一些辅助工具来帮助查找和定位问题,从而帮助我们顺利解决问题。没有任何辅助工具的单片机项目开发很可能就是无法完成的任务,不过…...

最新版微服务项目搭建
一,项目总体介绍 在本项目中,我将使用alibabba的 nacos 作为项目的注册中心,使用 spring cloud gateway 做为项目的网关,用 openfeign 作为服务间的调用组件。 项目总体架构图如下: 注意:我的Java环境是17…...

spring揭秘19-spring事务01-事务抽象
文章目录 【README】【1】事务基本元素【1.1】事务分类 【2】java事务管理【2.1】基于java的局部事务管理【2.2】基于java的分布式事务管理【2.2.1】基于JTA的分布式事务管理【2.2.2】基于JCA的分布式事务管理 【2.3】java事务管理的问题 【3】spring事务抽象概述【3.1】spring…...

基于Matlab的图像去雾系统(四种方法)关于图像去雾的基本算法代码的集合,方法包括局部直方图均衡法、全部直方图均衡法、暗通道先验法、Retinex增强。
基于Matlab的图像去雾系统(四种方法) 关于图像去雾的基本算法代码的集合,方法包括局部直方图均衡法、全部直方图均衡法、暗通道先验法、Retinex增强。 所有代码整合到App designer编写的GUI界面中,包括导入图片,保存处…...

油猴插件录制请求,封装接口自动化参数
参考:如何使用油猴插件提高测试工作效率 一、背景 在酷家乐设计工具测试中,总会有许多高频且较繁琐的工作,比如: 查询插件版本:需要打开Chrome控制台,输入好几个命令然后过滤出版本信息。 查询模型商品&…...

循环购模式!结合引流和复购于一体的商业模型!
欢迎各位朋友,我是你们的电商策略顾问吴军。今天,我将向大家介绍一种新颖的商业模式——循环购模式,它将如何改变我们的消费和收益方式。你是否好奇,为何商家会提供如此慷慨的优惠?消费一千元,不仅能够得到…...

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧
Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用&…...

c中 int 和 unsigned int
c语言中,char、short、int、int64以及unsigned char、unsigned short、unsigned int、unsigned int64等等类型都可以表示整数。但是他们表示整数的位数不同,比如:char/unisigned char表示8位整数; short/unsigned short表示16位整…...

sheng的学习笔记-AI-话题模型(topic model),LDA模型,Unigram Model,pLSA Model
AI目录:sheng的学习笔记-AI目录-CSDN博客 基础知识 什么是话题模型(topic model) 话题模型(topic model)是一族生成式有向图模型,主要用于处理离散型的数据(如文本集合),在信息检索、自然语言处理等领域有广泛应用…...

html 页面引入 vue 组件之 http-vue-loader.js
一、http-vue-loader.js http-vue-loader.js 是一个 Vue 单文件组件加载器,可以让我们在传统的 HTML 页面中使用 Vue 单文件组件,而不必依赖 Node.js 等其他构建工具。它内置了 Vue.js 和样式加载器,并能自动解析 Vue 单文件组件中的所有内容…...

html+css网页设计 旅行 蜘蛛旅行社3个页面
htmlcss网页设计 旅行 蜘蛛旅行社3个页面 网页作品代码简单,可使用任意HTML辑软件(如:Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作)。 获取源码 1&#…...

考拉悠然产品发布会丨以悠然远智全模态AI应用平台探索AI行业应用
9月6日,成都市大模型新技术新成果发布暨供需对接系列活动——考拉悠然专场,在成都市高新区菁蓉汇盛大举行。考拉悠然重磅发布了悠然远智丨全模态AI应用平台,并精彩展示了交通大模型应用——智析快处等最新的AI产品和技术成果。 在四川省科学…...

LLM大模型学习:揭秘LLM应用构建:探究文本加载器的必要性及在LangChain中的运用
构建 LLM 应用为什么需要文本加载器,langchain 中如何使用文本加载器? 在不同的应用场景中需要使用不同的文本内容作为内容的载体,针对不同的类型的文本,langchain 提供了多种文本加载器来帮助我们快速的将文本切片,从…...

Flutter函数
在Dart中,函数为 一等公民,可以作为参数对象传递,也可以作为返回值返回。 函数定义 // 返回值 (可以不写返回值,但建议写)、函数名、参数列表 showMessage(String message) {//函数体print(message); }void showMessage(String m…...

P3565 [POI2014] HOT-Hotels
~~~~~ P3565 [POI2014] HOT-Hotels ~~~~~ 总题单链接 思路 ~~~~~ 设 g [ u ] [ i ] g[u][i] g[u][i] 表示在 u u u 的子树内,距离 u u u 为 i i i 的点的个数。 ~~~~~ 设 d p [ u ] [ i ] dp[u][i] dp[u][i] 表示: u u u 的子树内存在两个点 x , …...

设计模式 | 单例模式
定义 单例设计模式(Singleton Pattern)是一种创建型设计模式,它确保一个类只有一个实例,并提供一个全局访问点来获取该实例。这种模式常用于需要控制对某些资源的访问的场景,例如数据库连接、日志记录等。 单例模式涉…...

Web安全之CSRF攻击详解与防护
在互联网应用中,安全性问题是开发者必须时刻关注的核心内容之一。跨站请求伪造(Cross-Site Request Forgery, CSRF),是一种常见的Web安全漏洞。通过CSRF攻击,黑客可以冒用受害者的身份,发送恶意请求&#x…...

IDEA运行Java程序提示“java: 警告: 源发行版 11 需要目标发行版 11”
遇到这个提示一般是在pom.xml中已经指定了构建的Java版本环境是11例如(此时添加了build插件的情况下虽然不能直接运行代码但是maven是可以正常打包构建): <build><plugins><plugin><groupId>org.apache.maven.plugins</groupId><…...

车载测试| 汽车的五域架构 (含线控技术知识)
汽车的五域架构是一种将汽车电子控制系统按照功能进行划分的架构模式,主要包括动力域、底盘域、座舱域、自动驾驶域和车身域。(汽车三域架构通常是指将汽车电子系统划分为三个主要领域:动力域、底盘域和智能座舱域(或车身舒适域&a…...

【Linux】gcc/g++ 、make/Makefile、git、gdb 的使用
目录 1. Linux编译器-gcc/g1.1 编译器gcc/g的工作步骤1.2 函数库1.2.1 函数库的作用及分类1.2.2 动态链接和静态链接1.2.3 动态库和静态库的优缺点 1.3 gcc选项 2. Linux项目自动化构建工具-make/Makefile2.1 .PHONY2.2 尝试编写进度条程序 3. git3.1 安装 git3.2 下载项目到本…...

Elastic Stack--ES的DSL语句查询
前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 学习B站博主教程笔记: 最新版适合自学的ElasticStack全套视频(Elk零基础入门到精通教程)Linux运维必备—Elastic…...

ARM基础知识---CPU---处理器
目录 一、ARM架构 1.1.RAM---随机存储器 1.2.ROM---只读存储器 1.3.flash---闪存存储器 1.4.时钟(振晶) 1.5.复位 二、CPU---ARM920T 2.1.R0~R12---通用寄存器 2.2.PC程序计数器 2.3.LR连接寄存器 2.4.SP栈指针寄存器 2.5.CPSR当前程序状态寄存…...

将星 x17 安装ubuntu 20.04 双系统
准备工作,包含关闭快速启动,关闭Secret Boot 1.进入控制面板选择小图标,找到电源选项 2.点击更改当前不可用的设置,关闭快速启动 3.开机启动时快速按F2,进入BIOS 4.选择Setup Utiltity,选择Security&#…...

E31.【C语言】练习:指针运算习题集(上)
Exercise 1 求下列代码的运行结果 #include <stdio.h> int main() {int a[5] { 1, 2, 3, 4, 5 };int* ptr (int*)(&a 1);printf("%d",*(ptr - 1));return 0; } 答案速查: 分析: Exercise 2 求下列代码的运行结果 //在x86环境下 //假设结…...

git分支的管理
分支管理是 Git 版本控制系统中的一个核心功能,它涉及如何创建、管理、合并和删除分支,以便在团队协作和开发过程中更有效地组织代码。以下是分支管理中的一些关键概念和实践: 1. 分支的创建 创建新分支:在开发新功能、修复 bug…...

对于消息队列的一些思考
如何保证消息不被重复消费 唯一ID:你提到的通过唯一ID解决重复消费问题非常重要。这通常通过业务系统引入唯一消息ID(如UUID)来实现。在消费端,先检查消息ID是否已经被处理,未处理过的才进行处理,确保幂等…...

IM即时通讯软件-WorkPlus私有化部署的局域网即时通讯工具
随着企业对通讯安全和数据掌控的需求不断增加,许多企业开始选择私有化部署的即时通讯工具,以在内部局域网环境中实现安全、高效的沟通与协作。IM-WorkPlus作为一款受欢迎的即时通讯软件,提供了私有化部署的选项,使企业能够在自己的…...

AI大模型的饕餮盛宴,系统学习大模型技术,你想要的书都在这里了
AI大模型的饕餮盛宴,系统学习大模型技术,你想要的书都在这里了 要说现在最热门的技术,可谓非大模型莫属!不少小伙伴都想要学习大模型技术,转战AI领域,以适应未来的大趋势,寻求更有前景的发展~~…...

支付宝开放平台-开发者社区——AI 日报「9 月 9 日」
1 离开 OpenAl 后,llya 拿了10亿美金对抗 Al 作恶 极窖公园 丨阅读原文 lya Sutskever, OpenAl的前联合创始人,成立了SS1 (Safe Superintelligence),旨在构建安全的Al模型。SSl获得了10亿美元的融资,估值达到50亿美元ÿ…...