当前位置: 首页 > news >正文

huggingface NLP-微调一个预训练模型

微调一个预训练模型
在这里插入图片描述

1 预处理数据

1.1 处理数据
1.1.1 fine-tune
使用tokenizer后的token 进行训练

batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")# This is new
batch["labels"] = torch.tensor([1, 1])optimizer = AdamW(model.parameters())
loss = model(**batch).loss
loss.backward()
optimizer.step()

1.2 从模型中心(Hub)加载数据集
1.2.1 数据集
DatasetDict对象,其中包含训练集、验证集和测试集
。每一个集合都包含几个列(sentence1, sentence2, label, and idx)以及一个代表行数的变量,即每个集合中的行的个数
下载数据集并缓存到 ~/.cache/huggingface/datasets. 回想一下第2章,您可以通过设置HF_HOME环境变量来自定义缓存的文件夹。
1.3 预处理数据集
1.3.1 预处理数据集,我们需要将文本转换为模型能够理解的数字
1.3.2 类型标记ID(token_type_ids)的作用就是告诉模型输入的哪一部分是第一句,哪一部分是第二句
1.3.3 不一定具有类型标记ID(token_type_ids
1.3.4 将数据保存为数据集,我们将使用Dataset.map()
调用map时使用了batch =True,这样函数就可以同时应用到数据集的多个元素上,而不是分别应用到每个元素上。这将使我们的预处理快许多
1.3.5 省略padding参数
在标记的时候将所有样本填充到最大长度的效率不高
一个更好的做法:在构建批处理时填充样本更好,因为这样我们只需要填充到该批处理中的最大长度,而不是整个数据集的最大长度。当输入长度变化很大时,这可以节省大量时间和处理能力!
1.3.6 将所有示例填充到最长元素的长度——我们称之为动态填充
为了解决句子长度统一的问题,我们必须定义一个collate函数,该函数会将每个batch句子填充到正确的长度
transformer库通过DataCollatorWithPadding为我们提供了这样一个函数
It’s when you pad your inputs when the batch is created, to the maximum length of the sentences inside that batch.
1.3.7 数据集是以Apache Arrow文件存储在磁盘上
1.4 benefits
1.4.1 The results of the function are cached, so it won’t take any time if we re-execute the code.
1.4.2 It can apply multiprocessing to go faster than applying the function on each element of the dataset.
1.4.3 It does not load the whole dataset into memory, saving the results as soon as one element is processed.

2 使用 Trainer API 微调模型(非并行批量模式)
2.1 TrainingArguments 类
2.1.1 它将包含 Trainer用于训练和评估的所有超参数
2.1.2 可以只调整部分默认参数进行微调

training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")

2.2 Training
2.2.1 简单training

trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["validation"],data_collator=data_collator,tokenizer=tokenizer,
)
trainer.train()

2.3 评估
2.3.1 使用 Trainer.predict() 命令来使用我们的模型进行预测
predict() 的输出结果是具有三个字段的命名元组: predictions , label_ids , 和 metrics
metrics 字段将只包含传递的数据集的loss,以及一些运行时间(预测所需的总时间和平均时间)
要将我们的预测的可以与真正的标签进行比较,我们需要在第二个轴上取最大值的索引

preds = np.argmax(predictions.predictions, axis=-1)
def compute_metrics(eval_preds):metric = evaluate.load("glue", "mrpc")logits, labels = eval_predspredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)

如何使用compute_metrics()函数定义一个新的 Trainer
2.4 AutoModelForSequenceClassification
2.4.1 when we used AutoModelForSequenceClassification with bert-base-uncased, we got warnings when instantiating the model. The pretrained head is not used for the sequence classification task, so it’s discarded and a new head is instantiated(实例化) with random weights.

3 完整的训练,使用Accelerator和scheduler

3.1 训练前的数据准备
3.1.1 删除与模型不期望的值相对应的列(如sentence1和sentence2列)。
3.1.2 将列名label重命名为labels(因为模型期望参数是labels)。
3.1.3 设置数据集的格式,使其返回 PyTorch 张量而不是列表。
3.1.4 代码

tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names

3.2 data loader
3.2.1 from torch.utils.data import DataLoader

train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

3.3 优化器和学习率调度器
3.3.1 optimizer = AdamW(model.parameters(), lr=5e-5)

lr_scheduler = get_scheduler("linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps,
)

3.4 训练循环
3.4.1

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)

3.5 使用accelerator
3.5.1

train_dl, eval_dl, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)num_epochs = 3
num_training_steps = num_epochs * len(train_dl)

3.5.2 分布式
accelerate config
accelerate launch train.py
3.5.3 With 🤗Accelerate, your training loops will work for multiple GPUs and TPUs.

4 注意train只会对参数有调整

4.1 超参数(Hyperparameter),是机器学习算法中的调优参数,用于控制模型的学习过程和结构。 与模型参数(Model Parameter)不同,模型参数是在训练过程中通过数据学习得到的,而超参数是在训练之前由开发者或实践者直接设定的,并且在训练过程中保持不变。
4.2 需要自己设定,不是机器自己找出来的,称为超参数(hyperparameter)。
4.2.1 需要人工设置: 超参数的值不是通过训练过程自动学习得到的,而是需要训练者根据经验或实验来设定。
4.2.2 影响模型性能: 超参数的选择会直接影响模型的训练过程和最终性能。
4.2.3 需要优化: 为了获得更好的模型性能,通常需要对超参数进行优化,选择最优的超参数组合。
4.3 validated 数据集作用
4.3.1 验证集,用于挑选超参数的数据子集。
4.3.2

相关文章:

huggingface NLP-微调一个预训练模型

微调一个预训练模型 1 预处理数据 1.1 处理数据 1.1.1 fine-tune 使用tokenizer后的token 进行训练 batch tokenizer(sequences, paddingTrue, truncationTrue, return_tensors"pt")# This is new batch["labels"] torch.tensor([1, 1])optimizer A…...

【BUG记录】Apifox 参数传入 + 号变成空格的 BUG

文章目录 1. 问题描述2. 原因2.1 编码2.2 解码 3. 解决方法 1. 问题描述 之前写了一个接口,用 Apifox 请求,参数传入一个 86 的电话,结果到服务器 就变成空格了。 Java 接收请求的接口: 2. 原因 2.1 编码 进行 URL 请求的…...

Spring AI API 介绍

目录: Spring AI 框架介绍 Spring AI API 核心API简介 Spring AI 提供了很多便利的功能,主要如下: AI Model API “Model API” 提供了聊天、文本转图像、音频转录、文本转语音、嵌入等功能,且不局限于某个固定的大模型提供商…...

【MySQL】Linux使用C语言连接安装

📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...

2024年第十五届蓝桥杯青少组C++国赛—割点

割点 题目描述 一张棋盘由n行 m 列的网格矩阵组成,每个网格中最多放一颗棋子。当前棋盘上已有若干棋子。所有水平方向或竖直方向上相邻的棋子属于同一连通块。 现给定棋盘上所有棋子的位置,如果要使棋盘上出现两个及以上的棋子连通块,请问…...

【软件开发】做出技术决策

文章目录 专注于核心业务除非绝对必要,不要重写代码保持技术栈简单尽量减少依赖避免范围蔓延按照业务实际情况确定优先级在做出高风险决策前构建原型跨职能团队协作信任你的团队在过去的二十年里,我曾在多家初创企业担任软件开发人员、技术负责人以及首席技术官(包括创办自己…...

Airborne使用教程

1.安装环境 前提条件:系统已安装Ruby 打开终端输入如下命令 gem install airborne 或者在Gemfile添加 gem airborne 然后运行bundle install 2.编写脚本 在项目中新建api_tests_spec.rb文件 以GET接口"https://www.thunderclient.com/welcome"为…...

WPF实现曲线数据展示【案例:震动数据分析】

wpf实现曲线数据展示,函数曲线展示,实例:震动数据分析为例。 如上图所示,如果你想实现上图中的效果,请详细参考我的内容,创作不易,给个赞吧。 一共有两种方式来实现,一种是使用第三…...

EasyExcel 动态设置表格的背景颜色和排列

项目中使用EasyExcel把数据以excel格式导出&#xff0c;其中设置某一行、某一列单元格的背景颜色、排列方式十分常用&#xff0c;记录下来方便以后查阅。 1. 导入maven依赖&#xff1a; <dependency><groupId>com.alibaba</groupId><artifactId>easy…...

【 C++11 】类的新功能

C类的新功能 一、默认成员函数二、类成员变量初始化三、default关键字四、delete关键字六、final关键字七、override关键字 一、默认成员函数 八个默认成员函数 在C11之前&#xff0c;一个类中有如下六个默认成员函数&#xff1a; 构造函数。析构函数。拷贝构造函数。拷贝赋值…...

防止SQL注入:PHP安全最佳实践

防止SQL注入&#xff1a;PHP安全最佳实践 SQL注入是一种常见的网络攻击方式&#xff0c;攻击者通过向应用程序的SQL查询中插入恶意代码&#xff0c;来获取、操控或破坏数据库中的数据。为了保护PHP应用免受SQL注入攻击&#xff0c;开发者需要遵循一系列安全最佳实践。本文将介…...

自动化生产或质量检测准备工作杂记

自动化生产或质量检测一个流程是&#xff1a; 上料位上料&#xff1a; “上料位”指的是物料被放置以供机器或设备处理的位置。“上料”指的是将物料从存储位置移动到加工或检测位置的过程。移动到对位相机位置&#xff1a; “对位相机”是一种高精度相机&#xff0c;用于精确…...

张志辰医生

在医学领域&#xff0c;北京中医药大学东方医院的张志辰副主任医师宛如一颗璀璨的明星。自 2011 年于北京中医药大学获取博士学位后&#xff0c;他便扎根临床一线&#xff0c;以精湛医术和仁心仁术&#xff0c;为众多患者排忧解难 张志辰曾先后前往北京天坛医院、广东中山医院…...

CodeMirror 如何动态更新definemode

CodeMirror 如何动态更新definemode 问题描述&#xff1a;解决方法&#xff1a; 问题描述&#xff1a; 项目中有一部分用到了CodeMirror组件&#xff0c;其高亮显示的内容需要根据最新的json动态的更新&#xff0c;需要使用definemode自定义高亮内容。 想要的效果如下&#xf…...

舵机SG90详解

舵机&#xff0c;也叫伺服电机&#xff0c;在嵌入式开发中&#xff0c;舵机作为一种常见的运动控制组件&#xff0c;具有广泛的应用。其中&#xff0c;SG90 舵机以其高效、稳定的性能特点&#xff0c;成为了许多工程师和爱好者的首选&#xff0c;无论是航模、云台、机器人、智能…...

程序设计考题汇总(四:SQL练习)

文章目录 查询结果限制返回行数 查询结果限制返回行数 select device_id from user_profile LIMIT 2;...

明达IOT平台助力工业废水运维智能化

背景简介 相较于生活污水&#xff0c;工业废水的处理挑战性更高&#xff0c;原因在于其源于多样化的工业生产流程&#xff0c;成分复杂且多变&#xff0c;可能包含重金属、有毒化学…...

深入理解 Ansible Playbook:组件与实战

目录 1 playbook介绍 2 YAML语言 2.1语法简介 2.2数据类型 3 Playbook核心组件 3.1 hosts组件 3.2 remote_user组件 3.3 task列表和action组件 3.4 handlers 3.5 tags组件 3.6 其他组件说明 1 playbook介绍 playbook 剧本是由一个或多个"play"组成的列表。…...

JavaEE初阶——多线程(线程安全-锁)

复习上节内容&#xff08;部分-掌握程度不够的&#xff09; 加锁&#xff0c;解决线程安全问题。 synchronized关键字&#xff0c;对锁对象进行加锁。 锁对象&#xff0c;可以是随便一个Object对象&#xff08;或者其子类的对象&#xff09;&#xff0c;需要关注的是&#xff…...

Stable Diffusion 提示词语法

1.提示词基础 1.提示词之间用英文逗号,分隔 2.提示词之间是可以换行的 3.权重默认为1,越靠前权重越高 4.数量控制在75个单位以内 2.提示词各种符号的意义 2.1 ()、[]、{}符号 权重值()小括号[]中括号{}大括号默认1111层()1.1[]0.9{}1.052层(()) 1.121.21[[]]0.920.81{{}}1.…...

【功能安全】安全确认

目录 01 功能安全确认介绍 02 安全确认用例 03 安全确认模板 01 功能安全确认介绍 定义: 来源...

在pycharm2024.3.1中配置anaconda3-2024-06环境

version: anaconda3-2024.06-1 pycharm-community-2024.3.1 1、安装anaconda和pycharm 最新版最详细Anaconda新手安装配置环境创建教程_anaconda配置-CSDN博客 【2024最新版】超详细Pycharm安装保姆级教程&#xff0c;Pycharm环境配置和使用指南&#xff0c;看完这一篇就够了…...

linux不同发行版中的主要差异

一、初始化系统 Linux不同发行版中的系统初始化系统&#xff08;如 System V init、Upstart 或 systemd&#xff09; System V init&#xff1a; 历史&#xff1a;System V init 是最传统的 Linux 系统初始化系统&#xff0c;起源于 Unix System V 操作系统。运行级别&#xff…...

概率论得学习和整理29: 用EXCEL 描述二项分布

目录 1 关于二项分布的基本内容 2 二项分布的概率 2.1 核心要素 2.2 成功K次的概率&#xff0c;二项分布公式 2.3 期望和方差 2.4 具体试验 2.5 概率质量函数pmf 和cdf 3 二项分布的pmf图的改进 3.1 改进折线图 3.2 如何生成这种竖线图呢 4 不同的二项分布 4.1 p0.…...

C++打造局域网聊天室第九课: 客户端队列及其处理线程

文章目录 前言一、添加客户端队列的参数初始化二、相关函数总结 前言 C打造局域网聊天室第九课&#xff1a; 客户端队列及其处理线程 一、添加客户端队列的参数初始化 在Server.cpp的 ListenThreadFunc()函数内的其他操作处实现客户端队列的添加。 首先进行部分参数的初始化…...

请求go web后端接口 java安卓端播放视频

前端代码 添加gradle依赖 implementation com.squareup.retrofit2:retrofit:2.9.0 implementation com.squareup.retrofit2:converter-gson:2.9.0 添加访问网络权限 <uses-permission android:name"android.permission.INTERNET" />允许http 请求请求 andro…...

XML Schema 复合类型 - 混合内容

XML Schema 复合类型 - 混合内容 XML Schema 是一种用于定义 XML 文档结构和内容的语言。在 XML Schema 中&#xff0c;复合类型是一种包含其他元素和/或属性的复杂类型。混合内容&#xff08;Mixed Content&#xff09;是复合类型的一种特殊形式&#xff0c;它允许元素包含其…...

第8章 搬移特性

8.1 搬移函数 模块化是优秀软件设计的核心所在&#xff0c;好的模块化能够让我在修改程序时只需理解程序的一小部分。为了设计出高度模块化的程序&#xff0c;我得保证互相关联的软件要素都能集中到一块&#xff0c;并确保块与块之间的联系易于查找、直观易懂。同时&#xff0c…...

ARM/Linux嵌入式面经(五九):海尔

1.以后打算在哪里工作 问题回答: 1. 以后打算在哪里工作? 回答这个问题时,我首先会考虑我的个人目标、职业规划以及家庭和生活因素。从职业发展的角度来看,我希望能够在技术氛围浓厚、创新能力强、且能提供良好职业成长机会的地方工作。具体来说,我对以下几个方向特别感…...

java中的List、数组和set

在Java中&#xff0c;List、数组&#xff08;Array&#xff09;和Set 是三种常用的数据结构&#xff0c;它们各自有不同的特性、用途和实现方式。下面我们将深入探讨这三者的特点、区别以及它们在 Java 中的常见使用场景。 1. 数组&#xff08;Array&#xff09; 特性&#x…...