当前位置: 首页 > news >正文

政安晨:示例演绎TensorFlow的官方指南(二){Estimator}

咱们接着演绎TensorFlow官方指南,我的这个系列的上一篇文章为:

政安晨:示例演绎TensorFlow的官方指南(一){基础知识}icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136067030为什么要演绎官方指南,我在上一篇说过了,这次没有废话,直接开始。

Estimator介绍


政安晨:

咱们先看一下Estimator的背景。

TensorFlow的Estimator API是一种高级的机器学习API,用于简化模型的训练、评估和推理过程。它提供了一种更加高层次的抽象,使开发者能够更加专注于模型的架构和数据流水线的设计,而不需要太多地关注底层的实现细节。

Estimator API提供了一套统一的接口,可以用于各种机器学习任务,如分类、回归、聚类等。它具有以下几个主要特点:

  1. 封装了模型的训练、评估和推理过程,提供了一种简单且一致的方式来组织代码和配置模型。

  2. 支持分布式训练,可以轻松地在多个GPU或多台机器上进行训练,以加速模型的训练过程。

  3. 提供了一系列内置的模型,如线性模型、DNN模型、CNN模型等,可以根据任务的需求快速构建模型。

  4. 可以使用预定义的特征列(feature columns)来处理和预处理输入数据,简化了数据准备的过程。

  5. 可以使用高层的tf.data.Dataset API来读取和处理数据,使数据加载和预处理过程更加灵活和高效。

使用Estimator API时,需要定义一个Estimator对象,这个对象包含了模型的结构和参数。然后,通过调用Estimator对象的train()方法来训练模型,evaluate()方法来评估模型,predict()方法来进行预测。在训练模型时,可以通过tf.estimator.TrainSpec对象来指定训练数据的路径和其他相关参数。在评估模型时,可以通过tf.estimator.EvalSpec对象来指定评估数据的路径和其他相关参数。

总之,Estimator API提供了一种简单、灵活且高效的方式来构建、训练和评估机器学习模型,使开发者能够更加专注于模型的设计和业务逻辑。


这篇官方文档介绍了 tf.estimator,它是一种高级 TensorFlow API。Estimator 封装了以下操作:

  • 训练
  • 评估
  • 预测
  • 导出以供使用

您可以使用我们提供的预制 Estimator 或编写您自己的自定义 Estimator。所有 Estimator(无论是预制还是自定义)都是基于 tf.estimator.Estimator 类的类。

有关 API 设计概述,请参阅白皮书。

优势

与 tf.keras.Model 类似,estimator 是模型级别的抽象。tf.estimator 提供了一些目前仍在为 tf.keras 开发中的功能。包括:

  • 基于参数服务器的训练
  • 完整的 TFX 集成

政安晨:

为了后面的演绎,我们先设置一下环境:


Estimator 功能

Estimator 提供了以下优势:

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您还可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。
  • Estimator 提供了安全的分布式训练循环,可控制如何以及何时进行以下操作:
    • 加载数据
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存 TensorBoard 摘要

在用 Estimator 编写应用时,您必须将数据输入流水线与模型分离。这种分离简化了使用不同数据集进行的实验。

预制 Estimator

使用预制 Estimator,您能够在比基础 TensorFlow API 高很多的概念层面上工作。您无需再担心创建计算图或会话,因为 Estimator 会替您完成所有“基础工作”。此外,使用预制 Estimator,您只需改动较少代码就能试验不同的模型架构。例如,tf.estimator.DNNClassifier 是一个预制 Estimator 类,可基于密集的前馈神经网络对分类模型进行训练。

预制 Estimator 程序结构

依赖于预制 Estimator 的 TensorFlow 程序通常包括以下四个步骤:

1. 编写一个或多个数据集导入函数。

例如,您可以创建一个函数来导入训练集,创建另一个函数来导入测试集。每个数据集导入函数必须返回以下两个对象:

  • 字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)
  • 包含一个或多个标签的张量

例如,以下代码展示了输入函数的基本框架:

def input_fn(dataset):     ...  # manipulate dataset, extracting the feature dict and the label     return feature_dict, label

政安晨:

数据框架其实是这样的,不知为何官方文档中没有给出?

def train_input_fn():titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")titanic = tf.data.experimental.make_csv_dataset(titanic_file, batch_size=32,label_name="survived")titanic_batches = (titanic.cache().repeat().shuffle(500).prefetch(tf.data.AUTOTUNE))return titanic_batches

执行如下:


2. 定义特征列。

每个 tf.feature_column 标识了特征名称、特征类型,以及任何输入预处理。例如,以下代码段创建了三个包含整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个会被程序调用以缩放原始数据的 lambda:

# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column(   'median_education',   normalizer_fn=lambda x: x - global_education_mean)

3. 实例化相关预制 Estimator。

例如,下面是对名为 LinearClassifier 的预制 Estimator 进行实例化的示例:

# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier(   feature_columns=[population, crime_rate, median_education])

4. 调用训练、评估或推断方法。

例如,所有 Estimator 都会提供一个用于训练模型的 train 方法。

# `input_fn` is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)

预制 Estimator 的优势

预制 Estimator 对最佳做法进行了编码,具有以下优势:

  • 确定计算图不同部分的运行位置,以及在单台机器或集群上实施策略的最佳做法。
  • 事件(摘要)编写和通用摘要的最佳做法。

如果不使用预制 Estimator,则您必须自己实现上述功能。

自定义 Estimator

每个 Estimator(无论预制还是自定义)的核心是其模型函数,这是一种为训练、评估和预测构建计算图的方法。当您使用预制 Estimator 时,已经有人为您实现了模型函数。当使用自定义 Estimator 时,您必须自己编写模型函数。

推荐工作流

  1. 假设存在一个合适的预制 Estimator,用它构建您的第一个模型,并将其结果作为基准。
  2. 使用此预制 Estimator 构建并测试您的整个流水线,包括数据的完整性和可靠性。
  3. 如果有其他合适的预制 Estimator,可通过运行实验确定哪个预制 Estimator 能够生成最佳结果。
  4. 如果可能,您可以通过构建自己的自定义 Estimator 进一步改进模型。
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

从 Keras 模型创建 Estimator

您可以使用 tf.keras.estimator.model_to_estimator 将现有的 Keras 模型转换为 Estimator。这样一来,您的 Keras 模型就可以利用 Estimator 的优势,例如分布式训练。

实例化 Keras MobileNet V2 模型并用训练中使用的优化器、损失和指标来编译模型:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = Falseestimator_model = tf.keras.Sequential([keras_mobilenet_v2,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(1)
])# Compile the model
estimator_model.compile(optimizer='adam',loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])

政安晨执行:

从已编译的 Keras 模型创建 Estimator。Keras 模型的初始模型状态会保留在已创建的 Estimator中:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)

您可以像对待任何其他 Estimator 一样对待派生的 Estimator

IMG_SIZE = 160  # All images will be resized to 160x160def preprocess(image, label):image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return image, label
def train_input_fn(batch_size):data = tfds.load('cats_vs_dogs', as_supervised=True)train_data = data['train']train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)return train_data

要进行训练,可调用 Estimator 的训练函数:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)

同样,要进行评估,可调用 Estimator 的评估函数:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)

有关详细信息,请参阅 tf.keras.estimator.model_to_estimator 文档。

写在最后

其实这一篇中官方指南并不详尽,尤其是最后的训练部分,咱们补充了一些,但仍然存在缺失,我将在后续的文章中以实际项目为例,详细演绎。

相关文章:

政安晨:示例演绎TensorFlow的官方指南(二){Estimator}

咱们接着演绎TensorFlow官方指南,我的这个系列的上一篇文章为: 政安晨:示例演绎TensorFlow的官方指南(一){基础知识}https://blog.csdn.net/snowdenkeke/article/details/136067030为什么要演绎官方指南,我…...

vue3:24—组件通信方式

目录 1、props 2、自定义事件 (emit) 3、mitt(任意组件的通讯) 4、v-model【封装ui组件库用的多,平时用的少。和vue2有点不同】 5、$attrs 6、$refs和$parent 7、provide和inject 8、pinia(即vue2中…...

WebGL+Three.js入门与实战——绘制水平移动的点、通过鼠标控制绘制(点击绘制、移动绘制、模拟画笔)

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…...

大数据环境搭建(一)-Hive

1 hive介绍 由Facebook开源的,用于解决海量结构化日志的数据统计的项目 本质上是将HQL转化为MapReduce、Tez、Spark等程序 Hive表的数据是HDFS上的目录和文件 Hive元数据 metastore,包含Hive表的数据库、表名、列、分区、表类型、表所在目录等。 根据Hive部署模…...

mac电脑上使用android studio创建flutter项目

mac电脑环境配置可以看这篇文章:https://xiaoshen.blog.csdn.net/article/details/136068650 配置玩环境之后,开始创建第一个flutter项目:点击new flutter project或者new project都可以 然后选择flutter: 并将sdk配置为解压后的…...

Excel——分类汇总

1.一级分类汇总 Q:请根据各销售地区统计销售额总数。 第一步:排序,我们需要根据销售地区汇总数据,我们就要对【销售地区】的内容进行排序。点击【销售地区】列中任意一个单元格,选择【数据】——【排序】&#xff0c…...

Backtrader 文档学习- Observers - Reference

Backtrader 文档学习- Observers - Reference 1.Benchmark class backtrader.observers.Benchmark() 观察器存储策略的回报和参考资产的回报,参考资产是传递给系统的数据之一。 参数: timeframe (default: None) ,如果None,则将…...

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Radio组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Radio组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Radio组件 单选框,提供相应的用户交互选择项。 子组件 无。 接口 …...

【go】结构体切片去重

场景 自定义结构体切片,去除切片中的重复元素(所有值完全相同) 代码 // 自定义struct去重 type AssetAppIntranets struct {ID string json:"id,omitempty"AppID string json:"app_id,omitempty"IP …...

百面嵌入式专栏(面试题)C语言面试题22道

沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将介绍C语言相关面试题 。 宏定义是在编译的哪个阶段被处理的?答案:宏定义是在编译预处理阶段被处理的。 解读:编译预处理:头文件包含、宏替换、条件编译、去除注释、添加行号。 写一个“标准”宏MIN,这个…...

Docker方式创建keepalived连接MGR集群

记录一下简单的搭建步骤以便后期查验 目录 前言步骤1. 安装环境2. 重新制作镜像3. 导入新镜像4. 创建容器 前言 假设已经搭建了MySQL8的MGR集群方式(一主两从)。 MGR本身有故障转移重新选举新的主节点功能,但是上游的应用程序需要自己手动修…...

Oracle PL/SQL Programming 第5章:Iterative Processing with Loops 读书笔记

总的目录和进度,请参见开始读 Oracle PL/SQL Programming 第6版 本章探讨 PL/SQL 的迭代控制结构(也称为循环),它允许您重复执行相同的代码。 PL/SQL 提供了三种不同类型的循环结构: 简单或无限循环FOR 循环&#x…...

C入门番外篇——C, Are you OK?

今日路上看到一个车牌,52U0K,感觉很有意思,如果做一下简单的翻译就是,“我爱你,好么?” 这样让我脑子中闪现了这样的一个画面:“一个男生追一个女生,看到女生不怎么搭理自己的样子&a…...

python-产品篇-游戏-象棋

文章目录 代码效果 代码 import pygame import time import constants from button import Button import pieces import computerclass MainGame():window NoneStart_X constants.Start_XStart_Y constants.Start_YLine_Span constants.Line_SpanMax_X Start_X 8 * Lin…...

用linux文件系统的链接功能实现文件缓存LRU

概述: 目前,随着家庭宽带网络、无线宽带技术,以及终端设备性能的不断发展,基于多媒体的应用越来越广泛,特别是互联网视频的应用更是成为推动这些技术发展的源动力。作为互联网视频VOD的应用,提高视频播放的流畅度是一个非常重要的指标之一。除了编解码技术,视频C…...

AI大模型开发架构设计(10)——AI大模型架构体系与典型应用场景

文章目录 AI大模型架构体系与典型应用场景1 AI大模型架构体系你了解多少?GPT 助手训练流程GPT 助手训练数据预处理2个训练案例分析 2 AI 大模型的典型应用场景以及应用架构剖析AI 大模型的典型应用场景AI 大模型应用架构 AI大模型架构体系与典型应用场景 1 AI大模型架构体系你…...

C# async/await的使用

C# 中的 async 和 await 关键字是用于实现异步编程的重要工具,它们简化了编写和维护非阻塞代码的过程。以下是对这两个关键字用法的简要说明: async 关键字 定义异步方法:在方法声明前使用 async 关键字,表示该方法是一个异步方…...

C语言之找单身狗

个人主页(找往期文章包括但不限于本期文章中不懂的知识点): 我要学编程(ಥ_ಥ)-CSDN博客 题目: 在一个整型数组中,只有一个数字出现一次,其他数组都是成对出现的,请找出那个只出现一次的数字。…...

读懂 FastChat 大模型部署源码所需的异步编程基础

原文:读懂 FastChat 大模型部署源码所需的异步编程基础 - 知乎 目录 0. 前言 1. 同步与异步的区别 2. 协程 3. 事件循环 4. await 5. 组合协程 6. 使用 Semaphore 限制并发数 7. 运行阻塞任务 8. 异步迭代器 async for 9. 异步上下文管理器 async with …...

【华为】GRE VPN 实验配置

【华为】GRE VPN 实验配置 前言报文格式 实验需求配置思路配置拓扑GRE配置步骤R1基础配置GRE 配置 ISP_R2基础配置 R3基础配置GRE 配置 PCPC1PC2 抓包检查OSPF建立GRE隧道建立 配置文档 前言 VPN :(Virtual Private Network),即“…...

创建一个VUE项目(vue2和vue3)

背景:电脑已经安装完vue2和vue3环境 一台Mac同时安装vue2和vue3 https://blog.csdn.net/c103363/article/details/136059783 创建vue2项目 vue init webpack "项目名称"创建vue3项目 vue create "项目名称"...

Android 10.0 动态壁纸 LiveWallpaper

前言 在 Android 中,壁纸分为动态与静态两种,但其实两者得本质都是一样。都以一个 Service 得形式在后台运行,在一个类型为 TYPE_WALLPAPER 的窗口上绘制内容。也可以这么去理解:静态壁纸是一种特殊的动态壁纸,它仅在…...

Linux内核与驱动面试经典“小”问题集锦(4)

接前一篇文章:Linux内核与驱动面试经典“小”问题集锦(3) 问题5 问:Linux内核中内存分配都有哪些方式?它们之间的使用场景都是什么? 备注:这个问题是笔者近期参加蔚来面试时遇到的一个问题。这…...

使用python实现:判断任意坐标点在STL几何模型的内部或外部

简介 在STL几何模型处理的过程中,经常需要判断一个点是否在模型的内部。网上给出的资料主要是使用C vtk的,而python vtk的很少。本文给出了一段精简版的python代码,实现判断任意坐标点在STL几何模型的内部或外部。 代码 首先定义三个函数 …...

leetcode(滑动窗口)483.找到字符中所有字母异位词(C++详细解释)DAY4

文章目录 1.题目示例提示 2.解答思路3.实现代码结果 4.总结 1.题目 给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词 的子串,返回这些子串的起始索引。不考虑答案输出的顺序。 异位词 指由相同字母重排列形成的字符串(包括相同的字符串&a…...

Leaf——美团点评分布式ID生成系统

0.普通算法生成id的缺点 1.Leaf-segment数据库方案 第一种Leaf-segment方案,在使用数据库的方案上,做了如下改变: - 原方案每次获取ID都得读写一次数据库,造成数据库压力大。改为利用proxy server批量获取,每次获取一…...

ProcessSlot构建流程分析

ProcessorSlot ProcessorSlot构建流程 // com.alibaba.csp.sentinel.CtSph#lookProcessChain private Entry entryWithPriority(ResourceWrapper resourceWrapper, int count, boolean prioritized, Object... args)throws BlockException {// 省略创建 Context 的代码// 黑盒…...

工业笔记本丨行业三防笔记本丨亿道加固笔记本定制丨极端温度优势

工业笔记本是专为在恶劣环境条件下工作而设计的高度耐用的计算机设备。与传统消费者级笔记本电脑相比,工业笔记本在极端温度下展现出了许多优势。本文将探讨工业笔记本在极端温度环境中的表现,并介绍其优势。 耐高温性能: 工业笔记本具有更高的耐高温性…...

游戏服务器多少钱一台?腾讯云32元,阿里云26元

游戏服务器租用多少钱一年?1个月游戏服务器费用多少?阿里云游戏服务器26元1个月、腾讯云游戏服务器32元,游戏服务器配置从4核16G、4核32G、8核32G、16核64G等配置可选,可以选择轻量应用服务器和云服务器,阿腾云atengyu…...

实战分享:SpringBoot在创新创业项目管理中的应用

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…...