【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面
上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。
这篇文章我们将了解keras高层API,将手写数字识别项目用高层API重写一遍。
写在中间
学习之前,来探讨之前一直有的疑问,运行时提示没有相应模块,原因肯定是引入包的原因,如果不懂keras和tf.keras的区别,还是很有必要看看这篇文章!
- 其实 keras 可以理解为一套搭建与训练神经网络的高层 API 协议,Keras 本身已经实现了此协议,安装标准的 Keras 库就可以方便地调用TensorFlow、CNTK 等后端完成加速计算;在 TensorFlow 中,也实现了一套 keras 协议,即 tf.keras,它与 TensorFlow 深度融合,且只能基于 TensorFlow 后端运算,并对TensorFlow 的支持更完美。对于使用 TensorFlow 的开发者来说,tf.keras 可以理解为一个普通的子模块,与其他子模块,如 tf.math,tf.data 等并没有什么差别。
但是为了方便我们操作,为避免混淆,我们就选择tf.keras来完成代码中的相关操作。
注意:tensorflow版本和keras版本一定要相兼容,不兼容的话,引入tensorflow.keras就会报错。
1. 引包
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers, models, losses
# pycharm中会出现红色波浪线,但不影响运行
2. 数据集的读取与处理
这一步就老生常谈了,直接复制粘贴过来
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [-1, 28*28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y(x, y), (x_test, y_test) = datasets.mnist.load_data()# 数据集的处理,由于返回的数据集是numpy类型的,若要使用GPU加速,需转换为张量
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000).batch(128).map(preprocess).repeat(5)
# 对测试集的简单处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000).batch(128).map(preprocess)
3. 网络层的构建
封装创建
对于常见的网络,原来需要手动调用每一层的类实例完成前向传播运算,当网络层数变得较深时,这一部分代码显得非常臃肿。可以通过 tf.keras 提供的网络容器 Sequential 将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。
在完成网络创建时,网络层类并没有创建内部权值张量等成员变量,此时通过调用类的 build 方法并指定输入大小,即可自动创建所有层的内部张量。通过 summary()函数可以方便打印出网络结构和参数量
# 创建网络
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.build(input_shape=(None, 28 * 28)) # None代表batch不定
network.summary()
就如我们会打印出以下信息:
-
Layer (type):层名称、层类型 -
Output Shape:输出形状 -
Param #:层的参数个数 -
Trainable params、Non-trainable params:可优化的参数、不可优化的参数
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense (Dense) (None, 256) 200960 dense_1 (Dense) (None, 128) 32896 dense_2 (Dense) (None, 64) 8256 dense_3 (Dense) (None, 32) 2080 dense_4 (Dense) (None, 10) 330 =================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
4. 模型装配、训练、测试
在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过 Keras 提供的模型装配与训练等高层接口实现,简洁清晰。
在 tf.keras 中提供了 compile()和 fit()函数方便实现上述逻辑。首先通过compile 函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
模型装配完成后,即可通过 fit()函数送入待训练的数据集和验证用的数据集,这一步称为模型训练。
# 指定训练集和测试集,训练5个epochs,每2个epoch验证一次
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=2)
如果只是简单的测试模型的性能,可以通过 Model.evaluate(test_db)循环测试完 test_db数据集上所有样本,并打印出性能指标
network.evaluate(test_db)
5. 模型保存与加载
在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要。一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断地保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间和计算资源。因此模型的保存与加载非常重要。
仅保存网络参数
这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是张量参数的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。
print('模型参数自动保存...')
network.save_weights('weights.ckpt')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型的参数...')# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
保存模型及参数
这是一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方法。通过 Model.save(path)函数可以将模型的结构以及模型的参数保存到path文件上,在不需要网络源文件的条件下,通过tf.keras.models.load_model(path)即可恢复网络结构和网络参数。
network.save('model.h5')
print('模型已自动保存...')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型中...')
network = tf.keras.models.load_model('model.h5', compile=False)
学到这里就基本将手写数字识别的重点变化列举了出来,下面我们就去摩拳擦掌地试试吧!
写在最后
👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!
相关文章:
【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面 上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。 这篇文章我们将了解keras高层…...
柔性数组(C语言)
也许你从来没有听说过柔性数组( flexible array )这个概念,但是它确实是存在的。 C99 中,结构中的最后一个元素允许是未知大小的数组,这就叫做柔性数组成员,但结 构中的柔性数组成员前面必须至少一个其他…...
判断推理 -- 图形推理 -- 属性规律
中心对称:取一个点,穿过中心能找到另一个对称点。把轴对称 中心对称标出来。五角星不是中心对称。 BD对称轴方向相同,但135自带对称轴,24没带,所以6应该不带对称轴。 百分号不是轴对称。 白色对称轴 平行 或者 夹角…...
【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring团队不推荐使用Field注入)
问题发生场景: 在使用 IDEA 开发 SpringBoot 项目时,在 Controller 类中使用注解 Autowired 注入一个依赖出现了警告提示,查看其他使用该注解的地方同样出现了警告提示。这是怎么回事?由于先去使用了SpringBoot并没有对Spring进行…...
Rust语法: 枚举,泛型,trait
这是我学习Rust的笔记,本文适合于有一定高级语言基础的开发者看不适合刚入门编程的人,对于一些概念像枚举,泛型等,不会再做解释,只写在Rust中怎么用。 文章目录 枚举枚举的定义与赋值枚举绑定方法和函数match匹配枚举…...
hivesql-dayofweek 函数
返回日期或时间戳的星期几。 此函数是 extract(DAYOFWEEK FROM expr) 的同义函数。 语法 dayofweek(expr) 参数 expr:一个 DATE 或 TIMESTAMP 表达式。 返回 一个 INTEGER,其中 1 Sunday 和 7 Saturday。 示例 > SELECT dayofweek(2009-07-30)…...
DIP:《Deep Image Prior》经典文献阅读总结与实现
文章目录 Deep Image Prior1. 方法原理1.1 研究动机1.2 方法 2. 实验验证2.1 去噪2.2 超分辨率2.3 图像修复2.4 消融实验 3. 总结 Deep Image Prior 1. 方法原理 1.1 研究动机 动机 深度神经网络在图像复原和生成领域有非常好的表现一般归功于神经网络学习到了图像的先验信息…...
LAXCUS如何通过技术创新管理数千台服务器
随着互联网技术的不断发展,服务器已经成为企业和个人获取信息、进行计算和存储的重要工具。然而,随着服务器数量的不断增加,传统的服务器管理和运维方式已经无法满足现代企业的需求。LAXCUS做为专注服务器集群的【数存算管】一体化平台&#…...
【Java】BF算法(串模式匹配算法)
☀️ 什么是BF算法 BF算法,即暴力算法,是普通的模式匹配算法,BF算法的思想就是将目标串S的第一个与模式串T的第一个字符串进行匹配,若相等,则继续比较S的第二个字符和T的第二个字符;若不相等,则…...
Vue:使用Promise.all()方法并行执行多个请求
在Vue中,可以使用Promise.all()方法来并行执行多个请求。当需要同时执行多个异步请求时,可以将这些请求封装为Promise对象并使用Promise.all()方法来执行它们。 示例1: 以下是一个示例代码,展示了如何通过Promise.all()方法并行…...
21.0 CSS 介绍
1. CSS层叠样式表 1.1 CSS简介 CSS(层叠样式表): 是一种用于描述网页上元素外观和布局的样式标记语言. 它可以与HTML结合使用, 通过为HTML元素添加样式来改变其外观. CSS使用选择器来选择需要应用样式的元素, 并使用属性-值对来定义这些样式.1.2 CSS版本 CSS有多个版本, 每个…...
下一代计算:嵌入AI的云/雾/边缘/量子计算
计算系统在过去几十年中推动了计算机科学的发展,现在已成为企业世界的核心,提供基于云计算、雾计算、边缘计算、无服务器计算和量子计算的服务。现代计算系统解决了现实世界中许多需要低延迟和低响应时间的问题。这有助于全球各地的青年才俊创办初创企业…...
Gitlab-第四天-CD到k8s集群的坑
一、.gitlab-ci.yml #CD到k8s集群的 stages: - deploy-test build-image-deploy-test: stage: deploy-test image: bitnami/kubectl:latest # 使用一个包含 kubectl 工具的镜像 tags: - k8s script: - ls -al - kubectl apply -f deployment.yaml # 根据实际情况替换…...
【Java基础】Java对象的生命周期
【Java基础】Java对象的生命周期 一、概述 一个类通过编译器将一个Java文件编译为Class字节码文件,然后通过JVM中的解释器编译成不同操作系统的机器码。虽然操作系统不同,但是基于解释器的虚拟机是相同的。java类的生命周期就是指一个class文件加载到类…...
【每日一题】88. 合并两个有序数组
【每日一题】88. 合并两个有序数组 88. 合并两个有序数组题目描述解题思路 88. 合并两个有序数组 题目描述 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2,另有两个整数 m 和 n ,分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 …...
Navicat Premium连接sqlserve数据库失败?你需要注意这几点看看配置对了么?
新建数据库连接的时候这么填的信息 报错 原因1:sqlserver数据库的端口和IP地址之间不是:连接而是用,连接 改成如下样式用逗号连接端口和IP地址就好了 原因2:在Navicat Premium中需要安装一个sqlserver的插件 找到安装路径的根目…...
207、仿真-51单片机脉搏心率与血氧报警Proteus仿真设计(程序+Proteus仿真+配套资料等)
毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选择 方案一&a…...
flutter 初识(开发体验,优缺点)
前言 最近有个跨平台桌面应用的需求,需要支持 windows/linux/mac 系统,要做个更新应用的小界面,主要功能就是下载更新文件并在本地进行替换,很简单的小功能。 花了几分钟构建没做 UI 优化的示例界面: 由于我们的客…...
校验vue prop的几种方式
校验vue prop的几种方式 vue 要求将传递给组件的任何数据显式声明为 props。此外,它还提供了强大的内置机制来验证该数据。这充当组件和父级组件之间的约定,并确保组件能按预期使用。 让我们看看怎么对props进行校验。它可以帮助我们在开发和调试过程中…...
vue+springboot 前后端分离 上传文件处理后再下载,并且传递参数
vue代码 <template><div><input type"file" ref"fileInput" accept".json"/><el-button click"upload">上传</el-button></div> </template><script> export default {name: "…...
uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖
在前面的练习中,每个页面需要使用ref,onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入,需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...
Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
PL0语法,分析器实现!
简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
AI,如何重构理解、匹配与决策?
AI 时代,我们如何理解消费? 作者|王彬 封面|Unplash 人们通过信息理解世界。 曾几何时,PC 与移动互联网重塑了人们的购物路径:信息变得唾手可得,商品决策变得高度依赖内容。 但 AI 时代的来…...
Angular微前端架构:Module Federation + ngx-build-plus (Webpack)
以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...
云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...
人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...
