当前位置: 首页 > news >正文

政安晨:示例演绎TensorFlow的官方指南(二){Estimator}

咱们接着演绎TensorFlow官方指南,我的这个系列的上一篇文章为:

政安晨:示例演绎TensorFlow的官方指南(一){基础知识}icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136067030为什么要演绎官方指南,我在上一篇说过了,这次没有废话,直接开始。

Estimator介绍


政安晨:

咱们先看一下Estimator的背景。

TensorFlow的Estimator API是一种高级的机器学习API,用于简化模型的训练、评估和推理过程。它提供了一种更加高层次的抽象,使开发者能够更加专注于模型的架构和数据流水线的设计,而不需要太多地关注底层的实现细节。

Estimator API提供了一套统一的接口,可以用于各种机器学习任务,如分类、回归、聚类等。它具有以下几个主要特点:

  1. 封装了模型的训练、评估和推理过程,提供了一种简单且一致的方式来组织代码和配置模型。

  2. 支持分布式训练,可以轻松地在多个GPU或多台机器上进行训练,以加速模型的训练过程。

  3. 提供了一系列内置的模型,如线性模型、DNN模型、CNN模型等,可以根据任务的需求快速构建模型。

  4. 可以使用预定义的特征列(feature columns)来处理和预处理输入数据,简化了数据准备的过程。

  5. 可以使用高层的tf.data.Dataset API来读取和处理数据,使数据加载和预处理过程更加灵活和高效。

使用Estimator API时,需要定义一个Estimator对象,这个对象包含了模型的结构和参数。然后,通过调用Estimator对象的train()方法来训练模型,evaluate()方法来评估模型,predict()方法来进行预测。在训练模型时,可以通过tf.estimator.TrainSpec对象来指定训练数据的路径和其他相关参数。在评估模型时,可以通过tf.estimator.EvalSpec对象来指定评估数据的路径和其他相关参数。

总之,Estimator API提供了一种简单、灵活且高效的方式来构建、训练和评估机器学习模型,使开发者能够更加专注于模型的设计和业务逻辑。


这篇官方文档介绍了 tf.estimator,它是一种高级 TensorFlow API。Estimator 封装了以下操作:

  • 训练
  • 评估
  • 预测
  • 导出以供使用

您可以使用我们提供的预制 Estimator 或编写您自己的自定义 Estimator。所有 Estimator(无论是预制还是自定义)都是基于 tf.estimator.Estimator 类的类。

有关 API 设计概述,请参阅白皮书。

优势

与 tf.keras.Model 类似,estimator 是模型级别的抽象。tf.estimator 提供了一些目前仍在为 tf.keras 开发中的功能。包括:

  • 基于参数服务器的训练
  • 完整的 TFX 集成

政安晨:

为了后面的演绎,我们先设置一下环境:


Estimator 功能

Estimator 提供了以下优势:

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您还可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。
  • Estimator 提供了安全的分布式训练循环,可控制如何以及何时进行以下操作:
    • 加载数据
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存 TensorBoard 摘要

在用 Estimator 编写应用时,您必须将数据输入流水线与模型分离。这种分离简化了使用不同数据集进行的实验。

预制 Estimator

使用预制 Estimator,您能够在比基础 TensorFlow API 高很多的概念层面上工作。您无需再担心创建计算图或会话,因为 Estimator 会替您完成所有“基础工作”。此外,使用预制 Estimator,您只需改动较少代码就能试验不同的模型架构。例如,tf.estimator.DNNClassifier 是一个预制 Estimator 类,可基于密集的前馈神经网络对分类模型进行训练。

预制 Estimator 程序结构

依赖于预制 Estimator 的 TensorFlow 程序通常包括以下四个步骤:

1. 编写一个或多个数据集导入函数。

例如,您可以创建一个函数来导入训练集,创建另一个函数来导入测试集。每个数据集导入函数必须返回以下两个对象:

  • 字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)
  • 包含一个或多个标签的张量

例如,以下代码展示了输入函数的基本框架:

def input_fn(dataset):     ...  # manipulate dataset, extracting the feature dict and the label     return feature_dict, label

政安晨:

数据框架其实是这样的,不知为何官方文档中没有给出?

def train_input_fn():titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")titanic = tf.data.experimental.make_csv_dataset(titanic_file, batch_size=32,label_name="survived")titanic_batches = (titanic.cache().repeat().shuffle(500).prefetch(tf.data.AUTOTUNE))return titanic_batches

执行如下:


2. 定义特征列。

每个 tf.feature_column 标识了特征名称、特征类型,以及任何输入预处理。例如,以下代码段创建了三个包含整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个会被程序调用以缩放原始数据的 lambda:

# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column(   'median_education',   normalizer_fn=lambda x: x - global_education_mean)

3. 实例化相关预制 Estimator。

例如,下面是对名为 LinearClassifier 的预制 Estimator 进行实例化的示例:

# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier(   feature_columns=[population, crime_rate, median_education])

4. 调用训练、评估或推断方法。

例如,所有 Estimator 都会提供一个用于训练模型的 train 方法。

# `input_fn` is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)

预制 Estimator 的优势

预制 Estimator 对最佳做法进行了编码,具有以下优势:

  • 确定计算图不同部分的运行位置,以及在单台机器或集群上实施策略的最佳做法。
  • 事件(摘要)编写和通用摘要的最佳做法。

如果不使用预制 Estimator,则您必须自己实现上述功能。

自定义 Estimator

每个 Estimator(无论预制还是自定义)的核心是其模型函数,这是一种为训练、评估和预测构建计算图的方法。当您使用预制 Estimator 时,已经有人为您实现了模型函数。当使用自定义 Estimator 时,您必须自己编写模型函数。

推荐工作流

  1. 假设存在一个合适的预制 Estimator,用它构建您的第一个模型,并将其结果作为基准。
  2. 使用此预制 Estimator 构建并测试您的整个流水线,包括数据的完整性和可靠性。
  3. 如果有其他合适的预制 Estimator,可通过运行实验确定哪个预制 Estimator 能够生成最佳结果。
  4. 如果可能,您可以通过构建自己的自定义 Estimator 进一步改进模型。
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

从 Keras 模型创建 Estimator

您可以使用 tf.keras.estimator.model_to_estimator 将现有的 Keras 模型转换为 Estimator。这样一来,您的 Keras 模型就可以利用 Estimator 的优势,例如分布式训练。

实例化 Keras MobileNet V2 模型并用训练中使用的优化器、损失和指标来编译模型:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = Falseestimator_model = tf.keras.Sequential([keras_mobilenet_v2,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(1)
])# Compile the model
estimator_model.compile(optimizer='adam',loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])

政安晨执行:

从已编译的 Keras 模型创建 Estimator。Keras 模型的初始模型状态会保留在已创建的 Estimator中:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)

您可以像对待任何其他 Estimator 一样对待派生的 Estimator

IMG_SIZE = 160  # All images will be resized to 160x160def preprocess(image, label):image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return image, label
def train_input_fn(batch_size):data = tfds.load('cats_vs_dogs', as_supervised=True)train_data = data['train']train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)return train_data

要进行训练,可调用 Estimator 的训练函数:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)

同样,要进行评估,可调用 Estimator 的评估函数:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)

有关详细信息,请参阅 tf.keras.estimator.model_to_estimator 文档。

写在最后

其实这一篇中官方指南并不详尽,尤其是最后的训练部分,咱们补充了一些,但仍然存在缺失,我将在后续的文章中以实际项目为例,详细演绎。

相关文章:

政安晨:示例演绎TensorFlow的官方指南(二){Estimator}

咱们接着演绎TensorFlow官方指南,我的这个系列的上一篇文章为: 政安晨:示例演绎TensorFlow的官方指南(一){基础知识}https://blog.csdn.net/snowdenkeke/article/details/136067030为什么要演绎官方指南,我…...

vue3:24—组件通信方式

目录 1、props 2、自定义事件 (emit) 3、mitt(任意组件的通讯) 4、v-model【封装ui组件库用的多,平时用的少。和vue2有点不同】 5、$attrs 6、$refs和$parent 7、provide和inject 8、pinia(即vue2中…...

WebGL+Three.js入门与实战——绘制水平移动的点、通过鼠标控制绘制(点击绘制、移动绘制、模拟画笔)

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…...

大数据环境搭建(一)-Hive

1 hive介绍 由Facebook开源的,用于解决海量结构化日志的数据统计的项目 本质上是将HQL转化为MapReduce、Tez、Spark等程序 Hive表的数据是HDFS上的目录和文件 Hive元数据 metastore,包含Hive表的数据库、表名、列、分区、表类型、表所在目录等。 根据Hive部署模…...

mac电脑上使用android studio创建flutter项目

mac电脑环境配置可以看这篇文章:https://xiaoshen.blog.csdn.net/article/details/136068650 配置玩环境之后,开始创建第一个flutter项目:点击new flutter project或者new project都可以 然后选择flutter: 并将sdk配置为解压后的…...

Excel——分类汇总

1.一级分类汇总 Q:请根据各销售地区统计销售额总数。 第一步:排序,我们需要根据销售地区汇总数据,我们就要对【销售地区】的内容进行排序。点击【销售地区】列中任意一个单元格,选择【数据】——【排序】&#xff0c…...

Backtrader 文档学习- Observers - Reference

Backtrader 文档学习- Observers - Reference 1.Benchmark class backtrader.observers.Benchmark() 观察器存储策略的回报和参考资产的回报,参考资产是传递给系统的数据之一。 参数: timeframe (default: None) ,如果None,则将…...

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Radio组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Radio组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Radio组件 单选框,提供相应的用户交互选择项。 子组件 无。 接口 …...

【go】结构体切片去重

场景 自定义结构体切片,去除切片中的重复元素(所有值完全相同) 代码 // 自定义struct去重 type AssetAppIntranets struct {ID string json:"id,omitempty"AppID string json:"app_id,omitempty"IP …...

百面嵌入式专栏(面试题)C语言面试题22道

沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将介绍C语言相关面试题 。 宏定义是在编译的哪个阶段被处理的?答案:宏定义是在编译预处理阶段被处理的。 解读:编译预处理:头文件包含、宏替换、条件编译、去除注释、添加行号。 写一个“标准”宏MIN,这个…...

Docker方式创建keepalived连接MGR集群

记录一下简单的搭建步骤以便后期查验 目录 前言步骤1. 安装环境2. 重新制作镜像3. 导入新镜像4. 创建容器 前言 假设已经搭建了MySQL8的MGR集群方式(一主两从)。 MGR本身有故障转移重新选举新的主节点功能,但是上游的应用程序需要自己手动修…...

Oracle PL/SQL Programming 第5章:Iterative Processing with Loops 读书笔记

总的目录和进度,请参见开始读 Oracle PL/SQL Programming 第6版 本章探讨 PL/SQL 的迭代控制结构(也称为循环),它允许您重复执行相同的代码。 PL/SQL 提供了三种不同类型的循环结构: 简单或无限循环FOR 循环&#x…...

C入门番外篇——C, Are you OK?

今日路上看到一个车牌,52U0K,感觉很有意思,如果做一下简单的翻译就是,“我爱你,好么?” 这样让我脑子中闪现了这样的一个画面:“一个男生追一个女生,看到女生不怎么搭理自己的样子&a…...

python-产品篇-游戏-象棋

文章目录 代码效果 代码 import pygame import time import constants from button import Button import pieces import computerclass MainGame():window NoneStart_X constants.Start_XStart_Y constants.Start_YLine_Span constants.Line_SpanMax_X Start_X 8 * Lin…...

用linux文件系统的链接功能实现文件缓存LRU

概述: 目前,随着家庭宽带网络、无线宽带技术,以及终端设备性能的不断发展,基于多媒体的应用越来越广泛,特别是互联网视频的应用更是成为推动这些技术发展的源动力。作为互联网视频VOD的应用,提高视频播放的流畅度是一个非常重要的指标之一。除了编解码技术,视频C…...

AI大模型开发架构设计(10)——AI大模型架构体系与典型应用场景

文章目录 AI大模型架构体系与典型应用场景1 AI大模型架构体系你了解多少?GPT 助手训练流程GPT 助手训练数据预处理2个训练案例分析 2 AI 大模型的典型应用场景以及应用架构剖析AI 大模型的典型应用场景AI 大模型应用架构 AI大模型架构体系与典型应用场景 1 AI大模型架构体系你…...

C# async/await的使用

C# 中的 async 和 await 关键字是用于实现异步编程的重要工具,它们简化了编写和维护非阻塞代码的过程。以下是对这两个关键字用法的简要说明: async 关键字 定义异步方法:在方法声明前使用 async 关键字,表示该方法是一个异步方…...

C语言之找单身狗

个人主页(找往期文章包括但不限于本期文章中不懂的知识点): 我要学编程(ಥ_ಥ)-CSDN博客 题目: 在一个整型数组中,只有一个数字出现一次,其他数组都是成对出现的,请找出那个只出现一次的数字。…...

读懂 FastChat 大模型部署源码所需的异步编程基础

原文:读懂 FastChat 大模型部署源码所需的异步编程基础 - 知乎 目录 0. 前言 1. 同步与异步的区别 2. 协程 3. 事件循环 4. await 5. 组合协程 6. 使用 Semaphore 限制并发数 7. 运行阻塞任务 8. 异步迭代器 async for 9. 异步上下文管理器 async with …...

【华为】GRE VPN 实验配置

【华为】GRE VPN 实验配置 前言报文格式 实验需求配置思路配置拓扑GRE配置步骤R1基础配置GRE 配置 ISP_R2基础配置 R3基础配置GRE 配置 PCPC1PC2 抓包检查OSPF建立GRE隧道建立 配置文档 前言 VPN :(Virtual Private Network),即“…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用:实现组件通用属性的渐变过渡效果,提升用户体验。支持属性:width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项: 布局类属性(如宽高)变化时&#…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言:为什么 Eureka 依然是存量系统的核心? 尽管 Nacos 等新注册中心崛起,但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制,是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...

用docker来安装部署freeswitch记录

今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

AI书签管理工具开发全记录(十九):嵌入资源处理

1.前言 📝 在上一篇文章中,我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源,方便后续将资源打包到一个可执行文件中。 2.embed介绍 🎯 Go 1.16 引入了革命性的 embed 包,彻底改变了静态资源管理的…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...