当前位置: 首页 > article >正文

每日Attention学习22——Inverted Residual RWKV

模块出处

[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation


模块名称

Inverted Residual RWKV (IR-RWKV)


模块作用

用于vision的RWKV结构


模块结构

在这里插入图片描述


模块代码

注:cpp扩展请参考作者原仓库

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import loadT_MAX = 1024
inplace = True
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])def get_norm(norm_layer='in_1d'):eps = 1e-6norm_dict = {'none': nn.Identity,'in_1d': partial(nn.InstanceNorm1d, eps=eps),'in_2d': partial(nn.InstanceNorm2d, eps=eps),'in_3d': partial(nn.InstanceNorm3d, eps=eps),'bn_1d': partial(nn.BatchNorm1d, eps=eps),'bn_2d': partial(nn.BatchNorm2d, eps=eps),# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),'bn_3d': partial(nn.BatchNorm3d, eps=eps),'gn': partial(nn.GroupNorm, eps=eps),'ln_1d': partial(nn.LayerNorm, eps=eps),# 'ln_2d': partial(LayerNorm2d, eps=eps),}return norm_dict[norm_layer]def get_act(act_layer='relu'):act_dict = {'none': nn.Identity,'sigmoid': Sigmoid,'swish': Swish,'mish': Mish,'hsigmoid': HardSigmoid,'hswish': HardSwish,'hmish': HardMish,'tanh': Tanh,'relu': nn.ReLU,'relu6': nn.ReLU6,'prelu': PReLU,'gelu': GELU,'silu': nn.SiLU}return act_dict[act_layer]class ConvNormAct(nn.Module):def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):super(ConvNormAct, self).__init__()self.has_skip = skip and dim_in == dim_outpadding = math.ceil((kernel_size - stride) / 2)self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)self.norm = get_norm(norm_layer)(dim_out)self.act = get_act(act_layer)(inplace=inplace)self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()def forward(self, x):shortcut = xx = self.conv(x)x = self.norm(x)x = self.act(x)if self.has_skip:x = self.drop_path(x) + shortcutreturn xclass SE(nn.Module):def __init__(self,in_chs: int,rd_ratio: float = 0.25,rd_channels: Optional[int] = None,act_layer: LayerType = nn.ReLU,gate_layer: LayerType = nn.Sigmoid,force_act_layer: Optional[LayerType] = None,rd_round_fn: Optional[Callable] = None,):super(SE, self).__init__()if rd_channels is None:rd_round_fn = rd_round_fn or roundrd_channels = rd_round_fn(in_chs * rd_ratio)act_layer = force_act_layer or act_layerself.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)self.act1 = create_act_layer(act_layer, inplace=True)self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)self.gate = create_act_layer(gate_layer)def forward(self, x):x_se = x.mean((2, 3), keepdim=True)x_se = self.conv_reduce(x_se)x_se = self.act1(x_se)x_se = self.conv_expand(x_se)return x * self.gate(x_se)def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):assert gamma <= 1/4B, N, C = input.shapeinput = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])B, C, H, W = input.shapeoutput = torch.zeros_like(input)output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]return output.flatten(2).transpose(1, 2)def RUN_CUDA(B, T, C, w, u, k, v):return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())class WKV(torch.autograd.Function):@staticmethoddef forward(ctx, B, T, C, w, u, k, v):ctx.B = Bctx.T = Tctx.C = Cassert T <= T_MAXassert B * C % min(C, 1024) == 0half_mode = (w.dtype == torch.half)bf_mode = (w.dtype == torch.bfloat16)ctx.save_for_backward(w, u, k, v)w = w.float().contiguous()u = u.float().contiguous()k = k.float().contiguous()v = v.float().contiguous()y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)wkv_cuda.forward(B, T, C, w, u, k, v, y)if half_mode:y = y.half()elif bf_mode:y = y.bfloat16()return y@staticmethoddef backward(ctx, gy):B = ctx.BT = ctx.TC = ctx.Cassert T <= T_MAXassert B * C % min(C, 1024) == 0w, u, k, v = ctx.saved_tensorsgw = torch.zeros((B, C), device='cuda').contiguous()gu = torch.zeros((B, C), device='cuda').contiguous()gk = torch.zeros((B, T, C), device='cuda').contiguous()gv = torch.zeros((B, T, C), device='cuda').contiguous()half_mode = (w.dtype == torch.half)bf_mode = (w.dtype == torch.bfloat16)wkv_cuda.backward(B, T, C,w.float().contiguous(),u.float().contiguous(),k.float().contiguous(),v.float().contiguous(),gy.float().contiguous(),gw, gu, gk, gv)if half_mode:gw = torch.sum(gw.half(), dim=0)gu = torch.sum(gu.half(), dim=0)return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())elif bf_mode:gw = torch.sum(gw.bfloat16(), dim=0)gu = torch.sum(gu.bfloat16(), dim=0)return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())else:gw = torch.sum(gw, dim=0)gu = torch.sum(gu, dim=0)return (None, None, None, gw, gu, gk, gv)class VRWKV_SpatialMix(nn.Module):def __init__(self, n_embd, channel_gamma=1/4, shift_pixel=1):super().__init__()self.n_embd = n_embdattn_sz = n_embdself._init_weights()self.shift_pixel = shift_pixelif shift_pixel > 0:self.channel_gamma = channel_gammaelse:self.spatial_mix_k = Noneself.spatial_mix_v = Noneself.spatial_mix_r = Noneself.key = nn.Linear(n_embd, attn_sz, bias=False)self.value = nn.Linear(n_embd, attn_sz, bias=False)self.receptance = nn.Linear(n_embd, attn_sz, bias=False)self.key_norm = nn.LayerNorm(n_embd)self.output = nn.Linear(attn_sz, n_embd, bias=False)self.key.scale_init = 0self.receptance.scale_init = 0self.output.scale_init = 0def _init_weights(self):self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)def jit_func(self, x, patch_resolution):# Mix x with the previous timestep to produce xk, xv, xrB, T, C = x.size()# Use xk, xv, xr to produce k, v, rif self.shift_pixel > 0:xx = q_shift(x, self.shift_pixel, self.channel_gamma, patch_resolution)xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)else:xk = xxv = xxr = xk = self.key(xk)v = self.value(xv)r = self.receptance(xr)sr = torch.sigmoid(r)return sr, k, vdef forward(self, x, patch_resolution=None):B, T, C = x.size()sr, k, v = self.jit_func(x, patch_resolution)x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)x = self.key_norm(x)x = sr * xx = self.output(x)return xclass iR_RWKV(nn.Module):def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',act_layer='relu', dw_ks=3, stride=1, dilation=1, se_ratio=0.0,attn_s=True, drop_path=0., drop=0.,img_size=224, channel_gamma=1/4, shift_pixel=1):super().__init__()self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()dim_mid = int(dim_in * exp_ratio)self.ln1 = nn.LayerNorm(dim_mid)self.conv = ConvNormAct(dim_in, dim_mid, kernel_size=1)self.has_skip = (dim_in == dim_out and stride == 1) and has_skipif attn_s==True:self.att = VRWKV_SpatialMix(dim_mid, channel_gamma, shift_pixel)self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()self.proj_drop = nn.Dropout(drop)self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()self.attn_s=attn_sself.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation, groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)def forward(self, x):shortcut = xx = self.norm(x)x = self.conv(x)if self.attn_s:B, hidden, H, W = x.size()patch_resolution = (H,  W)x = x.view(B, hidden, -1)  # (B, hidden, H*W) = (B, C, N)x = x.permute(0, 2, 1)x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hiddeh, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))x = x.permute(0, 2, 1)x = x.contiguous().view(B, hidden, h, w)x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))x = self.proj_drop(x)x = self.proj(x)x = (shortcut + self.drop_path(x)) if self.has_skip else xreturn xif __name__ == '__main__':x = torch.randn([1, 64, 11, 11]).cuda()ir_rwkv = iR_RWKV(dim_in=64, dim_out=64).cuda()out = ir_rwkv(x)print(out.shape)  # [1, 64, 11, 11]

相关文章:

每日Attention学习22——Inverted Residual RWKV

模块出处 [arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation 模块名称 Inverted Residual RWKV (IR-RWKV) 模块作用 用于vision的RWKV结构 模块结构 模块代码 注&#xff1a;cpp扩展请参考作者原…...

机器学习之数学基础:线性代数、微积分、概率论 | PyTorch 深度学习实战

前一篇文章&#xff0c;使用线性回归模型逼近目标模型 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 本篇文章内容来自于 强化学习必修课&#xff1a;引领人工智能新时代【梗直哥瞿炜】 线性代数、微积分、概率论 …...

UNI-MOL: A UNIVERSAL 3D MOLECULAR REPRESENTATION LEARNING FRAMEWORK

UNI-MOL: A UNIVERSAL 3D MOLECULAR REPRESENTATION LEARNING FRAMEWORK Neurips23 推荐指数&#xff1a;#paper/⭐⭐⭐#​&#xff08;工作量不小) 动机 在大多数分子表征学习方法中&#xff0c;分子被视为 1D 顺序标记或2D 拓扑图&#xff0c;这限制了它们为下游任务整合…...

SQL Server查询计划操作符(7.3)——查询计划相关操作符(6)

7.3. 查询计划相关操作符 48)Key Lookup:该操作符对一个有簇索引的表进行书签查找。参数列包含簇索引的名字和用于查找簇索引中数据行的簇键。该操作符总是伴随一个Nested Loops操作符。如果其参数列中出现WITH PREFETCH子句,则查询处理器已决定使用异步预取(预读,read-ah…...

C语言【基础篇】之数组——解锁多维与动态数组的编程奥秘

数组 &#x1f680;前言&#x1f99c;数组的由来与用途&#x1f31f;一维数组详解&#x1f58a;️二维数组进阶&#x1f4af;动态数组原理&#x1f914;常见误区扫盲&#x1f4bb;学习路径建议✍️总结 &#x1f680;前言 大家好&#xff01;我是 EnigmaCoder。本文收录于我的专…...

C++ 字符串编码转换

UTF8 的string 转 UTF16 的 wstring std::wstring Utf8ToUtf16(const std::string& utf8Str) { // 获取 UTF-16 字符串所需的缓冲区大小 int wstrLength = MultiByteToWideChar(CP_UTF8, 0, utf8Str.c_str(), -1, NULL, 0); if (wstrLength == 0) { // …...

通讯录管理小程序

通讯录管理系统 是对c基础语法的巩固&#xff0c;比较简单的一个小程序&#xff0c;重点在于养成良好的c代码习惯。 通讯录是一个可以记录亲人、好友信息的工具。 本教程主要利用C来实现一个通讯录管理系统 下面是一些需要实现的功能&#xff1a; 1. 添加联系人 功能描述&…...

掌握API和控制点(从Java到JNI接口)_38 JNI从C调用Java函数 01

1. Why? 将控制点下移到下C/C层 对古典视角的反思 App接近User&#xff0c;所以App在整体架构里&#xff0c;是主导者&#xff0c;拥有控制权。所以&#xff0c; App是架构的控制点所在。Java函数调用C/C层函数&#xff0c;是合理的。 但是EIT造形告诉我们&#xff1a; App…...

理解UML中的四种关系:依赖、关联、泛化和实现

在软件工程中&#xff0c;统一建模语言&#xff08;UML&#xff09;是一种广泛使用的工具&#xff0c;用于可视化、设计、构造和文档化软件系统。UML提供了多种图表类型&#xff0c;如类图、用例图、序列图等&#xff0c;帮助开发者和设计师更好地理解系统的结构和行为。在UML中…...

windows蓝牙驱动开发-蓝牙 LE 邻近感应配置文件

邻近感应检测是蓝牙低功耗 (LE) 的常见用途。 本部分提供了创建可用于开发 UWP 设备应用的邻近感应配置文件的设备实现的指南。 在开发此应用之前&#xff0c;应熟悉蓝牙 LE 函数和蓝牙 LE 邻近感应配置文件规范。 示例服务声明 蓝牙低功耗引入了一个新的物理层&#xff0c;…...

【截图】selenium自动通过浏览器截取指定元素div的图片

【截图】selenium自动通过浏览器截取指定元素div的图片 思路 截取完整网页截图 通过元素的坐标 截图到指定位置的图片 前提是已经获取到 driver 了 # 定位目标divtarget_div driver.find_element(By.CLASS_NAME, headlines-right)# 获取div的位置和大小location target_div…...

【算法】动态规划专题⑨ —— 二维费用背包问题 python

目录 前置知识进入正题实战演练 前置知识 【算法】动态规划专题⑤ —— 0-1背包问题 滚动数组优化 python 进入正题 二维费用背包问题 方法思路 二维费用背包问题在传统背包问题的基础上增加了第二个维度的限制&#xff08;如重量&#xff09;。 每个物品具有两种费用&#x…...

免费windows pdf编辑工具Epdf

Epdf&#xff08;完全免费&#xff09; 作者&#xff1a;不染心 时间&#xff1a;2025/2/6 Github: https://github.com/dog-tired/Epdf Epdf Epdf 是一款使用 Rust 编写的 PDF 编辑器&#xff0c;目前仍在开发中。它提供了一系列实用的命令行选项&#xff0c;方便用户对 PDF …...

MVCC机制深度解析

在数据库管理系统中&#xff0c;多版本并发控制&#xff08;MVCC&#xff0c;Multi-Version Concurrency Control&#xff09;是一种用于提高数据库并发性能的技术。它通过在同一数据项上存储多个版本&#xff0c;允许事务在读取数据时不必等待其他事务的完成&#xff0c;从而提…...

C++:类和对象初识

C&#xff1a;类和对象初识 前言类的引入与定义引入定义类的两种定义方法1. 声明和定义全部放在类体中2. 声明和定义分离式 类的成员变量命名规则 类的访问限定符及封装访问限定符封装 类的作用域与实例化类的作用域类实例化实例化方式&#xff1a; 类对象模型类对象的大小存储…...

伪分布式Spark3.4.4安装

参考&#xff1a;Spark2.1.0入门&#xff1a;Spark的安装和使用_厦大数据库实验室博客 我的版本&#xff1a; hadoop 3.1.3 hbase 2.2.2 java openjdk version "1.8.0_432" 问了chatgpt,建议下载Spark3.4.4&#xff0c;不适合下载Spark 2.1.0: step1 Spark下载…...

kafka服务端之控制器

文章目录 概述控制器的选举与故障恢复控制器的选举故障恢复 优雅关闭分区leader的选举 概述 在Kafka集群中会有一个或多个broker&#xff0c;其中有一个broker会被选举为控制器&#xff08;Kafka Controler&#xff09;&#xff0c;它负责管理整个集群中所有分区和副本的状态。…...

element-plus el-tree-select 修改 value 字段

element-plus el-tree-select 修改 value 字段 &#xff0c;不显示label 需要注意两个地方&#xff1a; <el-tree-select v-model"value" :data"data" multiple :render-after-expand"false" show-checkbox style"width: 240px" …...

SQL最佳实践(笔记)

写在前面&#xff1a; 之前baeldung的Java Weekly &#xfeff;Reviews里面推荐了一篇关于SQL优化的文章&#xff0c;正好最近在学习数据库相关知识&#xff0c;记一些学习笔记 原文地址&#xff1a;SQL Best Practices Every Java Engineer Must Know 1. 使用索引 使用索引…...

在 Java 中执行一个复杂的 SQL 查询(包含多表连接、子查询和聚合函数),如何确保查询的性能?请列举至少三条措施。请简要描述其工作原理?

在Java中执行复杂的SQL查询时&#xff0c;确保查询性能是非常重要的。 以下是三条关键措施&#xff0c;以及它们的详细解释、代码示例和实际开发中的注意事项。 1. 使用索引 索引是提高数据库查询性能的最基本手段之一。通过在查询条件中使用的列上创建索引&#xff0c;可以…...

java将list转成树结构

首先是实体类 public class DwdCusPtlSelectDto {//idprivate String key;//值private String value;//中文名private String title;private List<DwdCusPtlSelectDto> children;private String parentId;public void addChild(DwdCusPtlSelectDto child) {if(this.chil…...

【R语言】数据分析

一、描述性统计量 借助R语言内置的airquality数据集进行简单地演示&#xff1a; 1、集中趋势&#xff1a;均值和中位数 head(airquality) # 求集中趋势 mean(airquality$Ozone, na.rmT) # 求均值 median(airquality$Ozone, na.rmT) # 求中位数 2、众数 众数&#xff08;mod…...

传输层协议 UDP 与 TCP

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Linux 目录 一&#xff1a;&#x1f525; 前置复盘&#x1f98b; 传输层&#x1f98b; 再谈端口号&#x1f98b; 端口号范围划分&#x1f98b; 认识知名端口号 (Well-Know Port Number) 二&#xf…...

Linux 调用可执行程序

Linux 调用可执行程序 1. system() 函数1.1 system() 函数的声明1.2 system() 函数的不同场景返回值1.3 system() 函数的代码示例 2. exec() 函数族2.1 exec() 函数族的声明2.2 exec() 函数族执行失败的情况2.3 exec() 函数族的代码示例 3. exec() 与 system() 的区别以及使用注…...

Java/Kotlin双语革命性ORM框架Jimmer(一)——介绍与简单使用

概览 Jimmer是一个Java/Kotlin双语框架 包含一个革命性的ORM 以此ORM为基础打造了一套综合性方案解决方案&#xff0c;包括 DTO语言 更全面更强大的缓存机制&#xff0c;以及高度自动化的缓存一致性 更强大客户端文档和代码生成能力&#xff0c;包括Jimmer独创的远程异常 …...

剪辑学习整理

文章目录 1. 剪辑介绍 1. 剪辑介绍 剪辑可以干什么&#xff1f;剪辑分为哪些种类&#xff1f; https://www.bilibili.com/video/BV15r421p7aF/?spm_id_from333.337.search-card.all.click&vd_source5534adbd427e3b01c725714cd93961af 学完剪辑之后如何找工作or兼职&#…...

IDEA查看项目依赖包及其版本

一.IDEA将现有项目转换为Maven项目 在IntelliJ IDEA中,将现有项目转换为Maven项目是一个常见的需求,可以通过几种不同的方法来实现。Maven是一个强大的构建工具,它可以帮助自动化项目的构建过程,管理依赖关系,以及其他许多方面。 添加Maven支持 如果你的项目还没有pom.xm…...

centos虚拟机迁移没有ip的问题

故事背景&#xff0c;我们的centos虚拟机本来是好好的&#xff0c;但是拷贝到其他电脑上就不能分配ip&#xff0c;我个人觉得这个vmware他们软件应该搞定这个啊&#xff0c;因为这个问题是每次都会出现的。 网络选桥接 网络启动失败 service network restart Restarting netw…...

Java 大视界 -- Java 大数据在智能供应链中的应用与优化(76)

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…...

Java中的继承及相关概念

在 Java 中&#xff0c;继承是一种允许一个类继承另一个类的特性。通过继承&#xff0c;子类可以获取父类的属性和方法&#xff0c;这有助于减少代码冗余并提高代码的可维护性。以下是关于文件内容的相关分析和知识点总结&#xff1a; 一、继承的核心概念 1.继承的语法 Java …...