当前位置: 首页 > news >正文

每日Attention学习22——Inverted Residual RWKV

模块出处

[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation


模块名称

Inverted Residual RWKV (IR-RWKV)


模块作用

用于vision的RWKV结构


模块结构

在这里插入图片描述


模块代码

注:cpp扩展请参考作者原仓库

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import loadT_MAX = 1024
inplace = True
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])def get_norm(norm_layer='in_1d'):eps = 1e-6norm_dict = {'none': nn.Identity,'in_1d': partial(nn.InstanceNorm1d, eps=eps),'in_2d': partial(nn.InstanceNorm2d, eps=eps),'in_3d': partial(nn.InstanceNorm3d, eps=eps),'bn_1d': partial(nn.BatchNorm1d, eps=eps),'bn_2d': partial(nn.BatchNorm2d, eps=eps),# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),'bn_3d': partial(nn.BatchNorm3d, eps=eps),'gn': partial(nn.GroupNorm, eps=eps),'ln_1d': partial(nn.LayerNorm, eps=eps),# 'ln_2d': partial(LayerNorm2d, eps=eps),}return norm_dict[norm_layer]def get_act(act_layer='relu'):act_dict = {'none': nn.Identity,'sigmoid': Sigmoid,'swish': Swish,'mish': Mish,'hsigmoid': HardSigmoid,'hswish': HardSwish,'hmish': HardMish,'tanh': Tanh,'relu': nn.ReLU,'relu6': nn.ReLU6,'prelu': PReLU,'gelu': GELU,'silu': nn.SiLU}return act_dict[act_layer]class ConvNormAct(nn.Module):def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):super(ConvNormAct, self).__init__()self.has_skip = skip and dim_in == dim_outpadding = math.ceil((kernel_size - stride) / 2)self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)self.norm = get_norm(norm_layer)(dim_out)self.act = get_act(act_layer)(inplace=inplace)self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()def forward(self, x):shortcut = xx = self.conv(x)x = self.norm(x)x = self.act(x)if self.has_skip:x = self.drop_path(x) + shortcutreturn xclass SE(nn.Module):def __init__(self,in_chs: int,rd_ratio: float = 0.25,rd_channels: Optional[int] = None,act_layer: LayerType = nn.ReLU,gate_layer: LayerType = nn.Sigmoid,force_act_layer: Optional[LayerType] = None,rd_round_fn: Optional[Callable] = None,):super(SE, self).__init__()if rd_channels is None:rd_round_fn = rd_round_fn or roundrd_channels = rd_round_fn(in_chs * rd_ratio)act_layer = force_act_layer or act_layerself.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)self.act1 = create_act_layer(act_layer, inplace=True)self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)self.gate = create_act_layer(gate_layer)def forward(self, x):x_se = x.mean((2, 3), keepdim=True)x_se = self.conv_reduce(x_se)x_se = self.act1(x_se)x_se = self.conv_expand(x_se)return x * self.gate(x_se)def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):assert gamma <= 1/4B, N, C = input.shapeinput = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])B, C, H, W = input.shapeoutput = torch.zeros_like(input)output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]return output.flatten(2).transpose(1, 2)def RUN_CUDA(B, T, C, w, u, k, v):return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())class WKV(torch.autograd.Function):@staticmethoddef forward(ctx, B, T, C, w, u, k, v):ctx.B = Bctx.T = Tctx.C = Cassert T <= T_MAXassert B * C % min(C, 1024) == 0half_mode = (w.dtype == torch.half)bf_mode = (w.dtype == torch.bfloat16)ctx.save_for_backward(w, u, k, v)w = w.float().contiguous()u = u.float().contiguous()k = k.float().contiguous()v = v.float().contiguous()y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)wkv_cuda.forward(B, T, C, w, u, k, v, y)if half_mode:y = y.half()elif bf_mode:y = y.bfloat16()return y@staticmethoddef backward(ctx, gy):B = ctx.BT = ctx.TC = ctx.Cassert T <= T_MAXassert B * C % min(C, 1024) == 0w, u, k, v = ctx.saved_tensorsgw = torch.zeros((B, C), device='cuda').contiguous()gu = torch.zeros((B, C), device='cuda').contiguous()gk = torch.zeros((B, T, C), device='cuda').contiguous()gv = torch.zeros((B, T, C), device='cuda').contiguous()half_mode = (w.dtype == torch.half)bf_mode = (w.dtype == torch.bfloat16)wkv_cuda.backward(B, T, C,w.float().contiguous(),u.float().contiguous(),k.float().contiguous(),v.float().contiguous(),gy.float().contiguous(),gw, gu, gk, gv)if half_mode:gw = torch.sum(gw.half(), dim=0)gu = torch.sum(gu.half(), dim=0)return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())elif bf_mode:gw = torch.sum(gw.bfloat16(), dim=0)gu = torch.sum(gu.bfloat16(), dim=0)return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())else:gw = torch.sum(gw, dim=0)gu = torch.sum(gu, dim=0)return (None, None, None, gw, gu, gk, gv)class VRWKV_SpatialMix(nn.Module):def __init__(self, n_embd, channel_gamma=1/4, shift_pixel=1):super().__init__()self.n_embd = n_embdattn_sz = n_embdself._init_weights()self.shift_pixel = shift_pixelif shift_pixel > 0:self.channel_gamma = channel_gammaelse:self.spatial_mix_k = Noneself.spatial_mix_v = Noneself.spatial_mix_r = Noneself.key = nn.Linear(n_embd, attn_sz, bias=False)self.value = nn.Linear(n_embd, attn_sz, bias=False)self.receptance = nn.Linear(n_embd, attn_sz, bias=False)self.key_norm = nn.LayerNorm(n_embd)self.output = nn.Linear(attn_sz, n_embd, bias=False)self.key.scale_init = 0self.receptance.scale_init = 0self.output.scale_init = 0def _init_weights(self):self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)def jit_func(self, x, patch_resolution):# Mix x with the previous timestep to produce xk, xv, xrB, T, C = x.size()# Use xk, xv, xr to produce k, v, rif self.shift_pixel > 0:xx = q_shift(x, self.shift_pixel, self.channel_gamma, patch_resolution)xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)else:xk = xxv = xxr = xk = self.key(xk)v = self.value(xv)r = self.receptance(xr)sr = torch.sigmoid(r)return sr, k, vdef forward(self, x, patch_resolution=None):B, T, C = x.size()sr, k, v = self.jit_func(x, patch_resolution)x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)x = self.key_norm(x)x = sr * xx = self.output(x)return xclass iR_RWKV(nn.Module):def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',act_layer='relu', dw_ks=3, stride=1, dilation=1, se_ratio=0.0,attn_s=True, drop_path=0., drop=0.,img_size=224, channel_gamma=1/4, shift_pixel=1):super().__init__()self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()dim_mid = int(dim_in * exp_ratio)self.ln1 = nn.LayerNorm(dim_mid)self.conv = ConvNormAct(dim_in, dim_mid, kernel_size=1)self.has_skip = (dim_in == dim_out and stride == 1) and has_skipif attn_s==True:self.att = VRWKV_SpatialMix(dim_mid, channel_gamma, shift_pixel)self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()self.proj_drop = nn.Dropout(drop)self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()self.attn_s=attn_sself.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation, groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)def forward(self, x):shortcut = xx = self.norm(x)x = self.conv(x)if self.attn_s:B, hidden, H, W = x.size()patch_resolution = (H,  W)x = x.view(B, hidden, -1)  # (B, hidden, H*W) = (B, C, N)x = x.permute(0, 2, 1)x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hiddeh, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))x = x.permute(0, 2, 1)x = x.contiguous().view(B, hidden, h, w)x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))x = self.proj_drop(x)x = self.proj(x)x = (shortcut + self.drop_path(x)) if self.has_skip else xreturn xif __name__ == '__main__':x = torch.randn([1, 64, 11, 11]).cuda()ir_rwkv = iR_RWKV(dim_in=64, dim_out=64).cuda()out = ir_rwkv(x)print(out.shape)  # [1, 64, 11, 11]

相关文章:

每日Attention学习22——Inverted Residual RWKV

模块出处 [arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation 模块名称 Inverted Residual RWKV (IR-RWKV) 模块作用 用于vision的RWKV结构 模块结构 模块代码 注&#xff1a;cpp扩展请参考作者原…...

使用jmeter进行压力测试

使用jmeter进行压力测试 jmeter安装 官网安装包下载&#xff0c;选择二进制文件&#xff0c;解压。 tar -xzvf apache-jmeter-x.tgz依赖jdk安装。 yum install java-1.8.0-openjdk环境变量配置&#xff0c;修改/etc/profile文件&#xff0c;添加以下内容。 export JMETER/…...

LQB(0)-python-基础知识

一、Python开发环境与基础知识 python解释器&#xff1a;用于解释python代码 方式&#xff1a; 1.直接安装python解释器 2.安装Anaconda管理python环境 python开发环境&#xff1a;用于编写python代码 1.vscode 2.pycharm # 3.安装Anaconda后可以使用网页版的jupyter n…...

每日Attention学习18——Grouped Attention Gate

模块出处 [ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation 模块名称 Grouped Attention Gate (GAG) 模块作用 轻量特征融合 模块结构 模块特点 特征融合前使用Group…...

QT 窗口A覆盖窗口B时,窗口B接受不到鼠标事件

一、问题 在项目的需求中&#xff0c;地图A上面需要叠放一个任务窗口B&#xff0c;B覆盖了A&#xff0c;导致A接受不到鼠标及滚轮事件。 二、解决方案 1、Qt::WA_TransparentForMouseEvents 是 Qt 框架中的一个属性&#xff0c;用于使指定的控件及其子控件不响应鼠标事件。当启…...

Unity安装教学与相关问题

文章目录 1. 前言2.Unity Hub2.1 下载Unity Hub2.2 安装Unity Hub2.3 注册Unity账号2.4 在Hub上登录账号2.5 在Hub上获取许可证 3. 下载并安装Unity3.1 从Unity Hub下载&#xff08;推荐&#xff09;3.1.1 选择下载版本3.1.2 选择下载组件3.1.3 安装Visual Studio Community 20…...

[Python人工智能] 四十九.PyTorch入门 (4)利用基础模块构建神经网络并实现分类预测

从本专栏开始,作者正式研究Python深度学习、神经网络及人工智能相关知识。前文讲解PyTorch构建回归神经网络。这篇文章将介绍如何利用PyTorch构建神经网络实现分类预测,其是使用基础模块构建。前面我们的Python人工智能主要以TensorFlow和Keras为主,而现在最主流的深度学习框…...

实现一个 LRU 风格的缓存类

实现一个缓存类 需求描述豆包解决思路&#xff1a;实现代码&#xff1a;优化11. std::list::remove 的时间复杂度问题2. 代码复用优化后的代码优化说明 优化21. 边界条件检查2. 异常处理3. 代码封装性4. 线程安全优化后的代码示例优化说明 DeepSeek&#xff08;深度思考R1&…...

【蓝桥杯嵌入式】4_key:单击+长按+双击

1、电路图 将4个按键的引脚设置为input&#xff0c;并将初始状态设置为Pull-up&#xff08;上拉输入&#xff09; 为解决按键抖动的问题&#xff0c;我们使用定时器中断进行消抖 打开TIM3时钟并设置参数&#xff0c;中断间隔10ms&#xff0c;当计数达到10000时溢出。80M/80/10…...

深入理解 C# 与.NET 框架

.NET学习资料 .NET学习资料 .NET学习资料 一、引言 在现代软件开发领域&#xff0c;C# 与.NET 框架是构建 Windows、Web、移动及云应用的强大工具。C# 作为一种面向对象的编程语言&#xff0c;而.NET 框架则是一个综合性的开发平台&#xff0c;它们紧密结合&#xff0c;为开…...

10. 神经网络(二.多层神经网络模型)

多层神经网络&#xff08;Multi-Layer Neural Network&#xff09;&#xff0c;也称为深度神经网络&#xff08;Deep Neural Network, DNN&#xff09;&#xff0c;是机器学习中一种重要的模型&#xff0c;能够通过多层次的非线性变换解决复杂的分类、回归和模式识别问题。以下…...

spark 性能调优 (一):执行计划

在 Spark 中&#xff0c;explain 函数用于提供数据框&#xff08;DataFrame&#xff09;或 SQL 查询的逻辑计划和物理执行计划的详细解释。它可以帮助开发者理解 Spark 是如何执行查询的&#xff0c;包括优化过程、转换步骤以及它将采用的物理执行策略。 1. 逻辑计划 (Logical…...

“卫星-无人机-地面”遥感数据快速使用及地物含量计算的实现方法

在与上千学员交流过程中&#xff0c;发现科研、生产和应用多源遥感数据时&#xff0c;能快速上手&#xff0c;发挥数据的时效性&#xff0c;尽快出创新性成果&#xff0c;是目前的学员最迫切的需求。特别是按照“遥感数据获取-处理-分析-计算-制图”全流程的答疑解惑&#xff0…...

杨氏数组中查找某一数值是否存在

判断数据是否存在于杨氏矩阵中 &#xff08;小米真题&#xff09; 题目&#xff1a;有一个数字矩阵&#xff0c;矩阵的每行从左到右是递增的&#xff0c;矩阵从上到下是递增的&#xff0c;请编写程序在这样的矩阵中查找某个数字是否存在。 要求&#xff1a;时间复杂度小于O(N) …...

c语言对应汇编写法(以中微单片机举例)

芯片手册资料 1. 赋值语句 C语言&#xff1a; a 5; b a; 汇编&#xff1a; ; 立即数赋值 LDIA 05H ; ACC 5 LD R01,A ; R01 ACC&#xff08;a5&#xff09;; 寄存器间赋值 LD A,R01 ; ACC R01&#xff08;读取a的值&#xff09; LD R02,A ; R02 ACC&…...

详解CSS `clear` 属性及其各个选项

详解CSS clear 属性及其各个选项 1. clear: left;示例代码 2. clear: right;示例代码 3. clear: both;示例代码 4. clear: none;示例代码 总结 在CSS布局中&#xff0c;clear 属性是一个非常重要的工具&#xff0c;特别是在处理浮动元素时。本文将详细解释 clear 属性及其各个选…...

算法设计与分析三级项目--管道铺设系统

摘 要 该项目使用c算法逻辑&#xff0c;开发环境为VS2022&#xff0c;旨在通过Prim算法优化建筑物间的连接路径&#xff0c;以支持管线铺设规划。可以读取文本文件中的建筑物名称和距离的信息&#xff0c;并计算出建筑物之间的最短连接路径和总路径长度&#xff0c;同时以利用…...

Page Assist - 本地Deepseek模型 Web UI 的安装和使用

Page Assist Page Assist是一个开源的Chrome扩展程序&#xff0c;为本地AI模型提供一个直观的交互界面。通过它可以在任何网页上打开侧边栏或Web UI&#xff0c;与自己的AI模型进行对话&#xff0c;获取智能辅助。这种设计不仅方便了用户随时调用AI的能力&#xff0c;还保护了…...

VMware Win10下载安装教程(超详细)

《网络安全自学教程》 从MSDN下载系统镜像&#xff0c;使用 VMware Workstation 17 Pro 安装 Windows 10 consumer家庭版 和 VMware Tools。 Win10下载安装 1、下载镜像2、创建虚拟机3、安装操作系统4、配置系统5、安装VMware Tools 1、下载镜像 到MSDN https://msdn.itellyou…...

DS目前曲线代替的网站汇总

DS目前还不稳定&#xff0c;好在国内外大厂平台都上线了&#xff0c;汇总如下&#xff1a; 秘塔搜索&#xff1a; https://metaso.cn 360纳米AI搜索&#xff1a; https://www.n.cn/ 硅基流动&#xff1a; https://cloud.siliconflow.cn/i/snHnLED8 字节跳动火山引擎&#xf…...

vscode里如何用git

打开vs终端执行如下&#xff1a; 1 初始化 Git 仓库&#xff08;如果尚未初始化&#xff09; git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应&#xff0c;这是一种非线性光学现象&#xff0c;主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场&#xff0c;对材料产生非线性响应&#xff0c;可能…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

基于Docker Compose部署Java微服务项目

一. 创建根项目 根项目&#xff08;父项目&#xff09;主要用于依赖管理 一些需要注意的点&#xff1a; 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件&#xff0c;否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

浅谈不同二分算法的查找情况

二分算法原理比较简单&#xff0c;但是实际的算法模板却有很多&#xff0c;这一切都源于二分查找问题中的复杂情况和二分算法的边界处理&#xff0c;以下是博主对一些二分算法查找的情况分析。 需要说明的是&#xff0c;以下二分算法都是基于有序序列为升序有序的情况&#xf…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...