当前位置: 首页 > news >正文

Efficient-KAN 源码详解

Efficient-KAN源码链接

Efficient-KAN (GitHub)

改进细节

1.内存效率提升

KAN网络的原始实现的性能问题主要在于它需要扩展所有中间变量以执行不同的激活函数。对于具有in_features个输入和out_features个输出的层,原始实现需要将输入扩展为shape为(batch_size, out_features, in_features)的tensor以执行激活函数。然而,所有激活函数都是一组固定基函数(3阶B样条)的线性组合。鉴于此,拟将计算重新表述为不同的基函数激活输入,然后将它们线性组合。这种重新表述可以显著减少内存消耗,并使计算变得更加简单的矩阵乘法,自然地适用于前向和后向传递。

2.正则化方法的改变

稀疏化被认为对KAN的可解释性至关重要。作者提出了一种定义在输入样本上的L1正则化,它需要对**(batch_size, out_features, in_features)** tensor进行非线性操作,因此与重新表述不兼容。拟改为对权重进行L1正则化,这在NN中更为常见,并且与重新表述兼容。

3.激活函数缩放选项

除了可学习的激活函数(B样条),原始实现还包括对每个激活函数的可学习缩放 ( w s ) (w_s) (ws)。拟提供一个名为enable_standalone_scale_spline的选项,默认情况下为True,以包含此功能。禁用它会使模型更高效,但可能会影响结果。这需要更多实验验证。

4.参数初始化的改变

为了解决在MNIST数据集上的性能问题,该代码修改了参数的初始化方式,采用Kaiming初始化

KAN_fast.py解析

基本参数和类定义

import torch
import torch.nn.functional as F
import mathclass KANLinear(torch.nn.Module):def __init__(self,in_features,out_features,grid_size=5,  # 网格大小,默认为 5spline_order=3, # 分段多项式的阶数,默认为 3scale_noise=0.1,  # 缩放噪声,默认为 0.1scale_base=1.0,   # 基础缩放,默认为 1.0scale_spline=1.0,    # 分段多项式的缩放,默认为 1.0enable_standalone_scale_spline=True,base_activation=torch.nn.SiLU,  # 基础激活函数,默认为 SiLU(Sigmoid Linear Unit)grid_eps=0.02,grid_range=[-1, 1],  # 网格范围,默认为 [-1, 1]):super(KANLinear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.grid_size = grid_size # 设置网格大小和分段多项式的阶数self.spline_order = spline_orderh = (grid_range[1] - grid_range[0]) / grid_size   # 计算网格步长

生成网格

 grid = ( # 生成网格(torch.arange(-spline_order, grid_size + spline_order + 1) * h+ grid_range[0] ).expand(in_features, -1).contiguous())
self.register_buffer("grid", grid)  # 将网格作为缓冲区注册

1. torch.arange(-spline_order, grid_size + spline_order + 1)
  • **torch.arange(start, end)**:生成一个从 startend-1 的整数序列(左闭右开区间)。
  • **-spline_order**:从负的 spline_order 开始。
  • **grid_size + spline_order + 1**:终止于 grid_size + spline_order(不包括 +1)。

这个序列的长度是 grid_size + 2 * spline_order + 1,用于涵盖所有需要的网格点,包括两端的扩展区域。

2. * h
  • 这一步将生成的整数序列乘以步长 h,将索引序列转换为实际的网格位置。
3. + grid_range[0]
  • 这一步将整个网格位置进行平移,使得网格的起始点与 grid_range[0] 对齐。

如果 grid_range[0] = -1,则每个位置都会减去 1

4. .expand(in_features, -1)
  • **.expand()**:将这个网格复制 in_features 次,以适应输入特征的维度。具体来说,它将原本的一维网格向量扩展成一个 in_features × (grid_size + 2 * spline_order + 1) 的二维张量。 其中每一行都是相同的网格向量。
5. .contiguous()
  • **.contiguous()**:确保扩展后的张量在内存中是连续存储的,方便后续的计算和操作。虽然在大多数情况下这个操作是可选的,但它可以提高计算效率并避免潜在的问题。

最终效果:

这段代码生成了一个二维张量 grid,它的形状为 [in_features, grid_size + 2 * spline_order + 1],其中每一行都是相同的、覆盖整个 grid_range 并适当扩展的网格点序列。这个网格用于模型中的 B 样条或其他基函数计算,使得模型可以在输入数据范围内执行灵活的插值和拟合操作。

初始化可训练参数

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) # 初始化基础权重和分段多项式权重self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))if enable_standalone_scale_spline:  # 如果启用独立的分段多项式缩放,则初始化分段多项式缩放参数self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features))

1. self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) ( w b ) (w_b) (wb)

  • **torch.Tensor(out_features, in_features)**:创建一个形状为 (out_features, in_features) 的未初始化张量,用于存储基础线性层的权重。这个张量的元素初始时没有具体的数值,通常在后续的 reset_parameters() 方法中进行初始化。
  • **torch.nn.Parameter**:将这个张量封装成 torch.nn.Parameter 对象。这意味着这个张量会被视为模型的可训练参数,PyTorch 会自动将其包含在模型的参数列表中,并在反向传播时更新其值。
  • **self.base_weight**:这个属性存储的是基础线性变换的权重矩阵。这个矩阵将在前向传播过程中被用来对输入特征进行线性变换。

2. self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order)) ( c i ) (c_i) (ci)

  • **torch.Tensor(out_features, in_features, grid_size + spline_order)**:创建一个形状为 (out_features, in_features, grid_size + spline_order) 的未初始化张量,用于存储分段多项式的权重。这些权重将用于 B 样条或其他类似方法的计算。
  • **torch.nn.Parameter**:同样地,将这个张量封装成 torch.nn.Parameter,使其成为模型的可训练参数。
  • **self.spline_weight**:这个属性存储的是与分段多项式相关的权重。这些权重决定了如何将输入特征映射到输出特征,特别是在使用 B 样条等非线性激活函数时。
为什么 spline_weight 的形状是 (out_features, in_features, grid_size + spline_order)
  • **out_features****in_features**:与 base_weight 类似,表示输出和输入的特征数量。
  • **grid_size + spline_order**:这个维度表示在 B 样条或其他分段多项式方法中,每个输入特征需要使用的基函数的数量。通过这些基函数的线性组合,可以生成灵活的非线性激活。

3. if enable_standalone_scale_spline:

  • 这个条件语句检查 enable_standalone_scale_spline 是否为 True。如果为 True,则会为每个分段多项式激活函数引入一个独立的缩放参数。

4. self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features)) ( w s ) (w_s) (ws)

  • **torch.Tensor(out_features, in_features)**:创建一个形状为 (out_features, in_features) 的张量,用于存储独立的分段多项式缩放参数。
  • **torch.nn.Parameter**:将张量封装成 torch.nn.Parameter,使其成为可训练参数。
  • **self.spline_scaler**:这个属性存储的是分段多项式的缩放参数。每个 spline_weight 都有一个对应的缩放参数,可以单独调整其幅度,从而提供更大的灵活性。

其他实例属性

        self.scale_noise = scale_noise # 保存缩放噪声、基础缩放、分段多项式的缩放、是否启用独立的分段多项式缩放、基础激活函数和网格范围的容差self.scale_base = scale_baseself.scale_spline = scale_splineself.enable_standalone_scale_spline = enable_standalone_scale_splineself.base_activation = base_activation()self.grid_eps = grid_epsself.reset_parameters()  # 重置参数

Kaiming初始化权重(reset_parameters)

def reset_parameters(self):torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)# 使用 Kaiming 均匀初始化基础权重with torch.no_grad():noise = (# 生成缩放噪声(torch.rand(self.grid_size + 1, self

相关文章:

Efficient-KAN 源码详解

Efficient-KAN源码链接 Efficient-KAN (GitHub) 改进细节 1.内存效率提升 KAN网络的原始实现的性能问题主要在于它需要扩展所有中间变量以执行不同的激活函数。对于具有in_features个输入和out_features个输出的层,原始实现需要将输入扩展为shape为(batch_size, out_featur…...

Jlink commander使用方法(附指令大全)

Jlinkcmd它可以方便用户在非仿真的情况下,hold内核、单步、全速、设置断点、查看内核和外设寄存器、读取flash代码等等,方便大家拥有最高的权限查看在运行中的MCU情况,查找非IDE仿真情况下,MCU运行异常的原因。 目录 驱动安装 …...

Java SpringBoot实现PDF转图片

不是单页图片,是多页PDF转成一张图片的逻辑。 我这里的场景是PDF转成图片之后返回给前端,前端再在图片上实现签字,并且可拖拽的逻辑,就是签订合同的场景。 但是这里只写后端多页PDF转图片的逻辑。 先说逻辑,后面直接…...

elasticsearch SQL:在Elasticsearch中启用和使用SQL功能

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…...

Java 并发编程:线程变量 ThreadLocal

大家好&#xff0c;我是栗筝i&#xff0c;这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 029 篇文章&#xff0c;在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验&#xff0c;并希望进…...

【OpenHarmony4.1 之 U-Boot 2024.07源码深度解析】018 - init_sequence_f 各函数源码分析(二)

【OpenHarmony4.1 之 U-Boot 2024.07源码深度解析】018 - init_sequence_f 各函数源码分析(二) 一、arch_cpu_init二、arch_cpu_init系列文章汇总:《【OpenHarmony4.1 之 U-Boot 源码深度解析】000 - 文章链接汇总》 本文链接:《【OpenHarmony4.1 之 U-Boot 2024.07源码深度…...

LVS原理——详细介绍

目录 介绍 lvs简介 LVS作用 LVS 的优势与不足 LVS概念与相关术语 LVS的3种工作模式 LVS调度算法 LVS-dr模式 LVS-tun模式 ipvsadm工具使用 实验 nat模式集群部署 实验环境 webserver1配置 webserver2配置 lvs配置 dr模式集群部署 实验环境 router 效果呈现…...

MYSQL 5.7.36 等保 建设记录

文章目录 前言一、开启审计日志1.1 查看当前状态1.2 开启方式1.3 查看开启后状态 二、密码有效期2.1 查看当前状态2.2 开启方式2.3 查看开启后状态 三、密码复杂度3.1 查看当前状态3.2 开启方式3.3 查看开启后状态 四、连接控制4.1 查看当前状态4.2 开启方式4.3 查看开启后状态…...

fatal: unable to access ‘https://github.com/xxxxx

ubuntu中git克隆项目异常 git clone https://github.com/xxx Cloning into ‘xxx’… fatal: unable to access ‘https://github.com/xxx/xx.git/’: Could not resolve host: github.com 解决办法使用命令&#xff1a; git config --global http.proxy git config --global…...

从零开始的CPP(38)——递归与动态规划

leetcode46 给定一个不含重复数字的数组 nums &#xff0c;返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1&#xff1a; 输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2&#xff1a; 输入&#…...

从战略到系统架构:信息系统设计的全面解析

在当今数字化时代&#xff0c;信息系统已成为企业运营、管理和创新的核心驱动力。信息系统设计的重要性不仅关乎企业的技术实现&#xff0c;更直接影响到企业的战略执行和市场竞争能力。本文将从战略视角出发&#xff0c;深入探讨信息系统设计的全过程&#xff0c;包括从战略制…...

GEE调用中国(China Land Cover Dataset,简称CLCD)1990-2022年30米分辨率的土地分类数据

博客推荐 GEE土地分类&#xff1a;中国30米年度土地覆盖产品annual China Land Cover Dataset, CLCD&#xff08;面积提取&#xff09;_30米土地利用数据gee-CSDN博客 简介 中国陆地覆盖数据集&#xff08;China Land Cover Dataset&#xff0c;简称CLCD&#xff09;是一个用…...

三十八、大数据技术之Kafka(1)

&#x1f33b;&#x1f33b; 目录 一、Kafka 概述1.1 定义1.2 消息队列1.2.1 消息队列内部实现原理1.2.2 传统消息队列的应用场景1.2.3 消息队列的两种模式 1.3 Kafka 基础架构 二、 Kafka 快速入门2.1 安装前的准备2.2 安装部署2.2.1 集群规划2.2.2 单节点或集群部署2.2.3 集群…...

将 Tcpdump 输出内容重定向到 Wireshark

在 Linux 系统中使用 Tcpdump 抓包后分析数据包不是很方便。 通常 Wireshark 比 tcpdump 更容易分析应用层协议。 一般的做法是在远程主机上先使用 tcpdump 抓取数据并写入文件&#xff0c;然后再将文件拷贝到本地工作站上用 Wireshark 分析。 还有一种更高效的方法&#xf…...

【Python蓝屏程序(管理员)】

说明&#xff1a;该程序为临摹(&#x1f600;)作品&#xff0c;源地址C蓝屏程序(非管理员) 我试图使用Python调用 NtRaiseHardError API &#xff0c;实现类似的蓝屏效果。可惜我发现Python在普通权限下&#xff0c;直接调用 NtRaiseHardError API 是不被允许的&#xff0c;因为…...

OpenGL ES->GLSurfaceView绘制图形的流程

自定义View代码 class MyGLSurfaceView(context: Context, attrs: AttributeSet) : GLSurfaceView(context, attrs), GLSurfaceView.Renderer {var mProgrem 0init {// 设置 OpenGL ES 3.0 版本setEGLContextClientVersion(3)// 设置当前类为渲染器, 注册回调接口的实现类set…...

Linux OOM Killer详解

Linux OOM Killer详解 一、概述二、OOM Killer的技术原理1. 内存区域划分2. 内存耗尽与OOM Killer触发3. 选择被杀进程的策略4. 内存回收机制5. 内存分配策略 三、OOM Killer的工作机制1. 内存压力监测2. 触发条件3. 选择被杀进程4. 终止进程 四、实际场景举例场景一&#xff1…...

2024rk(案例二)

试题二(25分) 阅读以下关于数据库缓存的叙述,在答题纸上回答问题1至问题3。 【说明】 某大型电商平台建立了一个在线 B2B 商店系统,并在全国多地建设了货物仓储中心,通过提前备货的方式来提高货物的运送效率。但是在运营过程中,发现会出现很多跨仓储中心调货从而延误货物…...

小红书爆文秘籍:ChatGPT助你从0到1创造热门内容!

在小红书打造爆款文案的策略中&#xff0c;以下是一些调整和同义词替换的建议&#xff0c;以便达到文章去重的要求&#xff1a; 了解目标受众&#xff1a; 在撰写文案前&#xff0c;先深入分析目标读者的属性&#xff0c;如年龄层次、性别、爱好和购买行为。通过ChatGPT, 你能迅…...

django快速实现个人博客(附源码)

文章目录 一、工程目录组织结构二、模型及管理实现1、模型2、admin管理 三、博客展现实现1、视图实现2、模板实现 四、部署及效果五、源代码 Django作为一款成熟的Python Web开发框架提供了丰富的内置功能&#xff0c;如ORM&#xff08;对象关系映射&#xff09;、Admin管理界面…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动

一、前言说明 在2011版本的gb28181协议中&#xff0c;拉取视频流只要求udp方式&#xff0c;从2016开始要求新增支持tcp被动和tcp主动两种方式&#xff0c;udp理论上会丢包的&#xff0c;所以实际使用过程可能会出现画面花屏的情况&#xff0c;而tcp肯定不丢包&#xff0c;起码…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动&#xff08;如演唱会、马拉松赛事、高考中考等&#xff09;期间&#xff0c;城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例&#xff0c;暖城商圈曾因观众集中离场导致周边…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

Python爬虫(二):爬虫完整流程

爬虫完整流程详解&#xff08;7大核心步骤实战技巧&#xff09; 一、爬虫完整工作流程 以下是爬虫开发的完整流程&#xff0c;我将结合具体技术点和实战经验展开说明&#xff1a; 1. 目标分析与前期准备 网站技术分析&#xff1a; 使用浏览器开发者工具&#xff08;F12&…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂度…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

听写流程自动化实践,轻量级教育辅助

随着智能教育工具的发展&#xff0c;越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式&#xff0c;也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建&#xff0c;…...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生&#xff0c;小白用户&#xff0c;想学习知识的 有点基础&#xff0c;想要通过项…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...

Python网页自动化Selenium中文文档

1. 安装 1.1. 安装 Selenium Python bindings 提供了一个简单的API&#xff0c;让你使用Selenium WebDriver来编写功能/校验测试。 通过Selenium Python的API&#xff0c;你可以非常直观的使用Selenium WebDriver的所有功能。 Selenium Python bindings 使用非常简洁方便的A…...