【Fine-Tuning】大模型微调理论及方法, PytorchHuggingFace微调实战
Fine-Tuning: 大模型微调理论及方法, Pytorch&HuggingFace微调实战
文章目录
- Fine-Tuning: 大模型微调理论及方法, Pytorch&HuggingFace微调实战
- 1. 什么是微调
- (1) 为什么要进行微调
- (2) 经典简单例子:情感分析
- 任务
- 背景
- 微调
- (3) 为什么微调work, 理论解释下
- 2. 详细介绍微调的流程
- (1) 准备数据, 预处理
- (2) 微调策略
- **前三种都差不多的逻辑, 古早**
- 1. 冻结, 逐层微调
- 2. 部分参数微调
- 3. 全参数微调
- 4. LoRA(低秩适应)
- 5. Prompt Tuning
- 6. RLHF(基于人类反馈的强化学习)
- 7. Prefix Tuning
- 8. Adapter微调
- (3) 设置微调超参数
- (4) 训练, 评估
- 3. 具体怎么做
- 常用的微调框架
- HuggingFace版
- Pytorch版
- Pytorch vs HuggingFace
- 易用性:
- 灵活性
- 性能
1. 什么是微调
大模型微调是指在预训练的大型模型基础上,使用特定数据集进行进一步训练,以适应特定任务或领域。
(1) 为什么要进行微调
- 大模型虽然知识丰富(由于其极大批量的预训练任务),但在特定领域可能不够准确。微调能让模型更好地理解特定任务。
- 相比从头开始训练一个新模型,微调节省了大量时间和计算资源(站在前人的肩膀上), 只需少量的数据就能有效提升模型在特定领域的性能。
(2) 经典简单例子:情感分析
任务
训练一个情感分析模型
背景
硬件很烂, 不可能从头训练一个情感分析大模型
但已经有预训练的语言模型比如BERT,已经在大量文本上进行过训练(这叫预训练)。
微调
BERT本身没有直接判断情感的能力, 但由于其在大量文本上进行的预训练任务, 其具有很多自然语言领域的 知识(预训练的权重), 通过少量的情感分析数据, 和合适的微调策略, 就能低成本的(数据, 算力)来微调出一个能进行情感分析的BERT
(3) 为什么微调work, 理论解释下
- 迁移学习: 深度学习模型有分层学习特征的特点, 底层学习通用特征, 高层学习任务相关特征, 将通用特征的知识迁移到相关的特定领域, 合理
- 统计学: 预训练可以看作为参数分布的先验估计, 微调就是在已有先验知识的基础上结合新数据
2. 详细介绍微调的流程
(1) 准备数据, 预处理
首先收集数据, 分成训练验证测试, 老生常谈, 都2024年了就不多说了
预处理: 每种大模型都有特定的输入格式, 要把原始数据转换成预训练大模型认识的数据输入
(2) 微调策略
策略有很多, 也有很多新冒出来的策略, 说一些常见的
前三种都差不多的逻辑, 古早
1. 冻结, 逐层微调
冻结就是权重固定, 不会再反向传播调整了
在这种策略中,模型的一部分参数被冻结,仅对特定层进行微调。逐层解冻的方法允许从顶层开始逐步释放冻结状态,以平衡预训练知识与新任务学习之间的关系.
2. 部分参数微调
和逐层微调本质上类似, 仅选择性地更新模型中的某些权重,通常是顶层或最后几层,而保持底层的大部分权重不变(冻结).
3. 全参数微调
全部参数都会反向传播, 这种方法资源消耗很大, 对数据要求也很高, 而且容易导致灾难性遗忘
灾难性遗忘(Catastrophic Forgetting): 微调模型在学习新任务时,突然或彻底忘记其预训练所学到的知识
4. LoRA(低秩适应)
LoRA通过在模型的每一层引入可训练的低秩矩阵来进行微调, 自适应的调整部分参数.
5. Prompt Tuning
轻量级的微调方法,不改变模型的主参数(全部冻结),通过为特定任务设计可学习的提示(prompt)来引导模型生成期望的输出。
6. RLHF(基于人类反馈的强化学习)
利用人类的反馈来纠正模型, 生成符合期望的结果
7. Prefix Tuning
在输入的前面前拼一些可训练的参数,使得模型在处理任务时能够更好地理解输入意图
8. Adapter微调
模型层之间插入小型可训练模块,这些模块可以适应新任务,而不影响原始模型的参数
(3) 设置微调超参数
设置/调整 学习率, BatchSize等参数, 让模型能收敛和防止拟合不好, 后面介绍
(4) 训练, 评估
用现成的框架训练, 验证, 测试, 后面介绍
3. 具体怎么做
由于深度学习技术的不断成熟, 各种稳定易用的框架逐渐出现, 让微调过程仅需要少许代码就能实现, 下面看看例子
常用的微调框架
-
Hugging Face Transformer
-
Pytorch
HuggingFace版
用HuggingFace对GraphCodeBERT进行微调
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments# 加载预训练模型和tokenizer
tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")
model = RobertaForSequenceClassification.from_pretrained("microsoft/graphcodebert-base")# 准备数据, 数据的预处理一般比较复杂
train_data = [...] # 训练数据
train_encodings = tokenizer(train_data, truncation=True, padding=True)# 定义训练参数
training_args = TrainingArguments(output_dir='./results',num_train_epochs=3,per_device_train_batch_size=16,save_steps=10_000,save_total_limit=2,
)# 创建Trainer实例并开始训练
trainer = Trainer(model=model,args=training_args,train_dataset=train_encodings,
)trainer.train()
实际肯定不止这么简单, 细节比较多, 比如数据的预处理, 和自定义的训练和评估.
由于各种下游任务的多样性, 不同任务的数据/标签差异非常大,这里没办法根据每种任务详细介绍预处理流程, 故在此略过. 我们一般需要自己写很多数据预处理的代码, 构造数据, 使得预训练模型能够接受数据输入.
再或是训练和评估, 由于使用者对模型的需求不同, 训练和评估过程也不一定相同, 自定义的流程往往需要写一些代码, 但是基本的训练和评估流程是封装好的. 代码中给出来了
一些基本的东西在HuggingFace中都有稳定的接口, 比如微调的策略, 参数定义, 基本的训练评估流程, 都是即插即用的
Pytorch版
用pytorch对BERT 使用RLHF策略 在情感分析任务上进行微调
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset# 假设我们有一个简单的数据集
class CustomDataset(Dataset):def __init__(self, texts, labels):self.texts = textsself.labels = labelsself.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')def __len__(self):return len(self.texts)def __getitem__(self, idx):encoding = self.tokenizer(self.texts[idx], padding='max_length', truncation=True, return_tensors='pt', max_length=128)return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'labels': torch.tensor(self.labels[idx], dtype=torch.long)}# 1. 监督微调(SFT)
def supervised_fine_tuning(model, dataloader):model.train()optimizer = AdamW(model.parameters(), lr=5e-5)for epoch in range(3): # 假设训练3个epochfor batch in dataloader:optimizer.zero_grad()input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['labels']outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()print(f"Loss: {loss.item()}")# 2. 奖励模型训练
def train_reward_model(model, reward_data):# 假设reward_data包含文本和对应的奖励分数model.train()optimizer = AdamW(model.parameters(), lr=5e-5)for epoch in range(3): # 假设训练3个epochfor text, reward in reward_data:inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)optimizer.zero_grad()outputs = model(**inputs)reward_loss = nn.MSELoss()(outputs.logits.squeeze(), torch.tensor(reward, dtype=torch.float32))reward_loss.backward()optimizer.step()print(f"Reward Loss: {reward_loss.item()}")# 3. RLHF训练
def rl_training(actor_model, critic_model, dataloader):actor_model.train()critic_model.eval() # 奖励模型在评估模式for epoch in range(3): # 假设训练3个epochfor batch in dataloader:input_ids = batch['input_ids']attention_mask = batch['attention_mask']# 使用actor模型生成输出actor_outputs = actor_model(input_ids=input_ids, attention_mask=attention_mask)# 使用critic模型评估输出的奖励with torch.no_grad():critic_outputs = critic_model(input_ids=input_ids, attention_mask=attention_mask)# 根据奖励调整actor模型的参数(PPO等算法可在此实现)# 此处省略具体的PPO实现,需根据具体需求添加# 示例数据集和模型初始化
texts = ["I love this!", "This is terrible."]
labels = [1, 0] # 假设1为正面,0为负面dataset = CustomDataset(texts, labels)
dataloader = DataLoader(dataset, batch_size=2)model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)# 执行微调流程
supervised_fine_tuning(model, dataloader)# 假设我们有一些奖励数据用于训练奖励模型
reward_data = [("I love this!", 1.0), ("This is terrible.", 0.0)]
train_reward_model(model, reward_data)# 最后,进行RLHF训练(需实现具体的PPO算法)
rl_training(model, model) # 此处使用同一模型作为示例
Pytorch vs HuggingFace
易用性:
HuggingFace的API非常简洁, 并且有丰富的涵盖多个领域的预训练模型库, 集成了多种常用的微调策略, 比如上面提到的LoRA等, 还有活跃的社区和丰富的文档
Pytorch缺乏高层封装, 在比如数据处理, 模型保存上需要用户手动实现更多的功能, 学习曲线陡峭
灵活性
HuggingFace灵活性不如Pytorch, 在高度自定义场景下, Pytorch表现更佳
性能
在一些情况下, Pytorch在计算上设计了专门的优化, HuggingFace的高层API不如Pytorch的性能优化高效
相关文章:

【Fine-Tuning】大模型微调理论及方法, PytorchHuggingFace微调实战
Fine-Tuning: 大模型微调理论及方法, Pytorch&HuggingFace微调实战 文章目录 Fine-Tuning: 大模型微调理论及方法, Pytorch&HuggingFace微调实战1. 什么是微调(1) 为什么要进行微调(2) 经典简单例子:情感分析任务背景微调 (3) 为什么微调work, 理论解释下 2…...

清华系“仓颉”来袭:图形起源:用AI颠覆字体设计,推动大模型商业化落地
大模型如何落地?又该如何实现商业化?这一议题已成为今年科技领域的焦点话题。 在一个鲜为人知的字体设计赛道上,清华创业公司“图形起源”悄然实现了商业变现:他们帮助字体公司将成本降低了80%,生产速度提升了10倍以上…...

分布式一致性协议的深度解析:Paxos与Raft
分布式系统的复杂性源于节点失效、网络分区、消息丢失等诸多不确定性。在这种背景下,分布式一致性问题应运而生,成为解决这些问题的核心。本文将从理论到实践,深入探讨两种经典的一致性协议:Paxos与Raft。文章适合有一定分布式系统…...

ai写作,五款软件助你快速写作!
在这个信息爆炸的时代,内容创作成为了连接用户、传递价值的桥梁。然而,面对日益增长的创作需求,如何在保证质量的同时提升效率,成为了每位创作者面临的难题。幸运的是,随着人工智能技术的飞速发展,AI写作软…...

解决JavaScript 数学运算精度丢失的问题
JavaScript 中执行浮点数运算时可能会遇到精度丢失的问题。这通常是因为浮点数的表示遵循IEEE 754标准,而这种表示法只能精确地表示有限的数字。对于大多数程序员来说,这不是一个问题,因为它允许计算机处理超出精度范围之外的数字。然而&…...

mysql学习教程,从入门到精通,SQL窗口函数(38)
1、SQL窗口函数 SQL窗口函数(Window Functions)是一种强大的数据分析工具,它们允许你在结果集的行上执行计算,而不需要将这些行分组到单独的输出行中。窗口函数通常与OVER()子句一起使用,该子句定义了窗口或分区&…...

gbase8s数据库实现黑白名单的几种方案
1、借用操作系统的黑白名单 2、使用数据库 TRUSTED CONTEXT 机制 CREATE TRUSTED CONTEXT tcx1USER rootATTRIBUTES (ADDRESS 172.16.39.162)ATTRIBUTES (ADDRESS 172.16.39.163)ENABLEWITH USE FOR wangyx WITHOUT AUTHENTICATION; 如上创建 可信任上下文对象 tcx1 在 jdb…...

Qt-窗口布局按钮输入类
1. 窗口布局 Qt 提供了很多摆放控件的辅助工具(又称布局管理器或者布局控件),它们可以完成两件事: 自动调整控件的位置,包括控件之间的间距、对齐等; 当用户调整窗口大小时,位于布局管理器内的…...

Apache DolphinScheduler社区9月进展记录
各位热爱 Apache DolphinScheduler 的小伙伴们,社区 9 月月报更新啦!这里将记录 Apache DolphinScheduler 社区每月的重要更新,欢迎关注! 月度 Merge Star 感谢以下小伙伴上个月为 Apache DolphinScheduler 做的精彩贡献&#x…...

在docker中安装并运行mysql8.0.31
第一步:命令行拉取mysql镜像 docker pull mysql:8.0.31查看是否拉取成功 docker images mysql:latest第二步:运行mysql镜像,启动mysql实例 docker run -p 3307:3307 -e MYSQL_ROOT_PASSWORD"123456" -d mysql:8.0.313307:3307前…...

C++ | Leetcode C++题解之第458题可怜的小猪
题目: 题解: class Solution { public:int poorPigs(int buckets, int minutesToDie, int minutesToTest) {if (buckets 1) {return 0;}vector<vector<int>> combinations(buckets 1,vector<int>(buckets 1));combinations[0][0] …...

【万字长文】Word2Vec计算详解(三)分层Softmax与负采样
【万字长文】Word2Vec计算详解(三)分层Softmax与负采样 写在前面 第三部分介绍Word2Vec模型的两种优化方案。 【万字长文】Word2Vec计算详解(一)CBOW模型 markdown行 9000 【万字长文】Word2Vec计算详解(二࿰…...

【分布式微服务云原生】探索Dubbo:接口定义语言的多样性与选择
目录 探索Dubbo:接口定义语言的多样性与选择引言Dubbo的接口定义语言(IDL)1. Java接口2. XML配置3. 注解4. Protobuf IDL 流程图:Dubbo服务定义流程表格:Dubbo IDL方式比较结论呼吁行动Excel表格:Dubbo IDL…...

SAP将假脱机(Spool requests)内容转换为PDF文档[RSTXPDFT4]
将假脱机(Spool requests)内容转换为PDF文档[RSTXPDFT4] 有时需要将Spool中的内容导出成PDF文件,sap提供了一个标准程序RSTXPDFT4可以实现此功能。 1, Tcode:SP01, 进入spool requests list 2, SE38 运行程序RSTXPDFT4 输入spool reqeust号码18680,然后…...

DNS能加速游戏吗?
在游戏玩家追求极致游戏体验的今天,任何可能提升游戏性能的因素都备受关注,DNS(域名系统)便是其中一个被探讨的对象。那么,DNS能加速游戏吗? 首先,我们需要了解DNS的基本功能。DNS就像是互联网…...

Raspberry Pi3B+之C/C++开发环境搭建
Raspberry Pi3B之C/C开发环境搭建 1. 源由2. 环境搭建2.1 搭建C语言开发环境2.2 工程目录结构2.3 Makefile2.4 Demo (main.c) 3. 测试工程3.1 编译3.2 运行 4. 总结5. 参考资料 1. 源由 为了配合《Ardupilot开源飞控之FollowMe验证平台搭建》,以及VINS-Fusion对于图…...

[笔记] 仿射变换性质的代数证明
Title: [笔记] 仿射变换性质的代数证明 文章目录 I. 仿射变换的代数表示II. 仿射变换的性质III. 同素性的代数证明1. 点变换为点2. 直线变换为直线 IV. 结合性的代数证明1. 直线上一点映射为直线上一点2. 直线外一点映射为直线外一点 V. 保持单比的代数证明VI. 平行性的代数证明…...

遥感影像-语义分割数据集:sar水体数据集详细介绍及训练样本处理流程
原始数据集详情 简介:该数据集由WHU-OPT-SAR数据集整理而来,覆盖面积51448.56公里,分辨率为5米。据我们所知,WHU-OPT-SAR是第一个也是最大的土地利用分类数据集,它融合了高分辨率光学和SAR图像,并进行了充…...

极狐GitLab 发布安全补丁版本 17.4.1、17.3.4、17.2.8
GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署极狐GitLab。 学习极狐GitLab 的相关资料: 极狐GitLab 官网极狐…...

汽车管理系统中使用函数
目录 setupUisetEnabledcurrentText()setTextsetFocus()query.exec(...)addWidgetconnect setupUi setupUi() 是 ui 对象的一个成员函数,它的作用是根据 .ui 文件中的设计,将设计好的组件(如按钮、文本框、布局等)添加到当前的窗…...

大数据分析入门概述
大数据分析入门概述 本文旨在为有意向学习数据分析、数据开发等大数据方向的初学者提供一个学习指南,当然如果你希望通过视频课程的方式快速入门,B站UP主戴戴戴师兄的课程质量很高,并且适合初学者快速入门。本文的目的旨在为想要了解大数据但…...

提示工程、微调和 RAG
自众多大型语言模型(LLM)和高级对话模型发布以来,人们已经运用了各种技术来从这些 AI 系统中提取所需的输出。其中一些方法会改变模型的行为来更好地贴近我们的期望,而另一些方法则侧重于增强我们查询 LLM 的方式,以提…...

自动化测试中如何高效进行元素定位!
前言 在自动化测试中,元素定位是一项非常重要的工作。良好的元素定位可以帮助测试人员处理大量的测试用例,加快测试进度,降低工作负担。但是在实际的测试工作中,我们常常遇到各种各样的定位问题,比如元素定位失败、元…...

UE5数字人制作平台使用及3D模型生成
第10章 数字人制作平台使用及3D模型生成 在数字娱乐、虚拟现实(VR)、增强现实(AR)等领域,高质量的3D模型是数字内容创作的核心。本章将引导你了解如何使用UE5(Unreal Engine 5)虚幻引擎这一强大…...

Linux进程被占用如何杀死进程
文章目录 前言一、根据名称进行查找程序所占用的端口号二、杀死进程总结 前言 由于Linux中,校园网登录的时候容易出现端口被占用,如何快速查找程序所占用的端口号。 提示:以下是本篇文章正文内容,下面案例可供参考 一、根据名称…...

详解Xilinx JESD204B PHY层端口信号含义及动态切换线速率(JESD204B五)
点击进入高速收发器系列文章导航界面 Xilinx官方提供了两个用于开发JESD204B的IP,其中一个完成PHY层设计,另一个完成传输层的逻辑,两个IP必须一起使用才能正常工作。 7系列FPGA只能使用最多12通道的JESD204B协议,线速率为1.0至12.…...

Java面试——场景题
1.如何分批处理数据? 1.使用LIMIT和OFFSET子句: 这是最常用的分批查询方法。例如,你可以使用以下SQL语句来分批查询数据: SELECT * FROM your_table LIMIT 1000 OFFSET 0; 分批查询到的数据在后端进行处理,达到分批…...

xss-labs靶场第一关测试报告
目录 一、测试环境 1、系统环境 2、使用工具/软件 二、测试目的 三、操作过程 1、注入点寻找 2、使用hackbar进行payload测试 3、绕过结果 四、源代码分析 五、结论 一、测试环境 1、系统环境 渗透机:本机(127.0.0.1) 靶 机:本机(127.0.0.…...

微软PowerBI认证!数据分析师入门级证书备考攻略来啦
#微软PowerBI认证!数据分析师入门级证书! 😃Power BI是一种强大的数据可视化和分析工具,学习Power BI,能提高数据的分析能力,将数据转化为有意义的见解,并支持数据驱动的决策制定。 ㅤ ✨微软P…...

上海AI Lab视频生成大模型书生.筑梦环境搭建推理测试
引子 最近视频生成大模型层出不穷,上海AI Lab推出新一代视频生成大模型 “书生・筑梦 2.0”(Vchitect 2.0)。根据官方介绍,书生・筑梦 2.0 是集文生视频、图生视频、插帧超分、训练系统一体化的视频生成大模型。OK,那就让我们开始吧。 一、模…...