【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面
上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。
这篇文章我们将了解keras高层API,将手写数字识别项目用高层API重写一遍。
写在中间
学习之前,来探讨之前一直有的疑问,运行时提示没有相应模块,原因肯定是引入包的原因,如果不懂keras
和tf.keras
的区别,还是很有必要看看这篇文章!
- 其实 keras 可以理解为一套搭建与训练神经网络的高层 API 协议,Keras 本身已经实现了此协议,安装标准的 Keras 库就可以方便地调用TensorFlow、CNTK 等后端完成加速计算;在 TensorFlow 中,也实现了一套 keras 协议,即 tf.keras,它与 TensorFlow 深度融合,且只能基于 TensorFlow 后端运算,并对TensorFlow 的支持更完美。对于使用 TensorFlow 的开发者来说,tf.keras 可以理解为一个普通的子模块,与其他子模块,如 tf.math,tf.data 等并没有什么差别。
但是为了方便我们操作,为避免混淆,我们就选择tf.keras来完成代码中的相关操作。
注意:tensorflow版本和keras版本一定要相兼容,不兼容的话,引入tensorflow.keras
就会报错。
1. 引包
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers, models, losses
# pycharm中会出现红色波浪线,但不影响运行
2. 数据集的读取与处理
这一步就老生常谈了,直接复制粘贴过来
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [-1, 28*28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y(x, y), (x_test, y_test) = datasets.mnist.load_data()# 数据集的处理,由于返回的数据集是numpy类型的,若要使用GPU加速,需转换为张量
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000).batch(128).map(preprocess).repeat(5)
# 对测试集的简单处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000).batch(128).map(preprocess)
3. 网络层的构建
封装创建
对于常见的网络,原来需要手动调用每一层的类实例完成前向传播运算,当网络层数变得较深时,这一部分代码显得非常臃肿。可以通过 tf.keras
提供的网络容器 Sequential
将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。
在完成网络创建时,网络层类并没有创建内部权值张量等成员变量,此时通过调用类的 build
方法并指定输入大小,即可自动创建所有层的内部张量。通过 summary()
函数可以方便打印出网络结构和参数量
# 创建网络
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.build(input_shape=(None, 28 * 28)) # None代表batch不定
network.summary()
就如我们会打印出以下信息:
-
Layer (type)
:层名称、层类型 -
Output Shape
:输出形状 -
Param #
:层的参数个数 -
Trainable params、Non-trainable params
:可优化的参数、不可优化的参数
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense (Dense) (None, 256) 200960 dense_1 (Dense) (None, 128) 32896 dense_2 (Dense) (None, 64) 8256 dense_3 (Dense) (None, 32) 2080 dense_4 (Dense) (None, 10) 330 =================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
4. 模型装配、训练、测试
在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过 Keras 提供的模型装配与训练等高层接口实现,简洁清晰。
在 tf.keras
中提供了 compile()
和 fit()
函数方便实现上述逻辑。首先通过compile
函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
模型装配完成后,即可通过 fit()
函数送入待训练的数据集和验证用的数据集,这一步称为模型训练。
# 指定训练集和测试集,训练5个epochs,每2个epoch验证一次
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=2)
如果只是简单的测试模型的性能,可以通过 Model.evaluate(test_db)
循环测试完 test_db数据集上所有样本,并打印出性能指标
network.evaluate(test_db)
5. 模型保存与加载
在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要。一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断地保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间和计算资源。因此模型的保存与加载非常重要。
仅保存网络参数
这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是张量参数的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。
print('模型参数自动保存...')
network.save_weights('weights.ckpt')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型的参数...')# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
保存模型及参数
这是一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方法。通过 Model.save(path)
函数可以将模型的结构以及模型的参数保存到path
文件上,在不需要网络源文件的条件下,通过tf.keras.models.load_model(path)
即可恢复网络结构和网络参数。
network.save('model.h5')
print('模型已自动保存...')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型中...')
network = tf.keras.models.load_model('model.h5', compile=False)
学到这里就基本将手写数字识别的重点变化列举了出来,下面我们就去摩拳擦掌地试试吧!
写在最后
👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!
相关文章:
【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面 上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。 这篇文章我们将了解keras高层…...
柔性数组(C语言)
也许你从来没有听说过柔性数组( flexible array )这个概念,但是它确实是存在的。 C99 中,结构中的最后一个元素允许是未知大小的数组,这就叫做柔性数组成员,但结 构中的柔性数组成员前面必须至少一个其他…...

判断推理 -- 图形推理 -- 属性规律
中心对称:取一个点,穿过中心能找到另一个对称点。把轴对称 中心对称标出来。五角星不是中心对称。 BD对称轴方向相同,但135自带对称轴,24没带,所以6应该不带对称轴。 百分号不是轴对称。 白色对称轴 平行 或者 夹角…...

【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring团队不推荐使用Field注入)
问题发生场景: 在使用 IDEA 开发 SpringBoot 项目时,在 Controller 类中使用注解 Autowired 注入一个依赖出现了警告提示,查看其他使用该注解的地方同样出现了警告提示。这是怎么回事?由于先去使用了SpringBoot并没有对Spring进行…...
Rust语法: 枚举,泛型,trait
这是我学习Rust的笔记,本文适合于有一定高级语言基础的开发者看不适合刚入门编程的人,对于一些概念像枚举,泛型等,不会再做解释,只写在Rust中怎么用。 文章目录 枚举枚举的定义与赋值枚举绑定方法和函数match匹配枚举…...
hivesql-dayofweek 函数
返回日期或时间戳的星期几。 此函数是 extract(DAYOFWEEK FROM expr) 的同义函数。 语法 dayofweek(expr) 参数 expr:一个 DATE 或 TIMESTAMP 表达式。 返回 一个 INTEGER,其中 1 Sunday 和 7 Saturday。 示例 > SELECT dayofweek(2009-07-30)…...

DIP:《Deep Image Prior》经典文献阅读总结与实现
文章目录 Deep Image Prior1. 方法原理1.1 研究动机1.2 方法 2. 实验验证2.1 去噪2.2 超分辨率2.3 图像修复2.4 消融实验 3. 总结 Deep Image Prior 1. 方法原理 1.1 研究动机 动机 深度神经网络在图像复原和生成领域有非常好的表现一般归功于神经网络学习到了图像的先验信息…...

LAXCUS如何通过技术创新管理数千台服务器
随着互联网技术的不断发展,服务器已经成为企业和个人获取信息、进行计算和存储的重要工具。然而,随着服务器数量的不断增加,传统的服务器管理和运维方式已经无法满足现代企业的需求。LAXCUS做为专注服务器集群的【数存算管】一体化平台&#…...

【Java】BF算法(串模式匹配算法)
☀️ 什么是BF算法 BF算法,即暴力算法,是普通的模式匹配算法,BF算法的思想就是将目标串S的第一个与模式串T的第一个字符串进行匹配,若相等,则继续比较S的第二个字符和T的第二个字符;若不相等,则…...
Vue:使用Promise.all()方法并行执行多个请求
在Vue中,可以使用Promise.all()方法来并行执行多个请求。当需要同时执行多个异步请求时,可以将这些请求封装为Promise对象并使用Promise.all()方法来执行它们。 示例1: 以下是一个示例代码,展示了如何通过Promise.all()方法并行…...

21.0 CSS 介绍
1. CSS层叠样式表 1.1 CSS简介 CSS(层叠样式表): 是一种用于描述网页上元素外观和布局的样式标记语言. 它可以与HTML结合使用, 通过为HTML元素添加样式来改变其外观. CSS使用选择器来选择需要应用样式的元素, 并使用属性-值对来定义这些样式.1.2 CSS版本 CSS有多个版本, 每个…...

下一代计算:嵌入AI的云/雾/边缘/量子计算
计算系统在过去几十年中推动了计算机科学的发展,现在已成为企业世界的核心,提供基于云计算、雾计算、边缘计算、无服务器计算和量子计算的服务。现代计算系统解决了现实世界中许多需要低延迟和低响应时间的问题。这有助于全球各地的青年才俊创办初创企业…...

Gitlab-第四天-CD到k8s集群的坑
一、.gitlab-ci.yml #CD到k8s集群的 stages: - deploy-test build-image-deploy-test: stage: deploy-test image: bitnami/kubectl:latest # 使用一个包含 kubectl 工具的镜像 tags: - k8s script: - ls -al - kubectl apply -f deployment.yaml # 根据实际情况替换…...

【Java基础】Java对象的生命周期
【Java基础】Java对象的生命周期 一、概述 一个类通过编译器将一个Java文件编译为Class字节码文件,然后通过JVM中的解释器编译成不同操作系统的机器码。虽然操作系统不同,但是基于解释器的虚拟机是相同的。java类的生命周期就是指一个class文件加载到类…...
【每日一题】88. 合并两个有序数组
【每日一题】88. 合并两个有序数组 88. 合并两个有序数组题目描述解题思路 88. 合并两个有序数组 题目描述 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2,另有两个整数 m 和 n ,分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 …...

Navicat Premium连接sqlserve数据库失败?你需要注意这几点看看配置对了么?
新建数据库连接的时候这么填的信息 报错 原因1:sqlserver数据库的端口和IP地址之间不是:连接而是用,连接 改成如下样式用逗号连接端口和IP地址就好了 原因2:在Navicat Premium中需要安装一个sqlserver的插件 找到安装路径的根目…...

207、仿真-51单片机脉搏心率与血氧报警Proteus仿真设计(程序+Proteus仿真+配套资料等)
毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选择 方案一&a…...

flutter 初识(开发体验,优缺点)
前言 最近有个跨平台桌面应用的需求,需要支持 windows/linux/mac 系统,要做个更新应用的小界面,主要功能就是下载更新文件并在本地进行替换,很简单的小功能。 花了几分钟构建没做 UI 优化的示例界面: 由于我们的客…...
校验vue prop的几种方式
校验vue prop的几种方式 vue 要求将传递给组件的任何数据显式声明为 props。此外,它还提供了强大的内置机制来验证该数据。这充当组件和父级组件之间的约定,并确保组件能按预期使用。 让我们看看怎么对props进行校验。它可以帮助我们在开发和调试过程中…...
vue+springboot 前后端分离 上传文件处理后再下载,并且传递参数
vue代码 <template><div><input type"file" ref"fileInput" accept".json"/><el-button click"upload">上传</el-button></div> </template><script> export default {name: "…...
C++:std::is_convertible
C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:
一、属性动画概述NETX 作用:实现组件通用属性的渐变过渡效果,提升用户体验。支持属性:width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项: 布局类属性(如宽高)变化时&#…...
解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错
出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上,所以报错,到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本,cu、torch、cp 的版本一定要对…...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...
浅谈不同二分算法的查找情况
二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文全面剖析RNN核心原理,深入讲解梯度消失/爆炸问题,并通过LSTM/GRU结构实现解决方案,提供时间序列预测和文本生成…...

C++ 设计模式 《小明的奶茶加料风波》
👨🎓 模式名称:装饰器模式(Decorator Pattern) 👦 小明最近上线了校园奶茶配送功能,业务火爆,大家都在加料: 有的同学要加波霸 🟤,有的要加椰果…...
0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化
是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可,…...
Python 高级应用10:在python 大型项目中 FastAPI 和 Django 的相互配合
无论是python,或者java 的大型项目中,都会涉及到 自身平台微服务之间的相互调用,以及和第三发平台的 接口对接,那在python 中是怎么实现的呢? 在 Python Web 开发中,FastAPI 和 Django 是两个重要但定位不…...

Canal环境搭建并实现和ES数据同步
作者:田超凡 日期:2025年6月7日 Canal安装,启动端口11111、8082: 安装canal-deployer服务端: https://github.com/alibaba/canal/releases/1.1.7/canal.deployer-1.1.7.tar.gz cd /opt/homebrew/etc mkdir canal…...