当前位置: 首页 > news >正文

Mamba环境配置教程【自用】

1. 新建一个Conda虚拟环境

conda create -n mamba python=3.10

在这里插入图片描述

2. 进入该环境

conda activate mamba

在这里插入图片描述

3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio
直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、torchaudio
CSDN资源
在这里插入图片描述

下载完成后,进入这些文件的目录下,直接使用下面三个指令进行安装即可

pip install torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl 
pip install torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl 
pip install torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl

4. 安装triton和transformers库

pip install triton==2.3.1
pip install transformers==4.43.3

5. 安装完这些我们最基本Pytorch环境以及配置完成,接下来就是Mamba所需的一些依赖了,由于Mamba需要底层的C++进行编译,所以还需要手动安装一下cuda-nvcc这个库,直接使用conda命令即可

conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

6. 最后就是下载最重要的 causal-conv1d 和mamba-ssm库。在这里我们同样选择离线安装的方式,来避免大量奇葩的编译bug。首先进入下面各自的github网址种进行下载对应版本
causal-conv1d —— 1.4.0
在这里插入图片描述
mamba-ssm —— 2.2.2
在这里插入图片描述
和安装pytorch一样,进入下载的.whl文件所在文件夹,直接使用以下指令进行安装

pip install causal_conv1d-1.4.0+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install mamba_ssm-2.2.2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

7. 安装好环境后,验证一下Mamba块能否成功运行,直接复制下面代码保存问mamba2_test.py,并运行

# Copyright (c) 2024, Tri Dao, Albert Gu.import math
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom einops import rearrange, repeattry:from causal_conv1d import causal_conv1d_fn
except ImportError:causal_conv1d_fn = Nonetry:from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:RMSNormGated, LayerNorm = None, Nonefrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combinedclass Mamba2Simple(nn.Module):def __init__(self,d_model,d_state=128,d_conv=4,conv_init=None,expand=2,headdim=64,ngroups=1,A_init_range=(1, 16),dt_min=0.001,dt_max=0.1,dt_init_floor=1e-4,dt_limit=(0.0, float("inf")),learnable_init_states=False,activation="swish",bias=False,conv_bias=True,# Fused kernel and sharding optionschunk_size=256,use_mem_eff_path=True,layer_idx=None,  # Absorb kwarg for general moduledevice=None,dtype=None,):factory_kwargs = {"device": device, "dtype": dtype}super().__init__()self.d_model = d_modelself.d_state = d_stateself.d_conv = d_convself.conv_init = conv_initself.expand = expandself.d_inner = self.expand * self.d_modelself.headdim = headdimself.ngroups = ngroupsassert self.d_inner % self.headdim == 0self.nheads = self.d_inner // self.headdimself.dt_limit = dt_limitself.learnable_init_states = learnable_init_statesself.activation = activationself.chunk_size = chunk_sizeself.use_mem_eff_path = use_mem_eff_pathself.layer_idx = layer_idx# Order: [z, x, B, C, dt]d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheadsself.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)conv_dim = self.d_inner + 2 * self.ngroups * self.d_stateself.conv1d = nn.Conv1d(in_channels=conv_dim,out_channels=conv_dim,bias=conv_bias,kernel_size=d_conv,groups=conv_dim,padding=d_conv - 1,**factory_kwargs,)if self.conv_init is not None:nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)# self.conv1d.weight._no_weight_decay = Trueif self.learnable_init_states:self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))self.init_states._no_weight_decay = Trueself.act = nn.SiLU()# Initialize log dt biasdt = torch.exp(torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min))dt = torch.clamp(dt, min=dt_init_floor)# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759inv_dt = dt + torch.log(-torch.expm1(-dt))self.dt_bias = nn.Parameter(inv_dt)# Just to be explicit. Without this we already don't put wd on dt_bias because of the check# name.endswith("bias") in param_grouping.pyself.dt_bias._no_weight_decay = True# A parameterassert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)A_log = torch.log(A).to(dtype=dtype)self.A_log = nn.Parameter(A_log)# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)self.A_log._no_weight_decay = True# D "skip" parameterself.D = nn.Parameter(torch.ones(self.nheads, device=device))self.D._no_weight_decay = True# Extra normalization layer right before output projectionassert RMSNormGated is not Noneself.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)def forward(self, u, seq_idx=None):"""u: (B, L, D)Returns: same shape as u"""batch, seqlen, dim = u.shapezxbcdt = self.in_proj(u)  # (B, L, d_in_proj)A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else Nonedt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)if self.use_mem_eff_path:# Fully fused pathout = mamba_split_conv1d_scan_combined(zxbcdt,rearrange(self.conv1d.weight, "d 1 w -> d w"),self.conv1d.bias,self.dt_bias,A,D=self.D,chunk_size=self.chunk_size,seq_idx=seq_idx,activation=self.activation,rmsnorm_weight=self.norm.weight,rmsnorm_eps=self.norm.eps,outproj_weight=self.out_proj.weight,outproj_bias=self.out_proj.bias,headdim=self.headdim,ngroups=self.ngroups,norm_before_gate=False,initial_states=initial_states,**dt_limit_kwargs,)else:z, xBC, dt = torch.split(zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1)dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)assert self.activation in ["silu", "swish"]# 1D Convolutionif causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2))  # (B, L, self.d_inner + 2 * ngroups * d_state)xBC = xBC[:, :seqlen, :]else:xBC = causal_conv1d_fn(x=xBC.transpose(1, 2),weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),bias=self.conv1d.bias,activation=self.activation,).transpose(1, 2)# Split into 3 main branches: X, B, C# These correspond to V, K, Q respectively in the SSM/attention dualityx, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)y = mamba_chunk_scan_combined(rearrange(x, "b l (h p) -> b l h p", p=self.headdim),dt,A,rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),chunk_size=self.chunk_size,D=self.D,z=None,seq_idx=seq_idx,initial_states=initial_states,**dt_limit_kwargs,)y = rearrange(y, "b l h p -> b l (h p)")# Multiply "gate" branch and apply extra normalization layery = self.norm(y, z)out = self.out_proj(y)return outif __name__ == '__main__':model = Mamba2Simple(256).cuda()inputs = torch.randn(2, 128, 256).cuda()pred = model(inputs)print(pred.size())      

在这里插入图片描述

参考文献

相关文章:

Mamba环境配置教程【自用】

1. 新建一个Conda虚拟环境 conda create -n mamba python3.102. 进入该环境 conda activate mamba3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio 直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、…...

2021 年 6 月青少年软编等考 C 语言二级真题解析

目录 T1. 数字放大思路分析 T2. 统一文件名思路分析 T3. 内部元素之和思路分析 T4. 整数排序思路分析 T5. 计算好数思路分析 T1. 数字放大 给定一个整数序列以及放大倍数 x x x,将序列中每个整数放大 x x x 倍后输出。 时间限制:1 s 内存限制&#x…...

2024网络安全、应用软件系统开发决赛技术文件

用软件系统开发技术方案 一、竞赛项目 2024 年全国电子信息行业第二届职工技能竞赛四川省应用 软件系统开发选拔赛分理论比赛和实际操作两个部分。理论比赛 成绩占30%,实际操作成绩占70%。 二、理论比赛 1、理论比赛范围 ①计算机系统基础知识: …...

CSP-J初赛每日题目2(答案)

二进制数 00100100和 00010100 的和是( )。 A.00101000 B.01100111 C.01000100 D.00111000 正确答案: D \color{green}{正确答案: D} 正确答案:D 解析: \color{red}{解析:} 解析: 00100100 36 \color{r…...

为什么Node.js不适合CPU密集型应用?

Node.js不适合CPU密集型应用的原因主要基于其设计理念和核心特性,具体可以归纳为以下几点: 单线程模型 Node.js采用单线程模型来处理用户请求和异步I/O操作。虽然这种模型在处理高并发I/O密集型任务时非常高效,因为它避免了传统多线程模型中的…...

数模原理精解【12】

文章目录 广义线性模型多元回归中的 R 2 R^2 R2(也称为决定系数)一、定义二、性质三、计算四、例子五、例题 偏相关系数一、定义二、计算三、性质四、例子 多元回归相关定义性质假设检验定义计算性质检验方法例子和例题例子例题例子 参考文献 广义线性模…...

steamdeck执行exe文件

命令行安装: sudo pacman xxxx //"xxxx"为软件名 ,或者搜索“arch linux 软件安装命令” 安装wine及wineZGUI 命令行输入: sudo pacman -S wine 后面需要输入密码,deck设置的用户密码即可(输入无反应是正…...

三、集合原理-3.2、HashMap(下)

3.2、HashMap(下) 3.2.2、单线程下的HashMap的工作原理(底层逻辑)是什么? 答: HashMap的源码位于Java的标准库中,你可以在java.util包中找到它。 以下是HashMap的简化源码示例,用于说明其实现逻辑&#…...

【激活函数】Activation Function——在卷积神经网络中的激活函数是一个什么样的角色??

【激活函数】Activation Function——在卷积神经网络中的激活函数是一个什么样的角色?? Activation Function——在卷积神经网络中的激活函数是一个什么样的角色?? 文章目录 【激活函数】Activation Function——在卷积神经网络中…...

重生之我在Java世界------学单例设计模式

什么是单例设计模式? 单例模式是面向对象编程中最简单却又最常用的设计模式之一。它的核心思想是确保一个类只有一个实例,并提供一个全局访问点。本文将深入探讨单例模式的原理、常见实现方法、优缺点,以及在使用过程中可能遇到的陷阱。 单…...

快速提升Python Pandas处理速度的秘诀

大家好,Python的Pandas库为数据处理和分析提供了丰富的功能,但当处理大规模数据时,性能问题往往成为瓶颈。本文将介绍一些在Pandas中进行性能优化的方法与技巧,帮助有效提升数据处理速度,优化代码运行效率。 1.数据类…...

在基于线程的环境中运行 MATLAB 函数

MATLAB 和其他工具箱中的数百个函数可以在基于线程的环境中运行。可以使用 backgroundPool 或 parpool("threads") 在基于线程的环境中运行代码。 ​要在后台运行函数,请使用 parfeval 和 backgroundPool。​ ​具体信息可以参考Choose Between Thread-B…...

黑神话悟空+云技术,游戏新体验!

近期,一款名为黑神话悟空的游戏因其独特的艺术风格和创新的技术实现在玩家中产生了不小的影响。 而云桌面技术作为一种新兴的解决方案,正在改变人们的游戏体验方式,使得高性能游戏可以在更多设备上流畅运行。 那么,黑神话悟空如…...

【Android 13源码分析】WindowContainer窗口层级-3-实例分析

在安卓源码的设计中,将将屏幕分为了37层,不同的窗口将在不同的层级中显示。 对这一块的概念以及相关源码做了详细分析,整理出以下几篇。 【Android 13源码分析】WindowContainer窗口层级-1-初识窗口层级树 【Android 13源码分析】WindowCon…...

Redis常用操作及springboot整合redis

1. Redis和Mysql的区别 数据模型:二者都是数据库,但是不同的是mysql是进行存储到磁盘当中,而Redis是进行存储到内存中. 数据模型 : mysql的存储的形式是二维表而Redis是通过key-value键值对的形式进行存储数据. 实际的应用的场景: Redis适合于需要快速读写的场景&…...

动态规划day34|背包理论基础(1)(2)、46.携带研究材料(纯粹的01背包)、416. 分割等和子集(01背包的应用)

动态规划day34|背包理论基础(1)(2)、46.携带研究材料、416. 分割等和子集 背包理论基础(1)——二维背包理论基础(2)——一维46.携带研究材料(卡码网 01背包)1. 二维背包2. 一维背包 …...

pytorch优化器

在反向传播计算完所有参数的梯度后,还需要使用优化方法更新网络的权重和参数。例如,随机梯度下降法(SGD)的更新策略如下: weight weight - learning_rate * gradient 手动实现如下: learning_rate 0.01 …...

必备工具,AI生成证件照,再也不用麻烦他人,电子驾驶证等多种证件照一键生成

最近有一个生成证件照的开源项目很火,今天我们来学习一下。之前我生成证件照都是线下去拍照,线上使用也是各种限制,需要付费或看广告,而且效果也不是很理想, 今天要分享的这个 AI 证件照生成工具可以一键可以生成一寸…...

深度解析 MintRich 独特的价格曲线机制玩法

随着 Meme 币赛道的迅速崛起,NFT 市场也迎来了新的变革。作为一个创新的 NFT 发行平台,Mint.Rich 正掀起一场全民参与的 NFT 热潮。其简易的操作界面和独特的价格曲线设计,让任何人都能以极低的门槛发行和交易自己的 NFT,从而参与…...

实时数仓3.0DWD层

实时数仓3.0DWD层 DWD层设计要点:9.1 流量域未经加工的事务事实表9.1.1 主要任务9.1.2 思路9.1.3 图解9.1.4 代码 9.2 流量域独立访客事务事实表9.2.1 主要任务9.2.2 思路分析9.2.3 图解9.2.4 代码 9.3 流量域用户跳出事务事实表9.3.1 主要任务9.3.2 思路分析9.3.3 …...

Java应用等保三级合规改造:3天完成代码层、配置层、运维层全栈优化(附Checklist)

第一章:Java应用等保三级合规改造全景图等保三级是国家网络安全等级保护制度中面向重要信息系统的核心要求,对Java应用而言,合规改造不是单一技术点的修补,而是一套覆盖开发、运行、运维全生命周期的安全治理工程。其核心目标在于…...

重构求职效率:boss_batch_push批量投递工具的颠覆性价值

重构求职效率:boss_batch_push批量投递工具的颠覆性价值 【免费下载链接】boss_batch_push Boss直聘批量投简历,解放双手 项目地址: https://gitcode.com/gh_mirrors/bo/boss_batch_push boss_batch_push是一款专为Boss直聘平台设计的开源自动化投…...

新手入门指南:基于快马生成的代码理解设备配对功能实现

今天想和大家分享一个特别适合新手学习的设备配对功能实现案例。这个例子用最基础的HTML、CSS和原生JavaScript就能完成,特别适合刚接触前端开发的朋友理解交互逻辑。 项目结构设计 整个项目分为三个部分:两个模拟设备(用不同图标表示&#x…...

L1-064 估值一亿的ai核心代码 (分数20)字符串处理

•无论用户说什么,首先把对方说的话在一行中原样打印出来;•消除原文中多余空格:把相邻单词间的多个空格换成 1 个空格,把行首尾的空格全部删掉,把标点符号前面的空格删掉; •把原文中所有大写英文字母变成…...

2026最权威的AI写作神器解析与推荐

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 在学术研究范畴之内,人工智能技术的深度交融催生出了多种具备专业性的学术辅助平…...

Jimeng LoRA环境部署教程:Python+Torch+CUDA兼容性避坑与版本匹配指南

Jimeng LoRA环境部署教程:PythonTorchCUDA兼容性避坑与版本匹配指南 1. 项目简介 Jimeng LoRA(即梦LoRA)是一个专门为LoRA模型测试设计的轻量级文本生成图像系统。这个项目的核心价值在于它能让你只用加载一次基础模型,然后快速…...

企业微信考勤自动化解决方案:基于EasyWeChat的实战指南

企业微信考勤自动化解决方案:基于EasyWeChat的实战指南 【免费下载链接】easywechat 📦 一个 PHP 微信 SDK 项目地址: https://gitcode.com/gh_mirrors/ea/easywechat 在数字化办公普及的今天,企业考勤管理面临着数据采集繁琐、统计分…...

快速掌握C#语言基础知识点(16.访问修饰符)

关注我的动态 namespace _16.访问修饰符 {internal class Program {//私有内部类,被嵌套定义,能被直接外部类访问,外部类之外无法访问private class Class_Private{//公有public int a { get; set; }//私有private int b { get; set; }//受保…...

Phi-4-mini-reasoning 128K上下文应用创新:法律条文交叉引用推理案例

Phi-4-mini-reasoning 128K上下文应用创新:法律条文交叉引用推理案例 1. 模型简介与核心能力 Phi-4-mini-reasoning 是一个轻量级开源模型,专注于高质量推理任务。作为Phi-4模型家族成员,它通过合成数据训练和微调,特别擅长处理…...

别再只盯着数据了!用Arduino+GP2Y1014AU传感器,手把手教你做个能“看见”空气的PM2.5监测仪

用Arduino打造智能PM2.5监测仪:从硬件连接到可视化交互 在空气质量日益受到关注的今天,拥有一个实时监测PM2.5浓度的设备不仅能提升生活品质,还能为健康保驾护航。不同于市面上千篇一律的商用监测仪,自己动手打造一个兼具实用性和…...