从零学习大模型(七)-----LoRA(中)
自注意力层中的 LoRA 应用
Transformer 的自注意力机制是模型理解输入序列之间复杂关系的核心部分。自注意力层通常包含多个线性变换,包括键(Key)、查询(Query) 和 值(Value) 三个权重矩阵的线性映射,这些权重矩阵需要被训练来适应不同的任务。
LoRA 在自注意力层中的步骤:
冻结原始权重:将自注意力层中的键、查询、值矩阵的权重参数冻结,不进行训练。这些原始权重是通过预训练获得的,用于保留模型在大规模数据上学习到的通用知识。
添加低秩适配矩阵:对于键、查询、值矩阵的权重分别添加低秩适配矩阵 A A A 和 B B B。例如,如果键矩阵 W k ∈ R d k × d model W_k \in \mathbb{R}^{d_k \times d_{\text{model}}} Wk∈Rdk×dmodel,则引入两个适配矩阵:
A k ∈ R d k × r , B k ∈ R r × d model A_k \in \mathbb{R}^{d_k \times r}, \quad B_k \in \mathbb{R}^{r \times d_{\text{model}}} Ak∈Rdk×r,Bk∈Rr×dmodel
其中 r r r 是低秩值,通常比 d k d_k dk 和 d model d_{\text{model}} dmodel 小很多。
权重变换:在执行前向传播时,键、查询、值矩阵的线性变换将使用 W k ′ = W k + A k B k W_k' = W_k + A_k B_k Wk′=Wk+AkBk,作为新的权重矩阵,其中 W k W_k Wk 是冻结的原始权重, A k B k A_k B_k AkBk 是任务特定的微调部分。
训练适配矩阵:在特定任务的微调过程中,仅更新新增的适配矩阵 A k A_k Ak 和 B k B_k Bk 的参数,而不更新原始的键矩阵权重。这使得微调过程变得高效和轻量化。
前馈网络中的 LoRA 应用
Transformer 的前馈网络(FFN) 通常由两个线性层和一个非线性激活函数(如 ReLU)组成。在前馈网络中,参数数量非常庞大,因为它在每个位置上独立地对输入进行两次线性变换。
LoRA 在前馈网络中的步骤:
冻结全连接层权重:前馈网络中的两个全连接层的权重也被冻结。这些层用于对每个输入位置进行独立的变换,其预训练参数保持不变,用于保留预训练期间获得的通用知识。
添加适配矩阵:对于每个全连接层,添加低秩适配矩阵。例如,假设前馈网络中的第一层权重矩阵为 W f f ∈ R d f f × d model W_{ff} \in \mathbb{R}^{d_{ff} \times d_{\text{model}}} Wff∈Rdff×dmodel,则 LoRA 在这个矩阵上引入两个适配矩阵: A f f ∈ R d f f × r , B f f ∈ R r × d model A_{ff} \in \mathbb{R}^{d_{ff} \times r}, \quad B_{ff} \in \mathbb{R}^{r \times d_{\text{model}}} Aff∈Rdff×r,Bff∈Rr×dmodel
权重变换与前向传播:在前向传播中,使用权重矩阵 W f f ′ = W f f + A f f B f f W_{ff}' = W_{ff} + A_{ff} B_{ff} Wff′=Wff+AffBff来代替原始的全连接层权重矩阵。这样,前馈网络的线性变换就不仅包含原始的预训练权重,还包含适配矩阵的贡献,用于更好地适应特定任务。
训练适配矩阵:在训练过程中,仅更新适配矩阵 A f f A_{ff} Aff 和 B f f B_{ff} Bff,而冻结的原始权重 W f f W_{ff} Wff 不变,这大大减少了训练中需要更新的参数数量。
LoRA 在 Transformer 中的整体优势
参数效率:在自注意力层和前馈网络中引入低秩适配矩阵,可以大大减少需要训练的参数数量,同时保留模型在大规模数据上学到的知识。
任务适应性:LoRA 的方法允许对每个特定任务引入单独的低秩适配矩阵,而保持预训练模型的核心不变,这使得同一个预训练模型可以高效地适应多种任务。
降低计算与存储开销:由于适配矩阵的秩 rrr 通常远小于模型的维度,LoRA 显著降低了训练和存储的开销,使得大模型在计算资源受限的情况下能够得到有效应用。
代码实现LoRA微调Bert的过程
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel# 假设我们有一个预训练好的 Transformer 模型,例如 BERT
pretrained_model = BertModel.from_pretrained('bert-base-uncased')# 定义 LoRA 适配层
class LoRAAdapter(nn.Module):def __init__(self, input_dim, low_rank_dim):super(LoRAAdapter, self).__init__()self.A = nn.Linear(input_dim, low_rank_dim, bias=False) # 低秩矩阵 Aself.B = nn.Linear(low_rank_dim, input_dim, bias=False) # 低秩矩阵 Bdef forward(self, x):return self.B(self.A(x)) # 输出 B(A(x))# 定义一个新的 BERT 模型,包含 LoRA 适配器
class LoRABertModel(nn.Module):def __init__(self, pretrained_model, low_rank_dim):super(LoRABertModel, self).__init__()self.bert = pretrained_modelself.low_rank_dim = low_rank_dim# 为 BERT 的每一层添加 LoRA 适配器for layer in self.bert.encoder.layer:layer.attention.self.query_adapter = LoRAAdapter(layer.attention.self.query.in_features, low_rank_dim)layer.attention.self.key_adapter = LoRAAdapter(layer.attention.self.key.in_features, low_rank_dim)layer.attention.self.value_adapter = LoRAAdapter(layer.attention.self.value.in_features, low_rank_dim)# 冻结原始 BERT 模型的所有参数for param in self.bert.parameters():param.requires_grad = False# 只训练 LoRA 适配层for layer in self.bert.encoder.layer:for param in layer.attention.self.query_adapter.parameters():param.requires_grad = Truefor param in layer.attention.self.key_adapter.parameters():param.requires_grad = Truefor param in layer.attention.self.value_adapter.parameters():param.requires_grad = Truedef forward(self, input_ids, attention_mask=None, token_type_ids=None):# 对每一层应用 LoRA 适配器for layer in self.bert.encoder.layer:# 使用适配后的 Query、Key、Value 进行注意力计算query = layer.attention.self.query(input_ids) + layer.attention.self.query_adapter(input_ids)key = layer.attention.self.key(input_ids) + layer.attention.self.key_adapter(input_ids)value = layer.attention.self.value(input_ids) + layer.attention.self.value_adapter(input_ids)# 注意力计算attn_output, _ = layer.attention.self.attn(query, key, value)input_ids = attn_outputreturn self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)# 创建包含 LoRA 适配器的 BERT 模型
low_rank_dim = 16
lora_bert = LoRABertModel(pretrained_model, low_rank_dim)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, lora_bert.parameters()), lr=1e-4)# 生成一些随机数据用于训练
data = torch.randint(0, 30522, (8, 64)) # 随机输入数据,批量大小为 8,序列长度为 64
labels = torch.randint(0, 2, (8, 64)) # 随机标签# 训练循环
num_epochs = 5
for epoch in range(num_epochs):# 前向传播outputs = lora_bert(data).last_hidden_statelogits = outputs.view(-1, outputs.size(-1))labels = labels.view(-1)# 计算损失loss = criterion(logits, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 打印损失值print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
利用LoRA对LLAMA2进行微调的代码
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset# 1. 加载预训练的 LLaMA 2 7B 模型
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)# 2. 手动添加 LoRA 适配器
class LoRAAdapter(nn.Module):def __init__(self, input_dim, low_rank_dim):super(LoRAAdapter, self).__init__()self.A = nn.Linear(input_dim, low_rank_dim, bias=False) # 低秩矩阵 Aself.B = nn.Linear(low_rank_dim, input_dim, bias=False) # 低秩矩阵 Bdef forward(self, x):return self.B(self.A(x)) # 输出 B(A(x))# 添加 LoRA 适配器到模型的注意力层
low_rank_dim = 16
for name, module in model.named_modules():if "attn" in name.lower() and hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj'):# 为 Query, Key, Value 添加 LoRA 适配器input_dim = module.q_proj.in_featuresmodule.q_proj_lora = LoRAAdapter(input_dim, low_rank_dim)module.k_proj_lora = LoRAAdapter(input_dim, low_rank_dim)module.v_proj_lora = LoRAAdapter(input_dim, low_rank_dim)# 使用 hooks 修改前向传播逻辑def hook_fn_forward_q(module, input, output):return output + module.q_proj_lora(input[0])def hook_fn_forward_k(module, input, output):return output + module.k_proj_lora(input[0])def hook_fn_forward_v(module, input, output):return output + module.v_proj_lora(input[0])module.q_proj.register_forward_hook(hook_fn_forward_q)module.k_proj.register_forward_hook(hook_fn_forward_k)module.v_proj.register_forward_hook(hook_fn_forward_v)# 3. 定义训练参数和优化器
training_args = TrainingArguments(output_dir="./llama2_7b_lora_finetuned",num_train_epochs=3,per_device_train_batch_size=4,gradient_accumulation_steps=8,learning_rate=1e-4,logging_dir="./logs",logging_steps=10,save_total_limit=2,fp16=True # 采用混合精度加速训练
)optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)# 4. 加载训练数据集
train_dataset = load_dataset("daily_dialog", split="train")# 5. 定义 Trainer 对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,tokenizer=tokenizer,optimizers=(optimizer, None)
)# 6. 开始训练
trainer.train()# 7. 验证和保存模型
trainer.evaluate()
model.save_pretrained("./llama2_7b_lora_finetuned")相关文章:
从零学习大模型(七)-----LoRA(中)
自注意力层中的 LoRA 应用 Transformer 的自注意力机制是模型理解输入序列之间复杂关系的核心部分。自注意力层通常包含多个线性变换,包括键(Key)、查询(Query) 和 值(Value) 三个权重矩阵的线…...
Java知识巩固(十二)
I/O JavaIO流了解吗? IO 即 Input/Output,输入和输出。数据输入到计算机内存的过程即输入,反之输出到外部存储(比如数据库,文件,远程主机)的过程即输出。数据传输过程类似于水流,因…...
一家光伏企业终止,恐不具行业代表性,市占率仅为2.35%
海达光能终止原因如下:报告期内海达光能销售金额较所在行业第二名亚玛顿相差两倍以上,公司毛利率更是远低于行业龙头福莱特,恐难以说明公司行业代表性。在企业竞争上,公司2021年度的市场占有率约为2.35%,公司未来光伏玻…...
企业计算机监控软件是什么?6款电脑监控软件分享!提升企业管理效率,吐血推荐!
嘿,各位企业管理者和IT小伙伴们! 您是否曾担忧员工在工作时间内效率低下?是否对公司的数据安全感到不安? 别担心,今天我们就来聊聊企业计算机监控软件,它就像是企业的"超级侦探",帮…...
VisionPro —— CogOCRMaxTool工具详解
CogOCRMaxTool的作用: CogOCRMaxTool:是一个字符识别工具,主要用于字符识别,它能够根据已训练的字符样本读取灰度图像中的字符,并返回读取结果。 一:工具位置 二:添加图片 三:工具的初始页面 将识别框拖到需要识别处…...
网站安全问题都有哪些,分别详细说明
网站安全问题涉及多个方面,以下是一些常见的网站安全问题及其详细说明: 数据泄露 问题描述:数据泄露是指网站存储的用户敏感信息(如用户名、密码、信用卡信息等)被非法获取。黑客可能通过SQL注入、XSS攻击等手段窃取这…...
DiskGenius一键修复磁盘损坏
下午外接磁盘和U盘都出现扇区损坏,估计就是在开着电脑,可能是电脑运行的软件还在对磁盘进行读写,不小心按到笔记本关机键,重新开机读写磁盘分区变得异常卡顿,估摸就是这个原因导致扇区损坏。在进行读写时,整…...
Matlab实现鼠群优化算法优化回声状态网络模型 (ROS-ESN)(附源码)
目录 1.内容介绍 2.部分代码 3.实验结果 4.内容获取 1内容介绍 鼠群优化算法(Rat Swarm Optimization, ROS)是一种基于老鼠群体行为的群体智能优化算法。ROS通过模拟老鼠在寻找食物时的聚集、分散和跟随行为,来探索解空间并寻找最优解。该算…...
nfs作业
一、作业要求 1、开放/nfs/shared目录,供所有用户查询资料 2、开放/nfs/upload目录,为192.168.xxx.0/24网段主机可以上传目录, 并将所有用户及所属的组映射为nfs-upload,其UID和GID均为210 3、将/home/tom目录仅共享给192.168.xxx.xxx这台…...
Linux 基础io_理解文件系统_软硬链接_动静态库
一.磁盘 1.磁盘物理结构 盘片 磁盘可以有多个磁片,每个磁片有两个盘面,每个盘面都对应一个磁头,都可以存储数据。 磁道 扇区 磁道是指在盘面上,由磁头读写的数据环形轨道。每个磁道都是由一圈圈的圆形区域组成,数据…...
大语言模型参数传递、model 构建与tokenizer构建(基于llama3模型)
文章目录 前言一、传递参数构建1、构建模型参数2、构建数据参数3、构建训练参数4、类似parse方式解析数据、模型、训练参数五、构建tokenizer与model1、tokenizer与model调用代码2、tokenizer实现2、model实现前言 上一篇说到huggingface的参数传递理论方法,本篇文章应用与ll…...
使用 `screen` + `nohup` 实现高效日志记录和多环境任务管理
使用 screen nohup 实现高效日志记录和多环境任务管理 在深度学习模型训练中,特别是在服务器上运行长时间的任务时,有效的任务管理和日志记录至关重要。我们通常需要在后台运行多个任务,同时为每个任务配置不同的 conda 环境。通过结合使用…...
【探索数字孪生,引领未来技术】
在数字化浪潮的推动下,数字孪生技术正成为连接虚拟与现实的桥梁,它不仅是工业互联网的基石,更是智慧城市、智慧园区、智慧楼宇以及元宇宙构建的核心。为了帮助更多专业人士掌握这一前沿技术,我们荣幸地宣布,“新质技术…...
Tcp_Sever(线程池版本的 TCP 服务器)
Tcp_Sever(线程池版本的 TCP 服务器) 前言1. 功能介绍及展示1.1 服务端连接1.2 客户端连接(可多个用户同时在线连接服务端)1.3 功能服务1.3.1 defaultService(默认服务)1.3.2 transform(大小写转…...
第十一章 Vue生命周期及生命周期的四个阶段
目录 一、引言 1.1. Vue生命周期的具体阶段 1.2. 每个阶段的具体作用和常用场景 1.3. 生命周期钩子函数 二、代码示例 三、运行效果 一、引言 Vue生命周期是指Vue组件实例从创建到销毁的整个过程。在这个过程中,组件经历了一系列的阶段,每个阶段…...
展厅展会客流显示屏的客流统计功能如何实现
随着科技的发展,展厅和展会的管理越来越智能化。客流显示屏作为一种高效的管理工具,能够实时显示参观人数,帮助主办方更好地了解客流情况,优化资源配置。本文将详细介绍展厅展会客流显示屏的客流统计功能如何实现,分为…...
golang正则表达式的使用及举例
正则表达式很强大,在一些场合如抓包,爬虫等方面很有用。在 Go语言中,正则表达式通过标准库 regexp 提供支持。使用正则表达式可以进行字符串匹配、替换和分割等操作。 以下是正则表达式的基本使用方法及示例: 1. 导入 regexp 包 …...
Flutter杂学: iOS 上启用自动填充和关联域
下面是详细的配置和代码,以确保在 iOS 上启用自动填充和关联域(Associated Domains)功能。 配置步骤 1. 在 Apple Developer 控制台中启用 Associated Domains 登录 Apple Developer。导航至您的 App ID 设置页面。找到您要配置的 App ID&…...
接口自动化-框架搭建(Python+request+pytest+allure)
使用代码如何开展接口自动化测试。 一 选择自动化测试用例 业务流程优先,单接口靠后,功能稳定优先,变更频繁不选。 二 搭建自动化测试环境 (1)安装python编译器3.7版本以上--自行安装 (2)安…...
[论文阅读]Constrained Decision Transformer for Offline Safe Reinforcement Learning
Constrained Decision Transformer for Offline Safe Reinforcement Learning Proceedings of the 40th International Conference on Machine Learning (ICML), July 23-29, 2023 https://arxiv.org/abs/2302.07351 泛读只需要了解其核心思想即可。 安全强化学习(Safe Rei…...
多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度
一、引言:多云环境的技术复杂性本质 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,基础设施的技术债呈现指数级积累。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...
docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...
Day131 | 灵神 | 回溯算法 | 子集型 子集
Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...
SCAU期末笔记 - 数据分析与数据挖掘题库解析
这门怎么题库答案不全啊日 来简单学一下子来 一、选择题(可多选) 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘:专注于发现数据中…...
select、poll、epoll 与 Reactor 模式
在高并发网络编程领域,高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表,以及基于它们实现的 Reactor 模式,为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。 一、I…...
【Go语言基础【12】】指针:声明、取地址、解引用
文章目录 零、概述:指针 vs. 引用(类比其他语言)一、指针基础概念二、指针声明与初始化三、指针操作符1. &:取地址(拿到内存地址)2. *:解引用(拿到值) 四、空指针&am…...
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...
C语言中提供的第三方库之哈希表实现
一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...
SQL Server 触发器调用存储过程实现发送 HTTP 请求
文章目录 需求分析解决第 1 步:前置条件,启用 OLE 自动化方式 1:使用 SQL 实现启用 OLE 自动化方式 2:Sql Server 2005启动OLE自动化方式 3:Sql Server 2008启动OLE自动化第 2 步:创建存储过程第 3 步:创建触发器扩展 - 如何调试?第 1 步:登录 SQL Server 2008第 2 步…...
GraphQL 实战篇:Apollo Client 配置与缓存
GraphQL 实战篇:Apollo Client 配置与缓存 上一篇:GraphQL 入门篇:基础查询语法 依旧和上一篇的笔记一样,主实操,没啥过多的细节讲解,代码具体在: https://github.com/GoldenaArcher/graphql…...
