huggingface的self.state与self.control来源(TrainerState与TrainerControl)
文章目录
- 前言
- 一、huggingface的trainer的self.state与self.control初始化调用
- 二、TrainerState源码解读(self.state)
- 1、huggingface中self.state初始化参数
- 2、TrainerState类的Demo
- 三、TrainerControl源码解读(self.control)
- 总结
前言
在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。
一、huggingface的trainer的self.state与self.control初始化调用
trainer函数初始化调用代码如下:
# 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,)
在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:
class Trainer:def __init__(self,model: Union[PreTrainedModel, nn.Module] = None,args: TrainingArguments = None,data_collator: Optional[DataCollator] = None,train_dataset: Optional[Dataset] = None,eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,tokenizer: Optional[PreTrainedTokenizerBase] = None,model_init: Optional[Callable[[], PreTrainedModel]] = None,compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,callbacks: Optional[List[TrainerCallback]] = None,optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,):...self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)self.control = TrainerControl()...
二、TrainerState源码解读(self.state)
1、huggingface中self.state初始化参数
这里多解读一点huggingface的self.state初始化调用参数方法,
self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)
而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。
def is_local_process_zero(self) -> bool:"""Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on severalmachines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。"""return self.args.local_process_index == 0def is_world_process_zero(self) -> bool:"""Whether or not this process is the global main process (when training in a distributed fashion on severalmachines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。"""# Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global# process index.if is_sagemaker_mp_enabled():return smp.rank() == 0else:return self.args.process_index == 0
self.args.local_process_index与self.args.process_index来源self.args
2、TrainerState类的Demo
介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0
等内容,也进行了默认参数赋值,其Demo如下:
from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:epoch: Optional[float] = Noneglobal_step: int = 0max_steps: int = 0num_train_epochs: int = 0total_flos: float = 0log_history: List[Dict[str, float]] = Nonebest_metric: Optional[float] = Nonebest_model_checkpoint: Optional[str] = Noneis_local_process_zero: bool = Trueis_world_process_zero: bool = Trueis_hyper_param_search: bool = Falsetrial_name: str = Nonetrial_params: Dict[str, Union[str, float, int, bool]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []def save_to_json(self, json_path: str):"""Save the content of this instance in JSON format inside `json_path`."""json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)@classmethoddef load_from_json(cls, json_path: str):"""Create an instance from the content of `json_path`."""with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))if __name__ == '__main__':state = TrainerState()state.save_to_json('state.json')state_new = state.load_from_json('state.json')
我这里使用state = TrainerState()
方法对TrainerState()类实例化,使用state.save_to_json('state.json')
进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')
方式载入会得到新的state_new实例化。
三、TrainerControl源码解读(self.control)
该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False
内容,也进行了默认参数赋值,其源码如下:
@dataclass
class TrainerControl:"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate someswitches in the training loop.Args:should_training_stop (`bool`, *optional*, defaults to `False`):Whether or not the training should be interrupted.If `True`, this variable will not be set back to `False`. The training will just stop.should_epoch_stop (`bool`, *optional*, defaults to `False`):Whether or not the current epoch should be interrupted.If `True`, this variable will be set back to `False` at the beginning of the next epoch.should_save (`bool`, *optional*, defaults to `False`):Whether or not the model should be saved at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_evaluate (`bool`, *optional*, defaults to `False`):Whether or not the model should be evaluated at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_log (`bool`, *optional*, defaults to `False`):Whether or not the logs should be reported at this step.If `True`, this variable will be set back to `False` at the beginning of the next step."""should_training_stop: bool = Falseshould_epoch_stop: bool = Falseshould_save: bool = Falseshould_evaluate: bool = Falseshould_log: bool = Falsedef _new_training(self):"""Internal method that resets the variable for a new training."""self.should_training_stop = Falsedef _new_epoch(self):"""Internal method that resets the variable for a new epoch."""self.should_epoch_stop = Falsedef _new_step(self):"""Internal method that resets the variable for a new step."""self.should_save = Falseself.should_evaluate = Falseself.should_log = False
总结
本文主要介绍huggingface的trainer中的self.control与self.state的来源。
相关文章:

huggingface的self.state与self.control来源(TrainerState与TrainerControl)
文章目录 前言一、huggingface的trainer的self.state与self.control初始化调用二、TrainerState源码解读(self.state)1、huggingface中self.state初始化参数2、TrainerState类的Demo 三、TrainerControl源码解读(self.control)总结 前言 在 Hugging Face 中,self.s…...

30【Aseprite 作图】桌子——拆解
1 桌子只要画左上方,竖着5,斜着3个1,斜着两个2,斜着2个3,斜着一个5,斜着一个很长的 然后左右翻转 再上下翻转 在桌子腿部分,竖着三个直线,左右都是斜线;这是横着水平线不…...

C++设计模式-单例模式,反汇编
文章目录 25. 单例模式25.1. 饿汉式单例模式25.2. 懒汉式单例模式25.2.1. 解决方案125.2.2. 解决方案2 (推荐写法) 运行在VS2022,x86,Debug下。 25. 单例模式 单例即该类只能有一个实例。 应用:如在游戏开发中&#x…...

Django 做migrations时出错,解决方案
在做migrations的时候,偶尔会出现出错。 在已有数据的表中新增字段时,会弹出下面的信息 运行这个命令时 python manage.py makemigrationsTracking file by folder pattern: migrations It is impossible to add a non-nullable field ‘example’ to …...

QT::QNetworkReply类readAll()读取不到数据的可能原因
程序中,当发送请求时,并没有加锁,而是在响应函数中加了锁,导致可能某个请求的finished信号影响到其他请求响应数据的读取 connect(reply,&QNetworkReply::finished,this,&Display::replyFinished);参考这篇文章ÿ…...

vxe-form-design 表单设计器的使用
vxe-form-design 在 vue3 中表单设计器的使用 查看官网 https://vxeui.com 安装 npm install vxe-pc-ui // ... import VxeUI from vxe-pc-ui import vxe-pc-ui/lib/style.css // ...// ... createApp(App).use(VxeUI).mount(#app) // ...使用 github vxe-form-design 用…...

【Linux】TCP协议【上】{协议段属性:源端口号/目的端口号/序号/确认序号/窗口大小/紧急指针/标记位}
文章目录 1.引入2.协议段格式4位首部长度16位窗口大小32位序号思考三个问题【demo】标记位URG: 紧急指针是否有效提升某报文被处理优先级【0表示不设置1表示设置】ACK: 确认号是否有效PSH: 提示接收端应用程序立刻从TCP缓冲区把数据读走RST: 对方要求重新建立连接; 我们把携带R…...

php之sql代码审计
1 SQL注入代码审计流程 1.1 反向查找流程 通过可控变量(输入点)回溯危险函数 查找危险函数确定可控变量 传递的过程中触发漏洞 1.2 反向查找流程特点 暴力:全局搜索危险函数 简单:无需过多理解目标网站功能与架构 快速:适用于自动化代码审…...

【Java用法】java中计算两个时间差
java中计算两个时间差 不多说,直接上代码,可自行查看示例 package org.example.calc;import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit;public class MinusTest {public static void…...

tinymce富文本编辑器使用
安卓富文本编辑器:npm i tinymce/tinymce-vue 当前项目中富文本是放在一个dialog中,因此部分样式会有层叠问题,该组件样式部分不添加scope。这里图片上传只是前端静态数据展示收集。 <template><div class"desc-editor"…...

Java——接口后续
1.Comparable 接口 在Java中,我们对一个元素是数字的数组可以使用sort方法进行排序,如果要对一个元素是对象的数组按某种规则排序,就会用到Comparable接口 当实现Comparable接口后,sort会自动调用Comparable接口里的compareTo 方法…...

最新上市公司控制变量大全(1413+指标)1990-2023年
数据介绍:根据2023年上市公司年报数据进行更新,包括基本信息、财务指标、环境、社会与治理、数字化转型、企业发展、全要素生产率等1413指标。数据范围:A股上市公司数据年份:1990-2023年指标数目:1413个指标࿰…...

jmeter多用户并发登录教程
有时候为了模拟更真实的场景,在项目中需要多用户登录操作,大致参考如下 jmx脚本:百度网盘链接 提取码:0000 一: 单用户登录 先使用1个用户登录(先把1个请求调试通过) 发送一个登录请求&…...

【高频】redis快的原因
相关问题: 1.为什么Redis能够如此快速地进行数据存储和检索? 2.Redis作为内存数据库,其内存存储有什么优势吗? 3.Redis的网络模型有何特点,如何帮助提升性能? 一、问题回答 Redis使用了内存数据结构,例如字符串、哈希表、列表、集合、有…...

hive3从入门到精通(一)
Hive3入门至精通(基础、部署、理论、SQL、函数、运算以及性能优化)1-14章 第1章:数据仓库基础理论 1-1.数据仓库概念 数据仓库(英语:Data Warehouse,简称数仓、DW),是一个用于存储、分析、报告的数据系统。 数据仓库的目的是构…...

c++编程(15)——list的模拟实现
欢迎来到博主的专栏——c编程 博主ID:代码小豪 文章目录 前言list的数据结构list的默认构造尾插与尾删iterator插入和删除构造、析构、赋值copy构造initializer_list构造operator 析构函数 前言 受限于博主当前的技术水平,暂时还不能模拟实现出STL当中用…...

【深度学习】吸烟行为检测软件系统
往期文章列表: 【YOLO深度学习系列】图像分类、物体检测、实例分割、物体追踪、姿态估计、定向边框检测演示系统【含源码】【深度学习】YOLOV8数据标注及模型训练方法整体流程介绍及演示【深度学习】行人跌倒行为检测软件系统【深度学习】火灾检测软件系统【深度学…...

你见过哪些不过度设计的优秀APP?
优联前端https://ufrontend.com/ 提供一站式企业前端解决方案 “每日故宫”是一款以故宫博物院丰富的藏品为基础,结合日历形式展示每日精选藏品的移动应用。通过这款应用,用户可以随时随地欣赏到故宫的珍贵藏品,感受中华五千年文化的魅力。…...

全栈:session用户会话信息,用户浏览记录实例
PHP中的session是一种存储机制,它允许您存储和跟踪用户在访问Web应用程序时的信息。会话通常用于存储用户特定的数据,如用户ID、购物车内容、用户偏好设置等,这些数据需要在多个页面请求之间保持不变。 session详解 1. 会话是如何工作的 会…...

设计模式--》 装饰模式的应用
装饰模式的定义: 装饰模式(Decorator Pattern)是一种结构型设计模式,它允许你动态地给一个对象添加一些额外的职责。就增加功能来说,装饰模式相比生成子类更为灵活。 何时应用装饰模式? 1.当需要动态地给…...

深入解析Web前端三大主流框架:Angular、React和Vue
Web前端三大主流框架分别是Angular、React和Vue。下面我将为您详细介绍这三大框架的特点和使用指南。 Angular 核心概念: 组件(Components): 组件是Angular应用的构建块,每个组件由一个带有装饰器的类、一个HTML模板、一个CSS样式表组成。组件通过输入(@Input)和输出(…...

ch3运输层--计算机网络期末复习(持续更新中)
运输层位于网络层之上 运输层协议提供的某些服务受到网络层协议的限制。比如,时限和带宽保证。 运输层也提供自己的特殊服务。比如,可靠数据传输服务,安全性服务。 网络层:两个主机之间的逻辑通信 运输层:两个进程之间的逻辑通信 网络地址:主机的标识(IP地址) 传输地址: …...

mysql中的内连接与外连接
在MySQL中,内连接和外连接是用于从多个表中检索数据的两种不同的连接方式。 内连接(INNER JOIN): 内连接返回两个表之间匹配的行。它只返回两个表中共同匹配的行,如果在一个表中没有匹配到对应的行,则不会显…...

0基础认识C语言(理论+实操 2)
小伙伴们大家好,今天也要撸起袖子加油干!万事开头难,越学到后面越轻松~ 话不多说,开始正题~ 前提回顾: 接上次博客,我们学到了转义字符,最后留下两个转义字符不知道大家有没有动手尝试了一遍&a…...

ChatGPT的基本原理是什么?又该如何提高其准确性?
在深入探索如何提升ChatGPT的准确性之前,让我们先来了解一下它的工作原理吧。ChatGPT是一种基于深度学习的自然语言生成模型,它通过预训练和微调两个关键步骤来学习和理解自然语言。 在预训练阶段,ChatGPT会接触到大规模的文本数据集&#x…...

云计算OpenStack基础
1.什么是虚拟化? •虚拟化是云计算的基础。 •虚拟化是指计算元件在虚拟的而不是真实的硬件基础上运行。 •虚拟化将物理资源转变为具有可管理性的逻辑资源,以消除物理结构之间的隔离,将物理资源融为一个整体。虚拟化是一种简化管理和优化…...

[10] CUDA程序性能的提升 与 流
CUDA程序性能的提升 与 流 1. CUDA程序性能的提升 在本节中,我们会看到用来遵循的基本的一些性能来提升准则,我们会逐一解释它们1.1 使用适当的块数量和线程数量 研究表明,如果块的数量是 GPU 的流多处理器数量的两倍,则会给出最佳性能,不过,块和线程的数量与具体的算法…...

TH方程学习(1)
一、背景介绍 根据CW方程的学习,CW方程的限制条件为圆轨道,不考虑摄动,二者距离相对较小。TH方程则可以将物体间的相对运动推广到椭圆轨道的二体运动模型,本部分将结合STK的仿真功能,联合考察TH方程的有用性ÿ…...

【九十七】【算法分析与设计】图论,迷宫,1207. 大臣的旅费,走出迷宫,石油采集,after与迷宫,逃离迷宫,3205. 最优配餐,路径之谜
1207. 大臣的旅费 - AcWing题库 很久以前,TT 王国空前繁荣。 为了更好地管理国家,王国修建了大量的快速路,用于连接首都和王国内的各大城市。 为节省经费,TT 国的大臣们经过思考,制定了一套优秀的修建方案,…...

【Tools】SpringBoot工程中,对于时间属性从后端返回到前端的格式问题
Catalog 时间属性格式问题一、需求二、怎么使用 时间属性格式问题 一、需求 对于表中时间字段,后端创建对应的实体类的时间属性需要设定格式(默认的格式不方便阅读),再返回给前端。 二、怎么使用 导入jackson相关的坐标&#x…...