huggingface NLP-微调一个预训练模型
微调一个预训练模型
1 预处理数据
1.1 处理数据
1.1.1 fine-tune
使用tokenizer后的token 进行训练
batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")# This is new
batch["labels"] = torch.tensor([1, 1])optimizer = AdamW(model.parameters())
loss = model(**batch).loss
loss.backward()
optimizer.step()
1.2 从模型中心(Hub)加载数据集
1.2.1 数据集
DatasetDict对象,其中包含训练集、验证集和测试集
。每一个集合都包含几个列(sentence1, sentence2, label, and idx)以及一个代表行数的变量,即每个集合中的行的个数
下载数据集并缓存到 ~/.cache/huggingface/datasets. 回想一下第2章,您可以通过设置HF_HOME环境变量来自定义缓存的文件夹。
1.3 预处理数据集
1.3.1 预处理数据集,我们需要将文本转换为模型能够理解的数字
1.3.2 类型标记ID(token_type_ids)的作用就是告诉模型输入的哪一部分是第一句,哪一部分是第二句
1.3.3 不一定具有类型标记ID(token_type_ids
1.3.4 将数据保存为数据集,我们将使用Dataset.map()
调用map时使用了batch =True,这样函数就可以同时应用到数据集的多个元素上,而不是分别应用到每个元素上。这将使我们的预处理快许多
1.3.5 省略padding参数
在标记的时候将所有样本填充到最大长度的效率不高
一个更好的做法:在构建批处理时填充样本更好,因为这样我们只需要填充到该批处理中的最大长度,而不是整个数据集的最大长度。当输入长度变化很大时,这可以节省大量时间和处理能力!
1.3.6 将所有示例填充到最长元素的长度——我们称之为动态填充
为了解决句子长度统一的问题,我们必须定义一个collate函数,该函数会将每个batch句子填充到正确的长度
transformer库通过DataCollatorWithPadding为我们提供了这样一个函数
It’s when you pad your inputs when the batch is created, to the maximum length of the sentences inside that batch.
1.3.7 数据集是以Apache Arrow文件存储在磁盘上
1.4 benefits
1.4.1 The results of the function are cached, so it won’t take any time if we re-execute the code.
1.4.2 It can apply multiprocessing to go faster than applying the function on each element of the dataset.
1.4.3 It does not load the whole dataset into memory, saving the results as soon as one element is processed.
2 使用 Trainer API 微调模型(非并行批量模式)
2.1 TrainingArguments 类
2.1.1 它将包含 Trainer用于训练和评估的所有超参数
2.1.2 可以只调整部分默认参数进行微调
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
2.2 Training
2.2.1 简单training
trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["validation"],data_collator=data_collator,tokenizer=tokenizer,
)
trainer.train()
2.3 评估
2.3.1 使用 Trainer.predict() 命令来使用我们的模型进行预测
predict() 的输出结果是具有三个字段的命名元组: predictions , label_ids , 和 metrics
metrics 字段将只包含传递的数据集的loss,以及一些运行时间(预测所需的总时间和平均时间)
要将我们的预测的可以与真正的标签进行比较,我们需要在第二个轴上取最大值的索引
preds = np.argmax(predictions.predictions, axis=-1)
def compute_metrics(eval_preds):metric = evaluate.load("glue", "mrpc")logits, labels = eval_predspredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)
如何使用compute_metrics()函数定义一个新的 Trainer
2.4 AutoModelForSequenceClassification
2.4.1 when we used AutoModelForSequenceClassification with bert-base-uncased, we got warnings when instantiating the model. The pretrained head is not used for the sequence classification task, so it’s discarded and a new head is instantiated(实例化) with random weights.
3 完整的训练,使用Accelerator和scheduler
3.1 训练前的数据准备
3.1.1 删除与模型不期望的值相对应的列(如sentence1和sentence2列)。
3.1.2 将列名label重命名为labels(因为模型期望参数是labels)。
3.1.3 设置数据集的格式,使其返回 PyTorch 张量而不是列表。
3.1.4 代码
tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names
3.2 data loader
3.2.1 from torch.utils.data import DataLoader
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)
3.3 优化器和学习率调度器
3.3.1 optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler("linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps,
)
3.4 训练循环
3.4.1
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
3.5 使用accelerator
3.5.1
train_dl, eval_dl, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)num_epochs = 3
num_training_steps = num_epochs * len(train_dl)
3.5.2 分布式
accelerate config
accelerate launch train.py
3.5.3 With 🤗Accelerate, your training loops will work for multiple GPUs and TPUs.
4 注意train只会对参数有调整
4.1 超参数(Hyperparameter),是机器学习算法中的调优参数,用于控制模型的学习过程和结构。 与模型参数(Model Parameter)不同,模型参数是在训练过程中通过数据学习得到的,而超参数是在训练之前由开发者或实践者直接设定的,并且在训练过程中保持不变。
4.2 需要自己设定,不是机器自己找出来的,称为超参数(hyperparameter)。
4.2.1 需要人工设置: 超参数的值不是通过训练过程自动学习得到的,而是需要训练者根据经验或实验来设定。
4.2.2 影响模型性能: 超参数的选择会直接影响模型的训练过程和最终性能。
4.2.3 需要优化: 为了获得更好的模型性能,通常需要对超参数进行优化,选择最优的超参数组合。
4.3 validated 数据集作用
4.3.1 验证集,用于挑选超参数的数据子集。
相关文章:

huggingface NLP-微调一个预训练模型
微调一个预训练模型 1 预处理数据 1.1 处理数据 1.1.1 fine-tune 使用tokenizer后的token 进行训练 batch tokenizer(sequences, paddingTrue, truncationTrue, return_tensors"pt")# This is new batch["labels"] torch.tensor([1, 1])optimizer A…...

【BUG记录】Apifox 参数传入 + 号变成空格的 BUG
文章目录 1. 问题描述2. 原因2.1 编码2.2 解码 3. 解决方法 1. 问题描述 之前写了一个接口,用 Apifox 请求,参数传入一个 86 的电话,结果到服务器 就变成空格了。 Java 接收请求的接口: 2. 原因 2.1 编码 进行 URL 请求的…...

Spring AI API 介绍
目录: Spring AI 框架介绍 Spring AI API 核心API简介 Spring AI 提供了很多便利的功能,主要如下: AI Model API “Model API” 提供了聊天、文本转图像、音频转录、文本转语音、嵌入等功能,且不局限于某个固定的大模型提供商…...

【MySQL】Linux使用C语言连接安装
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...

2024年第十五届蓝桥杯青少组C++国赛—割点
割点 题目描述 一张棋盘由n行 m 列的网格矩阵组成,每个网格中最多放一颗棋子。当前棋盘上已有若干棋子。所有水平方向或竖直方向上相邻的棋子属于同一连通块。 现给定棋盘上所有棋子的位置,如果要使棋盘上出现两个及以上的棋子连通块,请问…...
【软件开发】做出技术决策
文章目录 专注于核心业务除非绝对必要,不要重写代码保持技术栈简单尽量减少依赖避免范围蔓延按照业务实际情况确定优先级在做出高风险决策前构建原型跨职能团队协作信任你的团队在过去的二十年里,我曾在多家初创企业担任软件开发人员、技术负责人以及首席技术官(包括创办自己…...

Airborne使用教程
1.安装环境 前提条件:系统已安装Ruby 打开终端输入如下命令 gem install airborne 或者在Gemfile添加 gem airborne 然后运行bundle install 2.编写脚本 在项目中新建api_tests_spec.rb文件 以GET接口"https://www.thunderclient.com/welcome"为…...

WPF实现曲线数据展示【案例:震动数据分析】
wpf实现曲线数据展示,函数曲线展示,实例:震动数据分析为例。 如上图所示,如果你想实现上图中的效果,请详细参考我的内容,创作不易,给个赞吧。 一共有两种方式来实现,一种是使用第三…...

EasyExcel 动态设置表格的背景颜色和排列
项目中使用EasyExcel把数据以excel格式导出,其中设置某一行、某一列单元格的背景颜色、排列方式十分常用,记录下来方便以后查阅。 1. 导入maven依赖: <dependency><groupId>com.alibaba</groupId><artifactId>easy…...
【 C++11 】类的新功能
C类的新功能 一、默认成员函数二、类成员变量初始化三、default关键字四、delete关键字六、final关键字七、override关键字 一、默认成员函数 八个默认成员函数 在C11之前,一个类中有如下六个默认成员函数: 构造函数。析构函数。拷贝构造函数。拷贝赋值…...
防止SQL注入:PHP安全最佳实践
防止SQL注入:PHP安全最佳实践 SQL注入是一种常见的网络攻击方式,攻击者通过向应用程序的SQL查询中插入恶意代码,来获取、操控或破坏数据库中的数据。为了保护PHP应用免受SQL注入攻击,开发者需要遵循一系列安全最佳实践。本文将介…...
自动化生产或质量检测准备工作杂记
自动化生产或质量检测一个流程是: 上料位上料: “上料位”指的是物料被放置以供机器或设备处理的位置。“上料”指的是将物料从存储位置移动到加工或检测位置的过程。移动到对位相机位置: “对位相机”是一种高精度相机,用于精确…...
张志辰医生
在医学领域,北京中医药大学东方医院的张志辰副主任医师宛如一颗璀璨的明星。自 2011 年于北京中医药大学获取博士学位后,他便扎根临床一线,以精湛医术和仁心仁术,为众多患者排忧解难 张志辰曾先后前往北京天坛医院、广东中山医院…...

CodeMirror 如何动态更新definemode
CodeMirror 如何动态更新definemode 问题描述:解决方法: 问题描述: 项目中有一部分用到了CodeMirror组件,其高亮显示的内容需要根据最新的json动态的更新,需要使用definemode自定义高亮内容。 想要的效果如下…...

舵机SG90详解
舵机,也叫伺服电机,在嵌入式开发中,舵机作为一种常见的运动控制组件,具有广泛的应用。其中,SG90 舵机以其高效、稳定的性能特点,成为了许多工程师和爱好者的首选,无论是航模、云台、机器人、智能…...

程序设计考题汇总(四:SQL练习)
文章目录 查询结果限制返回行数 查询结果限制返回行数 select device_id from user_profile LIMIT 2;...

明达IOT平台助力工业废水运维智能化
背景简介 相较于生活污水,工业废水的处理挑战性更高,原因在于其源于多样化的工业生产流程,成分复杂且多变,可能包含重金属、有毒化学…...
深入理解 Ansible Playbook:组件与实战
目录 1 playbook介绍 2 YAML语言 2.1语法简介 2.2数据类型 3 Playbook核心组件 3.1 hosts组件 3.2 remote_user组件 3.3 task列表和action组件 3.4 handlers 3.5 tags组件 3.6 其他组件说明 1 playbook介绍 playbook 剧本是由一个或多个"play"组成的列表。…...

JavaEE初阶——多线程(线程安全-锁)
复习上节内容(部分-掌握程度不够的) 加锁,解决线程安全问题。 synchronized关键字,对锁对象进行加锁。 锁对象,可以是随便一个Object对象(或者其子类的对象),需要关注的是ÿ…...

Stable Diffusion 提示词语法
1.提示词基础 1.提示词之间用英文逗号,分隔 2.提示词之间是可以换行的 3.权重默认为1,越靠前权重越高 4.数量控制在75个单位以内 2.提示词各种符号的意义 2.1 ()、[]、{}符号 权重值()小括号[]中括号{}大括号默认1111层()1.1[]0.9{}1.052层(()) 1.121.21[[]]0.920.81{{}}1.…...

Spark 之 入门讲解详细版(1)
1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...

对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...

React19源码系列之 事件插件系统
事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...
【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】
1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件(System Property Definition File),用于声明和管理 Bluetooth 模块相…...
【git】把本地更改提交远程新分支feature_g
创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...
【JavaSE】绘图与事件入门学习笔记
-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角,以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向,距离坐标原点x个像素;第二个是y坐标,表示当前位置为垂直方向,距离坐标原点y个像素。 坐标体系-像素 …...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...