Huggingface训练Transformer
在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客
Huggingface提供了一个TRL的扩展库,可以对transformer模型进行强化学习,SFT是其中的一个训练步骤,为此我也测试一下如何用Huggingface来进行SFT训练,和Pytorch的训练方式做一个比较。
训练数据
首先是获取训练数据,这里同样是采用Huggingface的chatbot_instruction_prompts的数据集,这个数据集涵盖了不同类型的问答,可以用作我们的模型优化之用。
from datasets import load_datasetds = load_dataset("alespalla/chatbot_instruction_prompts")
train_ds = ds['train']
eval_dataset = ds['test']
eval_dataset = eval_dataset.select(range(1024))
训练集总共包括了258042条问答数据,对于验证集我只选取了头1024条记录,因为总的数据集太长,如果在训练过程中全部验证的话耗时太长。
加载GPT2模型
Huggingface提供了很多大模型的训练好的参数,这里我们可以直接加载一个已经训练好的GPT2模型来做优化
from transformers import AutoModelForCausalLM, AutoTokenizermodel = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
然后我们可以定义TRL提供的SFTTrainer来进行训练,首先需要对训练数据处理一下,因为训练数据包括了两列,分别是prompt和response,我们需要把两列的文本合为一起,通过格式化字符来区分,如以下格式化函数:
def formatting_func(example):text = f"### Prompt: {example['prompt']}\n ### Response: {example['response']}"return text
定义SFTTrainer的训练参数,具体每个参数的含义可见官网的文档:
args = TrainingArguments(output_dir='checkpoints_hf_sft',overwrite_output_dir=True, per_device_train_batch_size=4,per_device_eval_batch_size=4,fp16=True,torch_compile=True,evaluation_strategy='steps',prediction_loss_only=True,eval_accumulation_steps=1,learning_rate=0.00006,weight_decay=0.01,adam_beta1=0.9,adam_beta2=0.95,warmup_steps=1000,eval_steps=4000,save_steps=4000,save_total_limit=4,dataloader_num_workers=4,max_steps=12000,optim='adamw_torch_fused')
最后就可以定义一个trainer来训练了
trainer = SFTTrainer(model,args = args,train_dataset=dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,packing=True,formatting_func=formatting_func,max_seq_length=1024
)trainer.train(resume_from_checkpoint=False)
因为是第一次训练,我设置了resume_from_checkpoint=False,如果是继续训练,把这个参数设为True即可从上次checkpoint目录自动加载最新的checkpoint来训练。
训练结果如下:
[12000/12000 45:05, Epoch 0/1]
Step | Training Loss | Validation Loss |
---|---|---|
4000 | 2.137900 | 2.262321 |
8000 | 2.187800 | 2.232235 |
12000 | 2.218500 | 2.210413 |
总共耗时45分钟,比我在pytorch上的训练要快一些(快了10分钟多一些),但是这个训练集的Loss随着Step的增加反而增加了,Validation Loss就减少了,有些奇怪。在Pytorch上我同样训练12000个迭代,最后training loss是去到1.8556的。可能还要再调整一下trainer的参数看看。
测试
最后我们把SFT训练完成的模型,通过huggingface的pipeline就可加载进行测试了。
from transformers import pipelinemodel = AutoModelForCausalLM.from_pretrained('checkpoints_hf_sft/checkpoint-12000/')
pipe = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=0)pipe('### Prompt: Who is the current president of USA?')
[{'generated_text': '### Prompt: Who is the current president of USA?\n ### Response: Harry K. Busby is currently President of the United States.'}]
回答的语法没问题,不过内容是错的。
再测试另一个问题
### Prompt: How to make a cup of coffee?
[{'generated_text': '### Prompt: How to make a cup of coffee?\n ### Response: 1. Preheat the oven to 350°F (175°C).\n\n2. Boil the coffee beans according to package instructions.\n\n3. In a'}]
这个回答就正确了。
总结
通过用Huggingface可以很方便的对大模型进行强化学习的训练,不过也正因为太方便了,很多训练的细节被包装了,所以训练的结果不太容易优化,不像在Pytorch里面控制的自由度更高一些。当然可能我对huggingface的trainer参数的细节还不太了解,这个有待后续继续了解。
另外我还发现huggingface的一个小的bug,就是模型从头训练的时候没有问题,但是当我从之前的checkpoint继续训练时,会报CUDA OOM的错误,从nvidia-smi命令看到的显存占用率来看,好像trainer定义模型和装载Checkpoint会重复占用了显存,因此同样的batch_size,在继续训练时就报内存不够了,这个也有待后溪继续了解。
相关文章:
Huggingface训练Transformer
在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客 Huggingface提供了一个TRL的扩展库,可以对tra…...

IA-YOLO项目中DIP模块的初级解读
IA-YOLO项目源自论文Image-Adaptive YOLO for Object Detection in Adverse Weather Conditions,其提出端到端方式联合学习CNN-PP和YOLOv3,这确保了CNN-PP可以学习适当的DIP,以弱监督的方式增强图像检测。IA-YOLO方法可以自适应地处理正常和不…...

MathType7.4mac最新版本数学公式编辑器安装教程
MathType7.4中文版是一款功能强大且易于使用的公式编辑器。该软件可与word软件配合使用,有效提高了教学人员的工作效率,避免了一些数学符号和公式无法在word中输入的麻烦。新版MathType7.4启用了全新的LOGO,带来了更多对数学符号和公式的支持…...

为Claude的分析内容做准备:提取PDF页面内容的简易应用程序
由于Claude虽然可以分析整个文件,但是对文件的大小以及字数是有限制的,为了将pdf文件分批传入Claude人工智能分析和总结文章内容,才有了这篇博客: 在本篇博客中,我们将介绍一个基于 wxPython 和 PyMuPDF 库编写的简易的…...

js中作用域的理解?
1.作用域 作用域,即变量(变量作用域又称上下文)和函数生效(能被访问)的区域或集合 换句话说,作用域决定了代码区块中变量和其他资源的可见性 举个例子 function myFunction() {let inVariable "函数内部变量"; } myFunction();//要先执行这…...

机器学习基础之《分类算法(4)—案例:预测facebook签到位置》
一、背景 1、说明 2、数据集 row_id:签到行为的编码 x y:坐标系,人所在的位置 accuracy:定位的准确率 time:时间戳 place_id:预测用户将要签到的位置 3、数据集下载 https://www.kaggle.com/navoshta/gr…...
【Java】反射 之 调用方法
调用方法 我们已经能通过Class实例获取所有Field对象,同样的,可以通过Class实例获取所有Method信息。Class类提供了以下几个方法来获取Method: Method getMethod(name, Class...):获取某个public的Method(包括父类&a…...

Java——单例设计模式
什么是设计模式? 设计模式是在大量的实践中总结和理论化之后优选的代码结构、编程风格、以及解决问题的思考方式。设计模式免去我们自己再思考和摸索。就像是经典的棋谱,不同的棋局,我们用不同的棋谱、“套路”。 经典的设计模式共有23种。…...

Java实现excel表数据的批量存储(结合easyexcel插件)
场景:加哥最近在做项目时,苦于系统自身并未提供数据批量导入的功能还不能自行添加上该功能,且自身不想手动一条一条将数据录入系统。随后,自己使用JDBC连接数据库、使用EasyExcel插件读取表格并将数据按照业务逻辑批量插入数据库完…...

Config:客户端连接服务器访问远程
springcloud-config: springcloud-config push pom <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocatio…...

【KMP算法-代码随想录】
目录 1.什么是KMP2.什么是next数组3.什么是前缀表(1)前后缀含义(2)最长公共前后缀(3)前缀表的必要性 4.计算前缀表5.前缀表与next数组(1)使用next数组来匹配 6.构造next数组…...

【手写promise——基本功能、链式调用、promise.all、promise.race】
文章目录 前言一、前置知识二、实现基本功能二、实现链式调用三、实现Promise.all四、实现Promise.race总结 前言 关于动机,无论是在工作还是面试中,都会遇到Promise的相关使用和原理,手写Promise也有助于学习设计模式以及代码设计。 本文主…...

计算机网络-笔记-第二章-物理层
目录 二、第二章——物理层 1、物理层的基本概念 2、物理层下面的传输媒体 (1)光纤、同轴电缆、双绞线、电力线【导引型】 (2)无线电波、微波、红外线、可见光【非导引型】 (3)无线电【频谱的使用】 …...
前端开发中的单伪标签清除和双伪标签清除
引言 在前端开发中,我们经常会遇到一些样式上的问题,其中之一就是伪元素造成的布局问题。为了解决这个问题,我们可以使用伪标签清除技术。本篇博客将介绍单伪标签清除和双伪标签清除的概念、用法和示例代码,并详细解释它们的原理…...

云计算中的数据安全与隐私保护策略
文章目录 1. 云计算中的数据安全挑战1.1 数据泄露和数据风险1.2 多租户环境下的隔离问题 2. 隐私保护策略2.1 数据加密2.2 访问控制和身份验证 3. 应对方法与技术3.1 零知识证明(Zero-Knowledge Proofs)3.2 同态加密(Homomorphic Encryption&…...

MacOS软件安装包分享(附安装教程)
目录 一、软件简介 二、软件下载 一、软件简介 MacOS是一种由苹果公司开发的操作系统,专门用于苹果公司的计算机硬件。它被广泛用于创意和专业应用程序,如图像设计、音频和视频编辑等。以下是关于MacOS的详细介绍。 1、MacOS的历史和演变 MacOS最初于…...

【linux进程概念】
目录: 冯诺依曼体系结构操作系统进程 基本概念描述进程-PCBtask_struct-PCB的一种task_ struct内容分类组织进程查看进程 fork()函数 冯诺依曼体系结构 我们常见的计算机,如笔记本。我们不常见的计算机,如服务器,大部分都遵守冯诺…...

直击成都国际车展:远航汽车多款车型登陆车展,打造完美驾乘体验
随着市场渗透率日益高涨,新能源汽车成为今年成都国际车展的关注焦点。在本届车展上,新能源品牌占比再创新高,覆盖两个展馆,印证了当下新能源汽车市场的火爆。作为大运集团重磅打造的高端品牌,远航汽车深度洞察高端智能…...
android nv21 转 yuv420sp
上面两个函数的目标都是将NV21格式的数据转换为YUV420P格式,但是它们在处理U和V分量的方式上有所不同。 在第一个函数NV21toYUV420P_1中,U和V分量的处理方式是这样的:对于U分量,它从NV21数据的Y分量之后的每个奇数位置取数据&…...

使用Nacos与Spring Boot实现配置管理
🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...

C++_核心编程_多态案例二-制作饮品
#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为:煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例,提供抽象制作饮品基类,提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具
作者:来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗?了解下一期 Elasticsearch Engineer 培训的时间吧! Elasticsearch 拥有众多新功能,助你为自己…...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...
C++.OpenGL (14/64)多光源(Multiple Lights)
多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...
音视频——I2S 协议详解
I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议,专门用于在数字音频设备之间传输数字音频数据。它由飞利浦(Philips)公司开发,以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...
【JavaSE】多线程基础学习笔记
多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...

STM32---外部32.768K晶振(LSE)无法起振问题
晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...
uniapp 实现腾讯云IM群文件上传下载功能
UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中,群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS,在uniapp中实现: 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...
Python实现简单音频数据压缩与解压算法
Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中,压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言,提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...