huggingface的self.state与self.control来源(TrainerState与TrainerControl)
文章目录
- 前言
- 一、huggingface的trainer的self.state与self.control初始化调用
- 二、TrainerState源码解读(self.state)
- 1、huggingface中self.state初始化参数
- 2、TrainerState类的Demo
- 三、TrainerControl源码解读(self.control)
- 总结
前言
在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。
一、huggingface的trainer的self.state与self.control初始化调用
trainer函数初始化调用代码如下:
# 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,)
在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:
class Trainer:def __init__(self,model: Union[PreTrainedModel, nn.Module] = None,args: TrainingArguments = None,data_collator: Optional[DataCollator] = None,train_dataset: Optional[Dataset] = None,eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,tokenizer: Optional[PreTrainedTokenizerBase] = None,model_init: Optional[Callable[[], PreTrainedModel]] = None,compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,callbacks: Optional[List[TrainerCallback]] = None,optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,):...self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)self.control = TrainerControl()...
二、TrainerState源码解读(self.state)
1、huggingface中self.state初始化参数
这里多解读一点huggingface的self.state初始化调用参数方法,
self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)
而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。
def is_local_process_zero(self) -> bool:"""Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on severalmachines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。"""return self.args.local_process_index == 0def is_world_process_zero(self) -> bool:"""Whether or not this process is the global main process (when training in a distributed fashion on severalmachines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。"""# Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global# process index.if is_sagemaker_mp_enabled():return smp.rank() == 0else:return self.args.process_index == 0
self.args.local_process_index与self.args.process_index来源self.args
2、TrainerState类的Demo
介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0
等内容,也进行了默认参数赋值,其Demo如下:
from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:epoch: Optional[float] = Noneglobal_step: int = 0max_steps: int = 0num_train_epochs: int = 0total_flos: float = 0log_history: List[Dict[str, float]] = Nonebest_metric: Optional[float] = Nonebest_model_checkpoint: Optional[str] = Noneis_local_process_zero: bool = Trueis_world_process_zero: bool = Trueis_hyper_param_search: bool = Falsetrial_name: str = Nonetrial_params: Dict[str, Union[str, float, int, bool]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []def save_to_json(self, json_path: str):"""Save the content of this instance in JSON format inside `json_path`."""json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)@classmethoddef load_from_json(cls, json_path: str):"""Create an instance from the content of `json_path`."""with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))if __name__ == '__main__':state = TrainerState()state.save_to_json('state.json')state_new = state.load_from_json('state.json')
我这里使用state = TrainerState()
方法对TrainerState()类实例化,使用state.save_to_json('state.json')
进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')
方式载入会得到新的state_new实例化。
三、TrainerControl源码解读(self.control)
该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False
内容,也进行了默认参数赋值,其源码如下:
@dataclass
class TrainerControl:"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate someswitches in the training loop.Args:should_training_stop (`bool`, *optional*, defaults to `False`):Whether or not the training should be interrupted.If `True`, this variable will not be set back to `False`. The training will just stop.should_epoch_stop (`bool`, *optional*, defaults to `False`):Whether or not the current epoch should be interrupted.If `True`, this variable will be set back to `False` at the beginning of the next epoch.should_save (`bool`, *optional*, defaults to `False`):Whether or not the model should be saved at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_evaluate (`bool`, *optional*, defaults to `False`):Whether or not the model should be evaluated at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_log (`bool`, *optional*, defaults to `False`):Whether or not the logs should be reported at this step.If `True`, this variable will be set back to `False` at the beginning of the next step."""should_training_stop: bool = Falseshould_epoch_stop: bool = Falseshould_save: bool = Falseshould_evaluate: bool = Falseshould_log: bool = Falsedef _new_training(self):"""Internal method that resets the variable for a new training."""self.should_training_stop = Falsedef _new_epoch(self):"""Internal method that resets the variable for a new epoch."""self.should_epoch_stop = Falsedef _new_step(self):"""Internal method that resets the variable for a new step."""self.should_save = Falseself.should_evaluate = Falseself.should_log = False
总结
本文主要介绍huggingface的trainer中的self.control与self.state的来源。
相关文章:

huggingface的self.state与self.control来源(TrainerState与TrainerControl)
文章目录 前言一、huggingface的trainer的self.state与self.control初始化调用二、TrainerState源码解读(self.state)1、huggingface中self.state初始化参数2、TrainerState类的Demo 三、TrainerControl源码解读(self.control)总结 前言 在 Hugging Face 中,self.s…...

30【Aseprite 作图】桌子——拆解
1 桌子只要画左上方,竖着5,斜着3个1,斜着两个2,斜着2个3,斜着一个5,斜着一个很长的 然后左右翻转 再上下翻转 在桌子腿部分,竖着三个直线,左右都是斜线;这是横着水平线不…...

C++设计模式-单例模式,反汇编
文章目录 25. 单例模式25.1. 饿汉式单例模式25.2. 懒汉式单例模式25.2.1. 解决方案125.2.2. 解决方案2 (推荐写法) 运行在VS2022,x86,Debug下。 25. 单例模式 单例即该类只能有一个实例。 应用:如在游戏开发中&#x…...

Django 做migrations时出错,解决方案
在做migrations的时候,偶尔会出现出错。 在已有数据的表中新增字段时,会弹出下面的信息 运行这个命令时 python manage.py makemigrationsTracking file by folder pattern: migrations It is impossible to add a non-nullable field ‘example’ to …...

QT::QNetworkReply类readAll()读取不到数据的可能原因
程序中,当发送请求时,并没有加锁,而是在响应函数中加了锁,导致可能某个请求的finished信号影响到其他请求响应数据的读取 connect(reply,&QNetworkReply::finished,this,&Display::replyFinished);参考这篇文章ÿ…...

vxe-form-design 表单设计器的使用
vxe-form-design 在 vue3 中表单设计器的使用 查看官网 https://vxeui.com 安装 npm install vxe-pc-ui // ... import VxeUI from vxe-pc-ui import vxe-pc-ui/lib/style.css // ...// ... createApp(App).use(VxeUI).mount(#app) // ...使用 github vxe-form-design 用…...

【Linux】TCP协议【上】{协议段属性:源端口号/目的端口号/序号/确认序号/窗口大小/紧急指针/标记位}
文章目录 1.引入2.协议段格式4位首部长度16位窗口大小32位序号思考三个问题【demo】标记位URG: 紧急指针是否有效提升某报文被处理优先级【0表示不设置1表示设置】ACK: 确认号是否有效PSH: 提示接收端应用程序立刻从TCP缓冲区把数据读走RST: 对方要求重新建立连接; 我们把携带R…...

php之sql代码审计
1 SQL注入代码审计流程 1.1 反向查找流程 通过可控变量(输入点)回溯危险函数 查找危险函数确定可控变量 传递的过程中触发漏洞 1.2 反向查找流程特点 暴力:全局搜索危险函数 简单:无需过多理解目标网站功能与架构 快速:适用于自动化代码审…...

【Java用法】java中计算两个时间差
java中计算两个时间差 不多说,直接上代码,可自行查看示例 package org.example.calc;import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit;public class MinusTest {public static void…...

tinymce富文本编辑器使用
安卓富文本编辑器:npm i tinymce/tinymce-vue 当前项目中富文本是放在一个dialog中,因此部分样式会有层叠问题,该组件样式部分不添加scope。这里图片上传只是前端静态数据展示收集。 <template><div class"desc-editor"…...

Java——接口后续
1.Comparable 接口 在Java中,我们对一个元素是数字的数组可以使用sort方法进行排序,如果要对一个元素是对象的数组按某种规则排序,就会用到Comparable接口 当实现Comparable接口后,sort会自动调用Comparable接口里的compareTo 方法…...

最新上市公司控制变量大全(1413+指标)1990-2023年
数据介绍:根据2023年上市公司年报数据进行更新,包括基本信息、财务指标、环境、社会与治理、数字化转型、企业发展、全要素生产率等1413指标。数据范围:A股上市公司数据年份:1990-2023年指标数目:1413个指标࿰…...

jmeter多用户并发登录教程
有时候为了模拟更真实的场景,在项目中需要多用户登录操作,大致参考如下 jmx脚本:百度网盘链接 提取码:0000 一: 单用户登录 先使用1个用户登录(先把1个请求调试通过) 发送一个登录请求&…...

【高频】redis快的原因
相关问题: 1.为什么Redis能够如此快速地进行数据存储和检索? 2.Redis作为内存数据库,其内存存储有什么优势吗? 3.Redis的网络模型有何特点,如何帮助提升性能? 一、问题回答 Redis使用了内存数据结构,例如字符串、哈希表、列表、集合、有…...

hive3从入门到精通(一)
Hive3入门至精通(基础、部署、理论、SQL、函数、运算以及性能优化)1-14章 第1章:数据仓库基础理论 1-1.数据仓库概念 数据仓库(英语:Data Warehouse,简称数仓、DW),是一个用于存储、分析、报告的数据系统。 数据仓库的目的是构…...

c++编程(15)——list的模拟实现
欢迎来到博主的专栏——c编程 博主ID:代码小豪 文章目录 前言list的数据结构list的默认构造尾插与尾删iterator插入和删除构造、析构、赋值copy构造initializer_list构造operator 析构函数 前言 受限于博主当前的技术水平,暂时还不能模拟实现出STL当中用…...

【深度学习】吸烟行为检测软件系统
往期文章列表: 【YOLO深度学习系列】图像分类、物体检测、实例分割、物体追踪、姿态估计、定向边框检测演示系统【含源码】【深度学习】YOLOV8数据标注及模型训练方法整体流程介绍及演示【深度学习】行人跌倒行为检测软件系统【深度学习】火灾检测软件系统【深度学…...

你见过哪些不过度设计的优秀APP?
优联前端https://ufrontend.com/ 提供一站式企业前端解决方案 “每日故宫”是一款以故宫博物院丰富的藏品为基础,结合日历形式展示每日精选藏品的移动应用。通过这款应用,用户可以随时随地欣赏到故宫的珍贵藏品,感受中华五千年文化的魅力。…...
全栈:session用户会话信息,用户浏览记录实例
PHP中的session是一种存储机制,它允许您存储和跟踪用户在访问Web应用程序时的信息。会话通常用于存储用户特定的数据,如用户ID、购物车内容、用户偏好设置等,这些数据需要在多个页面请求之间保持不变。 session详解 1. 会话是如何工作的 会…...
设计模式--》 装饰模式的应用
装饰模式的定义: 装饰模式(Decorator Pattern)是一种结构型设计模式,它允许你动态地给一个对象添加一些额外的职责。就增加功能来说,装饰模式相比生成子类更为灵活。 何时应用装饰模式? 1.当需要动态地给…...
基于算法竞赛的c++编程(28)结构体的进阶应用
结构体的嵌套与复杂数据组织 在C中,结构体可以嵌套使用,形成更复杂的数据结构。例如,可以通过嵌套结构体描述多层级数据关系: struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...

超短脉冲激光自聚焦效应
前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...
在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module
1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放
简介 前面两期文章我们介绍了I2S的读取和写入,一个是通过INMP441麦克风模块采集音频,一个是通过PCM5102A模块播放音频,那如果我们将两者结合起来,将麦克风采集到的音频通过PCM5102A播放,是不是就可以做一个扩音器了呢…...
【C语言练习】080. 使用C语言实现简单的数据库操作
080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

用docker来安装部署freeswitch记录
今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...