Token Statistics Transformer:线性注意力革命,重新定义Transformer效率天花板
“TOKEN STATISTICS TRANSFORMER: LINEAR-TIME ATTENTION VIA VARIATIONAL RATE REDUCTION” 由Ziyang Wu等人撰写。文章提出一种新型Transformer注意力算子,通过对最大编码率降低( M C R 2 MCR^{2} MCR2)目标的变分形式进行展开优化得到,其计算复杂度与令牌数量呈线性关系,在保证性能的同时显著提高计算效率。
也因此 ToST 也作为 Spotlight 论文,入选了 ICLR 2025 大会。
下面是对本篇论文的重点总结。
-
研究背景
-
Transformer的问题:Transformer架构凭借注意力机制在多领域表现出色,但传统自注意力算子计算复杂度与令牌数量呈二次方关系,限制了模型扩展。
-
现有优化方法:已有方法如分块计算、滑动窗口注意力、低秩投影和Nystrom扩展等,试图解决自注意力的计算效率问题。
-
本文方法:基于“白盒”架构设计,从全新视角推导高效注意力算子,避免计算令牌间成对相似性。
-
-
理论基础
-
最大编码率降低表示学习:通过 M C R 2 MCR^{2} MCR2目标寻找合适的数据表示,其由扩展项和压缩项组成,分别促进特征扩展和组内压缩。
-
白盒深度网络构建:通过算法展开设计网络架构,将网络层操作视为优化目标函数的增量更新步骤。
-
-
Token Statistics Transformer(TOST)
-
编码率的新变分形式:提出基于矩阵谱的凹函数的 M C R 2 MCR^{2} MCR2目标变分形式,可通过计算矩阵乘积对角线的标量函数来上界大矩阵的函数值。
-
通过变分形式展开的高效架构:对变分目标进行梯度下降得到Token Statistics Self - Attention(TSSA)算子,其基于输入令牌特征的经验二阶矩统计进行低秩投影,而非计算令牌间成对相似性,计算和内存复杂度为线性。
-
实际实现考虑因素:实际中不强制U矩阵正交,通过寻找低维正交基降低其列数,并基于高斯混合模型估计组隶属矩阵Π,用TSSA算子构建TOST架构。
-
-
实验结果
-
TSSA算子的逐层分析:TOST注意力层能优化设计目标,且成员分配矩阵Π可对前景图像补丁聚类。
-
真实视觉数据集评估:TOST在ImageNet - 1k和迁移学习任务上性能与其他架构相当,但效率更高、参数更少,注意力图可视化显示其能自主学习分割和聚类。
-
语言和长序列任务评估:在长序列建模任务中,TOST性能优于多数基于Transformer的方法;在因果语言建模任务中,性能随模型规模提升,且计算效率更高。
-
-
结论与展望:提出的TOST架构通过新的注意力算子实现线性时间复杂度,性能与传统Transformer相当。未来需在大规模应用中验证其准确性,并设计更有效的MLP块替代方案。
总结完毕,下面我们一起来探究这篇论文所研究的 ToST 到底是怎么回事,为什么说它是线性注意力革命,能重新定义 Transformer 效率天花板,我们先来看一张图。
这张图表示 ToST 架构对比传统Transformer,可以在图中看出:
ToST在4096 token长度下,内存消耗仅为ViT的1/20(图源:论文)
一、注意力机制:从暴力美学到数学之美
2017年,Transformer以自注意力机制横扫NLP领域。其核心逻辑简单粗暴:让每个token与其他所有token对话。这种"全连接"式的设计虽然强大,却埋下了一个定时炸弹——当处理4096个token时,传统Transformer需要计算1600万次相似度!
ToST的突破在于发现了一个数学真理:无需两两对话,统计特征足以刻画全局关系。这就像从逐一采访每个公民转向分析人口普查数据,效率实现质的飞跃。
# 传统注意力计算(O(n²)复杂度)
def standard_attention(Q, K, V):scores = torch.matmul(Q, K.transpose(-2, -1)) # 两两相似度矩阵attn = torch.softmax(scores, dim=-1)return torch.matmul(attn, V)# ToST的统计注意力(O(n)复杂度)
def TSSA(X, heads=8):b, n, d = X.shapeproj = nn.Linear(d, heads*d)(X) # 投影到多头空间proj = proj.view(b, n, heads, d//heads).transpose(1,2)# 统计量计算(核心创新)stats = proj.pow(2).mean(dim=1, keepdim=True) # 二阶矩统计gate = 1 / (1 + stats) # 基于统计量的门控return (proj * gate).transpose(1,2).reshape(b, n, d)
二、ToST核心原理:用数学公式重塑注意力
1. 最大编码率缩减(MCR²)目标
Δ R = 1 2 log det ( I + d ϵ 2 Z Z ⊤ ) − 1 2 ∑ k = 1 K n k n log det ( I + d ϵ 2 Z k Z k ⊤ ) \Delta R = \frac{1}{2}\log\det(\mathbf{I}+\frac{d}{\epsilon^2}\mathbf{Z}\mathbf{Z}^\top) - \frac{1}{2}\sum_{k=1}^K \frac{n_k}{n}\log\det(\mathbf{I}+\frac{d}{\epsilon^2}\mathbf{Z}_k\mathbf{Z}_k^\top) ΔR=21logdet(I+ϵ2dZZ⊤)−21k=1∑Knnklogdet(I+ϵ2dZkZk⊤)
这个看似复杂的公式其实在做两件事:
- 全局扩张:让所有token特征尽可能分散(第一项最大化)
- 局部压缩:让同类token特征聚集(第二项最小化)
2. 变分编码率缩减(VRR)
通过引入正交投影矩阵 U k \mathbf{U}_k Uk,将原问题转化为:
R var = ∑ k = 1 K ∑ i = 1 d f ( ( U k ⊤ Z k ) i i 2 ) R^{\text{var}} = \sum_{k=1}^K \sum_{i=1}^d f\left( (\mathbf{U}_k^\top \mathbf{Z}_k)_{ii}^2 \right) Rvar=k=1∑Ki=1∑df((Uk⊤Zk)ii2)
其中 f ( x ) = log ( 1 + x ) f(x)=\log(1+x) f(x)=log(1+x)。这使得每个注意力头只需维护一个低维统计量。
三步实现线性复杂度:
- 特征投影:将d维特征映射到p维子空间(p << d)
- 统计门控:计算投影特征的二阶矩,生成抑制门控
- 残差连接:通过门控筛选重要特征方向
class TSSA(nn.Module):def __init__(self, dim, heads=8, dim_head=64):super().__init__()self.heads = headsself.scale = dim_head ** -0.5# 投影矩阵(学习不同统计视角)self.to_qkv = nn.Linear(dim, dim_head * heads * 3) # 动态门控生成self.gate = nn.Sequential(nn.Linear(dim_head, 1),nn.Sigmoid())def forward(self, x):b, n, _ = x.shapeqkv = self.to_qkv(x).chunk(3, dim=-1)# 多头投影q, k, v = map(lambda t: t.view(b, n, self.heads, -1).transpose(1,2), qkv)# 统计量计算(核心创新)stats = torch.einsum('bhid,bhjd->bhij', q, k).mean(dim=-1) # O(n)gate = self.gate(stats) # 基于统计量的动态门控# 门控特征聚合out = torch.einsum('bhij,bhjd->bhid', gate, v)return out.transpose(1,2).reshape(b, n, -1)
四、性能实测:效率与精度的双杀
1. 计算效率对比
模型 | 序列长度 | 内存占用(MB) | 推理时间(ms) |
---|---|---|---|
Transformer | 4096 | 12.8 | 342 |
ToST | 4096 | 0.6 | 28 |
2. 视觉任务表现
在ImageNet-1k上,ToST-Small以22.6M参数达到77.9% Top-1准确率,媲美ViT-Base(86.6M参数,79.8%),但计算量减少90%。
3. 长序列建模
在Long-Range Arena基准测试中,ToST在Path-X任务(16k长度)上以69.4%准确率超越Performer(77.0%),显存占用仅为1/10。
五、实战:用ToST构建高效语言模型
from torch import nn
import torchclass ToSTBlock(nn.Module):def __init__(self, dim, heads=8):super().__init__()self.attn = TSSA(dim, heads=heads)self.mlp = nn.Sequential(nn.Linear(dim, 4*dim),nn.GELU(),nn.Linear(4*dim, dim))self.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)def forward(self, x):x = x + self.attn(self.norm1(x))x = x + self.mlp(self.norm2(x))return xclass ToST(nn.Module):def __init__(self, num_layers=12, dim=768, heads=12):super().__init__()self.layers = nn.ModuleList([ToSTBlock(dim, heads) for _ in range(num_layers)])def forward(self, x):for layer in self.layers:x = layer(x)return x# 示例:处理512 token的文本序列
model = ToST()
x = torch.randn(1, 512, 768) # (batch, seq_len, dim)
print(model(x).shape) # torch.Size([1, 512, 768])
关键优化技巧:
- 动态门控量化:将统计门控转换为8位整数计算
- 内存复用:在投影阶段共享中间结果
- 混合精度训练:使用FP16存储统计矩阵
六、ToST的蝴蝶效应:AI未来的五大变革
-
大模型平民化
7B参数的ToST在单张3090显卡上可处理32k长度文本,成本降低10倍 -
实时视频理解
处理1080P视频(每帧产生2304个token)时,延迟从3.2秒降至0.3秒 -
科学计算革命
在蛋白质结构预测中,对10k氨基酸序列的处理时间从小时级缩短到分钟级 -
边缘智能爆发
在Jetson Nano等嵌入式设备上实现实时多模态推理 -
理论突破
为理解神经网络中的信息压缩提供了新的数学框架
七、挑战与展望
尽管ToST展现了巨大潜力,仍需解决:
- 统计偏差累积:长序列中统计误差的传播问题
- 多模态适配:如何统一视觉与语言的统计特征
- 动态序列处理:流式输入下的增量统计计算
马毅教授团队表示,下一步将探索:
# 伪代码:动态统计量更新
class StreamingTSSA:def update(self, new_token):self.stats = self.momentum * self.stats + (1 - self.momentum) * new_token**2self.gate = 1 / (1 + self.stats)return self.gate * new_token
这场由ToST引发的效率革命才刚刚开始。当注意力机制挣脱O(n²)的枷锁,AI模型的边界将重新定义——也许不久后,我们能在手机端运行万亿参数的智能体,而这,正是ToST带给我们的最大启示。
绑定的资源为本篇论文的原文,当然你也可以通过以下网站了解更多关于 ToST 的故事。
论文标题:Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction论文地址:https://arxiv.org/abs/2412.17810
项目主页:https://robinwu218.github.io/ToST/
开源地址:https://github.com/RobinWu218/ToST
相关文章:

Token Statistics Transformer:线性注意力革命,重新定义Transformer效率天花板
“TOKEN STATISTICS TRANSFORMER: LINEAR-TIME ATTENTION VIA VARIATIONAL RATE REDUCTION” 由Ziyang Wu等人撰写。文章提出一种新型Transformer注意力算子,通过对最大编码率降低( M C R 2 MCR^{2} MCR2)目标的变分形式进行展开优化得到&…...
Django 5实用指南(二)项目结构与管理
2.1 Django5项目结构概述 当你创建一个新的 Django 项目时,Django 会自动生成一个默认的项目结构。这个结构是根据 Django 的最佳实践来设计的,以便开发者能够清晰地管理和维护项目中的各种组件。理解并管理好这些文件和目录结构是 Django 开发的基础。…...
JAVA监听器(学习自用)
一、什么是监听器 servlet监听器是一种特殊的接口,用于监听特定的事件(如请求创建和销毁、会话创建和销毁、上下文的初始化和销毁)。 当Web应用程序中反生特定事件时,Servlet容器就会自动调用监听器中相应的方法来处理这些事件。…...

Ubuntu下mysql主从复制搭建
本文介绍mysql 8.4主从集群的搭建,从单个机器安装到集群的配置,整体走了一遍,希望对大家有帮助。mysql 8.4和之前的版本命令上有些变化,大家用来参考。 0、环境 ubuntu: 22.04mysql:8.4 1、安装mysql 1…...

VirtualBox 中使用 桥接网卡 并设置 MAC 地址
在 VirtualBox 中使用 桥接网卡 并设置 MAC 地址,可以按照以下步骤操作: 步骤 1:设置桥接网卡 打开 VirtualBox,选择你的虚拟机,点击 “设置” (Settings)。进入 “网络” (Network) 选项卡。在 “适配器 1” (Adapt…...

Ubuntu 20 掉显卡驱动的解决办法
目录 问题背景解决办法Step1:首先查看当前linux内核Step2:重启Step3:进入ubuntu advanced (即高级选项)Step4:查看有哪些linux内核Step5:如果滚回老板kernel还是没有驱动,就找到驱动…...

EasyPoi系列之框架集成及基础使用
EasyPoi系列之框架集成及基础使用 1 EasyPoi1.1 gitee仓库地址 2 EasyPoi集成至SpringBoot2.1 maven引入jar包 3 EasyPoi Excel导出3.1 基于实体对象导出3.1.1 Excel 注解3.1.2 编写实体3.1.3 编写导出方法3.1.4 导出效果 3.2 基于模板导出3.2.1 编写模板文件3.2.2 编写导出方法…...

Web后端 Tomcat服务器
一 Tomcat Web 服务器 介绍: Tomcat是一个开源的Java Servlet容器和Web服务器,由Apache软件基金会开发。它实现了Java Servlet和JavaServer Pages (JSP) 技术,用于运行Java Web应用程序。Tomcat轻量、易于配置,常作为开发和部署…...
【RK3588嵌入式图形编程】-SDL2-构建模块化UI
构建模块化UI 文章目录 构建模块化UI1、概述2、创建UI管理器3、嵌套组件4、继承5、多态子组件6、总结在本文中,将介绍如何使用C++和SDL创建一个灵活且可扩展的UI系统,重点关注组件层次结构和多态性。 1、概述 在前面的文章中,我们介绍了应用程序循环和事件循环,这为我们的…...

面向机器学习的Java库与平台简介、适用场景、官方网站、社区网址
Java机器学习的库与平台 最近听到有的人说要做机器学习就一定要学Python,我想他们掌握的知道还不够系统全面。本文作者给大家介绍几种常用Java实现的机器学习库,快快收藏加关注吧~ Java机器学习库表格 Java机器学习库整理库/平台概念适合场…...

基于YOLO11深度学习的心脏超声图像间隔壁检测分割与分析系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标分割、人工智能
《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…...
ubuntu24基于虚拟机无法从主机拖拽文件夹
以下是解决问题的精简步骤: 安装 open-vm-tools-desktop: bash复制 sudo apt-get install open-vm-tools-desktop 重启虚拟机后,文字复制粘贴功能可正常工作。 禁用 Wayland: 编辑 /etc/gdm3/custom.conf 文件: bash复…...
常用Webpack Loader汇总介绍
引言 在前端项目开发中,Webpack 作为强大的模块打包工具,能够将各种资源进行打包处理。而其中的 Loader 则是 Webpack 处理不同类型文件的关键,它允许 Webpack 不仅仅局限于处理 JavaScript 文件,还能处理 CSS、图片、字体等多种…...

剑指 Offer II 023. 两个链表的第一个重合节点
comments: true edit_url: https://github.com/doocs/leetcode/edit/main/lcof2/%E5%89%91%E6%8C%87%20Offer%20II%20023.%20%E4%B8%A4%E4%B8%AA%E9%93%BE%E8%A1%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E9%87%8D%E5%90%88%E8%8A%82%E7%82%B9/README.md 剑指 Offer II 023. 两…...
个人搭建CDN加速服务 特网科技
在互联网快速发展的今天,网站的加载速度对用户体验有着至关重要的影响,传统的网页加载方式依赖于服务器的性能和网络环境,这使得某些网站的页面加载时间过长,用户体验不佳,为了解决这个问题,许多企业开始采…...

用deepseek学大模型08-卷积神经网络(CNN)
yuanbao.tencent.com 从入门到精通卷积神经网络(CNN),着重介绍的目标函数,损失函数,梯度下降 标量和矩阵形式的数学推导,pytorch真实能跑的代码案例以及模型,数据,预测结果的可视化展示, 模型应用场景和优缺点…...

蓝桥杯单片机基础部分——6、555定时器
前言 NE555是一个纯硬件的设计,旦硬件电路确定了,其功能也确定了,没有可编程的部分,也没什么好去理解的地方,如果理解不了就直接背代码,这里也不是很常考,大家了解一下就可以了,知道…...
Python学习心得函数
一、函数的定义及调用 1.函数的定义: 函数的定义:函数是将一段能实现某种特定功能的代码,使用函数名进行封装,并通过函数名称进行调用。从而达到一次编写,多次调用的目的。 2.函数类型分为两类: &#…...

神经网络实验——MLP
目录 1 目的 2 方法 3 源代码 4 结果 1 目的 ①熟悉 Python 的输入输出流; ②学会使用 matplotlib进行图像可视化; ③掌握神经网络的基本原理,学会使用 sklearn 库中的 MLPClassifier 函数构建基础的多层感知机神经网络分类器; ④学会使用网格查找进行超参数优…...

配置Api自动生成
我的飞书:https://rvg7rs2jk1g.feishu.cn/docx/TVlJdMgYLoDJrsxAwMgcCE14nxt 使用Springfox Swagger生成API,并导入Postman,完成API单元测试 Swagger: 是一套API定义的规范,按照这套规范的要求去定义接口及接口相关信息,再通过可…...

【Python】 -- 趣味代码 - 小恐龙游戏
文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...
【位运算】消失的两个数字(hard)
消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...
大语言模型如何处理长文本?常用文本分割技术详解
为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...
将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?
Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
反射获取方法和属性
Java反射获取方法 在Java中,反射(Reflection)是一种强大的机制,允许程序在运行时访问和操作类的内部属性和方法。通过反射,可以动态地创建对象、调用方法、改变属性值,这在很多Java框架中如Spring和Hiberna…...

Linux-07 ubuntu 的 chrome 启动不了
文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了,报错如下四、启动不了,解决如下 总结 问题原因 在应用中可以看到chrome,但是打不开(说明:原来的ubuntu系统出问题了,这个是备用的硬盘&a…...
Rust 异步编程
Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台
🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...