在自定义数据集上微调Alpaca和LLaMA
本文将介绍使用LoRa在本地机器上微调Alpaca和LLaMA,我们将介绍在特定数据集上对Alpaca LoRa进行微调的整个过程,本文将涵盖数据处理、模型训练和使用流行的自然语言处理库(如Transformers和hugs Face)进行评估。此外还将介绍如何使用grado应用程序部署和测试模型。
配置
首先,alpaca-lora1 GitHub存储库提供了一个脚本(finetune.py)来训练模型。在本文中,我们将利用这些代码并使其在Google Colab环境中无缝地工作。
首先安装必要的依赖:
!pip install -U pip!pip install accelerate==0.18.0!pip install appdirs==1.4.4!pip install bitsandbytes==0.37.2!pip install datasets==2.10.1!pip install fire==0.5.0!pip install git+https://github.com/huggingface/peft.git!pip install git+https://github.com/huggingface/transformers.git!pip install torch==2.0.0!pip install sentencepiece==0.1.97!pip install tensorboardX==2.6!pip install gradio==3.23.0
安装完依赖项后,继续导入所有必要的库,并为matplotlib绘图配置设置:
import transformersimport textwrapfrom transformers import LlamaTokenizer, LlamaForCausalLMimport osimport sysfrom typing import Listfrom peft import (LoraConfig,get_peft_model,get_peft_model_state_dict,prepare_model_for_int8_training,)import fireimport torchfrom datasets import load_datasetimport pandas as pdimport matplotlib.pyplot as pltimport matplotlib as mplimport seaborn as snsfrom pylab import rcParams%matplotlib inlinesns.set(rc={'figure.figsize':(10, 7)})sns.set(rc={'figure.dpi':100})sns.set(style='white', palette='muted', font_scale=1.2)DEVICE = "cuda" if torch.cuda.is_available() else "cpu"DEVICE
数据
我们这里使用BTC Tweets Sentiment dataset4,该数据可在Kaggle上获得,包含大约50,000条与比特币相关的tweet。为了清理数据,删除了所有以“转发”开头或包含链接的推文。
使用Pandas来加载CSV:
df = pd.read_csv("bitcoin-sentiment-tweets.csv")df.head()
通过清理的数据集有大约1900条推文。
情绪标签用数字表示,其中-1表示消极情绪,0表示中性情绪,1表示积极情绪。让我们看看它们的分布:
df.sentiment.value_counts()# 0.0 860# 1.0 779# -1.0 258# Name: sentiment, dtype: int64
数据量差不多,虽然负面评论较少,但是可以简单的当成平衡数据来对待:
df.sentiment.value_counts().plot(kind='bar');
构建JSON数据集
原始Alpaca存储库中的dataset5格式由一个JSON文件组成,该文件具有具有指令、输入和输出字符串的对象列表。
让我们将Pandas的DF转换为一个JSON文件,该文件遵循原始Alpaca存储库中的格式:
def sentiment_score_to_name(score: float):if score > 0:return "Positive"elif score < 0:return "Negative"return "Neutral"dataset_data = [{"instruction": "Detect the sentiment of the tweet.","input": row_dict["tweet"],"output": sentiment_score_to_name(row_dict["sentiment"])}for row_dict in df.to_dict(orient="records")]dataset_data[0]
结果如下:
{"instruction": "Detect the sentiment of the tweet.","input": "@p0nd3ea Bitcoin wasn't built to live on exchanges.","output": "Positive"}
然后就是保存生成的JSON文件,以便稍后使用它来训练模型:
import jsonwith open("alpaca-bitcoin-sentiment-dataset.json", "w") as f:json.dump(dataset_data, f)
模型权重
虽然原始的Llama模型权重不可用,但它们被泄露并随后被改编用于HuggingFace Transformers库。我们将使用decapoda-research6:
BASE_MODEL = "decapoda-research/llama-7b-hf"model = LlamaForCausalLM.from_pretrained(BASE_MODEL,load_in_8bit=True,torch_dtype=torch.float16,device_map="auto",)tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)tokenizer.pad_token_id = (0 # unk. we want this to be different from the eos token)tokenizer.padding_side = "left"
这段代码使用来自Transformers库的LlamaForCausalLM类加载预训练的Llama 模型。load_in_8bit=True参数使用8位量化加载模型,以减少内存使用并提高推理速度。
代码还使用LlamaTokenizer类为同一个Llama模型加载标记器,并为填充标记设置一些附加属性。具体来说,它将pad_token_id设置为0以表示未知的令牌,并将padding_side设置为“left”以填充左侧的序列。
数据集加载
现在我们已经加载了模型和标记器,下一步就是加载之前保存的JSON文件,使用HuggingFace数据集库中的load_dataset()函数:
data = load_dataset("json", data_files="alpaca-bitcoin-sentiment-dataset.json")data["train"]
结果如下:
Dataset({features: ['instruction', 'input', 'output'],num_rows: 1897})
接下来,我们需要从加载的数据集中创建提示并标记它们:
def generate_prompt(data_point):return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501### Instruction:{data_point["instruction"]}### Input:{data_point["input"]}### Response:{data_point["output"]}"""def tokenize(prompt, add_eos_token=True):result = tokenizer(prompt,truncation=True,max_length=CUTOFF_LEN,padding=False,return_tensors=None,)if (result["input_ids"][-1] != tokenizer.eos_token_idand len(result["input_ids"]) < CUTOFF_LENand add_eos_token):result["input_ids"].append(tokenizer.eos_token_id)result["attention_mask"].append(1)result["labels"] = result["input_ids"].copy()return resultdef generate_and_tokenize_prompt(data_point):full_prompt = generate_prompt(data_point)tokenized_full_prompt = tokenize(full_prompt)return tokenized_full_prompt
第一个函数generate_prompt从数据集中获取一个数据点,并通过组合指令、输入和输出值来生成提示。第二个函数tokenize接收生成的提示,并使用前面定义的标记器对其进行标记。它还向输入序列添加序列结束标记,并将标签设置为与输入序列相同。第三个函数generate_and_tokenize_prompt结合了前两个函数,生成并标记提示。
数据准备的最后一步是将数据集分成单独的训练集和验证集:
train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)train_data = (train_val["train"].map(generate_and_tokenize_prompt))val_data = (train_val["test"].map(generate_and_tokenize_prompt))
我们还需要数据进行打乱,并且获取200个样本作为验证集。generate_and_tokenize_prompt()函数应用于训练和验证集中的每个示例,生成标记化的提示。
训练
训练过程需要几个参数,这些参数主要来自原始存储库中的微调脚本:
LORA_R = 8LORA_ALPHA = 16LORA_DROPOUT= 0.05LORA_TARGET_MODULES = ["q_proj","v_proj",]BATCH_SIZE = 128MICRO_BATCH_SIZE = 4GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZELEARNING_RATE = 3e-4TRAIN_STEPS = 300OUTPUT_DIR = "experiments"
下面就可以为训练准备模型了:
model = prepare_model_for_int8_training(model)config = LoraConfig(r=LORA_R,lora_alpha=LORA_ALPHA,target_modules=LORA_TARGET_MODULES,lora_dropout=LORA_DROPOUT,bias="none",task_type="CAUSAL_LM",)model = get_peft_model(model, config)model.print_trainable_parameters()#trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199
我们使用LORA算法初始化并准备模型进行训练,通过量化可以减少模型大小和内存使用,而不会显着降低准确性。
LoraConfig7是一个为LORA算法指定超参数的类,例如正则化强度(lora_alpha)、dropout概率(lora_dropout)和要压缩的目标模块(target_modules)。
然后就可以直接使用Transformers库进行训练:
training_arguments = transformers.TrainingArguments(per_device_train_batch_size=MICRO_BATCH_SIZE,gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,warmup_steps=100,max_steps=TRAIN_STEPS,learning_rate=LEARNING_RATE,fp16=True,logging_steps=10,optim="adamw_torch",evaluation_strategy="steps",save_strategy="steps",eval_steps=50,save_steps=50,output_dir=OUTPUT_DIR,save_total_limit=3,load_best_model_at_end=True,report_to="tensorboard")
这段代码创建了一个TrainingArguments对象,该对象指定用于训练模型的各种设置和超参数。这些包括:
- gradient_accumulation_steps:在执行向后/更新之前累积梯度的更新步数。
- warmup_steps:优化器的预热步数。
- max_steps:要执行的训练总数。
- learning_rate:学习率。
- fp16:使用16位精度进行训练。
DataCollatorForSeq2Seq是transformer库中的一个类,它为序列到序列(seq2seq)模型创建一批输入/输出序列。在这段代码中,DataCollatorForSeq2Seq对象用以下参数实例化:
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)
pad_to_multiple_of:表示最大序列长度的整数,四舍五入到最接近该值的倍数。
padding:一个布尔值,指示是否将序列填充到指定的最大长度。
以上就是训练的所有代码准备,下面就是训练了
trainer = transformers.Trainer(model=model,train_dataset=train_data,eval_dataset=val_data,args=training_arguments,data_collator=data_collator)model.config.use_cache = Falseold_state_dict = model.state_dictmodel.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(model, type(model))model = torch.compile(model)trainer.train()model.save_pretrained(OUTPUT_DIR)
在实例化训练器之后,代码在模型的配置中将use_cache设置为False,并使用get_peft_model_state_dict()函数为模型创建一个state_dict,该函数为使用低精度算法进行训练的模型做准备。
然后在模型上调用torch.compile()函数,该函数编译模型的计算图并准备使用PyTorch 2进行训练。
训练过程在A100上持续了大约2个小时。我们看一下Tensorboard上的结果:
训练损失和评估损失呈稳步下降趋势。看来我们的微调是有效的。
如果你想将模型上传到Hugging Face上,可以使用下面代码,
from huggingface_hub import notebook_loginnotebook_login()model.push_to_hub("curiousily/alpaca-bitcoin-tweets-sentiment", use_auth_token=True)
推理
我们可以使用generate.py脚本来测试模型:
!git clone https://github.com/tloen/alpaca-lora.git%cd alpaca-lora!git checkout a48d947
我们的脚本启动的gradio应用程序
!python generate.py \--load_8bit \--base_model 'decapoda-research/llama-7b-hf' \--lora_weights 'curiousily/alpaca-bitcoin-tweets-sentiment' \--share_gradio
简单的界面如下:
总结
我们已经成功地使用LoRa方法对Llama 模型进行了微调,还演示了如何在Gradio应用程序中使用它。
如果你对本文感兴趣,请看原文:
https://avoid.overfit.cn/post/34b6eaf7097a4929b9aab7809f3cfeaa
相关文章:

在自定义数据集上微调Alpaca和LLaMA
本文将介绍使用LoRa在本地机器上微调Alpaca和LLaMA,我们将介绍在特定数据集上对Alpaca LoRa进行微调的整个过程,本文将涵盖数据处理、模型训练和使用流行的自然语言处理库(如Transformers和hugs Face)进行评估。此外还将介绍如何使用grado应用程序部署和…...

Python 实现接口类的两种方式+邮件提醒+动态导入模块+反射(参考Django中间件源码)
实现抽象类的两种方式 方式一 from abc import ABCMeta from abc import abstractmethodclass BaseMessage(metaclassABCMeta):abstractmethoddef send(self,subject,body,to,name):pass 方式二 class BaseMessage(object):def send(self, subject, body, to, name):raise …...

Solr原理剖析
一、简介 Solr是一个高性能、基于Lucene的全文检索服务器。Solr对Lucene进行了扩展,提供了比Lucene更为丰富的查询语言,并实现了强大的全文检索功能、高亮显示、动态集群,具有高度的可扩展性。同时从Solr 4.0版本开始,支持SolrCl…...

解决 “无法将 ‘npm‘ 项识别为 cmdlet、函数、脚本文件或可运行程序的名称“ 错误的方法
系列文章目录 文章目录 系列文章目录前言一、错误原因:二、解决方法:三、注意事项:总结 前言 在使用 npm 进行前端项目开发时,有时会遇到错误信息 “无法将 ‘npm’ 项识别为 cmdlet、函数、脚本文件或可运行程序的名称”&#x…...

Python 电商API 开发最佳实践
一、简介 当你打卡了一家北京最具有地中海特色的餐厅,当我们在餐厅点餐时,服务员会给我们一份菜单,菜单上列出了所有可供选择的菜品和饮料。我们可以在菜单上选择我们想要的食物和饮料,然后告诉服务员我们的选择。服务员会根据我…...

JAVA基础-集合(List与Map)
目录 引言 一,Collection集合 1.1,List接口 1.1.1,ArrayList 1.1.1.1,ArrayList的add()添加方法 1.1.1.2,ArrayList的remove()删除方法 1.1.1.3,ArrayList的contai…...

19 QListWidget控件
Tips: 对于列表式数据可以使用QStringList进行左移一块输入。 代码: //listWidget使用 // QListWidgetItem * item new QListWidgetItem("锄禾日当午"); // QListWidgetItem * item2 new QListWidgetItem("汗滴禾下土"); // ui->…...

手动安装docsify
安装docsify详见:docsify 1、下载 wget https://codeload.github.com/docsifyjs/docsify/zip/refs/heads/master -o docsify-master.zip 2、解压 unzip docsify-master.zip 3、移动文件到nginx的html所在目录【略】 4、配置nginx,示例如下 locati…...

yaml语法详解
#kv #对空格的严格要求十分高 #注入到我们的配置类中 #普通的keyvalue name: qinjiang#对象 student:name: qingjiangage: 3#行内写法 student1: {name: qinjiang,age: 3}#数组 pets:- cat- dog- pigpet: [cat,dog,pig]yaml可以给实体类赋值 person:name: kuangshenage: 19happ…...

ubuntu下tmux安装
目录 0. 前言1. Tmux介绍2. 安装3. 验证安装 0. 前言 本节安装tmux终端复用工具,在Ubuntu中运行一些服务或脚本的时候往往不能退出终端,需要一直挂着。在有图形界面的linux中你还可以新开一个终端去做别的事,但是在无界面linux中,…...

ssh打开远程vscode
如果想要远程打开其他终端的vscode,首先要知道远程终端的ip地址和用户名称以及用户密码 1、打开本地vscode 2、点击左下角蓝色区域 3、页面上部出现如下图,点击ssh,我这里已经连接,所以是connect to host 4、选择Add New SSH Host…...

Socket发送数据---winsock库和boost库
一个是通过winsock库提供的api实现,一个是boost库实现,两个方法都可以,因为项目是vc++6.0实现的,不支持boost库,只能使用winsock库,vc++6.0太老,局限性大。 通过Winsock库提供的API 通过UDP #include<winsock2.h> #include<vector> #include<WS2tcpip.h…...

Qt Core学习日记——第七天QMetaObject(上)
每一个声明Q_OBJECT的类都具有QMetaObject对象 Q_OBJECT宏源代码: #define Q_OBJECT \ public: \ QT_WARNING_PUSH \ Q_OBJECT_NO_OVERRIDE_WARNING \ static const QMetaObject staticMetaObject; \ virtual const QMetaObject *metaObject() const; \ vir…...

100、用简洁的语言描述一下:TCP的三次握手和四次挥手(不需要长篇大论)
TCP的三次握手和四次挥手 TCP协议是7层网络协议中的传输层协议,负责数据的可靠传输。 1、三次握手 在建立TCP连接时,需要通过三次握手来建立,过程是: 客户端向服务端发送一个SYN服务端接收到SYN后,给客户端发送一个SYN_ACK客户…...

中南大学硕士论文latex版本全指导
要毕业了,闲下点时间写的东西。之前一直收益与师兄师姐流传下来的latex版本,用起来很舒服,希望后面的学弟学妹也能完美用上。latex功能很强大,不需要自己排版,只管内容即可,但是安装流程会多一丢丢。 目录 …...

RFC8470在HTTP中使用早期数据
摘要 使用TLS早期数据会暴露出重放攻击的可能性。本文定义了允许客户端与服务器就早期数据中发送的HTTP请求进行通信的机制。描述了使用这些机制来减轻重放风险的技术。 1. 介绍 TLS 1.3[TLS13]引入了早期数据(也称为零往返时间(0-RTT)数…...

macOS Big Sur 11.7.9 (20G1426) 正式版 ISO、PKG、DMG、IPSW 下载
macOS Big Sur 11.7.9 (20G1426) 正式版 ISO、PKG、DMG、IPSW 下载 本站下载的 macOS 软件包,既可以拖拽到 Applications(应用程序)下直接安装,也可以制作启动 U 盘安装,或者在虚拟机中启动安装。另外也支持在 Window…...

【LeetCode】62.不同路径
题目 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish” )。 问总共有多少条不同的路径? …...

使用序列化和反序列化函数archivedDataWithRootObject和unarchivedObjectOfClasses的使用和遇到问题及解决方案
为何archiveRootObject和unarchiveObjectWithFile正常,而archivedDataWithRootObject和unarchivedObjectOfClasses一直报错。 [NSKeyedArchiver archiveRootObject:account toFile:path];和c PPAccountModel *account [NSKeyedUnarchiver unarchiveObjectWithFile:…...

python获取鼠标出颜色
import pyautogui as pg import keyboarddef rgb2hex(r, g, b):return #{:02x}{:02x}{:02x}.format(r, g, b)try:width, height pg.size()print(f"Display resolution: {width} * {height}\n") # 打印屏幕分辨率print(按下shift键打印出鼠标所指位置的颜色......)w…...

Github Flow工作流简单介绍(以部署为中心的开发模式)
前言 这篇文章主要介绍Github Flow的理念,以下内容来源于《Github入门与实践》。 Github Flow是以部署为中心的开发模式,通过简单的规则,持续高速且安全地进行部署。而Gitflow则是以发布为中心的分支管理模型,它提供了一种更灵活…...

selenium浏览器驱动下载
Chrome谷歌浏览器 下载地址:http://chromedriver.storage.googleapis.com/index.html 不同的Chrome的版本对应的chromedriver.exe 版本也不一样,下载时不要搞错了。 如果是最新的Chrome, 下载最新的chromedriver.exe 就可以了。 Firefox火狐浏览器 驱…...

go学习 模块与包 - Init函数 - 如何导入第三方包 - 切片与数组的数据传递方式 - go中文件的读写
目录 包(package)是组织和复用代码的基本单元。 包的种类: 包的导入 包的组成 如下两个文件中定义了A变量和 sc_num变量,他们的首字母开头分别为大写和小写,因此可以说明A变量是公有变量,而sc_num是私…...

2023第五届全国生物资源提取与应用创新论坛即将举办
01、会议背景 为进一步加强生物资源提取行业交流与合作,促进业“产学研用”融合,提升行业科技创新水平,增强行业国际竞争力,中国生物发酵产业协会、浙江科技学院、浙江工业职业技术学院、浙江省农业生物资源生化制造协同创新中心&…...

Socks5代理在爬虫与HTTP应用中的重要性
IP代理的类型及原理常见的IP代理类型有HTTP代理、Socks代理等,本文重点关注Socks5代理。Socks5代理是一种网络协议,可以实现传输层的数据转发,使客户端在不直接连接服务器的情况下与其进行通信。其原理在于接收客户端的请求,然后将…...

二叉树详解
这里写目录标题 前言树型结构(了解)树常见的概念树的表示形式(了解)树的应用 二叉树概念两种特殊的二叉树二叉树的性质(重要)二叉树的存储二叉树的基本操作 前言 本篇博客讲述了以下几个知识点 树的基本概念二叉树概念及特性二叉树的基本操作 树型结构…...

Git的核心概念:探索Git中的提交、分支、合并、标签等核心概念,深入理解其作用和使用方法
🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~ἳ…...

JAVA设计模式——23种设计模式详解
一、什么是设计模式🍉 设计模式(Design pattern) 是解决软件开发某些特定问题而提出的一些解决方案也可以理解成解决问题的一些思路。通过设计模式可以帮助我们增强代码的可重用性、可扩充性、 可维护性、灵活性好。我们使用设计模式最终的目…...

Oracle输出文本平面(CSV、XML)文本数据详细过程
此过程是提供给前端,调用的接口,为报表提供”下载“功能。以下是本人在测试环境的测试,有什么不足的地方,请留言指教,谢谢。 1、测试表 分别对测试表输出csv、xml两种格式文件数据。前期的准备工作。 --在服务器端创建directory,用管理员用户 create or replace directo…...

基于C++的QT基础教程学习笔记
文章目录: 来源 教程社区 一:QT下载安装 二:注意事项 1.在哪里写程序 2.如何看手册 3.技巧 三:常用函数 1.窗口 2.相关 3.按钮 4.信号与槽函数 5.常用栏 菜单栏 工具栏 状态栏 6.铆接部件 7.文本编辑 8…...