当前位置: 首页 > article >正文

tf.Keras (tf-1.15)使用记录4-model.fit方法及其callbacks参数

model.fit() 方法是 TensorFlow Keras 中用于训练模型的核心方法。
其中里面的callbacks参数是实现模型保存、监控、以及和tensorboard联动的重要API

1 model.fit() 方法的参数及使用

必需参数

  • x: 训练数据的输入。可以是 NumPy 数组、TensorFlow tf.data.Dataset、Python 生成器或 keras.utils.Sequence 实例。
  • y: 训练数据的目标(标签)。与输入 x 相对应,应该是 NumPy 数组或 TensorFlow tf.data.Dataset。当 xtf.data.Dataset、生成器或 Sequence 实例时,y 应该不被提供,因为 x 已经包含了输入和目标。

常用可选参数

  • batch_size: 整数,指定进行梯度更新时每个批次的样本数。默认值为 32。注意,当使用 tf.data.Dataset、生成器或 Sequence 作为输入时,不应指定 batch_size,因为这些数据结构已经定义了批次大小。
  • epochs: 整数,训练模型的轮数,即整个数据集的前向和反向传播次数。
  • verbose: 整数,日志显示模式。0 = 不在标准输出流中输出日志信息,1 = 进度条(默认),2 = 每轮一行。
  • callbacks: keras.callbacks.Callback 实例的列表。一系列在训练过程中会被调用的回调函数,用于查看训练过程中内部状态和统计信息。
  • validation_split: 浮点数,0 到 1 之间,用来指定一定比例的训练数据作为验证数据的比例。模型会在这些数据上评估损失和任何模型指标,但这些数据不会用于训练。
  • validation_data: 用作验证的数据。格式可以是 (X_val, y_val) 的元组,或者是 tf.data.Dataset。如果提供此参数,则不会根据 validation_split 从训练数据中分割验证数据。
  • shuffle: 布尔值或字符串,表示是否在每轮训练前打乱数据。默认为 True。当设置为 False 时,不会打乱数据。当输入为 tf.data.Dataset、生成器或 Sequence 实例时,此参数无效,因为这些数据结构可能已经定义了自己的打乱数据的方式。
  • initial_epoch: 用于恢复之前的训练。从该轮次开始训练,之前的轮次被视为已经训练过。

高级参数

  • steps_per_epoch: 整数,当使用生成器或 Sequence 实例作为输入时定义一个 epoch 完成并开始下一个 epoch 的总步数(批次数)。通常,应该等于数据集的样本数除以批次大小。
  • validation_steps: 当 validation_data 是生成器或 Sequence 实例时,此参数指定在停止前验证集的总步数(批次数)。
  • validation_batch_size: 整数,仅当 validation_data 是 NumPy 数组时有效。指定验证批次的大小。
  • validation_freq: 指定验证的频率。可以是整数,也可以是 'epoch' 或列表。如果是整数,则表示每多少个 epoch 验证一次。如果是列表,则列表中的元素指定了需要进行验证的 epoch。

使用示例

基本用法:

model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)

使用验证数据:

model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))

使用回调函数:

from tensorflow.keras.callbacks import EarlyStoppingearly_stopping = EarlyStopping(monitor='val_loss', patience=3)
model.fit(x_train, y_train, epochs=10, validation_split=0.2, callbacks=[early_stopping])

使用 tf.data.Dataset

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32)
model.fit(train_dataset, epochs=10, validation_data=val_dataset)

model.fit() 方法提供了灵活的方式来训练模型,通过合理设置参数,可以有效地控制训练过程和评估模型性能。

2 callbacks参数使用

callbacks 参数是 model.fit() 方法中一个重要参数,属于keras的高级用法,它允许在训练的不同阶段(如训练开始、训练结束、每个 epoch 开始/结束时等)执行特定的操作。

callbacks 是一个 tf.keras.callbacks.Callback 实例的列表,每个实例都能够访问到模型的内部状态和统计信息。TensorFlow Keras 提供了多种内置的回调函数,同时也支持自定义回调。
以下是callbacks类的全部方法类(https://keras.io/api/callbacks/):
在这里插入图片描述

  1. ModelCheckpoint: 在训练过程中保存模型或模型权重。

    • filepath: 保存模型的路径。
    • monitor: 被监视的数据。
    • verbose: 详细信息模式。
    • save_best_only: 若为 True,则只保存在验证集上性能最好的模型。
    • save_weights_only: 若为 True,则只保存模型的权重。
    • mode: {auto, min, max} 中的一个。决定监视的数据是应该最大化还是最小化。
    • save_freq: 保存模型的频率。
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='model.h5', save_best_only=True, monitor='val_loss', mode='min')
    
  2. EarlyStopping: 当被监视的数据不再提升,则停止训练。

    • monitor: 被监视的数据。
    • min_delta: 改进的最小变化量,小于这个量的改进将被忽略。
    • patience: 没有进步的训练轮数,在这之后训练将被停止。
    • verbose: 详细信息模式。
    • mode: {auto, min, max} 中的一个。
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
    
  3. ReduceLROnPlateau: 当学习停滞时,减少学习率。

    • monitor: 被监视的数据。
    • factor: 学习率将以这个因子减少。新的学习率 = 学习率 * 因子。
    • patience: 没有进步的训练轮数,在这之后学习率将被减少。
    • verbose: 详细信息模式。
    • mode: {auto, min, max} 中的一个。
    • min_lr: 学习率的下限。
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
    
  4. TensorBoard: 为 TensorFlow 提供的可视化工具。

    • log_dir: 用来保存日志文件的路径,TensorBoard 将读取这个路径下的日志。
    • histogram_freq: 对于模型层的激活和权重直方图的计算频率(每个 epoch)。
    • write_graph: 是否在 TensorBoard 中可视化图形。如果 write_graph 被打开,日志文件会变得非常大。
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs')
    

使用示例

callbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),tf.keras.callbacks.ModelCheckpoint(filepath='model.h5', save_best_only=True, monitor='val_loss', mode='min'),tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001),tf.keras.callbacks.TensorBoard(log_dir='./logs')
]model.fit(x_train, y_train, validation_split=0.2, epochs=50, callbacks=callbacks)

自定义回调

也可以通过继承 tf.keras.callbacks.Callback 类来创建自定义回调,允许在训练的不同阶段执行自定义的逻辑。

class CustomCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):# 每个 epoch 结束时执行keys = list(logs.keys())print(f"结束 epoch {epoch},损失 = {logs['loss']}, 验证损失 = {logs['val_loss']}")model.fit(x_train, y_train, validation_split=0.2, epochs=50, callbacks=[CustomCallback()])

回调提供了一种灵活的方式来嵌入训练过程,使得你可以在不改变模型代码的情况下,监控模型的训练、保存模型、调整学习率等。

相关文章:

tf.Keras (tf-1.15)使用记录4-model.fit方法及其callbacks参数

model.fit() 方法是 TensorFlow Keras 中用于训练模型的核心方法。 其中里面的callbacks参数是实现模型保存、监控、以及和tensorboard联动的重要API 1 model.fit() 方法的参数及使用 必需参数 x: 训练数据的输入。可以是 NumPy 数组、TensorFlow tf.data.Dataset、Python 生…...

Easy系列PLC尺寸测量功能块ST代码(激光微距仪应用)

激光微距仪可以测量短距离内的产品尺寸,产品规格书的测量 精度可以到0.001mm。具体需要看不同的型号。 1、激光微距仪 2、尺寸测量应用 下面我们以测量高度为例子,设计一个高度测量功能块,同时给出测量数据和合格不合格指标。 3、高度测量功能块 4、复位完成信号 5、功能…...

996引擎 -地图-添加安全区

996引擎 -地图-添加安全区 文件位置配置 cfg_startpoint.xls特效效果1345参考资料文件位置 文件位置服务端D:\996M2-lua\MirServer-lua\Mir200客户端D:\996M2-lua\996M2_debug\dev配置 cfg_startpoint.xls 服务端\Mir200\Envir\DATA\cfg_startpoint.xls 填歪了也有可能只画一…...

Node.js 全局对象

Node.js 全局对象 引言 在Node.js中,全局对象是JavaScript环境中的一部分,它提供了对Node.js运行时环境的访问。全局对象在Node.js中扮演着重要的角色,它使得开发者能够访问和操作Node.js的许多核心功能。本文将详细介绍Node.js的全局对象,包括其特点、常用方法和应用场景…...

[Collection与数据结构] B树与B+树

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…...

redex快速体验

第一步: 2.回调函数在每次state发生变化时候自动执行...

【VM】VirtualBox安装CentOS8虚拟机

阅读本文前,请先根据 VirtualBox软件安装教程 安装VirtualBox虚拟机软件。 1. 下载centos8系统iso镜像 可以去两个地方下载,推荐跟随本文的操作用阿里云的镜像 centos官网:https://www.centos.org/download/阿里云镜像:http://…...

电子电气架构 --- 汽车电子拓扑架构的演进过程

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活…...

自动驾驶---苏箐对智驾产品的思考

1 前言 对于更高级别的自动驾驶,很多人都有不同的思考,方案也好,产品也罢。最近在圈内一位知名的自动驾驶专家苏箐发表了他自己对于自动驾驶未来的思考。 苏箐是地平线的副总裁兼首席架构师,同时也是高阶智能驾驶解决方案SuperDri…...

手写call函数、手写apply函数、手写bind函数

文章目录 1 手写call函数2 手写apply函数3 手写bind函数 1 手写call函数 call函数的实现步骤: 判断调用对象是否为函数。判断传入上下文对象是否存在,如果不存在,则设置为window。处理传入的参数,截取第一个参数后的所有参数。将…...

Python | Pytorch | Tensor知识点总结

如是我闻: Tensor 是我们接触Pytorch了解到的第一个概念,这里是一个关于 PyTorch Tensor 主题的知识点总结,涵盖了 Tensor 的基本概念、创建方式、运算操作、梯度计算和 GPU 加速等内容。 1. Tensor 基本概念 Tensor 是 PyTorch 的核心数据结…...

90,【6】攻防世界 WEB Web_php_unserialize

进入靶场 进入靶场 <?php // 定义一个名为 Demo 的类 class Demo { // 定义一个私有属性 $file&#xff0c;默认值为 index.phpprivate $file index.php;// 构造函数&#xff0c;当创建类的实例时会自动调用// 接收一个参数 $file&#xff0c;用于初始化对象的 $file 属…...

快速提升网站收录:如何设置网站标签?

本文转自&#xff1a;百万收录网 原文链接&#xff1a;https://www.baiwanshoulu.com/45.html 为了快速提升网站的收录&#xff0c;合理设置网站标签是至关重要的。网站标签主要包括标题标签&#xff08;TitleTag&#xff09;、描述标签&#xff08;DescriptionTag&#xff09…...

【数据分析】案例04:豆瓣电影Top250的数据分析与Web网页可视化(numpy+pandas+matplotlib+flask)

豆瓣电影Top250的数据分析与Web网页可视化(numpy+pandas+matplotlib+flask) 豆瓣电影Top250官网:https://movie.douban.com/top250写在前面 实验目的:实现豆瓣电影Top250详情的数据分析与Web网页可视化。电脑系统:Windows使用软件:PyCharm、NavicatPython版本:Python 3.…...

Banana JS,一个严格子集 JavaScript 的解释器

项目地址&#xff1a;https://github.com/shajunxing/banana-js 特色 我的目标是剔除我在实践中总结的JavaScript语言的没用的和模棱两可的部分&#xff0c;只保留我喜欢和需要的&#xff0c;创建一个最小的语法解释器。只支持 JSON 兼容的数据类型和函数&#xff0c;函数是第…...

2025.2.1——四、php_rce RCE漏洞|PHP框架

题目来源&#xff1a;攻防世界 php_rce 目录 一、打开靶机&#xff0c;整理信息 二、解题思路 step 1&#xff1a;PHP框架漏洞以及RCE漏洞信息 1.PHP常用框架 2.RCE远程命令执行 step 2&#xff1a;根据靶机提示&#xff0c;寻找版本漏洞 step 3&#xff1a;进行攻击…...

从0开始使用面对对象C语言搭建一个基于OLED的图形显示框架(绘图设备封装)

目录 图像层的底层抽象——绘图设备抽象 如何抽象一个绘图设备&#xff1f; 桥接绘图设备&#xff0c;特化为OLED设备 题外话&#xff1a;设备的属性&#xff0c;与设计一个相似函数化简的通用办法 使用函数指针来操作设备 总结一下 图像层的底层抽象——绘图设备抽象 在…...

对比DeepSeek、ChatGPT和Kimi的学术写作撰写引言能力

引言 引言部分引入研究主题&#xff0c;明确研究背景、问题陈述&#xff0c;并提出研究的目的和重要性&#xff0c;最后&#xff0c;概述研究方法和论文结构。 下面我们使用DeepSeek、ChatGPT4以及Kimi辅助引言撰写。 提示词&#xff1a; 你现在是一名[计算机理论专家]&#…...

【C++篇】哈希表

目录 一&#xff0c;哈希概念 1.1&#xff0c;直接定址法 1.2&#xff0c;哈希冲突 1.3&#xff0c;负载因子 二&#xff0c;哈希函数 2.1&#xff0c;除法散列法 /除留余数法 2.2&#xff0c;乘法散列法 2.3&#xff0c;全域散列法 三&#xff0c;处理哈希冲突 3.1&…...

TVM调度原语完全指南:从入门到微架构级优化

调度原语 在TVM的抽象体系中&#xff0c;调度&#xff08;Schedule&#xff09;是对计算过程的时空重塑。每一个原语都是改变计算次序、数据流向或并行策略的手术刀。其核心作用可归纳为&#xff1a; 优化目标 max ⁡ ( 计算密度 内存延迟 指令开销 ) \text{优化目标} \max…...

《解锁AI黑科技:数据分类聚类与可视化》

在当今数字化时代&#xff0c;数据如潮水般涌来&#xff0c;如何从海量数据中提取有价值的信息&#xff0c;成为了众多领域面临的关键挑战。人工智能&#xff08;AI&#xff09;技术的崛起&#xff0c;为解决这一难题提供了强大的工具。其中&#xff0c;能够实现数据分类与聚类…...

[MySQL]事务的隔离级别原理与底层实现

目录 1.为什么要有隔离性 2.事务的隔离级别 读未提交 读提交 可重复读 串行化 3.演示事务隔离级别的操作 查看与设置事务的隔离级别 演示读提交操作 演示可重复读操作 1.为什么要有隔离性 在真正的业务场景下&#xff0c;MySQL服务在同一时间一定会有大量的客户端进程…...

数据密码解锁之DeepSeek 和其他 AI 大模型对比的神秘面纱

本篇将揭露DeepSeek 和其他 AI 大模型差异所在。 目录 ​编辑 一本篇背景&#xff1a; 二性能对比&#xff1a; 2.1训练效率&#xff1a; 2.2推理速度&#xff1a; 三语言理解与生成能力对比&#xff1a; 3.1语言理解&#xff1a; 3.2语言生成&#xff1a; 四本篇小结…...

知识管理系统推动企业知识创新与人才培养的有效途径分析

内容概要 本文旨在深入探讨知识管理系统在现代企业中的应用及其对于知识创新与人才培养的重要性。通过分析知识管理系统的概念&#xff0c;企业可以认识到它不仅仅是信息管理的一种工具&#xff0c;更是提升整体创新能力的战略性资产。知识管理系统通过集成企业内部信息资源&a…...

【数据结构与算法】动态规划

目录 动态规划 1. 基本概念 2. 基本步骤 3. 经典应用场景 4. 优点和局限性 最长递增子序列&#xff08;中等&#xff09; 最大子数组和&#xff08;中等&#xff09; 动态规划 动态规划是一种用于解决多阶段决策问题的算法思想&#xff0c;它将复杂问题分解为一系列相对…...

ASP.NET Core 中使用依赖注入 (DI) 容器获取并执行自定义服务

目录 一、ASP.NET Core 中使用依赖注入 (DI) 容器获取并执行自定义服务 1. app.Services 2. GetRequiredService() 3. Init() 二、应用场景 三、依赖注入使用拓展 1、使用场景 2、使用步骤 1. 定义服务接口和实现类 2. 注册服务到依赖注入容器 3. 使用依赖注入获取并…...

Nginx知识

nginx 精简的配置文件 worker_processes 1; # 可以理解为一个内核一个worker # 开多了可能性能不好events {worker_connections 1024; } # 一个 worker 可以创建的连接数 # 1024 代表默认一般不用改http {include mime.types;# 代表引入的配置文件# mime.types 在 ngi…...

CSES Missing Coin Sum

思路是对数组排序 设 S [ i ] S[i] S[i] 是数组的前缀和 R [ i ] R[i] R[i] 是递增排序后的数组 遍历数组&#xff0c;如果出现 S [ i − 1 ] 1 < R [ i ] S[i - 1] 1 < R[i] S[i−1]1<R[i]&#xff0c;就代表S[i - 1] 1是不能被合成出来的数字 因为&#xff1a…...

nth_element函数——C++快速选择函数

目录 1. 函数原型 2. 功能描述 3. 算法原理 4. 时间复杂度 5. 空间复杂度 6. 使用示例 8. 注意事项 9. 自定义比较函数 11. 总结 nth_element 是 C 标准库中提供的一个算法&#xff0c;位于 <algorithm> 头文件中&#xff0c;用于部分排序序列。它的主要功能是将…...

Hot100之双指针

283移动零 题目 思路解析 那我们就把不为0的数字都放在数组前面&#xff0c;然后数组后面的数字都为0就行了 代码 class Solution {public void moveZeroes(int[] nums) {int left 0;for (int num : nums) {if (num ! 0) {nums[left] num;// left最后会变成数组中不为0的数…...