每日Attention学习23——KAN-Block
模块出处
[SPL 25] [link] [code] KAN See In the Dark
模块名称
Kolmogorov-Arnold Network Block (KAN-Block)
模块作用
用于vision的KAN结构
模块结构

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass Swish(nn.Module):def forward(self, x):return x * torch.sigmoid(x)class KANLinear(torch.nn.Module):def __init__(self,in_features,out_features,grid_size=5,spline_order=3,scale_noise=0.1,scale_base=1.0,scale_spline=1.0,enable_standalone_scale_spline=True,base_activation=torch.nn.SiLU,grid_eps=0.02,grid_range=[-1, 1],):super(KANLinear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.grid_size = grid_sizeself.spline_order = spline_orderself.weight = nn.Parameter(torch.Tensor(out_features, in_features))self.bias = nn.Parameter(torch.Tensor(out_features))h = (grid_range[1] - grid_range[0]) / grid_sizegrid = ((torch.arange(-spline_order, grid_size + spline_order + 1) * h+ grid_range[0]).expand(in_features, -1).contiguous())self.register_buffer("grid", grid)self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))if enable_standalone_scale_spline:self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features))self.scale_noise = scale_noiseself.scale_base = scale_baseself.scale_spline = scale_splineself.enable_standalone_scale_spline = enable_standalone_scale_splineself.base_activation = base_activation()self.grid_eps = grid_epsself.reset_parameters()def reset_parameters(self):torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)with torch.no_grad():noise = ((torch.rand(self.grid_size + 1, self.in_features, self.out_features)- 1 / 2)* self.scale_noise/ self.grid_size)self.spline_weight.data.copy_((self.scale_spline if not self.enable_standalone_scale_spline else 1.0)* self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order],noise,))if self.enable_standalone_scale_spline:# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)def b_splines(self, x: torch.Tensor):"""Compute the B-spline bases for the given input tensor.Args:x (torch.Tensor): Input tensor of shape (batch_size, in_features).Returns:torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order)."""assert x.dim() == 2 and x.size(1) == self.in_featuresgrid: torch.Tensor = (self.grid) # (in_features, grid_size + 2 * spline_order + 1)x = x.unsqueeze(-1)bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)for k in range(1, self.spline_order + 1):bases = ((x - grid[:, : -(k + 1)])/ (grid[:, k:-1] - grid[:, : -(k + 1)])* bases[:, :, :-1]) + ((grid[:, k + 1 :] - x)/ (grid[:, k + 1 :] - grid[:, 1:(-k)])* bases[:, :, 1:])assert bases.size() == (x.size(0),self.in_features,self.grid_size + self.spline_order,)return bases.contiguous()def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):"""Compute the coefficients of the curve that interpolates the given points.Args:x (torch.Tensor): Input tensor of shape (batch_size, in_features).y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).Returns:torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order)."""assert x.dim() == 2 and x.size(1) == self.in_featuresassert y.size() == (x.size(0), self.in_features, self.out_features)A = self.b_splines(x).transpose(0, 1) # (in_features, batch_size, grid_size + spline_order)B = y.transpose(0, 1) # (in_features, batch_size, out_features)solution = torch.linalg.lstsq(A, B).solution # (in_features, grid_size + spline_order, out_features)result = solution.permute(2, 0, 1) # (out_features, in_features, grid_size + spline_order)assert result.size() == (self.out_features,self.in_features,self.grid_size + self.spline_order,)return result.contiguous()@propertydef scaled_spline_weight(self):return self.spline_weight * (self.spline_scaler.unsqueeze(-1)if self.enable_standalone_scale_splineelse 1.0)def forward(self, x: torch.Tensor):assert x.dim() == 2 and x.size(1) == self.in_featuresbase_output = F.linear(self.base_activation(x), self.base_weight)spline_output = F.linear(self.b_splines(x).view(x.size(0), -1),self.scaled_spline_weight.view(self.out_features, -1),)return base_output + spline_output@torch.no_grad()def update_grid(self, x: torch.Tensor, margin=0.01):assert x.dim() == 2 and x.size(1) == self.in_featuresbatch = x.size(0)splines = self.b_splines(x) # (batch, in, coeff)splines = splines.permute(1, 0, 2) # (in, batch, coeff)orig_coeff = self.scaled_spline_weight # (out, in, coeff)orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)unreduced_spline_output = unreduced_spline_output.permute(1, 0, 2) # (batch, in, out)# sort each channel individually to collect data distributionx_sorted = torch.sort(x, dim=0)[0]grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_sizegrid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)* uniform_step+ x_sorted[0]- margin)grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptivegrid = torch.concatenate([grid[:1]- uniform_step* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),grid,grid[-1:]+ uniform_step* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),],dim=0,)self.grid.copy_(grid.T)self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):"""Compute the regularization loss.This is a dumb simulation of the original L1 regularization as stated in thepaper, since the original one requires computing absolutes and entropy from theexpanded (batch, in_features, out_features) intermediate tensor, which is hiddenbehind the F.linear function if we want an memory efficient implementation.The L1 regularization is now computed as mean absolute value of the splineweights. The authors implementation also includes this term in addition to thesample-based regularization."""l1_fake = self.spline_weight.abs().mean(-1)regularization_loss_activation = l1_fake.sum()p = l1_fake / regularization_loss_activationregularization_loss_entropy = -torch.sum(p * p.log())return (regularize_activation * regularization_loss_activation+ regularize_entropy * regularization_loss_entropy)class DW_bn_relu(nn.Module):def __init__(self, dim=768):super(DW_bn_relu, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)self.bn = nn.BatchNorm2d(dim)self.relu = nn.ReLU()def forward(self, x, H, W):B, N, C = x.shapex = x.transpose(1, 2).view(B, C, H, W)x = self.dwconv(x)x = self.bn(x)x = self.relu(x)x = x.flatten(2).transpose(1, 2)return xclass KANBlock(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dim = in_featuresgrid_size=5spline_order=3scale_noise=0.1scale_base=1.0scale_spline=1.0base_activation=torch.nn.SiLUgrid_eps=0.02grid_range=[-1, 1]self.fc1 = KANLinear(in_features,hidden_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)self.fc2 = KANLinear(hidden_features,out_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,)self.fc3 = KANLinear(hidden_features,out_features,grid_size=grid_size,spline_order=spline_order,scale_noise=scale_noise,scale_base=scale_base,scale_spline=scale_spline,base_activation=base_activation,grid_eps=grid_eps,grid_range=grid_range,) self.dwconv_1 = DW_bn_relu(hidden_features)self.dwconv_2 = DW_bn_relu(hidden_features)self.dwconv_3 = DW_bn_relu(hidden_features)self.drop = nn.Dropout(drop)self.shift_size = shift_sizeself.pad = shift_size // 2def forward(self, x, H, W):B, N, C = x.shapex = self.fc1(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_1(x, H, W)x = self.fc2(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_2(x, H, W)x = self.fc3(x.reshape(B*N,C))x = x.reshape(B,N,C).contiguous()x = self.dwconv_3(x, H, W)return xif __name__ == '__main__':x = torch.randn([1, 22*22, 128])kan = KANBlock(in_features=128)out = kan(x, H=22, W=22)print(out.shape) # [1, 22*22, 128]
相关文章:
每日Attention学习23——KAN-Block
模块出处 [SPL 25] [link] [code] KAN See In the Dark 模块名称 Kolmogorov-Arnold Network Block (KAN-Block) 模块作用 用于vision的KAN结构 模块结构 模块代码 import torch import torch.nn as nn import torch.nn.functional as F import mathclass Swish(nn.Module)…...
基于Python的Optimal Interpolation (OI) 方法实现
前言 Optimal Interpolation (OI) 方法概述与实现 Optimal Interpolation (OI) 是一种广泛应用于气象学、海洋学等领域的空间数据插值方法。该方法通过结合观测数据与模型预测数据,最小化误差方差,从而实现对空间数据的最优插值。以下是OI方法的一般步骤…...
学习数据结构(10)栈和队列下+二叉树(堆)上
1.关于栈和队列的算法题 (1)用队列实现栈 解法一:(参考代码) 题目要求实现六个函数,分别是栈初始化,入栈,移除并返回栈顶元素,返回栈顶元素,判空࿰…...
.NET版Word处理控件Aspose.Words教程:使用 C# 删除 Word 中的空白页
Word 文档中的空白页会使其看起来不专业并扰乱流程。用户会遇到需要删除 Word 中的空白页的情况,但手动删除它们需要时间和精力。在这篇博文中,我们将探讨如何使用 C# 删除 Word 中的空白页。 本文涵盖以下主题: C# 库用于删除 Word 中的空…...
《代码随想录》刷题笔记——回溯篇【java实现】
文章目录 组合组合总和 III电话号码的字母组合组合总和组合总和II思路代码实现 分割回文串※思路字符串分割回文串判断效率优化※ 复原 IP 地址优化版本 子集子集 II使用usedArr辅助去重不使用usedArr辅助去重 递增子序列※全排列全排列 II重新安排行程题意代码 N 皇后解数独直…...
【JavaEE进阶】验证码案例
目 🌲实现说明 🎄Hutool介绍 🌳准备工作 🌴约定前后端交互接口 🚩接口定义 🚩实现服务器后端代码 🚩前端代码 🚩整体测试 🌲实现说明 随着安全性的要求越来越⾼…...
TCP/UDP 简介,三次握手与四次挥手
一、TCP 三次握手 目的:为了解决在不可靠的信道上建立可靠的网络连接 三次握手是连接请求的过程: A 发送连接请求的数据给 B(发送 SYN 包) B 同意连接,返回数据给 A(返回 SYNACK 包) A 收到后回…...
C++之线程池(Thread Pool)
1.介绍 线程池是一种并发编程的设计模式,用于管理和复用多个线程。以避免频繁创建和销毁线程的开销。线程池的核心思想是预先创建一组线程,并将任务分配给这些线程执行,从而提高程序的性能和资源利用率。 2.线程池的核心组件 一个经典的线程…...
Django中实现简单易用的分页工具
如何在Django中实现简单易用的分页工具?📚 嗨,小伙伴们!今天我们来看看如何在 Django 中实现一个超简单的分页工具。无论你是在处理博客文章、产品列表,还是用户评论,当数据量一大时,分页显得尤…...
【kafka系列】Exactly Once语义
目录 1. Exactly-Once语义的定义 2. Kafka实现Exactly-Once的机制 3. 端到端Exactly-Once示例 场景描述 3.1 生产者配置与代码 3.2 消费者配置与代码 4. 异常场景与Exactly-Once保障 场景1:生产者发送消息后宕机 场景2:消费者处理消息后宕机 场…...
export default与export区别
1.定义: export default:用于导出模块中的默认成员。一个模块中只能有一个export default,通常用于导出模块的主要功能或对象。导入时可以使用任意名称,因为它没有具体的名称 export:用于导出模块中的多个成…...
Qt Creator 5.0.2 (Community)用久了突然变得很卡
目录 1.现象 2.解决方案 1.现象 很久没有用Qt Creator开发项目了,刚刚结束的项目又是用VS2019开发的;这两天刚好有时间去学习一下Qt,刚好要用Qt Creator,结果一打开就没反应,主界面显示出来要好几分钟,最…...
Windows搭建CUDA大模型Docker环境
Windows搭建CUDA大模型Docker环境 一、安装Docker二、拉取镜像三、启动容器四、安装依赖环境五、安装Miniconda3六、设置pip源地址 一、安装Docker windows中docker安装教程 二、拉取镜像 系统:Ubuntu20.04CUDA版本:11.8.0 docker pull nvcr.io/nvid…...
阅读论文笔记《Efficient Estimation of Word Representations in Vector Space》
这篇文章写于2013年,对理解 word2vec 的发展历程挺有帮助。 本文仅适用于 Word2Vect 的复盘 引言 这篇论文致力于探索从海量数据中学习高质量单词向量的技术。当时已发现词向量能保留语义特征,例如 “国王 - 男人 女人≈女王”。论文打算借助该特性&am…...
初学PADS使用技巧笔记(也许会继续更新)
操作意图:网上找某个芯片封装又不想自己画,再加上没经验,怎么办? 就以AC-DC芯片PN8036为例,打开嘉立创的的DFM,打开立创商城,输入PN8036,点击数据手册,然后点击直接打开…...
C#学习之数据转换
目录 一、创作说明 二、数据类型之间的转换 1.数据类型之间的转换表格 2.代码示例 三、进制之间的转换 1.进制之间的转换表格 2.代码示例 四、ASCII 编码和字符之间的转换 1.ASCII 编码和字符之间的转换表格 2.代码示例 五、总结 一、创作说明 C#大多数时候都是和各…...
从无序到有序:上北智信通过深度数据分析改善会议室资源配置
当前企业普遍面临会议室资源管理难题,预约机制不完善和临时会议多导致资源调度不合理,既有空置又有过度拥挤现象。 针对上述问题,上北智信采用了专业数据分析手段,巧妙融合楼层平面图、环形图、折线图和柱形图等多种可视化工具&a…...
JavaScript 中toLocaleString()的基本用法
toLocaleString() 是 JavaScript 中多个内置对象(如 Number、Date、Array 等)都拥有的方法,其作用是将对象的值转换为符合特定语言环境的字符串表示形式。下面分别介绍不同对象使用该方法的具体用法。 1. Number.prototype.toLocaleString()…...
CAS单点登录(第7版)4.管理
如有疑问,请看视频:CAS单点登录(第7版) 管理 概述 Admin Console & 仪表板 CAS 提供了许多可用于管理 CAS 服务器部署的工具和控制板。此类选项通常不是互斥的,旨在协同工作并呈现 CAS 配置和构建的各个方面&am…...
Baklib一站式云平台:全场景赋能企业知识资产激活
内容概要 在数字化浪潮推动下,企业知识资产的高效管理与价值释放成为核心议题。Baklib作为一站式云平台,以全场景赋能为核心定位,通过构建知识中台架构,为企业提供从资源整合到应用落地的闭环解决方案。该平台不仅支持文本、图像…...
登录弹窗效果
1,要求 点击登录按钮,弹出登录窗口 提示1:登录窗口 display:none 隐藏状态; 提示2:登录按钮点击后,触发事件,修改 display:block 显示状态 提示3:登录窗口中点击关闭按钮࿰…...
文本表示方法
词向量 独热编码模型和分布式表征模型 独热编码分布式表征固定长度的稠密词向量优点一个单词一个维度,彼此之间构成标准正交向量组数字化后的数值可以表示语义上的关系缺点稀疏,词向量维度大导致计算效率低 独热编码会根据语料库中的单词个数,来确定词…...
小小小病毒(3)(~_~|)
一分耕耘一分收获 声明: 仅供损害电脑,不得用于非法。损坏电脑,作者一律不负责。此作为作者原创,转载请经过同意。 欢迎来到小小小病毒(3) 感谢大家的支持 还是那句话:上代码! …...
微软AutoGen高级功能——Memory
介绍 大家好,博主又来给大家分享知识了。这次又要给大家分享什么呢?哈哈。这次要给大家分享的是微软AutoGen框架的高级且重要的功能:Memory。在微软AutoGen中,Memory(记忆)是一个重要概念,它主要用于存储和管理智能体…...
Debezium系列之:时区转换器,时间戳字段转换到指定时区
Debezium系列之:时区转换器,时间戳字段转换到指定时区 示例:基本配置应用TimezoneConverter SMT的效果示例:高级配置配置选项当Debezium发出事件记录时,记录中的时间戳字段的时区值可能会有所不同,这取决于数据源的类型和配置。为了在数据处理管道和应用程序中保持数据一…...
【Java 面试 八股文】Spring Cloud 篇
Spring Cloud 篇 1. Spring Cloud 5大组件有哪些?2. 服务注册和发现是什么意思?Spring Cloud 如何实现服务注册发现?3. 我看你之前也用过nacos,你能说下nacos与eureka的区别?4. 你们项目负载均衡如何实现的?…...
Esxi8.0设置nvidia显卡直通安装最新驱动
ESXI8.0设置显卡直通 在某些情况下,我们需要多次切换操作系统,以测试软件是否适用于特定系统和环境,减少多次重装系统的麻烦 ESXI8.0安装包 通过网盘分享的文件:ESXi-8.0U2-22380479-USB-NVME-集成网卡镜像.iso 链接: https://…...
LabVIEW袜品压力测试系统
开发了一种基于LabVIEW开发的袜品压力测试系统。该系统利用LabVIEW并结合灵敏的传感器和高精度的处理模块,实现了对袜品压力的精确测量和分析。系统不同于传统的服装压力测试方法,为研究和评价袜子的舒适性提供了新的测试手段。 项目背景 该系统的…...
TestHubo基础教程-创建项目
TestHubo是一款国产开源一站式测试工具,涵盖功能测试、接口测试、性能测试,以及 Web 和 App 测试,可以满足不同类型项目的测试需求。本文将介绍如何快速创建第一个项目,以快速入门上手。 1、创建项目 在 TestHubo 中,…...
3.3 企业级AI Agent工程实践:从API设计到高可用架构的全栈开发指南
企业级AI Agent工程实践:从API设计到高可用架构的全栈开发指南 引言:AI Agent开发中的工程化挑战 据2024年DevOps状态报告,AI Agent项目的失败案例中**61%**源于工程实现缺陷。本文将基于GitHub Sentinel的实战案例,揭示如何构建支持百万级请求的工业级Agent系统,涵盖AP…...
