huggingface的self.state与self.control来源(TrainerState与TrainerControl)

news/2024/7/22 1:23:36/文章来源:https://blog.csdn.net/weixin_38252409/article/details/139299941

文章目录

  • 前言
  • 一、huggingface的trainer的self.state与self.control初始化调用
  • 二、TrainerState源码解读(self.state)
    • 1、huggingface中self.state初始化参数
    • 2、TrainerState类的Demo
  • 三、TrainerControl源码解读(self.control)
  • 总结


前言

在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。


一、huggingface的trainer的self.state与self.control初始化调用

trainer函数初始化调用代码如下:

# 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,)

在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:

class Trainer:def __init__(self,model: Union[PreTrainedModel, nn.Module] = None,args: TrainingArguments = None,data_collator: Optional[DataCollator] = None,train_dataset: Optional[Dataset] = None,eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,tokenizer: Optional[PreTrainedTokenizerBase] = None,model_init: Optional[Callable[[], PreTrainedModel]] = None,compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,callbacks: Optional[List[TrainerCallback]] = None,optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,):...self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)self.control = TrainerControl()...

二、TrainerState源码解读(self.state)

1、huggingface中self.state初始化参数

这里多解读一点huggingface的self.state初始化调用参数方法,

 self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)

而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。

 def is_local_process_zero(self) -> bool:"""Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on severalmachines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。"""return self.args.local_process_index == 0def is_world_process_zero(self) -> bool:"""Whether or not this process is the global main process (when training in a distributed fashion on severalmachines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。"""# Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global# process index.if is_sagemaker_mp_enabled():return smp.rank() == 0else:return self.args.process_index == 0

self.args.local_process_index与self.args.process_index来源self.args

2、TrainerState类的Demo

介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0等内容,也进行了默认参数赋值,其Demo如下:

from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:epoch: Optional[float] = Noneglobal_step: int = 0max_steps: int = 0num_train_epochs: int = 0total_flos: float = 0log_history: List[Dict[str, float]] = Nonebest_metric: Optional[float] = Nonebest_model_checkpoint: Optional[str] = Noneis_local_process_zero: bool = Trueis_world_process_zero: bool = Trueis_hyper_param_search: bool = Falsetrial_name: str = Nonetrial_params: Dict[str, Union[str, float, int, bool]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []def save_to_json(self, json_path: str):"""Save the content of this instance in JSON format inside `json_path`."""json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)@classmethoddef load_from_json(cls, json_path: str):"""Create an instance from the content of `json_path`."""with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))if __name__ == '__main__':state = TrainerState()state.save_to_json('state.json')state_new = state.load_from_json('state.json')

我这里使用state = TrainerState()方法对TrainerState()类实例化,使用state.save_to_json('state.json')进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')方式载入会得到新的state_new实例化。
在这里插入图片描述

三、TrainerControl源码解读(self.control)

该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False内容,也进行了默认参数赋值,其源码如下:

@dataclass
class TrainerControl:"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate someswitches in the training loop.Args:should_training_stop (`bool`, *optional*, defaults to `False`):Whether or not the training should be interrupted.If `True`, this variable will not be set back to `False`. The training will just stop.should_epoch_stop (`bool`, *optional*, defaults to `False`):Whether or not the current epoch should be interrupted.If `True`, this variable will be set back to `False` at the beginning of the next epoch.should_save (`bool`, *optional*, defaults to `False`):Whether or not the model should be saved at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_evaluate (`bool`, *optional*, defaults to `False`):Whether or not the model should be evaluated at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_log (`bool`, *optional*, defaults to `False`):Whether or not the logs should be reported at this step.If `True`, this variable will be set back to `False` at the beginning of the next step."""should_training_stop: bool = Falseshould_epoch_stop: bool = Falseshould_save: bool = Falseshould_evaluate: bool = Falseshould_log: bool = Falsedef _new_training(self):"""Internal method that resets the variable for a new training."""self.should_training_stop = Falsedef _new_epoch(self):"""Internal method that resets the variable for a new epoch."""self.should_epoch_stop = Falsedef _new_step(self):"""Internal method that resets the variable for a new step."""self.should_save = Falseself.should_evaluate = Falseself.should_log = False

总结

本文主要介绍huggingface的trainer中的self.control与self.state的来源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ldbm.cn/p/430622.html

如若内容造成侵权/违法违规/事实不符,请联系编程新知网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Matlab|基于PMU相量测量单元进行电力系统电压幅值和相角状态估计

主要内容 程序采用三种方法对14节点和30节点电力系统状态进行评估: ①PMU同步相量测量单元结合加权最小二乘法(WLS)分析电力系统的电压幅值和相角状态; ②并采用牛顿-拉夫逊方法进行系统潮流计算,结果作为理论分…

【Linux进程篇】Linux进程管理——进程创建与终止

W...Y的主页 😊 代码仓库分享💕 目录 进程创建 fork函数初识 写时拷贝 fork常规用法 fork调用失败的原因 进程终止 进程退出场景 _exit函数 exit函数 return退出 进程创建 fork函数初识 在linux中fork函数时非常重要的函数,它从已…

【已解决】使用token登录机制,token获取不到,blog_list.html界面加载不出来

Bug产生 今天使用token完成用户登录信息的存储的时候被卡了大半天。 因为登录的功能写的已经很多了,所以今天就没有写一点验一点,而是在写完获取博客列表功功能,验证完它的后端后,了解完令牌的基本使用以及Jwt的基本使用方式——…

基于文本来推荐相似酒店

基于文本来推荐相似酒店 查看数据集基本信息 import pandas as pd import numpy as np from nltk.corpus import stopwords from sklearn.metrics.pairwise import linear_kernel from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extrac…

C语言 指针——指针变量的定义、初始化及解引用

目录 指针 内存如何编址? 如何对变量进行寻址? 用什么类型的变量来存放变量的地址? 如何显示变量的地址?​编辑 使用未初始化的指针会怎样? NULL是什么? 如何访问指针变量指向的存储单元中的数据? 指针变量的…

电脑录屏怎么录?7个电脑录屏软件免费版强势来袭,赶快收藏!

电脑录屏怎么录?相信很多小伙伴们都不知道怎么在Windows电脑上录屏吧?在当今社会,随着互联网的快速发展,越来越多的小伙伴们开始通过制作视频内容来分享知识、展示技能或者记录生活。电脑录屏成为了一种简单高效的方式&#xff0c…

Python OCR 文字识别使用模型:读光-文字识别-行识别模型-中英-通用领域

介绍 什么是OCR? OCR是“Optical Character Recognition”的缩写,中文意为“光学字符识别”。它是一种技术,可以识别和转换打印在纸张或图像上的文字和字符为机器可处理的格式,如计算机文本文件。通过使用OCR技术,可…

2024.05.29学习记录

1、css面经复习 2、代码随想录二刷 3、rosebush upload组件初步完成

0基础认识C语言(理论+实操 2)

小伙伴们大家好,今天也要撸起袖子加油干!万事开头难,越学到后面越轻松~ 话不多说,开始正题~ 前提回顾: 接上次博客,我们学到了转义字符,最后留下两个转义字符不知道大家有没有动手尝试了一遍&a…

MySQL8找不到my.ini配置文件以及报sql_mode=only_full_group_by解决方案

一、找不到my.ini配置文件 MySQL 8 安装或启动过程中,如果系统找不到my.ini文件,通常意味着 MySQL服务器没有找到其配置文件。在Windows系统上,MySQL 8 预期使用my.ini作为配置文件,而不是在某些情况下用到的my.cnf文件。 通过 …

自定义注解+AOP切面实现日志记录

自定义注解: Target(ElementType.METHOD)// 作用在方法上 Retention(RetentionPolicy.RUNTIME) Documented Inherited // 子类可以继承此注解 public interface OperationLog { } aop切面: Slf4j Aspect Component public class OperationAspect {Au…

格式化字符串

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 格式化字符串是指先制定一个模板,在这个模板中预留几个空位,然后再根据需要填上相应的内容。这些空位需要通过指定的符号标记…

tomcat学习--部署java项目

主流开发项目,springboot框架下,jar部署java传统的tomcat发布war包 一 什么是tomcat? 是一个用于运行java程序的软件,发布的时候:开发将源码使用maven打包,生产war包 二 安装tomcat tomcat是java写的&a…

基于GO 写的一款 GUI 工具,M3u8视频下载播放器-飞鸟视频助手

M3u8视频下载播放器-飞鸟视频助手 M3u8视频飞鸟视频助手使用m3u8下载m3u8 本地播放 软件下载地址m3u8嗅探 M3u8视频 M3u8视频格式是为网络视频播放设计,视频网站多数采用 m3u8格式。如腾讯,爱奇艺等网站。 m3u8和 mp4的区别: 一个 mp4是一个…

网络模型—BIO、NIO、IO多路复用、信号驱动IO、异步IO

一、用户空间和内核空间 以Linux系统为例,ubuntu和CentOS是Linux的两种比较常见的发行版,任何Linux发行版,其系统内核都是Linux。我们在发行版上操作应用,如Redis、Mysql等其实是无法直接执行访问计算机硬件(如cpu,内存…

STL库--stack

目录 stack的定义 stack容器内元素的访问 stack常用函数实例解析 stack的常见用途 stack的定义 其定义的写法和其他STL容器相同&#xff0c;typename可以任意基本类型或容器&#xff1a; stack<typename> name; stack容器内元素的访问 由于栈本身就是一种后进先出…

Facebook开户 | 如何检查公共主页的状态

想要了解你的Facebook公共主页的状态吗&#xff1f; Facebook公共主页是让广告主与粉丝互动、传播信息的绝佳平台&#xff0c;但是大家知道如何检查并维护自己的主页状态吗&#xff1f;别担心&#xff0c;Facebook提供了一系列简单易用的工具来帮助大家实现这一目标。 *Page Q…

如何恢复被盗的加密货币?

本世纪&#xff0c;网络犯罪的首要目标是加密货币。 这要归功于加密货币的日益普及和价值&#xff0c;网络犯罪分子已经认识到经济收益的潜力&#xff0c;并将重点转向利用这种数字资产中的漏洞。 在今天的文章中&#xff0c;我们将讨论加密货币恢复和被盗加密货币恢复。 我们…

爬虫案例-亚马逊反爬分析-验证码突破(x-amz-captcha)

总体概览&#xff1a;核心主要是需要突破该网站的验证码&#xff0c;成功后会返回我们需要的参数后再去请求一个中间页&#xff08;类似在后台注册一个session&#xff09;&#xff0c;最后需要注意一下 IP 是不能随意切换的 主要难点&#xff1a; 1、梳理整体反爬流程 2、验证…

深入解析Web前端三大主流框架:Angular、React和Vue

Web前端三大主流框架分别是Angular、React和Vue。下面我将为您详细介绍这三大框架的特点和使用指南。 Angular 核心概念: 组件(Components): 组件是Angular应用的构建块,每个组件由一个带有装饰器的类、一个HTML模板、一个CSS样式表组成。组件通过输入(@Input)和输出(…