当前位置: 首页 > news >正文

Huggingface训练Transformer

在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客

Huggingface提供了一个TRL的扩展库,可以对transformer模型进行强化学习,SFT是其中的一个训练步骤,为此我也测试一下如何用Huggingface来进行SFT训练,和Pytorch的训练方式做一个比较。

训练数据

首先是获取训练数据,这里同样是采用Huggingface的chatbot_instruction_prompts的数据集,这个数据集涵盖了不同类型的问答,可以用作我们的模型优化之用。

from datasets import load_datasetds = load_dataset("alespalla/chatbot_instruction_prompts")
train_ds = ds['train']
eval_dataset = ds['test']
eval_dataset = eval_dataset.select(range(1024))

训练集总共包括了258042条问答数据,对于验证集我只选取了头1024条记录,因为总的数据集太长,如果在训练过程中全部验证的话耗时太长。

加载GPT2模型

Huggingface提供了很多大模型的训练好的参数,这里我们可以直接加载一个已经训练好的GPT2模型来做优化

from transformers import AutoModelForCausalLM, AutoTokenizermodel = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

然后我们可以定义TRL提供的SFTTrainer来进行训练,首先需要对训练数据处理一下,因为训练数据包括了两列,分别是prompt和response,我们需要把两列的文本合为一起,通过格式化字符来区分,如以下格式化函数:

def formatting_func(example):text = f"### Prompt: {example['prompt']}\n ### Response: {example['response']}"return text

定义SFTTrainer的训练参数,具体每个参数的含义可见官网的文档:

args = TrainingArguments(output_dir='checkpoints_hf_sft',overwrite_output_dir=True, per_device_train_batch_size=4,per_device_eval_batch_size=4,fp16=True,torch_compile=True,evaluation_strategy='steps',prediction_loss_only=True,eval_accumulation_steps=1,learning_rate=0.00006,weight_decay=0.01,adam_beta1=0.9,adam_beta2=0.95,warmup_steps=1000,eval_steps=4000,save_steps=4000,save_total_limit=4,dataloader_num_workers=4,max_steps=12000,optim='adamw_torch_fused')

最后就可以定义一个trainer来训练了

trainer = SFTTrainer(model,args = args,train_dataset=dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,packing=True,formatting_func=formatting_func,max_seq_length=1024
)trainer.train(resume_from_checkpoint=False)

因为是第一次训练,我设置了resume_from_checkpoint=False,如果是继续训练,把这个参数设为True即可从上次checkpoint目录自动加载最新的checkpoint来训练。

训练结果如下:

 [12000/12000 45:05, Epoch 0/1]

StepTraining LossValidation Loss
40002.1379002.262321
80002.1878002.232235
120002.2185002.210413

总共耗时45分钟,比我在pytorch上的训练要快一些(快了10分钟多一些),但是这个训练集的Loss随着Step的增加反而增加了,Validation Loss就减少了,有些奇怪。在Pytorch上我同样训练12000个迭代,最后training loss是去到1.8556的。可能还要再调整一下trainer的参数看看。

测试

最后我们把SFT训练完成的模型,通过huggingface的pipeline就可加载进行测试了。

from transformers import pipelinemodel = AutoModelForCausalLM.from_pretrained('checkpoints_hf_sft/checkpoint-12000/')
pipe = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=0)pipe('### Prompt: Who is the current president of USA?')
[{'generated_text': '### Prompt: Who is the current president of USA?\n ### Response: Harry K. Busby is currently President of the United States.'}]

回答的语法没问题,不过内容是错的。

再测试另一个问题

### Prompt: How to make a cup of coffee?

[{'generated_text': '### Prompt: How to make a cup of coffee?\n ### Response: 1. Preheat the oven to 350°F (175°C).\n\n2. Boil the coffee beans according to package instructions.\n\n3. In a'}]

这个回答就正确了。

总结

通过用Huggingface可以很方便的对大模型进行强化学习的训练,不过也正因为太方便了,很多训练的细节被包装了,所以训练的结果不太容易优化,不像在Pytorch里面控制的自由度更高一些。当然可能我对huggingface的trainer参数的细节还不太了解,这个有待后续继续了解。

另外我还发现huggingface的一个小的bug,就是模型从头训练的时候没有问题,但是当我从之前的checkpoint继续训练时,会报CUDA OOM的错误,从nvidia-smi命令看到的显存占用率来看,好像trainer定义模型和装载Checkpoint会重复占用了显存,因此同样的batch_size,在继续训练时就报内存不够了,这个也有待后溪继续了解。

相关文章:

Huggingface训练Transformer

在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客 Huggingface提供了一个TRL的扩展库,可以对tra…...

IA-YOLO项目中DIP模块的初级解读

IA-YOLO项目源自论文Image-Adaptive YOLO for Object Detection in Adverse Weather Conditions,其提出端到端方式联合学习CNN-PP和YOLOv3,这确保了CNN-PP可以学习适当的DIP,以弱监督的方式增强图像检测。IA-YOLO方法可以自适应地处理正常和不…...

MathType7.4mac最新版本数学公式编辑器安装教程

MathType7.4中文版是一款功能强大且易于使用的公式编辑器。该软件可与word软件配合使用,有效提高了教学人员的工作效率,避免了一些数学符号和公式无法在word中输入的麻烦。新版MathType7.4启用了全新的LOGO,带来了更多对数学符号和公式的支持…...

为Claude的分析内容做准备:提取PDF页面内容的简易应用程序

由于Claude虽然可以分析整个文件,但是对文件的大小以及字数是有限制的,为了将pdf文件分批传入Claude人工智能分析和总结文章内容,才有了这篇博客: 在本篇博客中,我们将介绍一个基于 wxPython 和 PyMuPDF 库编写的简易的…...

js中作用域的理解?

1.作用域 作用域,即变量(变量作用域又称上下文)和函数生效(能被访问)的区域或集合 换句话说,作用域决定了代码区块中变量和其他资源的可见性 举个例子 function myFunction() {let inVariable "函数内部变量"; } myFunction();//要先执行这…...

机器学习基础之《分类算法(4)—案例:预测facebook签到位置》

一、背景 1、说明 2、数据集 row_id:签到行为的编码 x y:坐标系,人所在的位置 accuracy:定位的准确率 time:时间戳 place_id:预测用户将要签到的位置 3、数据集下载 https://www.kaggle.com/navoshta/gr…...

【Java】反射 之 调用方法

调用方法 我们已经能通过Class实例获取所有Field对象,同样的,可以通过Class实例获取所有Method信息。Class类提供了以下几个方法来获取Method: Method getMethod(name, Class...):获取某个public的Method(包括父类&a…...

Java——单例设计模式

什么是设计模式? 设计模式是在大量的实践中总结和理论化之后优选的代码结构、编程风格、以及解决问题的思考方式。设计模式免去我们自己再思考和摸索。就像是经典的棋谱,不同的棋局,我们用不同的棋谱、“套路”。 经典的设计模式共有23种。…...

Java实现excel表数据的批量存储(结合easyexcel插件)

场景:加哥最近在做项目时,苦于系统自身并未提供数据批量导入的功能还不能自行添加上该功能,且自身不想手动一条一条将数据录入系统。随后,自己使用JDBC连接数据库、使用EasyExcel插件读取表格并将数据按照业务逻辑批量插入数据库完…...

Config:客户端连接服务器访问远程

springcloud-config: springcloud-config push pom <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocatio…...

【KMP算法-代码随想录】

目录 1.什么是KMP2.什么是next数组3.什么是前缀表&#xff08;1&#xff09;前后缀含义&#xff08;2&#xff09;最长公共前后缀&#xff08;3&#xff09;前缀表的必要性 4.计算前缀表5.前缀表与next数组&#xff08;1&#xff09;使用next数组来匹配 6.构造next数组&#xf…...

【手写promise——基本功能、链式调用、promise.all、promise.race】

文章目录 前言一、前置知识二、实现基本功能二、实现链式调用三、实现Promise.all四、实现Promise.race总结 前言 关于动机&#xff0c;无论是在工作还是面试中&#xff0c;都会遇到Promise的相关使用和原理&#xff0c;手写Promise也有助于学习设计模式以及代码设计。 本文主…...

计算机网络-笔记-第二章-物理层

目录 二、第二章——物理层 1、物理层的基本概念 2、物理层下面的传输媒体 &#xff08;1&#xff09;光纤、同轴电缆、双绞线、电力线【导引型】 &#xff08;2&#xff09;无线电波、微波、红外线、可见光【非导引型】 &#xff08;3&#xff09;无线电【频谱的使用】 …...

前端开发中的单伪标签清除和双伪标签清除

引言 在前端开发中&#xff0c;我们经常会遇到一些样式上的问题&#xff0c;其中之一就是伪元素造成的布局问题。为了解决这个问题&#xff0c;我们可以使用伪标签清除技术。本篇博客将介绍单伪标签清除和双伪标签清除的概念、用法和示例代码&#xff0c;并详细解释它们的原理…...

云计算中的数据安全与隐私保护策略

文章目录 1. 云计算中的数据安全挑战1.1 数据泄露和数据风险1.2 多租户环境下的隔离问题 2. 隐私保护策略2.1 数据加密2.2 访问控制和身份验证 3. 应对方法与技术3.1 零知识证明&#xff08;Zero-Knowledge Proofs&#xff09;3.2 同态加密&#xff08;Homomorphic Encryption&…...

MacOS软件安装包分享(附安装教程)

目录 一、软件简介 二、软件下载 一、软件简介 MacOS是一种由苹果公司开发的操作系统&#xff0c;专门用于苹果公司的计算机硬件。它被广泛用于创意和专业应用程序&#xff0c;如图像设计、音频和视频编辑等。以下是关于MacOS的详细介绍。 1、MacOS的历史和演变 MacOS最初于…...

【linux进程概念】

目录&#xff1a; 冯诺依曼体系结构操作系统进程 基本概念描述进程-PCBtask_struct-PCB的一种task_ struct内容分类组织进程查看进程 fork()函数 冯诺依曼体系结构 我们常见的计算机&#xff0c;如笔记本。我们不常见的计算机&#xff0c;如服务器&#xff0c;大部分都遵守冯诺…...

直击成都国际车展:远航汽车多款车型登陆车展,打造完美驾乘体验

随着市场渗透率日益高涨&#xff0c;新能源汽车成为今年成都国际车展的关注焦点。在本届车展上&#xff0c;新能源品牌占比再创新高&#xff0c;覆盖两个展馆&#xff0c;印证了当下新能源汽车市场的火爆。作为大运集团重磅打造的高端品牌&#xff0c;远航汽车深度洞察高端智能…...

android nv21 转 yuv420sp

上面两个函数的目标都是将NV21格式的数据转换为YUV420P格式&#xff0c;但是它们在处理U和V分量的方式上有所不同。 在第一个函数NV21toYUV420P_1中&#xff0c;U和V分量的处理方式是这样的&#xff1a;对于U分量&#xff0c;它从NV21数据的Y分量之后的每个奇数位置取数据&…...

使用Nacos与Spring Boot实现配置管理

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…...

初识【类和对象】

目录 1.面向过程和面向对象初步认识 2.类的引入 3.类的定义 4.类的访问限定符及封装 5.类的作用域 6.类的实例化 7.类的对象大小的计算 8.类成员函数的this指针 1.面向过程和面向对象初步认识 C语言是面向过程的&#xff0c;关注的是过程&#xff0c;分析出求解问题的…...

软考高级系统架构设计师系列论文八十六:论企业应用集成

软考高级系统架构设计师系列论文八十六:论企业应用集成 一、企业应用集成相关知识点二、摘要三、正文四、总结一、企业应用集成相关知识点 软考高级系统架构设计师系列之:企业集成平台技术的应用和架构设计二、摘要 2022年10月,我参加了***车站综合信息平台项目的开发,承…...

HarmonyOS ArkUI 属性动画入门详解

HarmonyOS ArkUI 属性动画入门详解 前言属性动画是什么&#xff1f;我们借助官方的话来说&#xff0c;我们自己简单归纳下 参数解释举个例子旋转动画 位移动画组合动画总结 前言 鸿蒙OS最近吹的很凶&#xff0c;赶紧卷一下。学习过程中发现很多人吐槽官方属性动画这一章比较敷…...

基于XGBoots预测A股大盘《上证指数》(代码+数据+一键可运行)

对AI炒股感兴趣的小伙伴可加WX&#xff1a;caihaihua057200&#xff08;备注&#xff1a;学校/公司名字方向&#xff09; 另外我还有些AI的应用可以一起研究&#xff08;我一直开源代码&#xff09; 1、引言 在这期内容中&#xff0c;我们回到AI预测股票&#xff0c;转而探索…...

5G NR:PRACH频域资源

PRACH在频域位置由IE RACH-ConfigGeneric中参数msg1-FrequencyStart和msg1-FDM所指示&#xff0c;其中&#xff0c; msg1-FrequencyStart确定PRACH occasion 0的RB其实位置相对于上行公共BWP的频域其实位置(即BWP 0)的偏移&#xff0c;即确定PRACH的频域起始位置msg1-FDM的取值…...

设计模式——组合模式

什么是组合模式 组合模式(Composite Pattern)&#xff1a;组合多个对象形成树形结构以表示具有“整体—部分”关系的层次结构。组合模式对单个对象&#xff08;即叶子对象&#xff09;和组合对象&#xff08;即容器对象&#xff09;的使用具有一致性&#xff0c;组合模式又可以…...

get属性是什么?有什么用?在什么场景用?get会被Json序列化?

在JavaScript中&#xff0c;对象的属性不仅可以是数据属性&#xff08;即常规的键值对&#xff09;&#xff0c;还可以是访问器属性&#xff08;accessor properties&#xff09;。访问器属性不包含实际的数据值&#xff0c;而是定义了如何获取&#xff08;get&#xff09;和设…...

这可能是你看过最详细的 [八大排序算法]

排序算法 前置知识 [排序稳定性]一、直接插入排序二、希尔排序三、直接选择排序四、堆排序五、冒泡排序六、快速排序七、归并排序八、计数排序&#xff08;非比较排序&#xff09;排序复杂度和稳定性总结 前置知识 [排序稳定性] 假定在待排序的记录序列中&#xff0c;存在多个…...

docker的安装

CentOS7 安装 Docker 安装需要的软件包&#xff0c; yum-util 提供yum-config-manager功能&#xff0c;另两个是devicemapper驱动依赖 yum install -y yum-utils device-mapper-persistent-data lvm2 添加下载源 yum-config-manager --add-repo http://mirrors.aliyun.com/…...

【业务功能篇75】微服务项目环境搭建docker-mysql-redisSpringCloudAlibaba

项目环境准备 1.虚拟机环境 我们可以通过VMWare来安装&#xff0c;但是通过VMWare安装大家经常会碰到网络ip连接问题&#xff0c;为了减少额外的环境因素影响&#xff0c;Docker内容的讲解我们会通过VirtualBox结合Vagrant来安装虚拟机。 VirtualBox官网&#xff1a;https:/…...