当前位置: 首页 > news >正文

Huggingface训练Transformer

在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客

Huggingface提供了一个TRL的扩展库,可以对transformer模型进行强化学习,SFT是其中的一个训练步骤,为此我也测试一下如何用Huggingface来进行SFT训练,和Pytorch的训练方式做一个比较。

训练数据

首先是获取训练数据,这里同样是采用Huggingface的chatbot_instruction_prompts的数据集,这个数据集涵盖了不同类型的问答,可以用作我们的模型优化之用。

from datasets import load_datasetds = load_dataset("alespalla/chatbot_instruction_prompts")
train_ds = ds['train']
eval_dataset = ds['test']
eval_dataset = eval_dataset.select(range(1024))

训练集总共包括了258042条问答数据,对于验证集我只选取了头1024条记录,因为总的数据集太长,如果在训练过程中全部验证的话耗时太长。

加载GPT2模型

Huggingface提供了很多大模型的训练好的参数,这里我们可以直接加载一个已经训练好的GPT2模型来做优化

from transformers import AutoModelForCausalLM, AutoTokenizermodel = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

然后我们可以定义TRL提供的SFTTrainer来进行训练,首先需要对训练数据处理一下,因为训练数据包括了两列,分别是prompt和response,我们需要把两列的文本合为一起,通过格式化字符来区分,如以下格式化函数:

def formatting_func(example):text = f"### Prompt: {example['prompt']}\n ### Response: {example['response']}"return text

定义SFTTrainer的训练参数,具体每个参数的含义可见官网的文档:

args = TrainingArguments(output_dir='checkpoints_hf_sft',overwrite_output_dir=True, per_device_train_batch_size=4,per_device_eval_batch_size=4,fp16=True,torch_compile=True,evaluation_strategy='steps',prediction_loss_only=True,eval_accumulation_steps=1,learning_rate=0.00006,weight_decay=0.01,adam_beta1=0.9,adam_beta2=0.95,warmup_steps=1000,eval_steps=4000,save_steps=4000,save_total_limit=4,dataloader_num_workers=4,max_steps=12000,optim='adamw_torch_fused')

最后就可以定义一个trainer来训练了

trainer = SFTTrainer(model,args = args,train_dataset=dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,packing=True,formatting_func=formatting_func,max_seq_length=1024
)trainer.train(resume_from_checkpoint=False)

因为是第一次训练,我设置了resume_from_checkpoint=False,如果是继续训练,把这个参数设为True即可从上次checkpoint目录自动加载最新的checkpoint来训练。

训练结果如下:

 [12000/12000 45:05, Epoch 0/1]

StepTraining LossValidation Loss
40002.1379002.262321
80002.1878002.232235
120002.2185002.210413

总共耗时45分钟,比我在pytorch上的训练要快一些(快了10分钟多一些),但是这个训练集的Loss随着Step的增加反而增加了,Validation Loss就减少了,有些奇怪。在Pytorch上我同样训练12000个迭代,最后training loss是去到1.8556的。可能还要再调整一下trainer的参数看看。

测试

最后我们把SFT训练完成的模型,通过huggingface的pipeline就可加载进行测试了。

from transformers import pipelinemodel = AutoModelForCausalLM.from_pretrained('checkpoints_hf_sft/checkpoint-12000/')
pipe = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=0)pipe('### Prompt: Who is the current president of USA?')
[{'generated_text': '### Prompt: Who is the current president of USA?\n ### Response: Harry K. Busby is currently President of the United States.'}]

回答的语法没问题,不过内容是错的。

再测试另一个问题

### Prompt: How to make a cup of coffee?

[{'generated_text': '### Prompt: How to make a cup of coffee?\n ### Response: 1. Preheat the oven to 350°F (175°C).\n\n2. Boil the coffee beans according to package instructions.\n\n3. In a'}]

这个回答就正确了。

总结

通过用Huggingface可以很方便的对大模型进行强化学习的训练,不过也正因为太方便了,很多训练的细节被包装了,所以训练的结果不太容易优化,不像在Pytorch里面控制的自由度更高一些。当然可能我对huggingface的trainer参数的细节还不太了解,这个有待后续继续了解。

另外我还发现huggingface的一个小的bug,就是模型从头训练的时候没有问题,但是当我从之前的checkpoint继续训练时,会报CUDA OOM的错误,从nvidia-smi命令看到的显存占用率来看,好像trainer定义模型和装载Checkpoint会重复占用了显存,因此同样的batch_size,在继续训练时就报内存不够了,这个也有待后溪继续了解。

相关文章:

Huggingface训练Transformer

在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客 Huggingface提供了一个TRL的扩展库,可以对tra…...

IA-YOLO项目中DIP模块的初级解读

IA-YOLO项目源自论文Image-Adaptive YOLO for Object Detection in Adverse Weather Conditions,其提出端到端方式联合学习CNN-PP和YOLOv3,这确保了CNN-PP可以学习适当的DIP,以弱监督的方式增强图像检测。IA-YOLO方法可以自适应地处理正常和不…...

MathType7.4mac最新版本数学公式编辑器安装教程

MathType7.4中文版是一款功能强大且易于使用的公式编辑器。该软件可与word软件配合使用,有效提高了教学人员的工作效率,避免了一些数学符号和公式无法在word中输入的麻烦。新版MathType7.4启用了全新的LOGO,带来了更多对数学符号和公式的支持…...

为Claude的分析内容做准备:提取PDF页面内容的简易应用程序

由于Claude虽然可以分析整个文件,但是对文件的大小以及字数是有限制的,为了将pdf文件分批传入Claude人工智能分析和总结文章内容,才有了这篇博客: 在本篇博客中,我们将介绍一个基于 wxPython 和 PyMuPDF 库编写的简易的…...

js中作用域的理解?

1.作用域 作用域,即变量(变量作用域又称上下文)和函数生效(能被访问)的区域或集合 换句话说,作用域决定了代码区块中变量和其他资源的可见性 举个例子 function myFunction() {let inVariable "函数内部变量"; } myFunction();//要先执行这…...

机器学习基础之《分类算法(4)—案例:预测facebook签到位置》

一、背景 1、说明 2、数据集 row_id:签到行为的编码 x y:坐标系,人所在的位置 accuracy:定位的准确率 time:时间戳 place_id:预测用户将要签到的位置 3、数据集下载 https://www.kaggle.com/navoshta/gr…...

【Java】反射 之 调用方法

调用方法 我们已经能通过Class实例获取所有Field对象,同样的,可以通过Class实例获取所有Method信息。Class类提供了以下几个方法来获取Method: Method getMethod(name, Class...):获取某个public的Method(包括父类&a…...

Java——单例设计模式

什么是设计模式? 设计模式是在大量的实践中总结和理论化之后优选的代码结构、编程风格、以及解决问题的思考方式。设计模式免去我们自己再思考和摸索。就像是经典的棋谱,不同的棋局,我们用不同的棋谱、“套路”。 经典的设计模式共有23种。…...

Java实现excel表数据的批量存储(结合easyexcel插件)

场景:加哥最近在做项目时,苦于系统自身并未提供数据批量导入的功能还不能自行添加上该功能,且自身不想手动一条一条将数据录入系统。随后,自己使用JDBC连接数据库、使用EasyExcel插件读取表格并将数据按照业务逻辑批量插入数据库完…...

Config:客户端连接服务器访问远程

springcloud-config: springcloud-config push pom <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocatio…...

【KMP算法-代码随想录】

目录 1.什么是KMP2.什么是next数组3.什么是前缀表&#xff08;1&#xff09;前后缀含义&#xff08;2&#xff09;最长公共前后缀&#xff08;3&#xff09;前缀表的必要性 4.计算前缀表5.前缀表与next数组&#xff08;1&#xff09;使用next数组来匹配 6.构造next数组&#xf…...

【手写promise——基本功能、链式调用、promise.all、promise.race】

文章目录 前言一、前置知识二、实现基本功能二、实现链式调用三、实现Promise.all四、实现Promise.race总结 前言 关于动机&#xff0c;无论是在工作还是面试中&#xff0c;都会遇到Promise的相关使用和原理&#xff0c;手写Promise也有助于学习设计模式以及代码设计。 本文主…...

计算机网络-笔记-第二章-物理层

目录 二、第二章——物理层 1、物理层的基本概念 2、物理层下面的传输媒体 &#xff08;1&#xff09;光纤、同轴电缆、双绞线、电力线【导引型】 &#xff08;2&#xff09;无线电波、微波、红外线、可见光【非导引型】 &#xff08;3&#xff09;无线电【频谱的使用】 …...

前端开发中的单伪标签清除和双伪标签清除

引言 在前端开发中&#xff0c;我们经常会遇到一些样式上的问题&#xff0c;其中之一就是伪元素造成的布局问题。为了解决这个问题&#xff0c;我们可以使用伪标签清除技术。本篇博客将介绍单伪标签清除和双伪标签清除的概念、用法和示例代码&#xff0c;并详细解释它们的原理…...

云计算中的数据安全与隐私保护策略

文章目录 1. 云计算中的数据安全挑战1.1 数据泄露和数据风险1.2 多租户环境下的隔离问题 2. 隐私保护策略2.1 数据加密2.2 访问控制和身份验证 3. 应对方法与技术3.1 零知识证明&#xff08;Zero-Knowledge Proofs&#xff09;3.2 同态加密&#xff08;Homomorphic Encryption&…...

MacOS软件安装包分享(附安装教程)

目录 一、软件简介 二、软件下载 一、软件简介 MacOS是一种由苹果公司开发的操作系统&#xff0c;专门用于苹果公司的计算机硬件。它被广泛用于创意和专业应用程序&#xff0c;如图像设计、音频和视频编辑等。以下是关于MacOS的详细介绍。 1、MacOS的历史和演变 MacOS最初于…...

【linux进程概念】

目录&#xff1a; 冯诺依曼体系结构操作系统进程 基本概念描述进程-PCBtask_struct-PCB的一种task_ struct内容分类组织进程查看进程 fork()函数 冯诺依曼体系结构 我们常见的计算机&#xff0c;如笔记本。我们不常见的计算机&#xff0c;如服务器&#xff0c;大部分都遵守冯诺…...

直击成都国际车展:远航汽车多款车型登陆车展,打造完美驾乘体验

随着市场渗透率日益高涨&#xff0c;新能源汽车成为今年成都国际车展的关注焦点。在本届车展上&#xff0c;新能源品牌占比再创新高&#xff0c;覆盖两个展馆&#xff0c;印证了当下新能源汽车市场的火爆。作为大运集团重磅打造的高端品牌&#xff0c;远航汽车深度洞察高端智能…...

android nv21 转 yuv420sp

上面两个函数的目标都是将NV21格式的数据转换为YUV420P格式&#xff0c;但是它们在处理U和V分量的方式上有所不同。 在第一个函数NV21toYUV420P_1中&#xff0c;U和V分量的处理方式是这样的&#xff1a;对于U分量&#xff0c;它从NV21数据的Y分量之后的每个奇数位置取数据&…...

使用Nacos与Spring Boot实现配置管理

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…...

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式

一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明&#xff1a;假设每台服务器已…...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

抽象类和接口(全)

一、抽象类 1.概念&#xff1a;如果⼀个类中没有包含⾜够的信息来描绘⼀个具体的对象&#xff0c;这样的类就是抽象类。 像是没有实际⼯作的⽅法,我们可以把它设计成⼀个抽象⽅法&#xff0c;包含抽象⽅法的类我们称为抽象类。 2.语法 在Java中&#xff0c;⼀个类如果被 abs…...

c# 局部函数 定义、功能与示例

C# 局部函数&#xff1a;定义、功能与示例 1. 定义与功能 局部函数&#xff08;Local Function&#xff09;是嵌套在另一个方法内部的私有方法&#xff0c;仅在包含它的方法内可见。 • 作用&#xff1a;封装仅用于当前方法的逻辑&#xff0c;避免污染类作用域&#xff0c;提升…...

Xela矩阵三轴触觉传感器的工作原理解析与应用场景

Xela矩阵三轴触觉传感器通过先进技术模拟人类触觉感知&#xff0c;帮助设备实现精确的力测量与位移监测。其核心功能基于磁性三维力测量与空间位移测量&#xff0c;能够捕捉多维触觉信息。该传感器的设计不仅提升了触觉感知的精度&#xff0c;还为机器人、医疗设备和制造业的智…...