当前位置: 首页 > article >正文

用PyTorch手把手教你实现LoRA:从Linear到ConvLoRA的完整代码解析

用PyTorch手把手教你实现LoRA从Linear到ConvLoRA的完整代码解析在深度学习模型微调领域LoRALow-Rank Adaptation技术正逐渐成为资源敏感型场景下的首选方案。不同于传统微调需要更新整个庞大模型的参数LoRA通过引入轻量级的低秩矩阵来捕获任务特定的知识既保留了预训练模型的核心能力又大幅降低了计算开销。本文将带您从零实现LoRA的核心组件涵盖全连接层到卷积层的完整适配过程。1. LoRA技术原理与实现基础LoRA的核心思想建立在矩阵低秩分解的数学基础上。假设原始权重矩阵W∈R^(d×k)其更新量ΔW可以分解为两个小矩阵的乘积ΔWBA其中B∈R^(d×r)A∈R^(r×k)且秩r≪min(d,k)。这种分解使得参数量从d×k减少到r×(dk)当r8时通常可减少98%以上的可训练参数。在PyTorch中实现基础LoRA层需要解决三个关键问题参数冻结保持原始权重不可训练低秩适配构建可训练的A/B矩阵权重合并训练/推理模式的切换逻辑让我们先看一个最简单的Linear层LoRA实现框架import torch import torch.nn as nn import torch.nn.functional as F class LoRA_Linear(nn.Module): def __init__(self, in_features, out_features, rank8): super().__init__() # 原始线性层参数冻结 self.linear nn.Linear(in_features, out_features) self.linear.weight.requires_grad False # 低秩适配矩阵 self.lora_A nn.Parameter(torch.zeros(rank, in_features)) self.lora_B nn.Parameter(torch.zeros(out_features, rank)) # 初始化策略 nn.init.kaiming_uniform_(self.lora_A, amath.sqrt(5)) nn.init.zeros_(self.lora_B) self.rank rank self.scaling 1.0 / rank # 缩放因子 self.merged False # 权重合并状态标志这个基础框架已经包含了LoRA的核心组件。在实际应用中我们还需要实现训练/推理模式切换时的权重合并与分离逻辑这是LoRA能够无缝集成到现有模型中的关键。2. 完整Linear层LoRA实现扩展基础框架我们需要完善以下功能训练/推理模式的自动切换Dropout正则化支持权重合并与分离的数学正确性前向传播的完整计算流程下面是完整的Linear层LoRA实现class LoRA_Linear(nn.Linear): def __init__(self, in_features, out_features, rank8, lora_alpha1.0, lora_dropout0.0, **kwargs): nn.Linear.__init__(self, in_features, out_features, **kwargs) # LoRA配置参数 self.rank rank self.lora_alpha lora_alpha self.scaling lora_alpha / rank # 正则化设置 if lora_dropout 0.: self.lora_dropout nn.Dropout(plora_dropout) else: self.lora_dropout lambda x: x # 冻结原始权重 self.weight.requires_grad False # 初始化低秩矩阵 self.lora_A nn.Parameter(torch.zeros(rank, in_features)) self.lora_B nn.Parameter(torch.zeros(out_features, rank)) self.reset_parameters() self.merged False def reset_parameters(self): nn.Linear.reset_parameters(self) if hasattr(self, lora_A): nn.init.kaiming_uniform_(self.lora_A, amath.sqrt(5)) nn.init.zeros_(self.lora_B) def train(self, modeTrue): nn.Linear.train(self, mode) if mode: if self.merged: # 从合并权重中分离 self.weight.data - (self.lora_B self.lora_A) * self.scaling self.merged False else: if not self.merged: # 合并到原始权重 self.weight.data (self.lora_B self.lora_A) * self.scaling self.merged True def forward(self, x): if not self.merged: # 原始线性变换 result F.linear(x, self.weight, self.bias) # LoRA分支 lora_output (self.lora_dropout(x) self.lora_A.T self.lora_B.T) * self.scaling return result lora_output else: return F.linear(x, self.weight, self.bias)这个实现完整展示了LoRA在Linear层的应用关键点包括权重合并机制在eval模式下自动合并参数保持推理效率梯度隔离原始权重始终冻结仅训练低秩矩阵数值稳定性通过scaling因子控制更新幅度实际使用时只需将模型中的nn.Linear替换为我们的LoRA_Linear即可# 传统线性层 # layer nn.Linear(1024, 1024) # LoRA版本 layer LoRA_Linear(1024, 1024, rank8)3. ConvLoRA卷积层的低秩适配将LoRA思想扩展到卷积层面临新的挑战。卷积核是4D张量out_channels, in_channels, kH, kW直接应用低秩分解需要考虑空间维度。ConvLoRA的解决方案是将卷积核视为二维矩阵out_channels, in_channels×kH×kW然后应用类似的低秩分解。以下是Conv2d层的LoRA实现class LoRA_Conv2d(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, rank8, lora_alpha1.0, **kwargs): nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) # 参数设置 self.rank rank self.lora_alpha lora_alpha self.scaling lora_alpha / rank # 计算展开后的维度 self.kernel_size kernel_size if isinstance(kernel_size, tuple) \ else (kernel_size, kernel_size) self.unfold_dim in_channels * self.kernel_size[0] * self.kernel_size[1] # 初始化低秩矩阵 self.lora_A nn.Parameter( torch.zeros(rank * self.kernel_size[0], self.unfold_dim) ) self.lora_B nn.Parameter( torch.zeros(out_channels, rank * self.kernel_size[0]) ) self.reset_parameters() # 冻结原始权重 self.weight.requires_grad False self.merged False def reset_parameters(self): nn.Conv2d.reset_parameters(self) if hasattr(self, lora_A): nn.init.kaiming_uniform_(self.lora_A, amath.sqrt(5)) nn.init.zeros_(self.lora_B) def train(self, modeTrue): nn.Conv2d.train(self, mode) if mode: if self.merged: # 分离低秩更新 delta_w (self.lora_B self.lora_A).view(self.weight.shape) self.weight.data - delta_w * self.scaling self.merged False else: if not self.merged: # 合并更新到权重 delta_w (self.lora_B self.lora_A).view(self.weight.shape) self.weight.data delta_w * self.scaling self.merged True def forward(self, x): if not self.merged: # 计算低秩更新 delta_w (self.lora_B self.lora_A).view(self.weight.shape) effective_weight self.weight delta_w * self.scaling return F.conv2d( x, effective_weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: return super().forward(x)ConvLoRA的实现有几个关键技术点张量展开将4D卷积核展开为2D矩阵进行处理空间维度保留在低秩分解中保持kernel的空间结构权重视图转换确保合并后的权重恢复原始形状使用方式与Linear层类似# 传统卷积层 # conv nn.Conv2d(3, 64, kernel_size3) # LoRA版本 conv LoRA_Conv2d(3, 64, kernel_size3, rank8)4. 实战将LoRA集成到Transformer模型让我们以常见的Transformer架构为例展示如何将LoRA应用到实际模型中。我们将修改一个标准的BERT模型将其中的关键线性层替换为LoRA版本。首先定义LoRA化的MLP模块class LoRA_MLP(nn.Module): def __init__(self, hidden_size, intermediate_size, rank8): super().__init__() self.dense_in LoRA_Linear(hidden_size, intermediate_size, rankrank) self.dense_out LoRA_Linear(intermediate_size, hidden_size, rankrank) self.activation nn.GELU() def forward(self, x): x self.dense_in(x) x self.activation(x) return self.dense_out(x)然后实现LoRA化的Attention层class LoRA_Attention(nn.Module): def __init__(self, hidden_size, num_heads, rank8): super().__init__() self.num_heads num_heads self.head_dim hidden_size // num_heads # 使用MergedLinear处理qkv投影 self.qkv LoRA_Linear( hidden_size, 3 * hidden_size, rankrank ) self.proj LoRA_Linear(hidden_size, hidden_size, rankrank) def forward(self, x): B, L, D x.shape # qkv投影 qkv self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim) q, k, v qkv.unbind(2) # 注意力计算 attn (q k.transpose(-2, -1)) / math.sqrt(self.head_dim) attn attn.softmax(dim-1) # 输出投影 out (attn v).transpose(1, 2).reshape(B, L, D) return self.proj(out)最后组装完整的Transformer Blockclass LoRA_TransformerBlock(nn.Module): def __init__(self, hidden_size, num_heads, intermediate_size, rank8): super().__init__() self.attention LoRA_Attention(hidden_size, num_heads, rank) self.mlp LoRA_MLP(hidden_size, intermediate_size, rank) self.norm1 nn.LayerNorm(hidden_size) self.norm2 nn.LayerNorm(hidden_size) def forward(self, x): # 注意力分支 attn_out self.attention(self.norm1(x)) x x attn_out # MLP分支 mlp_out self.mlp(self.norm2(x)) return x mlp_out在实际应用中我们可以选择性地只对部分层进行LoRA化。例如在大型语言模型中通常只对注意力机制的投影矩阵应用LoRAdef convert_model_to_lora(model, rank8): for name, module in model.named_children(): if isinstance(module, nn.Linear): # 替换特定的线性层 if query in name or key in name or value in name: new_module LoRA_Linear( module.in_features, module.out_features, rankrank ) new_module.load_state_dict(module.state_dict(), strictFalse) setattr(model, name, new_module) else: convert_model_to_lora(module, rank)这种选择性转换可以在保持性能的同时最大化参数效率。实验表明仅对注意力层的QKV投影应用LoRArank8就能达到全参数微调90%以上的效果而可训练参数通常不到原模型的0.5%。5. 训练技巧与最佳实践成功应用LoRA需要一些实践技巧以下是我们在多个项目中总结的经验1. 秩的选择策略一般从rank8开始尝试对于关键层如注意力输出投影可适当增加使用以下公式作为初始估计rank min(64, max(4, int(0.01 * min(d_in, d_out))))2. 初始化方法对比初始化方案适用场景优点缺点KaimingA/B默认选择稳定收敛需要适当缩放全零初始化B保守微调初始状态等同原模型早期学习较慢正交初始化低秩约束强保持矩阵性质计算开销略大3. 学习率设置通常比全参数微调大5-10倍推荐使用分层学习率optimizer AdamW([ {params: model.lora_A.parameters(), lr: 5e-4}, {params: model.lora_B.parameters(), lr: 1e-3}, {params: other_params, lr: 1e-5} ])4. 混合精度训练LoRA特别适合与AMP自动混合精度配合使用scaler GradScaler() with autocast(): outputs model(inputs) loss criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()5. 参数保存与加载LoRA模型的保存需要特殊处理# 保存原始模型参数可选 torch.save(model.state_dict(), base_model.pth) # 仅保存LoRA参数 lora_params {n: p for n, p in model.named_parameters() if lora_ in n} torch.save(lora_params, lora_params.pth) # 加载时先加载基础模型再加载LoRA参数 model.load_state_dict(torch.load(base_model.pth), strictFalse) model.load_state_dict(torch.load(lora_params.pth), strictFalse)6. 梯度检查点对于极大模型可以结合梯度检查点技术from torch.utils.checkpoint import checkpoint class LoRA_TransformerBlock(nn.Module): def forward(self, x): return checkpoint(self._forward, x) def _forward(self, x): # 原来的前向计算 ...在实际项目中我们发现这些技巧的组合使用可以使LoRA的训练效率提升2-3倍同时保持模型性能。特别是在资源受限的场景下合理配置的LoRA方案往往能够达到与全参数微调相当的效果。

相关文章:

用PyTorch手把手教你实现LoRA:从Linear到ConvLoRA的完整代码解析

用PyTorch手把手教你实现LoRA:从Linear到ConvLoRA的完整代码解析 在深度学习模型微调领域,LoRA(Low-Rank Adaptation)技术正逐渐成为资源敏感型场景下的首选方案。不同于传统微调需要更新整个庞大模型的参数,LoRA通过引…...

Android Studio 升级后编译报错?手把手教你解决 minCompileSdk 版本冲突(以 appcompat 1.4.1 为例)

Android Studio升级后的minCompileSdk版本冲突全解析:从快速定位到长效预防 每次Android Studio或Gradle插件升级后,总有些"惊喜"等着我们。最近不少开发者反馈,项目在毫无改动的情况下突然编译失败,报出令人困惑的minC…...

从工行笔试到录用:一份‘科技菁英’岗的完整备考清单与时间线复盘(2022版)

从工行笔试到录用:一份‘科技菁英’岗的完整备考清单与时间线复盘(2022版) 银行科技岗的竞争向来激烈,尤其是工商银行这类国有大行的"科技菁英"计划,每年吸引数以万计的计算机相关专业学子投递。作为2022年成…...

别再重复造轮子了!Power Apps组件库保姆级教程,从创建到团队共享一次搞定

Power Apps组件库实战指南:从零构建到团队高效协作 在多人协作的Power Apps开发项目中,你是否遇到过这样的困扰:每个页面都需要重复设计相同的导航栏,当UI风格调整时不得不逐个修改几十个页面;团队成员各自开发的按钮样…...

Mac本地运行多模态大模型:mlx-vlm环境搭建与性能优化指南

1. 项目概述:在Mac上本地运行多模态大模型的利器如果你是一名Mac用户,同时又对当前火热的视觉语言大模型(VLM)感兴趣,那么你很可能面临一个尴尬的局面:网上那些炫酷的图片理解、视频分析、多轮对话演示&…...

避坑指南:微调chinese-roberta-wwm-ext做情感分析时,这5个参数调优细节千万别忽略

微调chinese-roberta-wwm-ext进行情感分析的五大调优实战技巧 当你第一次成功运行chinese-roberta-wwm-ext模型进行情感分析时,那种成就感确实令人振奋。但很快你会发现,从"能跑通"到"效果好"之间,还有一条充满陷阱的调优…...

考研数学救命稻草:一阶和二阶微分方程的通解公式,我帮你整理好了(附880/660真题解法)

考研数学微分方程通关手册:从公式推导到880/660真题实战拆解 微分方程作为考研数学(数一/数二/数三)的必考核心章节,每年在真题中至少占据10-15分权重。但面对纷繁复杂的方程类型和变化多端的题目条件,许多考生常陷入&…...

为Alexa注入ChatGPT灵魂:智能语音助手开发实战指南

1. 项目概述:为你的Alexa注入ChatGPT的灵魂 如果你和我一样,家里摆着个Alexa智能音箱,除了让它定个闹钟、播个天气,总觉得它那点“智能”有点不够看。官方技能商店里的东西要么是收费的,要么功能死板,想让…...

AI编码助手安全技能集成:vt、gakido等工具实战指南

1. 项目概述:为AI编码助手注入安全测试“超能力” 如果你是一名安全研究员、渗透测试工程师,或者正在学习网络安全,那么你肯定对“Happy Hacking Space”这个开源安全工具集不陌生。他们推出的工具,比如一键部署漏洞靶场的 vt …...

Obsidian BMO Chatbot:在笔记软件中集成AI助手的配置与实战指南

1. 项目概述:在笔记软件里塞进一个AI大脑如果你和我一样,是个重度Obsidian用户,同时又对各种大语言模型(LLM)爱不释手,那你肯定也经历过这种“精神分裂”般的体验:一边在Obsidian里奋笔疾书记录…...

【前端(十三)】JavaScript 数组与字符串笔记

文章目录JavaScript 数组与字符串笔记一、数组(Array)1.1 定义1.2 特点1.3 查询与索引访问1.4 修改与赋值1.5 length 属性与 empty1.6 删除元素1.7 常用方法精讲📌 添加元素📌 截取与合并📌 查找元素📌 遍历…...

【边缘AI场景Docker调优白皮书】:基于Raspberry Pi 5/JeVois-Bin/NVIDIA Jetson实测数据的12项关键参数配置清单

更多请点击: https://intelliparadigm.com 第一章:边缘AI场景下Docker容器化部署的独特挑战 在资源受限、网络不稳、硬件异构的边缘设备上运行AI推理服务,Docker虽提供标准化封装能力,却暴露出一系列深层矛盾。传统云原生容器设计…...

PX4 Autopilot系统调用架构:从实时通信到智能控制的深度解析

PX4 Autopilot系统调用架构:从实时通信到智能控制的深度解析 【免费下载链接】PX4-Autopilot PX4 Autopilot Software 项目地址: https://gitcode.com/gh_mirrors/px/PX4-Autopilot 在无人机开发领域,开发人员常常面临一个核心挑战:如…...

MXFP4量化技术提升LLM推理性能与精度

1. 项目背景与核心价值在大型语言模型(LLM)部署的实际场景中,模型量化技术一直是平衡计算资源消耗与推理性能的关键手段。传统FP4(4位浮点)量化虽然能显著减少模型体积,但在处理复杂语义任务时经常出现精度…...

别再死记硬背了!用Multisim仿真带你直观理解运放负反馈的三大魔法(增益、带宽、阻抗)

别再死记硬背了!用Multisim仿真带你直观理解运放负反馈的三大魔法(增益、带宽、阻抗) 第一次接触运算放大器负反馈时,我盯着课本上那些晦涩的公式和抽象的理论推导,感觉就像在看天书。"增益灵敏度降低"、&qu…...

程序化噪声在游戏开发中的应用:从Perlin到Shader实战

1. 项目概述:当游戏世界开始“呼吸”如果你是一位游戏开发者,或者对计算机图形学有浓厚兴趣,那么“噪声”这个词对你来说一定不陌生。它绝不仅仅是屏幕上恼人的雪花点,恰恰相反,它是构建数字世界“生命力”与“真实感”…...

从实践中提炼的架构设计与工程规范

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》、《前端求职突破计划》 🍚 蓝桥云课签约作者、…...

告别Diskpart恐惧症:手把手教你用命令行安全合并U盘分区(附完整命令清单)

命令行艺术:彻底掌握Diskpart合并U盘分区的底层逻辑 你是否遇到过这样的场景——插入U盘后系统提示需要格式化,打开磁盘管理工具却发现原本单一的存储空间被分割成多个陌生分区?这种"分区幽灵"现象往往让普通用户手足无措&#xff…...

从Vaadin 14到Vaadin 24的迁移:解决内存泄漏问题

引言 在现代Web应用开发中,迁移到新的版本是常见的需求。最近,我们将一个基于Spring Boot的Vaadin应用从版本14升级到了版本24,同时也保留了之前使用的Keycloak和OAuth2登录功能。然而,在这个迁移过程中,我们遇到了一个令人头疼的问题——内存泄漏。特别是在应用程序启动…...

3分钟快速上手:DamaiHelper大麦网抢票脚本完整指南

3分钟快速上手:DamaiHelper大麦网抢票脚本完整指南 【免费下载链接】DamaiHelper 大麦网演唱会演出抢票脚本。 项目地址: https://gitcode.com/gh_mirrors/dama/DamaiHelper 想要告别演唱会陪跑,轻松抢到心仪的门票吗?DamaiHelper大麦…...

终极PC多人游戏解决方案:Nucleus Co-Op分屏工具完全指南

终极PC多人游戏解决方案:Nucleus Co-Op分屏工具完全指南 【免费下载链接】nucleuscoop Starts multiple instances of a game for split-screen multiplayer gaming! 项目地址: https://gitcode.com/gh_mirrors/nu/nucleuscoop 你是否曾梦想过与好友在同一台…...

如何在 MATLAB 中调用 Taotoken 聚合的大模型 API 接口

如何在 MATLAB 中调用 Taotoken 聚合的大模型 API 接口 1. 准备工作 在 MATLAB 中调用 Taotoken 的大模型 API 接口前,需要确保具备以下条件: 有效的 Taotoken API Key,可在 Taotoken 控制台中创建。目标模型 ID,可在 Taotoken…...

解决iOS Safari上的SVG动画问题

引言 在移动设备上实现交互式SVG动画时,常常会遇到一些特定的挑战,尤其是对于iOS的Safari浏览器。本文将探讨如何解决在iOS Safari中SVG元素点击时无法触发淡入动画的问题,并提供一个实用的JavaScript解决方案。 背景介绍 最近我遇到一个问题,当在iOS Safari中点击SVG元…...

2025终极解决方案:八大网盘直链下载助手完整使用指南

2025终极解决方案:八大网盘直链下载助手完整使用指南 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天翼云…...

深度解构:如何基于PX4-Autopilot构建高性能无人机控制系统

深度解构:如何基于PX4-Autopilot构建高性能无人机控制系统 【免费下载链接】PX4-Autopilot PX4 Autopilot Software 项目地址: https://gitcode.com/gh_mirrors/px/PX4-Autopilot 在无人机系统开发中,实时性、可靠性和扩展性一直是开发团队面临的…...

基于容器与Seccomp的代码沙盒安全实践:以dify-sandbox为例

1. 项目概述:构建一个安全的代码沙盒环境在构建一个多租户的AI应用平台或在线代码评测系统时,一个核心且棘手的问题是如何安全地执行用户提交的、不可信的代码。直接在生产服务器上运行这些代码无异于敞开大门,恶意代码可以轻易地耗尽系统资源…...

开发者如何利用 Taotoken 快速切换模型以应对不同场景需求

开发者如何利用 Taotoken 快速切换模型以应对不同场景需求 1. 多模型统一接入的价值 在构建多功能 AI 应用时,开发者常面临模型选型与接入的复杂性。不同场景对模型能力的需求各异:对话交互可能需要更强的上下文理解,代码生成需要编程语言的…...

初次使用 Taotoken 模型广场进行模型选型与对比的体验

初次使用 Taotoken 模型广场进行模型选型与对比的体验 1. 模型广场概览 登录 Taotoken 控制台后,左侧导航栏的"模型广场"入口非常醒目。页面加载后,首先看到的是按热门程度排序的模型列表,每个卡片展示了模型名称、提供商、简要描…...

正点原子IMX6ULL SR04模块+Qt使用

本篇文章用于记录在使用正点原子开发板进行自主开发时使用SR04模块完成倒车雷达辅助功能遇到的问题及延伸问题,文章重点在于记录!问题还待解决问题背景:想要实现sr04的模块驱动且配合Qt应用程序完成倒车雷达辅助功能但是在过程中发现 1.当前系…...

保姆级避坑指南:用PX4 v1.12.3 + Gazebo搞定Offboard模式,解决‘Vehicle armed’失败问题

PX4 v1.12.3与Gazebo仿真环境深度调优:从Offboard模式解锁到轨迹飞行的全流程实战 去年夏天,当我第一次尝试用PX4的Offboard模式控制Gazebo中的无人机时,遇到了一个令人抓狂的问题——终端不断显示"Offboard enabled",但…...