huggingface NLP-微调一个预训练模型
微调一个预训练模型

1 预处理数据
1.1 处理数据
1.1.1 fine-tune
使用tokenizer后的token 进行训练
batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")# This is new
batch["labels"] = torch.tensor([1, 1])optimizer = AdamW(model.parameters())
loss = model(**batch).loss
loss.backward()
optimizer.step()
1.2 从模型中心(Hub)加载数据集
1.2.1 数据集
DatasetDict对象,其中包含训练集、验证集和测试集
。每一个集合都包含几个列(sentence1, sentence2, label, and idx)以及一个代表行数的变量,即每个集合中的行的个数
下载数据集并缓存到 ~/.cache/huggingface/datasets. 回想一下第2章,您可以通过设置HF_HOME环境变量来自定义缓存的文件夹。
1.3 预处理数据集
1.3.1 预处理数据集,我们需要将文本转换为模型能够理解的数字
1.3.2 类型标记ID(token_type_ids)的作用就是告诉模型输入的哪一部分是第一句,哪一部分是第二句
1.3.3 不一定具有类型标记ID(token_type_ids
1.3.4 将数据保存为数据集,我们将使用Dataset.map()
调用map时使用了batch =True,这样函数就可以同时应用到数据集的多个元素上,而不是分别应用到每个元素上。这将使我们的预处理快许多
1.3.5 省略padding参数
在标记的时候将所有样本填充到最大长度的效率不高
一个更好的做法:在构建批处理时填充样本更好,因为这样我们只需要填充到该批处理中的最大长度,而不是整个数据集的最大长度。当输入长度变化很大时,这可以节省大量时间和处理能力!
1.3.6 将所有示例填充到最长元素的长度——我们称之为动态填充
为了解决句子长度统一的问题,我们必须定义一个collate函数,该函数会将每个batch句子填充到正确的长度
transformer库通过DataCollatorWithPadding为我们提供了这样一个函数
It’s when you pad your inputs when the batch is created, to the maximum length of the sentences inside that batch.
1.3.7 数据集是以Apache Arrow文件存储在磁盘上
1.4 benefits
1.4.1 The results of the function are cached, so it won’t take any time if we re-execute the code.
1.4.2 It can apply multiprocessing to go faster than applying the function on each element of the dataset.
1.4.3 It does not load the whole dataset into memory, saving the results as soon as one element is processed.
2 使用 Trainer API 微调模型(非并行批量模式)
2.1 TrainingArguments 类
2.1.1 它将包含 Trainer用于训练和评估的所有超参数
2.1.2 可以只调整部分默认参数进行微调
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
2.2 Training
2.2.1 简单training
trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["validation"],data_collator=data_collator,tokenizer=tokenizer,
)
trainer.train()
2.3 评估
2.3.1 使用 Trainer.predict() 命令来使用我们的模型进行预测
predict() 的输出结果是具有三个字段的命名元组: predictions , label_ids , 和 metrics
metrics 字段将只包含传递的数据集的loss,以及一些运行时间(预测所需的总时间和平均时间)
要将我们的预测的可以与真正的标签进行比较,我们需要在第二个轴上取最大值的索引
preds = np.argmax(predictions.predictions, axis=-1)
def compute_metrics(eval_preds):metric = evaluate.load("glue", "mrpc")logits, labels = eval_predspredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)
如何使用compute_metrics()函数定义一个新的 Trainer
2.4 AutoModelForSequenceClassification
2.4.1 when we used AutoModelForSequenceClassification with bert-base-uncased, we got warnings when instantiating the model. The pretrained head is not used for the sequence classification task, so it’s discarded and a new head is instantiated(实例化) with random weights.
3 完整的训练,使用Accelerator和scheduler
3.1 训练前的数据准备
3.1.1 删除与模型不期望的值相对应的列(如sentence1和sentence2列)。
3.1.2 将列名label重命名为labels(因为模型期望参数是labels)。
3.1.3 设置数据集的格式,使其返回 PyTorch 张量而不是列表。
3.1.4 代码
tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names
3.2 data loader
3.2.1 from torch.utils.data import DataLoader
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)
3.3 优化器和学习率调度器
3.3.1 optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler("linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps,
)
3.4 训练循环
3.4.1
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
3.5 使用accelerator
3.5.1
train_dl, eval_dl, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)num_epochs = 3
num_training_steps = num_epochs * len(train_dl)
3.5.2 分布式
accelerate config
accelerate launch train.py
3.5.3 With 🤗Accelerate, your training loops will work for multiple GPUs and TPUs.
4 注意train只会对参数有调整
4.1 超参数(Hyperparameter),是机器学习算法中的调优参数,用于控制模型的学习过程和结构。 与模型参数(Model Parameter)不同,模型参数是在训练过程中通过数据学习得到的,而超参数是在训练之前由开发者或实践者直接设定的,并且在训练过程中保持不变。
4.2 需要自己设定,不是机器自己找出来的,称为超参数(hyperparameter)。
4.2.1 需要人工设置: 超参数的值不是通过训练过程自动学习得到的,而是需要训练者根据经验或实验来设定。
4.2.2 影响模型性能: 超参数的选择会直接影响模型的训练过程和最终性能。
4.2.3 需要优化: 为了获得更好的模型性能,通常需要对超参数进行优化,选择最优的超参数组合。
4.3 validated 数据集作用
4.3.1 验证集,用于挑选超参数的数据子集。

相关文章:
huggingface NLP-微调一个预训练模型
微调一个预训练模型 1 预处理数据 1.1 处理数据 1.1.1 fine-tune 使用tokenizer后的token 进行训练 batch tokenizer(sequences, paddingTrue, truncationTrue, return_tensors"pt")# This is new batch["labels"] torch.tensor([1, 1])optimizer A…...
【BUG记录】Apifox 参数传入 + 号变成空格的 BUG
文章目录 1. 问题描述2. 原因2.1 编码2.2 解码 3. 解决方法 1. 问题描述 之前写了一个接口,用 Apifox 请求,参数传入一个 86 的电话,结果到服务器 就变成空格了。 Java 接收请求的接口: 2. 原因 2.1 编码 进行 URL 请求的…...
Spring AI API 介绍
目录: Spring AI 框架介绍 Spring AI API 核心API简介 Spring AI 提供了很多便利的功能,主要如下: AI Model API “Model API” 提供了聊天、文本转图像、音频转录、文本转语音、嵌入等功能,且不局限于某个固定的大模型提供商…...
【MySQL】Linux使用C语言连接安装
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...
2024年第十五届蓝桥杯青少组C++国赛—割点
割点 题目描述 一张棋盘由n行 m 列的网格矩阵组成,每个网格中最多放一颗棋子。当前棋盘上已有若干棋子。所有水平方向或竖直方向上相邻的棋子属于同一连通块。 现给定棋盘上所有棋子的位置,如果要使棋盘上出现两个及以上的棋子连通块,请问…...
【软件开发】做出技术决策
文章目录 专注于核心业务除非绝对必要,不要重写代码保持技术栈简单尽量减少依赖避免范围蔓延按照业务实际情况确定优先级在做出高风险决策前构建原型跨职能团队协作信任你的团队在过去的二十年里,我曾在多家初创企业担任软件开发人员、技术负责人以及首席技术官(包括创办自己…...
Airborne使用教程
1.安装环境 前提条件:系统已安装Ruby 打开终端输入如下命令 gem install airborne 或者在Gemfile添加 gem airborne 然后运行bundle install 2.编写脚本 在项目中新建api_tests_spec.rb文件 以GET接口"https://www.thunderclient.com/welcome"为…...
WPF实现曲线数据展示【案例:震动数据分析】
wpf实现曲线数据展示,函数曲线展示,实例:震动数据分析为例。 如上图所示,如果你想实现上图中的效果,请详细参考我的内容,创作不易,给个赞吧。 一共有两种方式来实现,一种是使用第三…...
EasyExcel 动态设置表格的背景颜色和排列
项目中使用EasyExcel把数据以excel格式导出,其中设置某一行、某一列单元格的背景颜色、排列方式十分常用,记录下来方便以后查阅。 1. 导入maven依赖: <dependency><groupId>com.alibaba</groupId><artifactId>easy…...
【 C++11 】类的新功能
C类的新功能 一、默认成员函数二、类成员变量初始化三、default关键字四、delete关键字六、final关键字七、override关键字 一、默认成员函数 八个默认成员函数 在C11之前,一个类中有如下六个默认成员函数: 构造函数。析构函数。拷贝构造函数。拷贝赋值…...
防止SQL注入:PHP安全最佳实践
防止SQL注入:PHP安全最佳实践 SQL注入是一种常见的网络攻击方式,攻击者通过向应用程序的SQL查询中插入恶意代码,来获取、操控或破坏数据库中的数据。为了保护PHP应用免受SQL注入攻击,开发者需要遵循一系列安全最佳实践。本文将介…...
自动化生产或质量检测准备工作杂记
自动化生产或质量检测一个流程是: 上料位上料: “上料位”指的是物料被放置以供机器或设备处理的位置。“上料”指的是将物料从存储位置移动到加工或检测位置的过程。移动到对位相机位置: “对位相机”是一种高精度相机,用于精确…...
张志辰医生
在医学领域,北京中医药大学东方医院的张志辰副主任医师宛如一颗璀璨的明星。自 2011 年于北京中医药大学获取博士学位后,他便扎根临床一线,以精湛医术和仁心仁术,为众多患者排忧解难 张志辰曾先后前往北京天坛医院、广东中山医院…...
CodeMirror 如何动态更新definemode
CodeMirror 如何动态更新definemode 问题描述:解决方法: 问题描述: 项目中有一部分用到了CodeMirror组件,其高亮显示的内容需要根据最新的json动态的更新,需要使用definemode自定义高亮内容。 想要的效果如下…...
舵机SG90详解
舵机,也叫伺服电机,在嵌入式开发中,舵机作为一种常见的运动控制组件,具有广泛的应用。其中,SG90 舵机以其高效、稳定的性能特点,成为了许多工程师和爱好者的首选,无论是航模、云台、机器人、智能…...
程序设计考题汇总(四:SQL练习)
文章目录 查询结果限制返回行数 查询结果限制返回行数 select device_id from user_profile LIMIT 2;...
明达IOT平台助力工业废水运维智能化
背景简介 相较于生活污水,工业废水的处理挑战性更高,原因在于其源于多样化的工业生产流程,成分复杂且多变,可能包含重金属、有毒化学…...
深入理解 Ansible Playbook:组件与实战
目录 1 playbook介绍 2 YAML语言 2.1语法简介 2.2数据类型 3 Playbook核心组件 3.1 hosts组件 3.2 remote_user组件 3.3 task列表和action组件 3.4 handlers 3.5 tags组件 3.6 其他组件说明 1 playbook介绍 playbook 剧本是由一个或多个"play"组成的列表。…...
JavaEE初阶——多线程(线程安全-锁)
复习上节内容(部分-掌握程度不够的) 加锁,解决线程安全问题。 synchronized关键字,对锁对象进行加锁。 锁对象,可以是随便一个Object对象(或者其子类的对象),需要关注的是ÿ…...
Stable Diffusion 提示词语法
1.提示词基础 1.提示词之间用英文逗号,分隔 2.提示词之间是可以换行的 3.权重默认为1,越靠前权重越高 4.数量控制在75个单位以内 2.提示词各种符号的意义 2.1 ()、[]、{}符号 权重值()小括号[]中括号{}大括号默认1111层()1.1[]0.9{}1.052层(()) 1.121.21[[]]0.920.81{{}}1.…...
【杂谈】-递归进化:人工智能的自我改进与监管挑战
递归进化:人工智能的自我改进与监管挑战 文章目录 递归进化:人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管?3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...
select、poll、epoll 与 Reactor 模式
在高并发网络编程领域,高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表,以及基于它们实现的 Reactor 模式,为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。 一、I…...
html-<abbr> 缩写或首字母缩略词
定义与作用 <abbr> 标签用于表示缩写或首字母缩略词,它可以帮助用户更好地理解缩写的含义,尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时,会显示一个提示框。 示例&#x…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信
文章目录 Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket(服务端和客户端都要)2. 绑定本地地址和端口&#x…...
怎么让Comfyui导出的图像不包含工作流信息,
为了数据安全,让Comfyui导出的图像不包含工作流信息,导出的图像就不会拖到comfyui中加载出来工作流。 ComfyUI的目录下node.py 直接移除 pnginfo(推荐) 在 save_images 方法中,删除或注释掉所有与 metadata …...
第7篇:中间件全链路监控与 SQL 性能分析实践
7.1 章节导读 在构建数据库中间件的过程中,可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中,必须做到: 🔍 追踪每一条 SQL 的生命周期(从入口到数据库执行)&#…...
毫米波雷达基础理论(3D+4D)
3D、4D毫米波雷达基础知识及厂商选型 PreView : https://mp.weixin.qq.com/s/bQkju4r6med7I3TBGJI_bQ 1. FMCW毫米波雷达基础知识 主要参考博文: 一文入门汽车毫米波雷达基本原理 :https://mp.weixin.qq.com/s/_EN7A5lKcz2Eh8dLnjE19w 毫米波雷达基础…...
