huggingface的self.state与self.control来源(TrainerState与TrainerControl)
文章目录
- 前言
- 一、huggingface的trainer的self.state与self.control初始化调用
- 二、TrainerState源码解读(self.state)
- 1、huggingface中self.state初始化参数
- 2、TrainerState类的Demo
- 三、TrainerControl源码解读(self.control)
- 总结
前言
在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。
一、huggingface的trainer的self.state与self.control初始化调用
trainer函数初始化调用代码如下:
# 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,)
在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:
class Trainer:def __init__(self,model: Union[PreTrainedModel, nn.Module] = None,args: TrainingArguments = None,data_collator: Optional[DataCollator] = None,train_dataset: Optional[Dataset] = None,eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,tokenizer: Optional[PreTrainedTokenizerBase] = None,model_init: Optional[Callable[[], PreTrainedModel]] = None,compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,callbacks: Optional[List[TrainerCallback]] = None,optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,):...self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)self.control = TrainerControl()...
二、TrainerState源码解读(self.state)
1、huggingface中self.state初始化参数
这里多解读一点huggingface的self.state初始化调用参数方法,
self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)
而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。
def is_local_process_zero(self) -> bool:"""Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on severalmachines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。"""return self.args.local_process_index == 0def is_world_process_zero(self) -> bool:"""Whether or not this process is the global main process (when training in a distributed fashion on severalmachines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。"""# Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global# process index.if is_sagemaker_mp_enabled():return smp.rank() == 0else:return self.args.process_index == 0
self.args.local_process_index与self.args.process_index来源self.args
2、TrainerState类的Demo
介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0等内容,也进行了默认参数赋值,其Demo如下:
from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:epoch: Optional[float] = Noneglobal_step: int = 0max_steps: int = 0num_train_epochs: int = 0total_flos: float = 0log_history: List[Dict[str, float]] = Nonebest_metric: Optional[float] = Nonebest_model_checkpoint: Optional[str] = Noneis_local_process_zero: bool = Trueis_world_process_zero: bool = Trueis_hyper_param_search: bool = Falsetrial_name: str = Nonetrial_params: Dict[str, Union[str, float, int, bool]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []def save_to_json(self, json_path: str):"""Save the content of this instance in JSON format inside `json_path`."""json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)@classmethoddef load_from_json(cls, json_path: str):"""Create an instance from the content of `json_path`."""with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))if __name__ == '__main__':state = TrainerState()state.save_to_json('state.json')state_new = state.load_from_json('state.json')
我这里使用state = TrainerState()方法对TrainerState()类实例化,使用state.save_to_json('state.json')进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')方式载入会得到新的state_new实例化。

三、TrainerControl源码解读(self.control)
该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False内容,也进行了默认参数赋值,其源码如下:
@dataclass
class TrainerControl:"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate someswitches in the training loop.Args:should_training_stop (`bool`, *optional*, defaults to `False`):Whether or not the training should be interrupted.If `True`, this variable will not be set back to `False`. The training will just stop.should_epoch_stop (`bool`, *optional*, defaults to `False`):Whether or not the current epoch should be interrupted.If `True`, this variable will be set back to `False` at the beginning of the next epoch.should_save (`bool`, *optional*, defaults to `False`):Whether or not the model should be saved at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_evaluate (`bool`, *optional*, defaults to `False`):Whether or not the model should be evaluated at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_log (`bool`, *optional*, defaults to `False`):Whether or not the logs should be reported at this step.If `True`, this variable will be set back to `False` at the beginning of the next step."""should_training_stop: bool = Falseshould_epoch_stop: bool = Falseshould_save: bool = Falseshould_evaluate: bool = Falseshould_log: bool = Falsedef _new_training(self):"""Internal method that resets the variable for a new training."""self.should_training_stop = Falsedef _new_epoch(self):"""Internal method that resets the variable for a new epoch."""self.should_epoch_stop = Falsedef _new_step(self):"""Internal method that resets the variable for a new step."""self.should_save = Falseself.should_evaluate = Falseself.should_log = False
总结
本文主要介绍huggingface的trainer中的self.control与self.state的来源。
相关文章:
huggingface的self.state与self.control来源(TrainerState与TrainerControl)
文章目录 前言一、huggingface的trainer的self.state与self.control初始化调用二、TrainerState源码解读(self.state)1、huggingface中self.state初始化参数2、TrainerState类的Demo 三、TrainerControl源码解读(self.control)总结 前言 在 Hugging Face 中,self.s…...
30【Aseprite 作图】桌子——拆解
1 桌子只要画左上方,竖着5,斜着3个1,斜着两个2,斜着2个3,斜着一个5,斜着一个很长的 然后左右翻转 再上下翻转 在桌子腿部分,竖着三个直线,左右都是斜线;这是横着水平线不…...
C++设计模式-单例模式,反汇编
文章目录 25. 单例模式25.1. 饿汉式单例模式25.2. 懒汉式单例模式25.2.1. 解决方案125.2.2. 解决方案2 (推荐写法) 运行在VS2022,x86,Debug下。 25. 单例模式 单例即该类只能有一个实例。 应用:如在游戏开发中&#x…...
Django 做migrations时出错,解决方案
在做migrations的时候,偶尔会出现出错。 在已有数据的表中新增字段时,会弹出下面的信息 运行这个命令时 python manage.py makemigrationsTracking file by folder pattern: migrations It is impossible to add a non-nullable field ‘example’ to …...
QT::QNetworkReply类readAll()读取不到数据的可能原因
程序中,当发送请求时,并没有加锁,而是在响应函数中加了锁,导致可能某个请求的finished信号影响到其他请求响应数据的读取 connect(reply,&QNetworkReply::finished,this,&Display::replyFinished);参考这篇文章ÿ…...
vxe-form-design 表单设计器的使用
vxe-form-design 在 vue3 中表单设计器的使用 查看官网 https://vxeui.com 安装 npm install vxe-pc-ui // ... import VxeUI from vxe-pc-ui import vxe-pc-ui/lib/style.css // ...// ... createApp(App).use(VxeUI).mount(#app) // ...使用 github vxe-form-design 用…...
【Linux】TCP协议【上】{协议段属性:源端口号/目的端口号/序号/确认序号/窗口大小/紧急指针/标记位}
文章目录 1.引入2.协议段格式4位首部长度16位窗口大小32位序号思考三个问题【demo】标记位URG: 紧急指针是否有效提升某报文被处理优先级【0表示不设置1表示设置】ACK: 确认号是否有效PSH: 提示接收端应用程序立刻从TCP缓冲区把数据读走RST: 对方要求重新建立连接; 我们把携带R…...
php之sql代码审计
1 SQL注入代码审计流程 1.1 反向查找流程 通过可控变量(输入点)回溯危险函数 查找危险函数确定可控变量 传递的过程中触发漏洞 1.2 反向查找流程特点 暴力:全局搜索危险函数 简单:无需过多理解目标网站功能与架构 快速:适用于自动化代码审…...
【Java用法】java中计算两个时间差
java中计算两个时间差 不多说,直接上代码,可自行查看示例 package org.example.calc;import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit;public class MinusTest {public static void…...
tinymce富文本编辑器使用
安卓富文本编辑器:npm i tinymce/tinymce-vue 当前项目中富文本是放在一个dialog中,因此部分样式会有层叠问题,该组件样式部分不添加scope。这里图片上传只是前端静态数据展示收集。 <template><div class"desc-editor"…...
Java——接口后续
1.Comparable 接口 在Java中,我们对一个元素是数字的数组可以使用sort方法进行排序,如果要对一个元素是对象的数组按某种规则排序,就会用到Comparable接口 当实现Comparable接口后,sort会自动调用Comparable接口里的compareTo 方法…...
最新上市公司控制变量大全(1413+指标)1990-2023年
数据介绍:根据2023年上市公司年报数据进行更新,包括基本信息、财务指标、环境、社会与治理、数字化转型、企业发展、全要素生产率等1413指标。数据范围:A股上市公司数据年份:1990-2023年指标数目:1413个指标࿰…...
jmeter多用户并发登录教程
有时候为了模拟更真实的场景,在项目中需要多用户登录操作,大致参考如下 jmx脚本:百度网盘链接 提取码:0000 一: 单用户登录 先使用1个用户登录(先把1个请求调试通过) 发送一个登录请求&…...
【高频】redis快的原因
相关问题: 1.为什么Redis能够如此快速地进行数据存储和检索? 2.Redis作为内存数据库,其内存存储有什么优势吗? 3.Redis的网络模型有何特点,如何帮助提升性能? 一、问题回答 Redis使用了内存数据结构,例如字符串、哈希表、列表、集合、有…...
hive3从入门到精通(一)
Hive3入门至精通(基础、部署、理论、SQL、函数、运算以及性能优化)1-14章 第1章:数据仓库基础理论 1-1.数据仓库概念 数据仓库(英语:Data Warehouse,简称数仓、DW),是一个用于存储、分析、报告的数据系统。 数据仓库的目的是构…...
c++编程(15)——list的模拟实现
欢迎来到博主的专栏——c编程 博主ID:代码小豪 文章目录 前言list的数据结构list的默认构造尾插与尾删iterator插入和删除构造、析构、赋值copy构造initializer_list构造operator 析构函数 前言 受限于博主当前的技术水平,暂时还不能模拟实现出STL当中用…...
【深度学习】吸烟行为检测软件系统
往期文章列表: 【YOLO深度学习系列】图像分类、物体检测、实例分割、物体追踪、姿态估计、定向边框检测演示系统【含源码】【深度学习】YOLOV8数据标注及模型训练方法整体流程介绍及演示【深度学习】行人跌倒行为检测软件系统【深度学习】火灾检测软件系统【深度学…...
你见过哪些不过度设计的优秀APP?
优联前端https://ufrontend.com/ 提供一站式企业前端解决方案 “每日故宫”是一款以故宫博物院丰富的藏品为基础,结合日历形式展示每日精选藏品的移动应用。通过这款应用,用户可以随时随地欣赏到故宫的珍贵藏品,感受中华五千年文化的魅力。…...
全栈:session用户会话信息,用户浏览记录实例
PHP中的session是一种存储机制,它允许您存储和跟踪用户在访问Web应用程序时的信息。会话通常用于存储用户特定的数据,如用户ID、购物车内容、用户偏好设置等,这些数据需要在多个页面请求之间保持不变。 session详解 1. 会话是如何工作的 会…...
设计模式--》 装饰模式的应用
装饰模式的定义: 装饰模式(Decorator Pattern)是一种结构型设计模式,它允许你动态地给一个对象添加一些额外的职责。就增加功能来说,装饰模式相比生成子类更为灵活。 何时应用装饰模式? 1.当需要动态地给…...
盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...
python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...
Opencv中的addweighted函数
一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...
SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...
Matlab | matlab常用命令总结
常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...
学习STC51单片机32(芯片为STC89C52RCRC)OLED显示屏2
每日一言 今天的每一份坚持,都是在为未来积攒底气。 案例:OLED显示一个A 这边观察到一个点,怎么雪花了就是都是乱七八糟的占满了屏幕。。 解释 : 如果代码里信号切换太快(比如 SDA 刚变,SCL 立刻变&#…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度
文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...
