当前位置: 首页 > news >正文

TensorFlow高阶API和低阶API

TensorFlow提供了众多的API,简单地可以分类为高阶API和低阶API. API太多太乱也是TensorFlow被诟病的重点之一,可能因为Google的工程师太多了,社区太活跃了~当然后来Google也意识到这个问题,在TensorFlow 2.0中有了很大的改善。本文就简要介绍一下TensorFlow的高阶API和低阶API使用,提供推荐的使用方式。

高阶API(For beginners)

The best place to start is with the user-friendly Keras sequential API. Build models by plugging together building blocks.

TensorFlow推荐使用Keras的sequence函数作为高阶API的入口进行模型的构建,就像堆积木一样:

# 导入TensorFlow, 以及下面的常用Keras层
import tensorflow as tf  
from tensorflow.keras.layers import Flatten, Dense, Dropout# 加载并准备好MNIST数据集
mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()# 将样本从0~255的整数转换为0~1的浮点数x_train, x_test = x_train / 255.0, x_test / 255.0# 将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型
model = tf.keras.models.Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dropout(0.5),Dense(10, activation='softmax')])# 为训练选择优化器和损失函数model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# 训练并验证模型
model.fit(x_train, y_train, epochs=5)model.evaluate(x_test,  y_test, verbose=2)

输出的日志:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 4s 72us/sample - loss: 0.2919 - accuracy: 0.9156
Epoch 2/5
60000/60000 [==============================] - 4s 58us/sample - loss: 0.1439 - accuracy: 0.9568
Epoch 3/5
60000/60000 [==============================] - 4s 58us/sample - loss: 0.1080 - accuracy: 0.9671
Epoch 4/5
60000/60000 [==============================] - 4s 59us/sample - loss: 0.0875 - accuracy: 0.9731
Epoch 5/5
60000/60000 [==============================] - 3s 58us/sample - loss: 0.0744 - accuracy: 0.9766
10000/1 - 1s - loss: 0.0383 - accuracy: 0.9765
[0.07581, 0.9765]

日志的最后一行有两个数 [0.07581, 0.9765],0.07581是最终的loss值,也就是交叉熵;0.9765是测试集的accuracy结果,这个数字手写体模型的精度已经将近98%.

低阶API(For experts)

The Keras functional and subclassing APIs provide a define-by-run interface for customization and advanced research. Build your model, then write the forward and backward pass. Create custom layers, activations, and training loops.

说到TensorFlow低阶API,最先想到的肯定是tf.Session和著名的sess.run,但随着TensorFlow的发展,tf.Session最后出现在TensorFlow 1.15中,TensorFlow 2.0已经取消了这个API,如果非要使用的话只能使用兼容版本的tf.compat.v1.Session. 当然,还是推荐使用新版的API,这里也是用Keras,但是用的是subclass的相关API以及GradientTape. 下面会详细介绍。

# 导入TensorFlow, 以及下面的常用Keras层
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2Dfrom tensorflow.keras import Model# 加载并准备好MNIST数据集mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()# 将样本从0~255的整数转换为0~1的浮点数x_train, x_test = x_train / 255.0, x_test / 255.0
# 使用 tf.data 来将数据集切分为 batch 以及混淆数据集batch_size = 32
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
# 使用 Keras 模型子类化(model subclassing) API 构建 tf.keras 模型
class MyModel(Model):def __init__(self):super(MyModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.dropout = Dropout(0.5)self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)x = self.dropout(x)return self.d2(x)model = MyModel()# 为训练选择优化器和损失函数loss_object = tf.keras.losses.SparseCategoricalCrossentropy()optimizer = tf.keras.optimizers.Adam()# 选择衡量指标来度量模型的损失值(loss)和准确率(accuracy)。这些指标在 epoch 上累积值,然后打印出整体结果train_loss = tf.keras.metrics.Mean(name='train_loss')train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')test_loss = tf.keras.metrics.Mean(name='test_loss')test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')# 使用 tf.GradientTape 来训练模型@tf.functiondef train_step(images, labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss(loss)train_accuracy(labels, predictions)# 使用 tf.GradientTape 来训练模型@tf.functiondef train_step(images, labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss(loss)train_accuracy(labels, predictions)# 测试模型
@tf.functiondef test_step(images, labels):predictions = model(images)t_loss = loss_object(labels, predictions)test_loss(t_loss)test_accuracy(labels, predictions)EPOCHS = 5for epoch in range(EPOCHS):for images, labels in train_ds:train_step(images, labels)for test_images, test_labels in test_ds:test_step(test_images, test_labels)template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'print (template.format(epoch+1,train_loss.result(),train_accuracy.result()*100,test_loss.result(),test_accuracy.result()*100))

输出:

Epoch 1, Loss: 0.13822732865810394, Accuracy: 95.84833526611328, Test Loss: 0.07067110389471054, Test Accuracy: 97.75
Epoch 2, Loss: 0.09080979228019714, Accuracy: 97.25, Test Loss: 0.06446609646081924, Test Accuracy: 97.95999908447266
Epoch 3, Loss: 0.06777264922857285, Accuracy: 97.93944549560547, Test Loss: 0.06325332075357437, Test Accuracy: 98.04000091552734
Epoch 4, Loss: 0.054447807371616364, Accuracy: 98.33999633789062, Test Loss: 0.06611879169940948, Test Accuracy: 98.00749969482422
Epoch 5, Loss: 0.04556874558329582, Accuracy: 98.60433197021484, Test Loss: 0.06510476022958755, Test Accuracy: 98.10400390625

可以看出,低阶API把整个训练的过程都暴露出来了,包括数据的shuffle(每个epoch重新排序数据使得训练数据随机化,避免周期性重复带来的影响)及组成训练batch,组建模型的数据通路,具体定义各种评估指标(loss, accuracy),计算梯度,更新梯度(这两步尤为重要)。如果用户需要对梯度或者中间过程做处理,甚至打印等,使用低阶API可以完全进行完全的控制。

如何选择

从上面的标题也可以看出,对于初学者来说,建议使用高阶API,简单清晰,可以迅速入门。对于专家学者们,建议使用低阶API,可以随心所欲地对具体细节进行改造和加工。

 

相关文章:

TensorFlow高阶API和低阶API

TensorFlow提供了众多的API,简单地可以分类为高阶API和低阶API. API太多太乱也是TensorFlow被诟病的重点之一,可能因为Google的工程师太多了,社区太活跃了~当然后来Google也意识到这个问题,在TensorFlow 2.0中有了很大的改善。本文…...

强训之【参数解析和跳石板】

目录 1.参数解析1.1题目描述1.2思路1.3代码 2.跳石板2.1题目2.2思路2.3代码 3.选择题 1.参数解析 1.1题目描述 在命令行输入如下命令: xcopy /s c:\ d:\e, 各个参数如下: 参数1:命令字xcopy 参数2:字符串/s 参数…...

Redis队列Stream、Redis多线程详解(三)

Redis中的线程和IO模型 什么是Reactor模式 ? “反应”器名字中”反应“的由来: “反应”即“倒置”,“控制逆转”,具体事件处理程序不调用反应器,而向反应器注册一个事件处理器,表示自己对某些事件感兴趣&#xff0…...

MySQL统计函数count详解

count()概述 count() 是一个聚合函数,返回指定匹配条件的行数。开发中常用来统计表中数据,全部数据,不为null数据,或者去重数据 count(1)和count()和count(列名)的区别 1.函数说明 count(1):统计所有的记录&#xff0…...

实验04:图像压缩(DP算法)

1.实验目的: 掌握动态规划算法的基本思想以及用它解决问题的一般技巧。运用所熟悉的编程工具,运用动态规划的思想来求解图像压缩问题。 2.实验内容: 给定一幅图像,求解最佳压缩,使得压缩后的文件最小。 3.实验要求…...

4.19--面试系列之真题版本--redis出现大key怎么解决?Redis 大 Key 对持久化有什么影响?

对于redis出现大key的情况,可以通过以下几种方式来解决: 1.分布式存储:将大key拆分成多个小的key,分别存储在不同的节点上。 2.数据过期:对于大key中不经常使用的数据,可以使用redis自带的过期特性&#xf…...

新手在家做自媒体要如何起步?

不少人都想做自媒体来增加自己的收入或者创业,但没有人带领,自己像是无头苍蝇一样,不知道往哪里走。 今天这期内容大周就来给粉丝们分享一点干货,如果对你有所帮助,记得点赞支持一下大周。 1、注册账号 如果你连一个…...

易基因:禾本科植物群落的病毒组丰度/组成与人为管理/植物多样性变化的相关性 | 宏病毒组

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 现代农业通过简化生态系统、引入新宿主物种和减少作物遗传多样性来影响植物病毒的出现。因此,更好理解农业生态中种植和未种植群落中的病毒分布,以及它们之间的病…...

华为OD机试——对称美学(通过率只有8.51%???)

用java写的这道题,两个样例都可以通过,但是提交之后最终的通过率只有8.51%???自己搞了半天一直都是这个通过率,然后用网上说的100%通过率的代码也是一样的结果,最后时间到了还是没有拿到满分&am…...

【三十天精通Vue 3】第十六天 Vue 3 的虚拟 DOM 原理详解

引言 Vue 3 的虚拟 DOM 是一种用于优化 Vue 应用程序性能的技术。它通过将组件实例转换为虚拟 DOM,并在组件更新时递归地更新虚拟 DOM,以达到高效的渲染性能。在 Vue 3 中,虚拟 DOM 树由 VNode 组成,VNode 是虚拟 DOM 的基本单元…...

Arduino ESP8266通过udp获取时间以及同步本地时间方法

Arduino ESP8266通过udp获取时间以及同步本地时间 ✨通过udp获取NTP服务器上的时间戳,然后经过转换,得到当前具体的时间。转换相对复杂,对于获取时间还是相对比较准确。📝通过udp获取时间实现代码 #include <ESP8266WiFi.h> #include <WiFiUdp.h>//填写 WiFi…...

c/c++:char*定义常量字符串,strcmp()函数,strcpy()函数,寻找指定字符,字符串去空格

c/c&#xff1a;char*定义常量字符串&#xff0c;strcmp()函数&#xff0c;strcpy()函数&#xff0c;寻找指定字符&#xff0c;字符串去空格 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;此时学会c的话&#xff0c; 我所…...

2023年6月DAMA-CDGA/CDGP数据治理认证考试可报名地区公布

2023年4月23日&#xff0c;据DAMA中国官方信息&#xff0c;目前6月DAMA-CDGA/CDGP数据治理认证考试开放报名地区有&#xff1a;北京、上海、广州、深圳、长沙、呼和浩特。目前南京、济南、西安、杭州等地区还在接近开考人数中&#xff0c;打算6月考试的朋友们可以抓紧时间报名啦…...

UDS的0x19服务介绍

什么是 UDS&#xff1f; UEI (Unified Diagnostic Services&#xff0c;统一诊断服务) 是一种在车辆电子控制单元 (ECU) 之间交换诊断信息的标准通信协议&#xff0c;它是OBD-II的某些扩展。利用 UDS 协议&#xff0c;诊断工程师可以访问车辆的各种功能&#xff0c;如读取故障…...

QinQ技术与Portal技术

QinQ 802.1Q-in-802.1Q&#xff0c;是一种扩展VLAN标签技术。在城域网中&#xff0c;需要大量的VLAN来隔离区分不同的用户&#xff0c;但是原有的802.1Q只有12个比特&#xff0c;仅能标识4096个VLANQinQ即在802.1Q的基础上&#xff0c;再增加一层外层标签。使得可以标识4096*40…...

Vue-自定义表单验证(rule,value,callback)详细使用

前言 最近在实际开发中遇到需要验证合同编号是否在数据库已经存在&#xff0c;自定义表单验证。 的表单验证大家都知道form绑定rules&#xff0c;prop绑定值与form.值一样&#xff0c;必填&#xff0c;失去焦点触发 提示信息。 今天我们讲一讲自定义验证规则具体使用场景和它…...

港联证券|TMT板块全线退潮,这些个股获主力逆市抢筹

计算机、电子、传媒、通讯职业流出规模居前。 今天沪深两市主力资金净流出709.92亿元&#xff0c;其中创业板净流出218.36亿元&#xff0c;沪深300成份股净流出187.92亿元。 资金流向上&#xff0c;今天申万一级职业普跌&#xff0c;除了国防军工职业小幅上涨&#xff0c;获主…...

WPF学习

一、了解WPF的框架结构 &#xff08;第一小节随便看下就可以&#xff0c;简单练习就行&#xff09; 1、新建WPF项目 xmlns&#xff1a;XML的命名空间 Margin外边距&#xff1a;左上右下 HorizontalAlignment&#xff1a;水平位置 VerticalAlignment&#xff1a;垂直位置 2…...

C#使用WebDriver模拟浏览器操作WEB页面

有时候需要模拟访问页面触发某个功能&#xff0c;可以使用WebDriver来实现这一功能&#xff0c;驱动打开浏览器&#xff0c;并对页面重定向以及对页面写入脚本等操作。 安装Selenium.Chrome&#xff0c;Selenium.Support.UI&#xff0c;Selenium 引入 using OpenQA.Selenium.…...

正则表达式 - 简单模式匹配

目录 一、测试数据 二、简单模式匹配 1. 匹配字面值 2. 匹配数字和非数字字符 3. 匹配单词与非单词字符 4. 匹配空白字符 5. 匹配任意字符 6. 匹配单词边界 7. 匹配零个或多个字符 8. 单行模式与多行模式 一、测试数据 这里所用文本是《学习正则表达式》这本书带的&a…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上&#xff0c;你可以使用apt包管理器来安装NFS服务器。打开终端并运行&#xff1a; sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享&#xff0c;例如/shared&#xff1a; sudo mkdir /shared sud…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例&#xff1a;使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例&#xff1a;使用OpenAI GPT-3进…...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在 GPU 上对图像执行 均值漂移滤波&#xff08;Mean Shift Filtering&#xff09;&#xff0c;用于图像分割或平滑处理。 该函数将输入图像中的…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store&#xff1a; 我们在使用异步的时候理应是要使用中间件的&#xff0c;但是configureStore 已经自动集成了 redux-thunk&#xff0c;注意action里面要返回函数 import { configureS…...

安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖

在Vuzix M400 AR智能眼镜的助力下&#xff0c;卢森堡罗伯特舒曼医院&#xff08;the Robert Schuman Hospitals, HRS&#xff09;凭借在无菌制剂生产流程中引入增强现实技术&#xff08;AR&#xff09;创新项目&#xff0c;荣获了2024年6月7日由卢森堡医院药剂师协会&#xff0…...

站群服务器的应用场景都有哪些?

站群服务器主要是为了多个网站的托管和管理所设计的&#xff0c;可以通过集中管理和高效资源的分配&#xff0c;来支持多个独立的网站同时运行&#xff0c;让每一个网站都可以分配到独立的IP地址&#xff0c;避免出现IP关联的风险&#xff0c;用户还可以通过控制面板进行管理功…...

Kafka主题运维全指南:从基础配置到故障处理

#作者&#xff1a;张桐瑞 文章目录 主题日常管理1. 修改主题分区。2. 修改主题级别参数。3. 变更副本数。4. 修改主题限速。5.主题分区迁移。6. 常见主题错误处理常见错误1&#xff1a;主题删除失败。常见错误2&#xff1a;__consumer_offsets占用太多的磁盘。 主题日常管理 …...