Mamba环境配置教程【自用】
1. 新建一个Conda虚拟环境
conda create -n mamba python=3.10
2. 进入该环境
conda activate mamba
3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio
直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、torchaudio
CSDN资源
下载完成后,进入这些文件的目录下,直接使用下面三个指令进行安装即可
pip install torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
pip install torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl
pip install torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
4. 安装triton和transformers库
pip install triton==2.3.1
pip install transformers==4.43.3
5. 安装完这些我们最基本Pytorch环境以及配置完成,接下来就是Mamba所需的一些依赖了,由于Mamba需要底层的C++进行编译,所以还需要手动安装一下cuda-nvcc这个库,直接使用conda命令即可
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
6. 最后就是下载最重要的 causal-conv1d 和mamba-ssm库。在这里我们同样选择离线安装的方式,来避免大量奇葩的编译bug。首先进入下面各自的github网址种进行下载对应版本
causal-conv1d —— 1.4.0
mamba-ssm —— 2.2.2
和安装pytorch一样,进入下载的.whl文件所在文件夹,直接使用以下指令进行安装
pip install causal_conv1d-1.4.0+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install mamba_ssm-2.2.2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
7. 安装好环境后,验证一下Mamba块能否成功运行,直接复制下面代码保存问mamba2_test.py,并运行
# Copyright (c) 2024, Tri Dao, Albert Gu.import math
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom einops import rearrange, repeattry:from causal_conv1d import causal_conv1d_fn
except ImportError:causal_conv1d_fn = Nonetry:from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:RMSNormGated, LayerNorm = None, Nonefrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combinedclass Mamba2Simple(nn.Module):def __init__(self,d_model,d_state=128,d_conv=4,conv_init=None,expand=2,headdim=64,ngroups=1,A_init_range=(1, 16),dt_min=0.001,dt_max=0.1,dt_init_floor=1e-4,dt_limit=(0.0, float("inf")),learnable_init_states=False,activation="swish",bias=False,conv_bias=True,# Fused kernel and sharding optionschunk_size=256,use_mem_eff_path=True,layer_idx=None, # Absorb kwarg for general moduledevice=None,dtype=None,):factory_kwargs = {"device": device, "dtype": dtype}super().__init__()self.d_model = d_modelself.d_state = d_stateself.d_conv = d_convself.conv_init = conv_initself.expand = expandself.d_inner = self.expand * self.d_modelself.headdim = headdimself.ngroups = ngroupsassert self.d_inner % self.headdim == 0self.nheads = self.d_inner // self.headdimself.dt_limit = dt_limitself.learnable_init_states = learnable_init_statesself.activation = activationself.chunk_size = chunk_sizeself.use_mem_eff_path = use_mem_eff_pathself.layer_idx = layer_idx# Order: [z, x, B, C, dt]d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheadsself.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)conv_dim = self.d_inner + 2 * self.ngroups * self.d_stateself.conv1d = nn.Conv1d(in_channels=conv_dim,out_channels=conv_dim,bias=conv_bias,kernel_size=d_conv,groups=conv_dim,padding=d_conv - 1,**factory_kwargs,)if self.conv_init is not None:nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)# self.conv1d.weight._no_weight_decay = Trueif self.learnable_init_states:self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))self.init_states._no_weight_decay = Trueself.act = nn.SiLU()# Initialize log dt biasdt = torch.exp(torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min))dt = torch.clamp(dt, min=dt_init_floor)# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759inv_dt = dt + torch.log(-torch.expm1(-dt))self.dt_bias = nn.Parameter(inv_dt)# Just to be explicit. Without this we already don't put wd on dt_bias because of the check# name.endswith("bias") in param_grouping.pyself.dt_bias._no_weight_decay = True# A parameterassert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)A_log = torch.log(A).to(dtype=dtype)self.A_log = nn.Parameter(A_log)# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)self.A_log._no_weight_decay = True# D "skip" parameterself.D = nn.Parameter(torch.ones(self.nheads, device=device))self.D._no_weight_decay = True# Extra normalization layer right before output projectionassert RMSNormGated is not Noneself.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)def forward(self, u, seq_idx=None):"""u: (B, L, D)Returns: same shape as u"""batch, seqlen, dim = u.shapezxbcdt = self.in_proj(u) # (B, L, d_in_proj)A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else Nonedt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)if self.use_mem_eff_path:# Fully fused pathout = mamba_split_conv1d_scan_combined(zxbcdt,rearrange(self.conv1d.weight, "d 1 w -> d w"),self.conv1d.bias,self.dt_bias,A,D=self.D,chunk_size=self.chunk_size,seq_idx=seq_idx,activation=self.activation,rmsnorm_weight=self.norm.weight,rmsnorm_eps=self.norm.eps,outproj_weight=self.out_proj.weight,outproj_bias=self.out_proj.bias,headdim=self.headdim,ngroups=self.ngroups,norm_before_gate=False,initial_states=initial_states,**dt_limit_kwargs,)else:z, xBC, dt = torch.split(zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1)dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)assert self.activation in ["silu", "swish"]# 1D Convolutionif causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)) # (B, L, self.d_inner + 2 * ngroups * d_state)xBC = xBC[:, :seqlen, :]else:xBC = causal_conv1d_fn(x=xBC.transpose(1, 2),weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),bias=self.conv1d.bias,activation=self.activation,).transpose(1, 2)# Split into 3 main branches: X, B, C# These correspond to V, K, Q respectively in the SSM/attention dualityx, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)y = mamba_chunk_scan_combined(rearrange(x, "b l (h p) -> b l h p", p=self.headdim),dt,A,rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),chunk_size=self.chunk_size,D=self.D,z=None,seq_idx=seq_idx,initial_states=initial_states,**dt_limit_kwargs,)y = rearrange(y, "b l h p -> b l (h p)")# Multiply "gate" branch and apply extra normalization layery = self.norm(y, z)out = self.out_proj(y)return outif __name__ == '__main__':model = Mamba2Simple(256).cuda()inputs = torch.randn(2, 128, 256).cuda()pred = model(inputs)print(pred.size())
参考文献
相关文章:

Mamba环境配置教程【自用】
1. 新建一个Conda虚拟环境 conda create -n mamba python3.102. 进入该环境 conda activate mamba3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio 直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、…...
2021 年 6 月青少年软编等考 C 语言二级真题解析
目录 T1. 数字放大思路分析 T2. 统一文件名思路分析 T3. 内部元素之和思路分析 T4. 整数排序思路分析 T5. 计算好数思路分析 T1. 数字放大 给定一个整数序列以及放大倍数 x x x,将序列中每个整数放大 x x x 倍后输出。 时间限制:1 s 内存限制&#x…...
2024网络安全、应用软件系统开发决赛技术文件
用软件系统开发技术方案 一、竞赛项目 2024 年全国电子信息行业第二届职工技能竞赛四川省应用 软件系统开发选拔赛分理论比赛和实际操作两个部分。理论比赛 成绩占30%,实际操作成绩占70%。 二、理论比赛 1、理论比赛范围 ①计算机系统基础知识: …...
CSP-J初赛每日题目2(答案)
二进制数 00100100和 00010100 的和是( )。 A.00101000 B.01100111 C.01000100 D.00111000 正确答案: D \color{green}{正确答案: D} 正确答案:D 解析: \color{red}{解析:} 解析: 00100100 36 \color{r…...
为什么Node.js不适合CPU密集型应用?
Node.js不适合CPU密集型应用的原因主要基于其设计理念和核心特性,具体可以归纳为以下几点: 单线程模型 Node.js采用单线程模型来处理用户请求和异步I/O操作。虽然这种模型在处理高并发I/O密集型任务时非常高效,因为它避免了传统多线程模型中的…...
数模原理精解【12】
文章目录 广义线性模型多元回归中的 R 2 R^2 R2(也称为决定系数)一、定义二、性质三、计算四、例子五、例题 偏相关系数一、定义二、计算三、性质四、例子 多元回归相关定义性质假设检验定义计算性质检验方法例子和例题例子例题例子 参考文献 广义线性模…...
steamdeck执行exe文件
命令行安装: sudo pacman xxxx //"xxxx"为软件名 ,或者搜索“arch linux 软件安装命令” 安装wine及wineZGUI 命令行输入: sudo pacman -S wine 后面需要输入密码,deck设置的用户密码即可(输入无反应是正…...
三、集合原理-3.2、HashMap(下)
3.2、HashMap(下) 3.2.2、单线程下的HashMap的工作原理(底层逻辑)是什么? 答: HashMap的源码位于Java的标准库中,你可以在java.util包中找到它。 以下是HashMap的简化源码示例,用于说明其实现逻辑&#…...
【激活函数】Activation Function——在卷积神经网络中的激活函数是一个什么样的角色??
【激活函数】Activation Function——在卷积神经网络中的激活函数是一个什么样的角色?? Activation Function——在卷积神经网络中的激活函数是一个什么样的角色?? 文章目录 【激活函数】Activation Function——在卷积神经网络中…...
重生之我在Java世界------学单例设计模式
什么是单例设计模式? 单例模式是面向对象编程中最简单却又最常用的设计模式之一。它的核心思想是确保一个类只有一个实例,并提供一个全局访问点。本文将深入探讨单例模式的原理、常见实现方法、优缺点,以及在使用过程中可能遇到的陷阱。 单…...
快速提升Python Pandas处理速度的秘诀
大家好,Python的Pandas库为数据处理和分析提供了丰富的功能,但当处理大规模数据时,性能问题往往成为瓶颈。本文将介绍一些在Pandas中进行性能优化的方法与技巧,帮助有效提升数据处理速度,优化代码运行效率。 1.数据类…...
在基于线程的环境中运行 MATLAB 函数
MATLAB 和其他工具箱中的数百个函数可以在基于线程的环境中运行。可以使用 backgroundPool 或 parpool("threads") 在基于线程的环境中运行代码。 要在后台运行函数,请使用 parfeval 和 backgroundPool。 具体信息可以参考Choose Between Thread-B…...

黑神话悟空+云技术,游戏新体验!
近期,一款名为黑神话悟空的游戏因其独特的艺术风格和创新的技术实现在玩家中产生了不小的影响。 而云桌面技术作为一种新兴的解决方案,正在改变人们的游戏体验方式,使得高性能游戏可以在更多设备上流畅运行。 那么,黑神话悟空如…...

【Android 13源码分析】WindowContainer窗口层级-3-实例分析
在安卓源码的设计中,将将屏幕分为了37层,不同的窗口将在不同的层级中显示。 对这一块的概念以及相关源码做了详细分析,整理出以下几篇。 【Android 13源码分析】WindowContainer窗口层级-1-初识窗口层级树 【Android 13源码分析】WindowCon…...

Redis常用操作及springboot整合redis
1. Redis和Mysql的区别 数据模型:二者都是数据库,但是不同的是mysql是进行存储到磁盘当中,而Redis是进行存储到内存中. 数据模型 : mysql的存储的形式是二维表而Redis是通过key-value键值对的形式进行存储数据. 实际的应用的场景: Redis适合于需要快速读写的场景&…...
动态规划day34|背包理论基础(1)(2)、46.携带研究材料(纯粹的01背包)、416. 分割等和子集(01背包的应用)
动态规划day34|背包理论基础(1)(2)、46.携带研究材料、416. 分割等和子集 背包理论基础(1)——二维背包理论基础(2)——一维46.携带研究材料(卡码网 01背包)1. 二维背包2. 一维背包 …...
pytorch优化器
在反向传播计算完所有参数的梯度后,还需要使用优化方法更新网络的权重和参数。例如,随机梯度下降法(SGD)的更新策略如下: weight weight - learning_rate * gradient 手动实现如下: learning_rate 0.01 …...

必备工具,AI生成证件照,再也不用麻烦他人,电子驾驶证等多种证件照一键生成
最近有一个生成证件照的开源项目很火,今天我们来学习一下。之前我生成证件照都是线下去拍照,线上使用也是各种限制,需要付费或看广告,而且效果也不是很理想, 今天要分享的这个 AI 证件照生成工具可以一键可以生成一寸…...

深度解析 MintRich 独特的价格曲线机制玩法
随着 Meme 币赛道的迅速崛起,NFT 市场也迎来了新的变革。作为一个创新的 NFT 发行平台,Mint.Rich 正掀起一场全民参与的 NFT 热潮。其简易的操作界面和独特的价格曲线设计,让任何人都能以极低的门槛发行和交易自己的 NFT,从而参与…...

实时数仓3.0DWD层
实时数仓3.0DWD层 DWD层设计要点:9.1 流量域未经加工的事务事实表9.1.1 主要任务9.1.2 思路9.1.3 图解9.1.4 代码 9.2 流量域独立访客事务事实表9.2.1 主要任务9.2.2 思路分析9.2.3 图解9.2.4 代码 9.3 流量域用户跳出事务事实表9.3.1 主要任务9.3.2 思路分析9.3.3 …...

从WWDC看苹果产品发展的规律
WWDC 是苹果公司一年一度面向全球开发者的盛会,其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具,对过去十年 WWDC 主题演讲内容进行了系统化分析,形成了这份…...

Linux --进程控制
本文从以下五个方面来初步认识进程控制: 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程,创建出来的进程就是子进程,原来的进程为父进程。…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)
前言: 最近在做行为检测相关的模型,用的是时空图卷积网络(STGCN),但原有kinetic-400数据集数据质量较低,需要进行细粒度的标注,同时粗略搜了下已有开源工具基本都集中于图像分割这块,…...

通过 Ansible 在 Windows 2022 上安装 IIS Web 服务器
拓扑结构 这是一个用于通过 Ansible 部署 IIS Web 服务器的实验室拓扑。 前提条件: 在被管理的节点上安装WinRm 准备一张自签名的证书 开放防火墙入站tcp 5985 5986端口 准备自签名证书 PS C:\Users\azureuser> $cert New-SelfSignedCertificate -DnsName &…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现指南针功能
指南针功能是许多位置服务应用的基础功能之一。下面我将详细介绍如何在HarmonyOS 5中使用DevEco Studio实现指南针功能。 1. 开发环境准备 确保已安装DevEco Studio 3.1或更高版本确保项目使用的是HarmonyOS 5.0 SDK在项目的module.json5中配置必要的权限 2. 权限配置 在mo…...
Spring Boot + MyBatis 集成支付宝支付流程
Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例(电脑网站支付) 1. 添加依赖 <!…...
用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章
用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章 摘要: 操作系统内核的安全性、稳定性至关重要。传统 Linux 内核模块开发长期依赖于 C 语言,受限于 C 语言本身的内存安全和并发安全问题,开发复杂模块极易引入难以…...

新版NANO下载烧录过程
一、序言 搭建 Jetson 系列产品烧录系统的环境需要在电脑主机上安装 Ubuntu 系统。此处使用 18.04 LTS。 二、环境搭建 1、安装库 $ sudo apt-get install qemu-user-static$ sudo apt-get install python 搭建环境的过程需要这个应用库来将某些 NVIDIA 软件组件安装到 Je…...
Electron简介(附电子书学习资料)
一、什么是Electron? Electron 是一个由 GitHub 开发的 开源框架,允许开发者使用 Web技术(HTML、CSS、JavaScript) 构建跨平台的桌面应用程序(Windows、macOS、Linux)。它将 Chromium浏览器内核 和 Node.j…...