当前位置: 首页 > news >正文

每日Attention学习23——KAN-Block

模块出处

[SPL 25] [link] [code] KAN See In the Dark


模块名称

Kolmogorov-Arnold Network Block (KAN-Block)


模块作用

用于vision的KAN结构


模块结构

在这里插入图片描述


模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass Swish(nn.Module):def forward(self, x):return x * torch.sigmoid(x)class KANLinear(torch.nn.Module):def __init__(self,in_features,out_features,grid_size=5,spline_order=3,scale_noise=0.1,scale_base=1.0,scale_spline=1.0,enable_standalone_scale_spline=True,base_activation=torch.nn.SiLU,grid_eps=0.02,grid_range=[-1, 1],):super(KANLinear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.grid_size = grid_sizeself.spline_order = spline_orderself.weight = nn.Parameter(torch.Tensor(out_features, in_features))self.bias = nn.Parameter(torch.Tensor(out_features))h = (grid_range[1] - grid_range[0]) / grid_sizegrid = ((torch.arange(-spline_order, grid_size + spline_order + 1) * h+ grid_range[0]).expand(in_features, -1).contiguous())self.register_buffer("grid", grid)self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))if enable_standalone_scale_spline:self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features))self.scale_noise = scale_noiseself.scale_base = scale_baseself.scale_spline = scale_splineself.enable_standalone_scale_spline = enable_standalone_scale_splineself.base_activation = base_activation()self.grid_eps = grid_epsself.reset_parameters()def reset_parameters(self):torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)with torch.no_grad():noise = ((torch.rand(self.grid_size + 1, self.in_features, self.out_features)- 1 / 2)* self.scale_noise/ self.grid_size)self.spline_weight.data.copy_((self.scale_spline if not self.enable_standalone_scale_spline else 1.0)* self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order],noise,))if self.enable_standalone_scale_spline:# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)def b_splines(self, x: torch.Tensor):"""Compute the B-spline bases for the given input tensor.Args:x (torch.Tensor): Input tensor of shape (batch_size, in_features).Returns:torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order)."""assert x.dim() == 2 and x.size(1) == self.in_featuresgrid: torch.Tensor = (self.grid)  # (in_features, grid_size + 2 * spline_order + 1)x = x.unsqueeze(-1)bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)for k in range(1, self.spline_order + 1):bases = ((x - grid[:, : -(k + 1)])/ (grid[:, k:-1] - grid[:, : -(k + 1)])* bases[:, :, :-1]) + ((grid[:, k + 1 :] - x)/ (grid[:, k + 1 :] - grid[:, 1:(-k)])* bases[:, :, 1:])assert bases.size() == (x.size(0),self.in_features,self.grid_size + self.spline_order,)return bases.contiguous()def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):"""Compute the coefficients of the curve that interpolates the given points.Args:x (torch.Tensor): Input tensor of shape (batch_size, in_features).y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).Returns:torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order)."""assert x.dim() == 2 and x.size(1) == self.in_featuresassert y.size() == (x.size(0), self.in_features, self.out_features)A = self.b_splines(x).transpose(0, 1)  # (in_features, batch_size, grid_size + spline_order)B = y.transpose(0, 1)  # (in_features, batch_size, out_features)solution = torch.linalg.lstsq(A, B).solution  # (in_features, grid_size + spline_order, out_features)result = solution.permute(2, 0, 1)  # (out_features, in_features, grid_size + spline_order)assert result.size() == (self.out_features,self.in_features,self.grid_size + self.spline_order,)return result.contiguous()@propertydef scaled_spline_weight(self):return self.spline_weight * (self.spline_scaler.unsqueeze(-1)if self.enable_standalone_scale_splineelse 1.0)def forward(self, x: torch.Tensor):assert x.dim() == 2 and x.size(1) == self.in_featuresbase_output = F.linear(self.base_activation(x), self.base_weight)spline_output = F.linear(self.b_splines(x).view(x.size(0), -1),self.scaled_spline_weight.view(self.out_features, -1),)return base_output + spline_output@torch.no_grad()def update_grid(self, x: torch.Tensor, margin=0.01):assert x.dim() == 2 and x.size(1) == self.in_featuresbatch = x.size(0)splines = self.b_splines(x)  # (batch, in, coeff)splines = splines.permute(1, 0, 2)  # (in, batch, coeff)orig_coeff = self.scaled_spline_weight  # (out, in, coeff)orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)unreduced_spline_output = unreduced_spline_output.permute(1, 0, 2)  # (batch, in, out)# sort each channel individually to collect data distributionx_sorted = torch.sort(x, dim=0)[0]grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_sizegrid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)* uniform_step+ x_sorted[0]- margin)grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptivegrid = torch.concatenate([grid[:1]- uniform_step* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),grid,grid[-1:]+ uniform_step* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),],dim=0,)self.grid.copy_(grid.T)self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):"""Compute the regularization loss.This is a dumb simulation of the original L1 regularization as stated in thepaper, since the original one requires computing absolutes and entropy from theexpanded (batch, in_features, out_features) intermediate tensor, which is hiddenbehind the F.linear function if we want an memory efficient implementation.The L1 regularization is now computed as mean absolute value of the splineweights. The authors implementation also includes this term in addition to thesample-based regularization."""l1_fake = self.spline_weight.abs().mean(-1)regularization_loss_activation = l1_fake.sum()p = l1_fake / regularization_loss_activationregularization_loss_entropy = -torch.sum(p * p.log())return (regularize_activation * regularization_loss_activation+ regularize_entropy * regularization_loss_entropy)class DW_bn_relu(nn.Module):def __init__(self, dim=768):super(DW_bn_relu, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)self.bn = nn.BatchNorm2d(dim)self.relu = nn.ReLU()def forward(self, x, H, W):B, N, C = x.shapex = x.transpose(1, 2).view(B, C, H, W)x = self.dwconv(x)x = self.bn(x)x = self.relu(x)x = x.flatten(2).transpose(1, 2)return xclass KANBlock(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dim = in_featuresgrid_size=5spline_order=3scale_noise=0.1scale_base=1.0scale_spline=1.0base_activation=torch.nn.SiLUgrid_eps=0.02grid_range=[-1, 1]self.fc1 = KANLinear(in_features,hidden_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)self.fc2 = KANLinear(hidden_features,out_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)self.fc3 = KANLinear(hidden_features,out_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)   self.dwconv_1 = DW_bn_relu(hidden_features)self.dwconv_2 = DW_bn_relu(hidden_features)self.dwconv_3 = DW_bn_relu(hidden_features)self.drop = nn.Dropout(drop)self.shift_size = shift_sizeself.pad = shift_size // 2def forward(self, x, H, W):B, N, C = x.shapex = self.fc1(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_1(x, H, W)x = self.fc2(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_2(x, H, W)x = self.fc3(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_3(x, H, W)return xif __name__ == '__main__':x = torch.randn([1, 22*22, 128])kan = KANBlock(in_features=128)out = kan(x, H=22, W=22)print(out.shape)  # [1, 22*22, 128]

相关文章:

每日Attention学习23——KAN-Block

模块出处 [SPL 25] [link] [code] KAN See In the Dark 模块名称 Kolmogorov-Arnold Network Block (KAN-Block) 模块作用 用于vision的KAN结构 模块结构 模块代码 import torch import torch.nn as nn import torch.nn.functional as F import mathclass Swish(nn.Module)…...

今日写题04work

题目&#xff1a;移除链表元素 两种实现思路 思路一 使用双指针&#xff0c;prev&#xff0c;cur快慢指针解决。当cur不等于val&#xff0c;两个指针跳过。当等于val时&#xff0c;要考虑两种情况&#xff0c;一种是pos删&#xff0c;一种是头删除。 pos删除就是正常情况&am…...

Managed Lustre 和 WEKA:高性能文件系统的对比与应用

Managed Lustre 和 WEKA&#xff1a;高性能文件系统的对比与应用 1. 什么是 Managed Lustre&#xff1f;主要特点&#xff1a;适用场景&#xff1a; 2. 什么是 WEKA&#xff1f;主要特点&#xff1a;适用场景&#xff1a; 3. Managed Lustre 和 WEKA 的对比4. 如何选择 Managed…...

LeetCode541 反转字符串2

一、题目描述 给定一个字符串 s 和一个整数 k&#xff0c;从字符串开头算起&#xff0c;每计数至 2k 个字符&#xff0c;就反转这 2k 字符中的前 k 个字符。具体规则如下&#xff1a; 如果剩余字符少于 k 个&#xff0c;则将剩余字符全部反转。如果剩余字符小于 2k 但大于或等…...

MAC 系统关闭屏幕/睡眠 后被唤醒 Wake Requests

问题&#xff1b;查看wake 日志 pmset -g log | grep "Wake Requests" 为 Wake Requests [*processdasd requestSleepService...info"com.apple.alarm.user-invisible-com.apple.calaccessd...电源设置命令参考&#xff1a; pmset -g sched //查看定时…...

论文笔记:Multi-Head Mixture-of-Experts

2024 neurips 1 背景 稀疏混合专家&#xff08;SMoE&#xff09;可在不显著增加训练和推理成本的前提下提升模型的能力【比如Mixtral 8*7B&#xff0c;表现可以媲美LLaMA-2 70B】 但它也有两个问题 专家激活率低&#xff08;下图左&#xff09; 在优化时只有一小部分专家会被…...

vue和Django快速创建项目

一、VUE 1.创建 Vue 3 JavaScript 项目 npm create vitelatest 项目名称 -- --template vue创建 Vue 3 TypeScript 项目 npm create vitelatest 项目名称 -- --template vue-ts 2.然后 cd 项目名称 npm install npm install axios # 发送 API 请求 npm install pinia …...

Java LinkedList(单列集合)

LinkedList 是 Java 中实现了 List 接口的一个类&#xff0c;它属于 java.util 包。与 ArrayList 不同&#xff0c;LinkedList 是基于双向链表实现的&#xff0c;适合于频繁进行插入和删除操作的场景。 1. LinkedList 的基本特性 基于链表实现&#xff1a;LinkedList 使用双向…...

多线程基础面试题剖析

一、线程的创建方式有几种 创建线程的方式有两种&#xff0c;一种是继承Thread&#xff0c;一种是实现Runable 在这里推荐使用实现Runable接口&#xff0c;因为java是单继承的&#xff0c;一个类继承了Thread将无法继承其他的类&#xff0c;而java可以实现多个接口&#xff0…...

.NET SixLabors.ImageSharp v1.0 图像实用程序控制台示例

使用 C# 控制台应用程序示例在 Windows、Linux 和 MacOS 机器上处理图像&#xff0c;包括创建散点图和直方图&#xff0c;以及根据需要旋转图像以便正确显示。 这个小型实用程序库需要将 NuGet SixLabors.ImageSharp包&#xff08;版本 1.0.4&#xff09;添加到.NET Core 3.1/ …...

EasyExcel提取excel文档

目录 一、前言二、提取excel文档2.1、所有sheet----获取得到headerList和总行数2.2、所有sheet----获取合并单元格信息2.3、读取某个sheet的每行数据一、前言 EasyExcel 是阿里巴巴开源的一个高性能 Excel 读写库,相比于 Apache POI 和 JXL,它有明显的优势,特别是在处理大数…...

第十五届蓝桥杯嵌入式省赛真题(满分)

第十五届蓝桥杯嵌入式省赛真题 目录 第十五届蓝桥杯嵌入式省赛真题 一、题目 二、分析 1、配置 2、变量定义 3、LCD显示模块 4、按键模块 5、数据分析和处理模块 1、频率突变 2、频率超限 3、数据处理 三、评价结果 一、题目 二、分析 1、配置 首先是配置cubemx…...

ASP.NET Core Web应用(.NET9.0)读取数据库表记录并显示到页面

1.创建ASP.NET Core Web应用 选择.NET9.0框架 安装SqlClient依赖包 2.实现数据库记录读取: 引用数据库操作类命名空间 创建查询记录结构类 查询数据并返回数据集合 3.前端遍历数据并动态生成表格显示 生成结果:...

【Sceneform-EQR】实现3D场景背景颜色的定制化(背景融合的方式、Filament材质定制)

写在前面的话 Sceneform-EQR是基于&#xff08;filament&#xff09;扩展的一个用于安卓端的渲染引擎。故本文内容对Sceneform-EQR与Filament都适用。 需求场景 在使用Filament加载三维场景的过程中&#xff0c;一个3D场景对应加载一个背景纹理。而这样的话&#xff0c;即便…...

LeetCode1706

LeetCode1706 目录 LeetCode1706题目描述示例题目理解问题描述 示例分析思路分析问题核心 代码段代码逐行讲解1. 获取网格的列数2. 初始化结果数组3. 遍历每个球4. 逐行模拟下落过程5. 检查是否卡住6. 记录结果7. 返回结果数组 复杂度分析时间复杂度空间复杂度 总结的知识点1. …...

2517. 礼盒的最大甜蜜度(Maximum Tastiness of Candy Box)

2517. 礼盒的最大甜蜜度&#xff08;Maximum Tastiness of Candy Box&#xff09; 问题描述 给定一个正整数数组 price&#xff0c;其中 price[i] 表示第 i 类糖果的价格&#xff0c;另给定一个正整数 k。商店将 k 类不同糖果组合成礼盒出售。礼盒的甜蜜度是礼盒中任意两种糖…...

Golang 的字符编码与 regexp

前言 最近在使用 Golang 的 regexp 对网络流量做正则匹配时&#xff0c;发现有些情况无法正确进行匹配&#xff0c;找到资料发现 regexp 内部以 UTF-8 编码的方式来处理正则表达式&#xff0c;而网络流量是字节序列&#xff0c;由其中的非 UTF-8 字符造成的问题。 我们这里从 G…...

利用ollama 与deepseek r1大模型搭建本地知识库

1.安装运行ollama ollama下载 https://ollama.com/download/windows 验证ollama是否安装成功 ollama --version 访问ollama本地地址&#xff1a; http://localhost:11434/ 出现如下界面 ollama运行模型 ollama run llama3.2 ollama常用操作命令 启动 Ollama 服务&#xf…...

Java短信验证功能简单使用

注册登录阿里云官网&#xff1a;https://www.aliyun.com/ 搜索短信服务 自己一步步申请就可以了 开发文档&#xff1a; https://next.api.aliyun.com/api-tools/sdk/Dysmsapi?version2017-05-25&languagejava-tea&tabprimer-doc 1.引入依赖 <dependency>…...

CAS单点登录(第7版)21.可接受的使用政策

如有疑问&#xff0c;请看视频&#xff1a;CAS单点登录&#xff08;第7版&#xff09; 可接受的使用政策 概述 可接受的使用政策 CAS 也称为使用条款或 EULA&#xff0c;它允许用户在继续应用程序之前接受使用策略。此功能的生产级部署需要修改流&#xff0c;以便通过外部存…...

Docker 离线安装指南

参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性&#xff0c;不同版本的Docker对内核版本有不同要求。例如&#xff0c;Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本&#xff0c;Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动&#xff08;如演唱会、马拉松赛事、高考中考等&#xff09;期间&#xff0c;城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例&#xff0c;暖城商圈曾因观众集中离场导致周边…...

JVM垃圾回收机制全解析

Java虚拟机&#xff08;JVM&#xff09;中的垃圾收集器&#xff08;Garbage Collector&#xff0c;简称GC&#xff09;是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象&#xff0c;从而释放内存空间&#xff0c;避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

PL0语法,分析器实现!

简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

laravel8+vue3.0+element-plus搭建方法

创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观&#xff0c;可持续性好高效率高复用&#xff0c;可移植性好高内聚&#xff0c;低耦合没有冗余规范性&#xff0c;代码有规可循&#xff0c;可以看出自己当时的思考过程特殊排版&#xff0c;特殊语法&#xff0c;特殊指令&#xff0c;必须…...

tauri项目,如何在rust端读取电脑环境变量

如果想在前端通过调用来获取环境变量的值&#xff0c;可以通过标准的依赖&#xff1a; std::env::var(name).ok() 想在前端通过调用来获取&#xff0c;可以写一个command函数&#xff1a; #[tauri::command] pub fn get_env_var(name: String) -> Result<String, Stri…...