当前位置: 首页 > news >正文

[线性RNN系列] Mamba: S4史诗级升级

 前言

iclr24终于可以在openreview上看预印本了

这篇(可能是颠覆之作)文风一眼c re组出品;效果实在太惊艳了,实验相当完善,忍不住写一篇解读分享分享。

TL;DR (overview)

Structured State-Space Model (SSM, S4) 是一个线性时不变系统 ( Linear Time Invariance, LTI), 其参数 (Δ,A,B,C) 是static的,与输入无关,i.e., data independent。 S4虽然在玩具数据集LRA上表现良好,但是在下游任务普遍拉垮。Attention机制的成功arguably可以认为是有data dependent的QKV矩阵来进行交互,这篇的核心思路是让这些参数data dependent,做出了如下的改动:

B: batch size, L: sentence length, D: input dimension, N: RNN hidden dimension

我们可以看到 B,C 的大小从原来的 (D,N) 变成了 (B,L,N) , Δ 的大小由原来的 D 变成了 (B,L,D) ,每个位置的 B,C,Δ 都不相同 (之前是在所有位置共享)。

虽然A没有data dependent, 但是通过state space model的离散化操作之后, (A¯,B¯) 会经过outer product 变成 (B,L,N,D) 的data dependent张量,以一种parameter efficient的方式来达到data dependent的目的。

其余主要改动/贡献如下(技术细节在文末):

(1) 由于SSM的参数data dependent, 此时失去了LTI的性质,不能像之前的S4一样通过FFT来训练了。本文提出了IO-aware的parallel scan(一种memory bounded算子)算法来进行高效训练,降低整体的读写量从而提高wall-time efficiency。上面提到的outer product的参数化方式也对降低整体读写量很有帮助(大致思路是 (A¯,B¯) 在SRAM里面on-the-fly算出来,避免materialization带来的读写开销)

(2) 如果用一个线性层参数化 Δ:R[B×L×D]→R[B×L×D]需要 D[2] 参数。本文提出了一种low-rank projection的参数化方式,可以通过很小的额外参数量来获得较大的提升。最后负责token mixing的SSM只需要很少的参数,绝大多数参数都分给channel mixing了。从MetaFormer的视角来看,token mixing相对channel mixing而言不是重要,所以从这个视角出发的话分配很少的参数是极其合理的。

(3) 以往的SSM经常需要一个output gate来达到很好的效果,如Gated SSM, 这个结构跟gated MLP很像。所以作者干脆把token mixing和channel mixing合二为一,提出了一个新的极简风的Mamba block。(Update: 这跟Gated Attention Unit挺像的)

如下图所示。

实验部分是最让人惊喜的:

Chinchilla scaling laws, 训练长度2048

其中Transformer++指的是带有Rope和SwiGLU的版本(i.e., LLaMa用的)。可以看到之前声称match Transformer performance的model基本上最多也就match一下vanilla transformer的结果 (i.e., 不带rope,如图绿线所示)(吐槽:Hyena是真的辣鸡)

Mamba在8192训练长度上也能match Transformer++的结果

下游任务evaluation,Mamba无情刷榜

技术细节

S4简介

Recommended Reading:
Structured State Spaces for Sequence Modeling (S4)
Simplifying S4

S4的连续微分方程形式(一般也用不着):

离散形式:

其中最常用到的离散化方法是zero-order hold (ZOH):

其中 A¯∈R[N×N],B¯∈R[N×1],C∈R[1×N],Δ∈R, N 是SSM hidden state的大小。 需要强调的是 S4用的是Single-input-single-output (SISO), 即对应于每一个输入的维度,都有一套独立的SSM参数 (传统的RNN是MIMO, multiple-input-multiple-output, 很容易混淆)

Parameter-efficient的data dependent参数化方式

上面的S4的参数都是静态的,这肯定不行()所以要弄成data dependent的动态的

这一套的思路由来已久,CV领域的dynamic convolutional,Transformers里面的QKV, LSTM里面的gating都是类似的思想

注意到,对于每个input dimension A只需要N个参数, 因为我们通常会对A做对角化

作者用

来将 B,C,Δ data dependent化, 其中  Linear d(X) 是把 D维的输入向量 X 经过一个线性层map到 d 维。这里的总参数量大概是 D∗N∗2+D∗D 。 N 即SSM的hidden dimension,一般设的比较小 (e.g., 16),所以 D∗N∗2 部分的参数量是少头,而参数化 sΔ 的 D∗D 是大头(一般至少都是几k维)

所以作者用了一个low-rank projection来降低参数量:

sΔ(X)=LinearD⁡(Linear1⁡(X))

这样总参数量就从 D∗D 降低到了 2D 。

最后作者选择把A设成了data independent,作者给出的解释是反正离散化之后 A¯=exp⁡(ΔA) , Δ 的data dependent能够让整体的 A¯ data dependent。

(PS: 这个解释理由感觉有点牵强,因为如果这样的话, B 也完全可以data independent,靠 Δ 让 B¯ data dependent)


理解参数的含义和功能

step size Δthat represents the resolution of the input
discretization of SSMs is the principled foundation of heuristic gating mechanisms.

这个量跟RNN里的gating有着深刻的联系[1] ,data dependent的 Δ 跟RNN的forget gate的功能类似

经典的RNN gating可以理解成SSM离散化的一个特例。

而 B和C 所起到的功能类似于写(进RNN的memory)和读(取RNN的memory)。所以data dependent的B/C的功能跟RNN的input/output gate类似。

A的作用其实有点尴尬,因为 Δ 已经有点遗忘门的意思了。但注意到对于每个input维度来说, Δ 只是一个标量,而 A∈R[N×1] ,也就是说对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因(i.e., forget gate是跟隐藏层维度相同的一个向量,而不仅仅是一个标量)

这篇文章所强调的selectivity无非就是传统门控RNN经典的思想。。。属于是文艺复兴/新瓶装旧酒
Recommended Reading:
十分推荐一篇鞭辟入里的文章
Written Memories: Understanding, Deriving and Extending the LSTM

IO-aware Parallel Scan

因为现在的参数都是data dependent了,所以不再是LTI,也就失去了卷积的性质,不能用FFT来进行高效训练了。

不过这也不是什么问题,之前的S5已经指出了data dependent的SSM可以用parallel scan来进行训练。不过parallel scan依然是memory bounded的操作,对于SSM这种每个input维度对应一个RNN的SISO模型来说,总共有效的RNN hidden state可以理解成 N∗D ,所以实现的不好的话很容易比较慢。S5为了避免这个问题,选择了MIMO的方式并且降低总体的维度。Mamba选择迎难而上,利用kernel fusion, recomputation的经典优化思想来硬上 (PS: 很好很c re组)

一般的实现会提前先把大小为 (B,L,D,N) 的 A¯,B¯ 先算出来,然后把它们从HBM (high-bandwith memory, or GPU memopry) 读到SRAM, 然后调用scan算子算出 (B,L,D,N) 的output,写到HBM里面。再开一个kernel把 (B,L,D,N) 的output以及 (B,L,N) 的C读进来,multiply and sum with C得到最后的 (B,L,N) output 。整个过程的读写是 O(BLDN) 。本文提出的方法:

  • 把 (Δ,A,B,C) 读到SRAM里面,总共大小是 O(SLN+DN)
  • 在SRAM里面做离散化,得到 (B,L,D,N) 的 A¯,B¯
  • 在SRAM里面做scan,得到 (B,L,D,N) 的 output
  • multiply and sum with C,得到最后的 (B,L,D) output 写入HBM

整个过程的总读写量是 O(BLN) ,比之前省了O(N)。 backward的时候就把 A¯,B¯ 重算一遍,类似于flashattn重算attention分数矩阵的思想。只要重算的时间比读 O(BLDN) 快就算胜利

We benchmark the speed of the SSM scan operation (N = 16), as well as the end-to-end inference throughput of Mamba, in Figure 8.  Our efficient SSM scan is faster than the best attention implementation that we know of (FlashAttention-2 (Dao, 2023)) beyond sequence length 2K, and up to 20-40× faster than a standard scan implementation in PyTorch.

IO-aware的实现比naive实现快很多倍;(flash)scan 在输入长度2k的时候就开始比flashattention快了, 之后越长越快。同时scan也比long convolution (w/ FFT)快,再次给long convolution模型敲上丧钟(本来long conv模型inference的时候就很笨了,训练还慢就更...

Token mixing+Channel Mixing合二为一

之前的SSM模型要work,都会加上output gating,之后再过个线性层channel mixing,如上图的最左边所示。这两个部分跟Gated MLP(上图中间)右边的支路和最上面的channel mixing是一样的。所以SSM层如果跟Gated MLP叠的话,难免会感觉有点冗余,所以作者干脆把两个合二为一,把token mixing层和channel mixing层合二为一 (PS: 估计会有很深远的影响),并且做work了。

现在的新的Mamba block有 3ED[2] 个参数(E是FFN扩展的倍数,一般transformer里面E是扩大四倍)。如果E=4,那么正好对应于一个 12D[2] 也就是一层transformer layer的总参数量。但可能是因为RNN比较吃层数(也很好形象理解,RNN是比较local的模型,所以需要叠深度来换一层attend到的广度),所以作者选择E=2,一层包含两个这样的Mamda block。

消融实验

对不同参数data dependent的敏感性

上文提到 Δ 的作用类似遗忘门,而遗忘门毫无疑问是LSTM里面最重要的门[2],所以这个消融实验结果发现 Δ data dependent带来的收益效果最大就一点都令人惊讶啦

A用实数还是虚数,以及A的参数化方式

这篇发现complex的decay rate不如real;跟rwkv作者的观点一致。之前的data independent的ssm模型发现虚数挺重要的;这里的实验现象相左的可能原因是因为data dependent的ssm表达能力本身就足够强了,不需要复数带来的额外表达能力;而之前data independent的ssm如果不用虚数来对角化A,表达能力相当受限

\Delta参数化时使用的low-rank的rank size

之前提到了参数化 Δ 的时候用low-rank来降低ssm部分的参数。其中一个可能的深意是 Metaformer框架认为token mixing远不如channel mixing重要,所以与其把参数分配给token mixing,不如把参数分配给channel mixing。最上面的那一行是data independent;rank=1的时候可以发现就已经有提升了,证明了data dependent的有效性;之后接着加参数也有提升 (但不确定如果多出来的参数加到channel mixing里面会不会更好)

SSM hidden size的影响,上面是data independent, 下面是data dependent

我们可以看到data independent的时候,增大SSM hidden state size的帮助很小,反而增大了很多计算量;而data dependent的时候,增大SSM hidden state size的收益大得多,体现了selectivity的优势

这个表体现了把token mixing和channel mixing合二为一成一个单独的Mamba层的好处 (PS: 似乎只有对这个模型有效,对其他模型反向提升)。

总结

把经典LSTM选择性的思想引入了SSM,极致的implementation优化,solid的全方位的实验,惊艳的实验效果,可能彻底打破大家对RNN的印象

参考

  1. ^https://arxiv.org/abs/1804.11188
  2. ^https://arxiv.org/abs/1804.04849

附赠

【一】上千篇CVPR、ICCV顶会论文
【二】动手学习深度学习、花书、西瓜书等AI必读书籍
【三】机器学习算法+深度学习神经网络基础教程
【四】OpenCV、Pytorch、YOLO等主流框架算法实战教程

➤ 添加助理自取:

➤ 还可咨询论文辅导❤【毕业论文、SCI、CCF、中文核心、El会议】评职称、研博升学、本升海外学府!

相关文章:

[线性RNN系列] Mamba: S4史诗级升级

前言 iclr24终于可以在openreview上看预印本了 这篇(可能是颠覆之作)文风一眼c re组出品;效果实在太惊艳了,实验相当完善,忍不住写一篇解读分享分享。 TL;DR (overview) Structured State-Sp…...

【鸿蒙学习笔记】元服务

官方文档:元服务规格 目录标题 什么是元服务特征第一个元服务-案例介绍创建项目源码启动模拟器启动entry创建卡片出发元服务 什么是元服务 特征 免安装分包预加载老化和更新机制 第一个元服务-案例介绍 创建项目 源码 Entry Component struct WidgetCard {buil…...

LIS+找规律,CF 582B - Once Again...

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 582B - Once Again... 二、解题报告 1、思路分析 考虑朴素做法对T *n的数组求LIS 但是T * n可达1e9 思考一下,最优解无非就是几个循环节拼接,我们最差情况下对sqrt(T)个a[]求LIS即…...

数据赋能(145)——开发:数据拆分——实施过程、应用特点

实施过程 数据拆分的实施过程通常涉及以下几个关键步骤: 确定拆分目标和需求: 明确数据拆分的目的和需求,例如是为了减少数据处理的复杂性、提高查询效率还是为了满足特定的业务需求。根据需求确定拆分后的数据结构和拆分规则。选择拆分方法…...

【漏洞复现】Splunk Enterprise for Windows 任意文件读取漏洞 CVE-2024-36991

声明:本文档或演示材料仅用于教育和教学目的。如果任何个人或组织利用本文档中的信息进行非法活动,将与本文档的作者或发布者无关。 一、漏洞描述 Splunk Enterprise 是一款强大的机器数据管理和分析平台,广泛应用于企业中,用于实…...

FastAPI -- 第一弹

Hello World 经典的 Hello World 安装 pip install fastapi pip install "uvicorn[standard]"main.py from typing import Unionfrom fastapi import FastAPIapp FastAPI()app.get("/") def read_root():return {"Hello": "World"}…...

C++入门基础篇(1)

欢迎大家来到海盗猫鸥的博客—— 断更许久,让我们继续好好学习吧! 目录 1.namespace命名空间 命名空间的存在价值: 命名空间的定义: 命名空间的使用: 2.C输入输出函数 使用: 3.缺省参数 4.函数重载…...

基于html开发的在线网址导航在线工具箱源码

基于html开发的在线网址导航在线工具箱源码,将全部文件复制到服务器,入口文件是index.html 如需修改网址,可修改index.html 如需修改关于页面,可修改about里面的index页面 源码下载:https://download.csdn.net/down…...

【密码学】大整数分解问题和离散对数问题

公钥密码体制的主要思想是通过一种非对称性,即正向计算简单,逆向计算复杂的加密算法设计,来解决安全通信。本文介绍两种在密码学领域内最为人所熟知、应用最为广泛的数学难题——大整数分解问题与离散对数问题 一、大整数分解问题 &#xf…...

解析 pdfminer layout.py LAParams类及其应用实例

解析 pdfminer layout.py LAParams类及其应用实例 引言类的定义1. line_overlap2. char_margin3. word_margin4. line_margin5. boxes_flow6. detect_vertical7. all_texts 类的初始化参数验证类的表示总结 引言 在这篇文章中,我们将解析一个叫做 LAParams 的类。这…...

Redis官方可视化管理工具

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl RedisInsight是一个Redis可视化工具,提供设计、开发和优化 Redis 应用程序的功能。RedisInsight分为免费的社区版和一个付费的企业版,免费版具有基本…...

android 固定图片大小

在Android中,固定图片大小可以通过多种方法实现,这些方法主要涉及到ImageView控件的使用、Bitmap类的操作,以及第三方库(如Glide)的辅助。以下是几种常见的方法: 1. 使用ImageView控件 在Android的布局文…...

操作系统——内存管理(面试准备)

虚拟内存 单片机没有操作系统,每次写完代码,都需要借助工具把程序烧录进去,这样程序才能跑起来。 另外,单片机的CPU是直接操作内存的物理地址。 在这种情况下,想在内存中同时运行两个程序是不可能的,如果第…...

vue3实现vuedraggable实现拖拽到垃圾桶图标位置进行删除

当使用Vue 3和vuedraggable库时,你可以按照以下方式实现拖拽到垃圾桶图标位置进行删除的功能: 首先,确保你已经安装了vuedraggable库。如果没有安装,可以通过以下命令进行安装: vuedraggable 和vue-draggable-plus使…...

MySQL向自增列插入0失败问题

问题 在一次上线时,发现通过脚本添加的状态表中,待提交的状态不正确,本来应该是0,线上是101。 原因 默认情况下,MySQL对应自增列,认为0和null等价(因为mysql认为0不是最佳实践不推荐使用&…...

Python:Python基础知识(注释、命名、数据类型、运算符)

.注释 Python有两种注释方法:单行注释和多行注释。单行注释以#开头,多行注释以三个单引号 或三个双引号 """ 开头和结尾。 2.命名规则 命名规则: 大小写字母、数字、下划线和汉字等字符及组合; 注意事项: 大小写敏感、首…...

Protobuf: 大数据开发中的高效数据传输利器

作为一名大数据开发者,我经常需要处理海量的数据传输和存储。在这个过程中,选择一个高效、可靠的数据序列化工具至关重要。今天,我想和大家分享一下我在项目中使用 Protobuf 的经历。 目录 故事背景Protobuf 简介优点: 实战案例示…...

MySQL 面试相关问题

写在前面: 不喜勿喷,暴躁作者又不求你给钱【没办法,遇见的狗喷子太多了🐶】欢迎大家在评论区留言,指正文章中的信息错误有一些其他相关的问题,可以直接评论区留言,作者看到会及时更新到文章末尾…...

java org.aeonbits.owner库介绍

org.aeonbits.owner 是一个用于简化Java应用程序配置管理的库。它通过使用接口和注解来定义和读取配置,使得配置管理更加简洁和类型安全。以下是对这个库的一些主要特性和功能的介绍: 主要特性 类型安全的配置: OWNER 库允许开发者使用接口定义配置,从而提供了编译时的类型…...

YOLOv10改进 | 添加注意力机制篇 | 添加LSKAttention大核注意力机制助力极限涨点

一、本文介绍 在这篇文章中,我们将讲解如何将LSKAttention大核注意力机制应用于YOLOv10,以实现显著的性能提升。首先,我们介绍LSKAttention机制的基本原理,它主要通过将深度卷积层的2D卷积核分解为水平和垂直1D卷积核&#xff0…...

2025年能源电力系统与流体力学国际会议 (EPSFD 2025)

2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...

Neo4j 集群管理:原理、技术与最佳实践深度解析

Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

unix/linux,sudo,其发展历程详细时间线、由来、历史背景

sudo 的诞生和演化,本身就是一部 Unix/Linux 系统管理哲学变迁的微缩史。来,让我们拨开时间的迷雾,一同探寻 sudo 那波澜壮阔(也颇为实用主义)的发展历程。 历史背景:su的时代与困境 ( 20 世纪 70 年代 - 80 年代初) 在 sudo 出现之前,Unix 系统管理员和需要特权操作的…...

QT: `long long` 类型转换为 `QString` 2025.6.5

在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...

基于TurtleBot3在Gazebo地图实现机器人远程控制

1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...

基于 TAPD 进行项目管理

起因 自己写了个小工具,仓库用的Github。之前在用markdown进行需求管理,现在随着功能的增加,感觉有点难以管理了,所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD,需要提供一个企业名新建一个项目&#…...

【Go语言基础【12】】指针:声明、取地址、解引用

文章目录 零、概述:指针 vs. 引用(类比其他语言)一、指针基础概念二、指针声明与初始化三、指针操作符1. &:取地址(拿到内存地址)2. *:解引用(拿到值) 四、空指针&am…...