当前位置: 首页 > news >正文

TensorFlow高阶API和低阶API

TensorFlow提供了众多的API,简单地可以分类为高阶API和低阶API. API太多太乱也是TensorFlow被诟病的重点之一,可能因为Google的工程师太多了,社区太活跃了~当然后来Google也意识到这个问题,在TensorFlow 2.0中有了很大的改善。本文就简要介绍一下TensorFlow的高阶API和低阶API使用,提供推荐的使用方式。

高阶API(For beginners)

The best place to start is with the user-friendly Keras sequential API. Build models by plugging together building blocks.

TensorFlow推荐使用Keras的sequence函数作为高阶API的入口进行模型的构建,就像堆积木一样:

# 导入TensorFlow, 以及下面的常用Keras层
import tensorflow as tf  
from tensorflow.keras.layers import Flatten, Dense, Dropout# 加载并准备好MNIST数据集
mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()# 将样本从0~255的整数转换为0~1的浮点数x_train, x_test = x_train / 255.0, x_test / 255.0# 将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型
model = tf.keras.models.Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dropout(0.5),Dense(10, activation='softmax')])# 为训练选择优化器和损失函数model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# 训练并验证模型
model.fit(x_train, y_train, epochs=5)model.evaluate(x_test,  y_test, verbose=2)

输出的日志:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 4s 72us/sample - loss: 0.2919 - accuracy: 0.9156
Epoch 2/5
60000/60000 [==============================] - 4s 58us/sample - loss: 0.1439 - accuracy: 0.9568
Epoch 3/5
60000/60000 [==============================] - 4s 58us/sample - loss: 0.1080 - accuracy: 0.9671
Epoch 4/5
60000/60000 [==============================] - 4s 59us/sample - loss: 0.0875 - accuracy: 0.9731
Epoch 5/5
60000/60000 [==============================] - 3s 58us/sample - loss: 0.0744 - accuracy: 0.9766
10000/1 - 1s - loss: 0.0383 - accuracy: 0.9765
[0.07581, 0.9765]

日志的最后一行有两个数 [0.07581, 0.9765],0.07581是最终的loss值,也就是交叉熵;0.9765是测试集的accuracy结果,这个数字手写体模型的精度已经将近98%.

低阶API(For experts)

The Keras functional and subclassing APIs provide a define-by-run interface for customization and advanced research. Build your model, then write the forward and backward pass. Create custom layers, activations, and training loops.

说到TensorFlow低阶API,最先想到的肯定是tf.Session和著名的sess.run,但随着TensorFlow的发展,tf.Session最后出现在TensorFlow 1.15中,TensorFlow 2.0已经取消了这个API,如果非要使用的话只能使用兼容版本的tf.compat.v1.Session. 当然,还是推荐使用新版的API,这里也是用Keras,但是用的是subclass的相关API以及GradientTape. 下面会详细介绍。

# 导入TensorFlow, 以及下面的常用Keras层
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2Dfrom tensorflow.keras import Model# 加载并准备好MNIST数据集mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()# 将样本从0~255的整数转换为0~1的浮点数x_train, x_test = x_train / 255.0, x_test / 255.0
# 使用 tf.data 来将数据集切分为 batch 以及混淆数据集batch_size = 32
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
# 使用 Keras 模型子类化(model subclassing) API 构建 tf.keras 模型
class MyModel(Model):def __init__(self):super(MyModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.dropout = Dropout(0.5)self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)x = self.dropout(x)return self.d2(x)model = MyModel()# 为训练选择优化器和损失函数loss_object = tf.keras.losses.SparseCategoricalCrossentropy()optimizer = tf.keras.optimizers.Adam()# 选择衡量指标来度量模型的损失值(loss)和准确率(accuracy)。这些指标在 epoch 上累积值,然后打印出整体结果train_loss = tf.keras.metrics.Mean(name='train_loss')train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')test_loss = tf.keras.metrics.Mean(name='test_loss')test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')# 使用 tf.GradientTape 来训练模型@tf.functiondef train_step(images, labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss(loss)train_accuracy(labels, predictions)# 使用 tf.GradientTape 来训练模型@tf.functiondef train_step(images, labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss(loss)train_accuracy(labels, predictions)# 测试模型
@tf.functiondef test_step(images, labels):predictions = model(images)t_loss = loss_object(labels, predictions)test_loss(t_loss)test_accuracy(labels, predictions)EPOCHS = 5for epoch in range(EPOCHS):for images, labels in train_ds:train_step(images, labels)for test_images, test_labels in test_ds:test_step(test_images, test_labels)template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'print (template.format(epoch+1,train_loss.result(),train_accuracy.result()*100,test_loss.result(),test_accuracy.result()*100))

输出:

Epoch 1, Loss: 0.13822732865810394, Accuracy: 95.84833526611328, Test Loss: 0.07067110389471054, Test Accuracy: 97.75
Epoch 2, Loss: 0.09080979228019714, Accuracy: 97.25, Test Loss: 0.06446609646081924, Test Accuracy: 97.95999908447266
Epoch 3, Loss: 0.06777264922857285, Accuracy: 97.93944549560547, Test Loss: 0.06325332075357437, Test Accuracy: 98.04000091552734
Epoch 4, Loss: 0.054447807371616364, Accuracy: 98.33999633789062, Test Loss: 0.06611879169940948, Test Accuracy: 98.00749969482422
Epoch 5, Loss: 0.04556874558329582, Accuracy: 98.60433197021484, Test Loss: 0.06510476022958755, Test Accuracy: 98.10400390625

可以看出,低阶API把整个训练的过程都暴露出来了,包括数据的shuffle(每个epoch重新排序数据使得训练数据随机化,避免周期性重复带来的影响)及组成训练batch,组建模型的数据通路,具体定义各种评估指标(loss, accuracy),计算梯度,更新梯度(这两步尤为重要)。如果用户需要对梯度或者中间过程做处理,甚至打印等,使用低阶API可以完全进行完全的控制。

如何选择

从上面的标题也可以看出,对于初学者来说,建议使用高阶API,简单清晰,可以迅速入门。对于专家学者们,建议使用低阶API,可以随心所欲地对具体细节进行改造和加工。

 

相关文章:

TensorFlow高阶API和低阶API

TensorFlow提供了众多的API,简单地可以分类为高阶API和低阶API. API太多太乱也是TensorFlow被诟病的重点之一,可能因为Google的工程师太多了,社区太活跃了~当然后来Google也意识到这个问题,在TensorFlow 2.0中有了很大的改善。本文…...

强训之【参数解析和跳石板】

目录 1.参数解析1.1题目描述1.2思路1.3代码 2.跳石板2.1题目2.2思路2.3代码 3.选择题 1.参数解析 1.1题目描述 在命令行输入如下命令: xcopy /s c:\ d:\e, 各个参数如下: 参数1:命令字xcopy 参数2:字符串/s 参数…...

Redis队列Stream、Redis多线程详解(三)

Redis中的线程和IO模型 什么是Reactor模式 ? “反应”器名字中”反应“的由来: “反应”即“倒置”,“控制逆转”,具体事件处理程序不调用反应器,而向反应器注册一个事件处理器,表示自己对某些事件感兴趣&#xff0…...

MySQL统计函数count详解

count()概述 count() 是一个聚合函数,返回指定匹配条件的行数。开发中常用来统计表中数据,全部数据,不为null数据,或者去重数据 count(1)和count()和count(列名)的区别 1.函数说明 count(1):统计所有的记录&#xff0…...

实验04:图像压缩(DP算法)

1.实验目的: 掌握动态规划算法的基本思想以及用它解决问题的一般技巧。运用所熟悉的编程工具,运用动态规划的思想来求解图像压缩问题。 2.实验内容: 给定一幅图像,求解最佳压缩,使得压缩后的文件最小。 3.实验要求…...

4.19--面试系列之真题版本--redis出现大key怎么解决?Redis 大 Key 对持久化有什么影响?

对于redis出现大key的情况,可以通过以下几种方式来解决: 1.分布式存储:将大key拆分成多个小的key,分别存储在不同的节点上。 2.数据过期:对于大key中不经常使用的数据,可以使用redis自带的过期特性&#xf…...

新手在家做自媒体要如何起步?

不少人都想做自媒体来增加自己的收入或者创业,但没有人带领,自己像是无头苍蝇一样,不知道往哪里走。 今天这期内容大周就来给粉丝们分享一点干货,如果对你有所帮助,记得点赞支持一下大周。 1、注册账号 如果你连一个…...

易基因:禾本科植物群落的病毒组丰度/组成与人为管理/植物多样性变化的相关性 | 宏病毒组

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 现代农业通过简化生态系统、引入新宿主物种和减少作物遗传多样性来影响植物病毒的出现。因此,更好理解农业生态中种植和未种植群落中的病毒分布,以及它们之间的病…...

华为OD机试——对称美学(通过率只有8.51%???)

用java写的这道题,两个样例都可以通过,但是提交之后最终的通过率只有8.51%???自己搞了半天一直都是这个通过率,然后用网上说的100%通过率的代码也是一样的结果,最后时间到了还是没有拿到满分&am…...

【三十天精通Vue 3】第十六天 Vue 3 的虚拟 DOM 原理详解

引言 Vue 3 的虚拟 DOM 是一种用于优化 Vue 应用程序性能的技术。它通过将组件实例转换为虚拟 DOM,并在组件更新时递归地更新虚拟 DOM,以达到高效的渲染性能。在 Vue 3 中,虚拟 DOM 树由 VNode 组成,VNode 是虚拟 DOM 的基本单元…...

Arduino ESP8266通过udp获取时间以及同步本地时间方法

Arduino ESP8266通过udp获取时间以及同步本地时间 ✨通过udp获取NTP服务器上的时间戳,然后经过转换,得到当前具体的时间。转换相对复杂,对于获取时间还是相对比较准确。📝通过udp获取时间实现代码 #include <ESP8266WiFi.h> #include <WiFiUdp.h>//填写 WiFi…...

c/c++:char*定义常量字符串,strcmp()函数,strcpy()函数,寻找指定字符,字符串去空格

c/c&#xff1a;char*定义常量字符串&#xff0c;strcmp()函数&#xff0c;strcpy()函数&#xff0c;寻找指定字符&#xff0c;字符串去空格 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;此时学会c的话&#xff0c; 我所…...

2023年6月DAMA-CDGA/CDGP数据治理认证考试可报名地区公布

2023年4月23日&#xff0c;据DAMA中国官方信息&#xff0c;目前6月DAMA-CDGA/CDGP数据治理认证考试开放报名地区有&#xff1a;北京、上海、广州、深圳、长沙、呼和浩特。目前南京、济南、西安、杭州等地区还在接近开考人数中&#xff0c;打算6月考试的朋友们可以抓紧时间报名啦…...

UDS的0x19服务介绍

什么是 UDS&#xff1f; UEI (Unified Diagnostic Services&#xff0c;统一诊断服务) 是一种在车辆电子控制单元 (ECU) 之间交换诊断信息的标准通信协议&#xff0c;它是OBD-II的某些扩展。利用 UDS 协议&#xff0c;诊断工程师可以访问车辆的各种功能&#xff0c;如读取故障…...

QinQ技术与Portal技术

QinQ 802.1Q-in-802.1Q&#xff0c;是一种扩展VLAN标签技术。在城域网中&#xff0c;需要大量的VLAN来隔离区分不同的用户&#xff0c;但是原有的802.1Q只有12个比特&#xff0c;仅能标识4096个VLANQinQ即在802.1Q的基础上&#xff0c;再增加一层外层标签。使得可以标识4096*40…...

Vue-自定义表单验证(rule,value,callback)详细使用

前言 最近在实际开发中遇到需要验证合同编号是否在数据库已经存在&#xff0c;自定义表单验证。 的表单验证大家都知道form绑定rules&#xff0c;prop绑定值与form.值一样&#xff0c;必填&#xff0c;失去焦点触发 提示信息。 今天我们讲一讲自定义验证规则具体使用场景和它…...

港联证券|TMT板块全线退潮,这些个股获主力逆市抢筹

计算机、电子、传媒、通讯职业流出规模居前。 今天沪深两市主力资金净流出709.92亿元&#xff0c;其中创业板净流出218.36亿元&#xff0c;沪深300成份股净流出187.92亿元。 资金流向上&#xff0c;今天申万一级职业普跌&#xff0c;除了国防军工职业小幅上涨&#xff0c;获主…...

WPF学习

一、了解WPF的框架结构 &#xff08;第一小节随便看下就可以&#xff0c;简单练习就行&#xff09; 1、新建WPF项目 xmlns&#xff1a;XML的命名空间 Margin外边距&#xff1a;左上右下 HorizontalAlignment&#xff1a;水平位置 VerticalAlignment&#xff1a;垂直位置 2…...

C#使用WebDriver模拟浏览器操作WEB页面

有时候需要模拟访问页面触发某个功能&#xff0c;可以使用WebDriver来实现这一功能&#xff0c;驱动打开浏览器&#xff0c;并对页面重定向以及对页面写入脚本等操作。 安装Selenium.Chrome&#xff0c;Selenium.Support.UI&#xff0c;Selenium 引入 using OpenQA.Selenium.…...

正则表达式 - 简单模式匹配

目录 一、测试数据 二、简单模式匹配 1. 匹配字面值 2. 匹配数字和非数字字符 3. 匹配单词与非单词字符 4. 匹配空白字符 5. 匹配任意字符 6. 匹配单词边界 7. 匹配零个或多个字符 8. 单行模式与多行模式 一、测试数据 这里所用文本是《学习正则表达式》这本书带的&a…...

银行数字化转型导师坚鹏:银行数字化转型培训方案

目录 一、银行数字化转型培训背景 二、银行数字化转型模型 三、银行数字化转型课程设计思路 四、 银行数字化转型课程基本介绍 五、 银行数字化转型课程设置 六、银行数字化转型课程大纲 七、培训方案实施流程 一、银行数字化转型培训背景 2020年1月3日&#xff…...

多维时序 | MATLAB实现BO-CNN-LSTM贝叶斯优化卷积神经网络-长短期记忆网络多变量时间序列预测

多维时序 | MATLAB实现BO-CNN-LSTM贝叶斯优化卷积神经网络-长短期记忆网络多变量时间序列预测 目录 多维时序 | MATLAB实现BO-CNN-LSTM贝叶斯优化卷积神经网络-长短期记忆网络多变量时间序列预测效果一览基本介绍模型搭建程序设计参考资料 效果一览 基本介绍 MATLAB实现BO-CNN-…...

Shell知识点(一)

1.echo 命令 echo命令的作用是在屏幕输入一行文本&#xff0c;可以降该命令的参数原样输出。 $ echo hello world hello world 如果想要输出的是多行文本&#xff0c;包含换行符&#xff0c;这时就需要把多行文本放在引号里面 $ echo "<HTML><HEAD><TITLE…...

mysql 索引失效、联合索引失效场景和举例

索引失效 假设有一张user 表&#xff0c;表中包含索引 (id); (name); (birthday); (name,age); 对索引字段进行函数操作 select name from user where year(birthday) 2000;使用模糊查询&#xff0c;查询中使用通配符 select name from user where name like %益达%;使用i…...

快速将PDF转换为图片:使用在线转换器的步骤

PDF文件是一种常见的文档格式&#xff0c;但在某些情况下需要将其转换为图片格式&#xff0c;例如将PDF文件插入PPT演示文稿中。此时&#xff0c;使用在线PDF转换器是一种快速且简便的方法。 本文将介绍如何使用在线转换器将PDF文件转换为图片格式。 步骤1&#xff1a;选择合…...

什么是gpt一4-如何用上gpt-4

怎么使用gpt-4 目前GPT-4还未正式发布或公开&#xff0c;因此也没有详细的对接说明。但是我们可以根据GPT-4的前身GPT-3的应用经验&#xff0c;以及GPT-4的预期功能推测一些可能的使用步骤&#xff1a; 选择适合的GPT-4实现技术&#xff1a;GPT-4可能有不同的实现技术&#xff…...

Docker 相关概念

1、Docker是什么&#xff1f; 如何确保应用能够在这些环境中运行和通过质量检测&#xff1f;并且在部署过程中不出现令人头疼的版本、配置问题&#xff0c;也无需重新编写代码和进行故障修复&#xff1f; 答案就是使用容器。Docker之所以发展如此迅速&#xff0c;也是因为它对…...

STM32平衡小车 TB6612电机驱动学习

TB6612FNG简介 单片机引脚的电流一般只有几十个毫安&#xff0c;无法驱动电机&#xff0c;因此一般是通过单片机控制电机驱动芯片进而控制电机。TB6612是比较常用的电机驱动芯片之一。 TB6612FNG可以同时控制两个电机&#xff0c;工作电流1.2A&#xff0c;最大电流3.2A。 VM电…...

动态加载 JS 文件

动态加载JS文件是指在网页运行过程中通过JavaScript代码向页面中动态添加外部JS文件&#xff0c;这种方式能够提高页面加载速度和用户体验&#xff0c;并且可以帮助网站实现更多的功能和特效。 本文将详细介绍动态加载JS文件的基本原理、优势、注意事项以及具体实现方法&#…...

14、lldb调试指令

LLDB LLDB(Low Lever Debug): 默认内置于Xcode中的动态调试工具.标准的lldb提供了一组广泛的命令,旨在与老版本的GDB命令兼容.除了使用标准配置外,还可以很容易地自定义lldb以满足实际需要. 1.1 lldb语法: <command> [<subcommand> [<subcommand>...]] &l…...