当前位置: 首页 > news >正文

每日Attention学习23——KAN-Block

模块出处

[SPL 25] [link] [code] KAN See In the Dark


模块名称

Kolmogorov-Arnold Network Block (KAN-Block)


模块作用

用于vision的KAN结构


模块结构

在这里插入图片描述


模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass Swish(nn.Module):def forward(self, x):return x * torch.sigmoid(x)class KANLinear(torch.nn.Module):def __init__(self,in_features,out_features,grid_size=5,spline_order=3,scale_noise=0.1,scale_base=1.0,scale_spline=1.0,enable_standalone_scale_spline=True,base_activation=torch.nn.SiLU,grid_eps=0.02,grid_range=[-1, 1],):super(KANLinear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.grid_size = grid_sizeself.spline_order = spline_orderself.weight = nn.Parameter(torch.Tensor(out_features, in_features))self.bias = nn.Parameter(torch.Tensor(out_features))h = (grid_range[1] - grid_range[0]) / grid_sizegrid = ((torch.arange(-spline_order, grid_size + spline_order + 1) * h+ grid_range[0]).expand(in_features, -1).contiguous())self.register_buffer("grid", grid)self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))if enable_standalone_scale_spline:self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features))self.scale_noise = scale_noiseself.scale_base = scale_baseself.scale_spline = scale_splineself.enable_standalone_scale_spline = enable_standalone_scale_splineself.base_activation = base_activation()self.grid_eps = grid_epsself.reset_parameters()def reset_parameters(self):torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)with torch.no_grad():noise = ((torch.rand(self.grid_size + 1, self.in_features, self.out_features)- 1 / 2)* self.scale_noise/ self.grid_size)self.spline_weight.data.copy_((self.scale_spline if not self.enable_standalone_scale_spline else 1.0)* self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order],noise,))if self.enable_standalone_scale_spline:# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)def b_splines(self, x: torch.Tensor):"""Compute the B-spline bases for the given input tensor.Args:x (torch.Tensor): Input tensor of shape (batch_size, in_features).Returns:torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order)."""assert x.dim() == 2 and x.size(1) == self.in_featuresgrid: torch.Tensor = (self.grid)  # (in_features, grid_size + 2 * spline_order + 1)x = x.unsqueeze(-1)bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)for k in range(1, self.spline_order + 1):bases = ((x - grid[:, : -(k + 1)])/ (grid[:, k:-1] - grid[:, : -(k + 1)])* bases[:, :, :-1]) + ((grid[:, k + 1 :] - x)/ (grid[:, k + 1 :] - grid[:, 1:(-k)])* bases[:, :, 1:])assert bases.size() == (x.size(0),self.in_features,self.grid_size + self.spline_order,)return bases.contiguous()def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):"""Compute the coefficients of the curve that interpolates the given points.Args:x (torch.Tensor): Input tensor of shape (batch_size, in_features).y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).Returns:torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order)."""assert x.dim() == 2 and x.size(1) == self.in_featuresassert y.size() == (x.size(0), self.in_features, self.out_features)A = self.b_splines(x).transpose(0, 1)  # (in_features, batch_size, grid_size + spline_order)B = y.transpose(0, 1)  # (in_features, batch_size, out_features)solution = torch.linalg.lstsq(A, B).solution  # (in_features, grid_size + spline_order, out_features)result = solution.permute(2, 0, 1)  # (out_features, in_features, grid_size + spline_order)assert result.size() == (self.out_features,self.in_features,self.grid_size + self.spline_order,)return result.contiguous()@propertydef scaled_spline_weight(self):return self.spline_weight * (self.spline_scaler.unsqueeze(-1)if self.enable_standalone_scale_splineelse 1.0)def forward(self, x: torch.Tensor):assert x.dim() == 2 and x.size(1) == self.in_featuresbase_output = F.linear(self.base_activation(x), self.base_weight)spline_output = F.linear(self.b_splines(x).view(x.size(0), -1),self.scaled_spline_weight.view(self.out_features, -1),)return base_output + spline_output@torch.no_grad()def update_grid(self, x: torch.Tensor, margin=0.01):assert x.dim() == 2 and x.size(1) == self.in_featuresbatch = x.size(0)splines = self.b_splines(x)  # (batch, in, coeff)splines = splines.permute(1, 0, 2)  # (in, batch, coeff)orig_coeff = self.scaled_spline_weight  # (out, in, coeff)orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)unreduced_spline_output = unreduced_spline_output.permute(1, 0, 2)  # (batch, in, out)# sort each channel individually to collect data distributionx_sorted = torch.sort(x, dim=0)[0]grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_sizegrid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)* uniform_step+ x_sorted[0]- margin)grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptivegrid = torch.concatenate([grid[:1]- uniform_step* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),grid,grid[-1:]+ uniform_step* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),],dim=0,)self.grid.copy_(grid.T)self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):"""Compute the regularization loss.This is a dumb simulation of the original L1 regularization as stated in thepaper, since the original one requires computing absolutes and entropy from theexpanded (batch, in_features, out_features) intermediate tensor, which is hiddenbehind the F.linear function if we want an memory efficient implementation.The L1 regularization is now computed as mean absolute value of the splineweights. The authors implementation also includes this term in addition to thesample-based regularization."""l1_fake = self.spline_weight.abs().mean(-1)regularization_loss_activation = l1_fake.sum()p = l1_fake / regularization_loss_activationregularization_loss_entropy = -torch.sum(p * p.log())return (regularize_activation * regularization_loss_activation+ regularize_entropy * regularization_loss_entropy)class DW_bn_relu(nn.Module):def __init__(self, dim=768):super(DW_bn_relu, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)self.bn = nn.BatchNorm2d(dim)self.relu = nn.ReLU()def forward(self, x, H, W):B, N, C = x.shapex = x.transpose(1, 2).view(B, C, H, W)x = self.dwconv(x)x = self.bn(x)x = self.relu(x)x = x.flatten(2).transpose(1, 2)return xclass KANBlock(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dim = in_featuresgrid_size=5spline_order=3scale_noise=0.1scale_base=1.0scale_spline=1.0base_activation=torch.nn.SiLUgrid_eps=0.02grid_range=[-1, 1]self.fc1 = KANLinear(in_features,hidden_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)self.fc2 = KANLinear(hidden_features,out_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)self.fc3 = KANLinear(hidden_features,out_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)   self.dwconv_1 = DW_bn_relu(hidden_features)self.dwconv_2 = DW_bn_relu(hidden_features)self.dwconv_3 = DW_bn_relu(hidden_features)self.drop = nn.Dropout(drop)self.shift_size = shift_sizeself.pad = shift_size // 2def forward(self, x, H, W):B, N, C = x.shapex = self.fc1(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_1(x, H, W)x = self.fc2(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_2(x, H, W)x = self.fc3(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_3(x, H, W)return xif __name__ == '__main__':x = torch.randn([1, 22*22, 128])kan = KANBlock(in_features=128)out = kan(x, H=22, W=22)print(out.shape)  # [1, 22*22, 128]

相关文章:

每日Attention学习23——KAN-Block

模块出处 [SPL 25] [link] [code] KAN See In the Dark 模块名称 Kolmogorov-Arnold Network Block (KAN-Block) 模块作用 用于vision的KAN结构 模块结构 模块代码 import torch import torch.nn as nn import torch.nn.functional as F import mathclass Swish(nn.Module)…...

今日写题04work

题目&#xff1a;移除链表元素 两种实现思路 思路一 使用双指针&#xff0c;prev&#xff0c;cur快慢指针解决。当cur不等于val&#xff0c;两个指针跳过。当等于val时&#xff0c;要考虑两种情况&#xff0c;一种是pos删&#xff0c;一种是头删除。 pos删除就是正常情况&am…...

Managed Lustre 和 WEKA:高性能文件系统的对比与应用

Managed Lustre 和 WEKA&#xff1a;高性能文件系统的对比与应用 1. 什么是 Managed Lustre&#xff1f;主要特点&#xff1a;适用场景&#xff1a; 2. 什么是 WEKA&#xff1f;主要特点&#xff1a;适用场景&#xff1a; 3. Managed Lustre 和 WEKA 的对比4. 如何选择 Managed…...

LeetCode541 反转字符串2

一、题目描述 给定一个字符串 s 和一个整数 k&#xff0c;从字符串开头算起&#xff0c;每计数至 2k 个字符&#xff0c;就反转这 2k 字符中的前 k 个字符。具体规则如下&#xff1a; 如果剩余字符少于 k 个&#xff0c;则将剩余字符全部反转。如果剩余字符小于 2k 但大于或等…...

MAC 系统关闭屏幕/睡眠 后被唤醒 Wake Requests

问题&#xff1b;查看wake 日志 pmset -g log | grep "Wake Requests" 为 Wake Requests [*processdasd requestSleepService...info"com.apple.alarm.user-invisible-com.apple.calaccessd...电源设置命令参考&#xff1a; pmset -g sched //查看定时…...

论文笔记:Multi-Head Mixture-of-Experts

2024 neurips 1 背景 稀疏混合专家&#xff08;SMoE&#xff09;可在不显著增加训练和推理成本的前提下提升模型的能力【比如Mixtral 8*7B&#xff0c;表现可以媲美LLaMA-2 70B】 但它也有两个问题 专家激活率低&#xff08;下图左&#xff09; 在优化时只有一小部分专家会被…...

vue和Django快速创建项目

一、VUE 1.创建 Vue 3 JavaScript 项目 npm create vitelatest 项目名称 -- --template vue创建 Vue 3 TypeScript 项目 npm create vitelatest 项目名称 -- --template vue-ts 2.然后 cd 项目名称 npm install npm install axios # 发送 API 请求 npm install pinia …...

Java LinkedList(单列集合)

LinkedList 是 Java 中实现了 List 接口的一个类&#xff0c;它属于 java.util 包。与 ArrayList 不同&#xff0c;LinkedList 是基于双向链表实现的&#xff0c;适合于频繁进行插入和删除操作的场景。 1. LinkedList 的基本特性 基于链表实现&#xff1a;LinkedList 使用双向…...

多线程基础面试题剖析

一、线程的创建方式有几种 创建线程的方式有两种&#xff0c;一种是继承Thread&#xff0c;一种是实现Runable 在这里推荐使用实现Runable接口&#xff0c;因为java是单继承的&#xff0c;一个类继承了Thread将无法继承其他的类&#xff0c;而java可以实现多个接口&#xff0…...

.NET SixLabors.ImageSharp v1.0 图像实用程序控制台示例

使用 C# 控制台应用程序示例在 Windows、Linux 和 MacOS 机器上处理图像&#xff0c;包括创建散点图和直方图&#xff0c;以及根据需要旋转图像以便正确显示。 这个小型实用程序库需要将 NuGet SixLabors.ImageSharp包&#xff08;版本 1.0.4&#xff09;添加到.NET Core 3.1/ …...

EasyExcel提取excel文档

目录 一、前言二、提取excel文档2.1、所有sheet----获取得到headerList和总行数2.2、所有sheet----获取合并单元格信息2.3、读取某个sheet的每行数据一、前言 EasyExcel 是阿里巴巴开源的一个高性能 Excel 读写库,相比于 Apache POI 和 JXL,它有明显的优势,特别是在处理大数…...

第十五届蓝桥杯嵌入式省赛真题(满分)

第十五届蓝桥杯嵌入式省赛真题 目录 第十五届蓝桥杯嵌入式省赛真题 一、题目 二、分析 1、配置 2、变量定义 3、LCD显示模块 4、按键模块 5、数据分析和处理模块 1、频率突变 2、频率超限 3、数据处理 三、评价结果 一、题目 二、分析 1、配置 首先是配置cubemx…...

ASP.NET Core Web应用(.NET9.0)读取数据库表记录并显示到页面

1.创建ASP.NET Core Web应用 选择.NET9.0框架 安装SqlClient依赖包 2.实现数据库记录读取: 引用数据库操作类命名空间 创建查询记录结构类 查询数据并返回数据集合 3.前端遍历数据并动态生成表格显示 生成结果:...

【Sceneform-EQR】实现3D场景背景颜色的定制化(背景融合的方式、Filament材质定制)

写在前面的话 Sceneform-EQR是基于&#xff08;filament&#xff09;扩展的一个用于安卓端的渲染引擎。故本文内容对Sceneform-EQR与Filament都适用。 需求场景 在使用Filament加载三维场景的过程中&#xff0c;一个3D场景对应加载一个背景纹理。而这样的话&#xff0c;即便…...

LeetCode1706

LeetCode1706 目录 LeetCode1706题目描述示例题目理解问题描述 示例分析思路分析问题核心 代码段代码逐行讲解1. 获取网格的列数2. 初始化结果数组3. 遍历每个球4. 逐行模拟下落过程5. 检查是否卡住6. 记录结果7. 返回结果数组 复杂度分析时间复杂度空间复杂度 总结的知识点1. …...

2517. 礼盒的最大甜蜜度(Maximum Tastiness of Candy Box)

2517. 礼盒的最大甜蜜度&#xff08;Maximum Tastiness of Candy Box&#xff09; 问题描述 给定一个正整数数组 price&#xff0c;其中 price[i] 表示第 i 类糖果的价格&#xff0c;另给定一个正整数 k。商店将 k 类不同糖果组合成礼盒出售。礼盒的甜蜜度是礼盒中任意两种糖…...

Golang 的字符编码与 regexp

前言 最近在使用 Golang 的 regexp 对网络流量做正则匹配时&#xff0c;发现有些情况无法正确进行匹配&#xff0c;找到资料发现 regexp 内部以 UTF-8 编码的方式来处理正则表达式&#xff0c;而网络流量是字节序列&#xff0c;由其中的非 UTF-8 字符造成的问题。 我们这里从 G…...

利用ollama 与deepseek r1大模型搭建本地知识库

1.安装运行ollama ollama下载 https://ollama.com/download/windows 验证ollama是否安装成功 ollama --version 访问ollama本地地址&#xff1a; http://localhost:11434/ 出现如下界面 ollama运行模型 ollama run llama3.2 ollama常用操作命令 启动 Ollama 服务&#xf…...

Java短信验证功能简单使用

注册登录阿里云官网&#xff1a;https://www.aliyun.com/ 搜索短信服务 自己一步步申请就可以了 开发文档&#xff1a; https://next.api.aliyun.com/api-tools/sdk/Dysmsapi?version2017-05-25&languagejava-tea&tabprimer-doc 1.引入依赖 <dependency>…...

CAS单点登录(第7版)21.可接受的使用政策

如有疑问&#xff0c;请看视频&#xff1a;CAS单点登录&#xff08;第7版&#xff09; 可接受的使用政策 概述 可接受的使用政策 CAS 也称为使用条款或 EULA&#xff0c;它允许用户在继续应用程序之前接受使用策略。此功能的生产级部署需要修改流&#xff0c;以便通过外部存…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

vue3 字体颜色设置的多种方式

在Vue 3中设置字体颜色可以通过多种方式实现&#xff0c;这取决于你是想在组件内部直接设置&#xff0c;还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法&#xff1a; 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具&#xff0c;可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下&#xff1a; ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜&#xff1a; ffmpeg…...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文&#xff5c;DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色&#xff0c;华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型&#xff0c;能助力我们轻松驾驭 DeepSeek-V3/R1&#xff0c;本文中将分享如何…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马&#xff08;服务器方面的&#xff09;的原理&#xff0c;连接&#xff0c;以及各种木马及连接工具的分享 文件木马&#xff1a;https://w…...

DingDing机器人群消息推送

文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人&#xff0c;点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置&#xff0c;详见说明文档 成功后&#xff0c;记录Webhook 2 API文档说明 点击设置说明 查看自…...

Modbus RTU与Modbus TCP详解指南

目录 1. Modbus协议基础 1.1 什么是Modbus? 1.2 Modbus协议历史 1.3 Modbus协议族 1.4 Modbus通信模型 🎭 主从架构 🔄 请求响应模式 2. Modbus RTU详解 2.1 RTU是什么? 2.2 RTU物理层 🔌 连接方式 ⚡ 通信参数 2.3 RTU数据帧格式 📦 帧结构详解 🔍…...

数据结构:递归的种类(Types of Recursion)

目录 尾递归&#xff08;Tail Recursion&#xff09; 什么是 Loop&#xff08;循环&#xff09;&#xff1f; 复杂度分析 头递归&#xff08;Head Recursion&#xff09; 树形递归&#xff08;Tree Recursion&#xff09; 线性递归&#xff08;Linear Recursion&#xff09;…...

【Post-process】【VBA】ETABS VBA FrameObj.GetNameList and write to EXCEL

ETABS API实战:导出框架元素数据到Excel 在结构工程师的日常工作中,经常需要从ETABS模型中提取框架元素信息进行后续分析。手动复制粘贴不仅耗时,还容易出错。今天我们来用简单的VBA代码实现自动化导出。 🎯 我们要实现什么? 一键点击,就能将ETABS中所有框架元素的基…...