【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面
上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。
这篇文章我们将了解keras高层API,将手写数字识别项目用高层API重写一遍。
写在中间
学习之前,来探讨之前一直有的疑问,运行时提示没有相应模块,原因肯定是引入包的原因,如果不懂keras
和tf.keras
的区别,还是很有必要看看这篇文章!
- 其实 keras 可以理解为一套搭建与训练神经网络的高层 API 协议,Keras 本身已经实现了此协议,安装标准的 Keras 库就可以方便地调用TensorFlow、CNTK 等后端完成加速计算;在 TensorFlow 中,也实现了一套 keras 协议,即 tf.keras,它与 TensorFlow 深度融合,且只能基于 TensorFlow 后端运算,并对TensorFlow 的支持更完美。对于使用 TensorFlow 的开发者来说,tf.keras 可以理解为一个普通的子模块,与其他子模块,如 tf.math,tf.data 等并没有什么差别。
但是为了方便我们操作,为避免混淆,我们就选择tf.keras来完成代码中的相关操作。
注意:tensorflow版本和keras版本一定要相兼容,不兼容的话,引入tensorflow.keras
就会报错。
1. 引包
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers, models, losses
# pycharm中会出现红色波浪线,但不影响运行
2. 数据集的读取与处理
这一步就老生常谈了,直接复制粘贴过来
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [-1, 28*28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y(x, y), (x_test, y_test) = datasets.mnist.load_data()# 数据集的处理,由于返回的数据集是numpy类型的,若要使用GPU加速,需转换为张量
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000).batch(128).map(preprocess).repeat(5)
# 对测试集的简单处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000).batch(128).map(preprocess)
3. 网络层的构建
封装创建
对于常见的网络,原来需要手动调用每一层的类实例完成前向传播运算,当网络层数变得较深时,这一部分代码显得非常臃肿。可以通过 tf.keras
提供的网络容器 Sequential
将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。
在完成网络创建时,网络层类并没有创建内部权值张量等成员变量,此时通过调用类的 build
方法并指定输入大小,即可自动创建所有层的内部张量。通过 summary()
函数可以方便打印出网络结构和参数量
# 创建网络
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.build(input_shape=(None, 28 * 28)) # None代表batch不定
network.summary()
就如我们会打印出以下信息:
-
Layer (type)
:层名称、层类型 -
Output Shape
:输出形状 -
Param #
:层的参数个数 -
Trainable params、Non-trainable params
:可优化的参数、不可优化的参数
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense (Dense) (None, 256) 200960 dense_1 (Dense) (None, 128) 32896 dense_2 (Dense) (None, 64) 8256 dense_3 (Dense) (None, 32) 2080 dense_4 (Dense) (None, 10) 330 =================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
4. 模型装配、训练、测试
在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过 Keras 提供的模型装配与训练等高层接口实现,简洁清晰。
在 tf.keras
中提供了 compile()
和 fit()
函数方便实现上述逻辑。首先通过compile
函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
模型装配完成后,即可通过 fit()
函数送入待训练的数据集和验证用的数据集,这一步称为模型训练。
# 指定训练集和测试集,训练5个epochs,每2个epoch验证一次
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=2)
如果只是简单的测试模型的性能,可以通过 Model.evaluate(test_db)
循环测试完 test_db数据集上所有样本,并打印出性能指标
network.evaluate(test_db)
5. 模型保存与加载
在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要。一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断地保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间和计算资源。因此模型的保存与加载非常重要。
仅保存网络参数
这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是张量参数的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。
print('模型参数自动保存...')
network.save_weights('weights.ckpt')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型的参数...')# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
保存模型及参数
这是一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方法。通过 Model.save(path)
函数可以将模型的结构以及模型的参数保存到path
文件上,在不需要网络源文件的条件下,通过tf.keras.models.load_model(path)
即可恢复网络结构和网络参数。
network.save('model.h5')
print('模型已自动保存...')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型中...')
network = tf.keras.models.load_model('model.h5', compile=False)
学到这里就基本将手写数字识别的重点变化列举了出来,下面我们就去摩拳擦掌地试试吧!
写在最后
👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!
相关文章:
【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面 上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。 这篇文章我们将了解keras高层…...
柔性数组(C语言)
也许你从来没有听说过柔性数组( flexible array )这个概念,但是它确实是存在的。 C99 中,结构中的最后一个元素允许是未知大小的数组,这就叫做柔性数组成员,但结 构中的柔性数组成员前面必须至少一个其他…...

判断推理 -- 图形推理 -- 属性规律
中心对称:取一个点,穿过中心能找到另一个对称点。把轴对称 中心对称标出来。五角星不是中心对称。 BD对称轴方向相同,但135自带对称轴,24没带,所以6应该不带对称轴。 百分号不是轴对称。 白色对称轴 平行 或者 夹角…...

【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring团队不推荐使用Field注入)
问题发生场景: 在使用 IDEA 开发 SpringBoot 项目时,在 Controller 类中使用注解 Autowired 注入一个依赖出现了警告提示,查看其他使用该注解的地方同样出现了警告提示。这是怎么回事?由于先去使用了SpringBoot并没有对Spring进行…...
Rust语法: 枚举,泛型,trait
这是我学习Rust的笔记,本文适合于有一定高级语言基础的开发者看不适合刚入门编程的人,对于一些概念像枚举,泛型等,不会再做解释,只写在Rust中怎么用。 文章目录 枚举枚举的定义与赋值枚举绑定方法和函数match匹配枚举…...
hivesql-dayofweek 函数
返回日期或时间戳的星期几。 此函数是 extract(DAYOFWEEK FROM expr) 的同义函数。 语法 dayofweek(expr) 参数 expr:一个 DATE 或 TIMESTAMP 表达式。 返回 一个 INTEGER,其中 1 Sunday 和 7 Saturday。 示例 > SELECT dayofweek(2009-07-30)…...

DIP:《Deep Image Prior》经典文献阅读总结与实现
文章目录 Deep Image Prior1. 方法原理1.1 研究动机1.2 方法 2. 实验验证2.1 去噪2.2 超分辨率2.3 图像修复2.4 消融实验 3. 总结 Deep Image Prior 1. 方法原理 1.1 研究动机 动机 深度神经网络在图像复原和生成领域有非常好的表现一般归功于神经网络学习到了图像的先验信息…...

LAXCUS如何通过技术创新管理数千台服务器
随着互联网技术的不断发展,服务器已经成为企业和个人获取信息、进行计算和存储的重要工具。然而,随着服务器数量的不断增加,传统的服务器管理和运维方式已经无法满足现代企业的需求。LAXCUS做为专注服务器集群的【数存算管】一体化平台&#…...

【Java】BF算法(串模式匹配算法)
☀️ 什么是BF算法 BF算法,即暴力算法,是普通的模式匹配算法,BF算法的思想就是将目标串S的第一个与模式串T的第一个字符串进行匹配,若相等,则继续比较S的第二个字符和T的第二个字符;若不相等,则…...
Vue:使用Promise.all()方法并行执行多个请求
在Vue中,可以使用Promise.all()方法来并行执行多个请求。当需要同时执行多个异步请求时,可以将这些请求封装为Promise对象并使用Promise.all()方法来执行它们。 示例1: 以下是一个示例代码,展示了如何通过Promise.all()方法并行…...

21.0 CSS 介绍
1. CSS层叠样式表 1.1 CSS简介 CSS(层叠样式表): 是一种用于描述网页上元素外观和布局的样式标记语言. 它可以与HTML结合使用, 通过为HTML元素添加样式来改变其外观. CSS使用选择器来选择需要应用样式的元素, 并使用属性-值对来定义这些样式.1.2 CSS版本 CSS有多个版本, 每个…...

下一代计算:嵌入AI的云/雾/边缘/量子计算
计算系统在过去几十年中推动了计算机科学的发展,现在已成为企业世界的核心,提供基于云计算、雾计算、边缘计算、无服务器计算和量子计算的服务。现代计算系统解决了现实世界中许多需要低延迟和低响应时间的问题。这有助于全球各地的青年才俊创办初创企业…...

Gitlab-第四天-CD到k8s集群的坑
一、.gitlab-ci.yml #CD到k8s集群的 stages: - deploy-test build-image-deploy-test: stage: deploy-test image: bitnami/kubectl:latest # 使用一个包含 kubectl 工具的镜像 tags: - k8s script: - ls -al - kubectl apply -f deployment.yaml # 根据实际情况替换…...

【Java基础】Java对象的生命周期
【Java基础】Java对象的生命周期 一、概述 一个类通过编译器将一个Java文件编译为Class字节码文件,然后通过JVM中的解释器编译成不同操作系统的机器码。虽然操作系统不同,但是基于解释器的虚拟机是相同的。java类的生命周期就是指一个class文件加载到类…...
【每日一题】88. 合并两个有序数组
【每日一题】88. 合并两个有序数组 88. 合并两个有序数组题目描述解题思路 88. 合并两个有序数组 题目描述 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2,另有两个整数 m 和 n ,分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 …...

Navicat Premium连接sqlserve数据库失败?你需要注意这几点看看配置对了么?
新建数据库连接的时候这么填的信息 报错 原因1:sqlserver数据库的端口和IP地址之间不是:连接而是用,连接 改成如下样式用逗号连接端口和IP地址就好了 原因2:在Navicat Premium中需要安装一个sqlserver的插件 找到安装路径的根目…...

207、仿真-51单片机脉搏心率与血氧报警Proteus仿真设计(程序+Proteus仿真+配套资料等)
毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选择 方案一&a…...

flutter 初识(开发体验,优缺点)
前言 最近有个跨平台桌面应用的需求,需要支持 windows/linux/mac 系统,要做个更新应用的小界面,主要功能就是下载更新文件并在本地进行替换,很简单的小功能。 花了几分钟构建没做 UI 优化的示例界面: 由于我们的客…...
校验vue prop的几种方式
校验vue prop的几种方式 vue 要求将传递给组件的任何数据显式声明为 props。此外,它还提供了强大的内置机制来验证该数据。这充当组件和父级组件之间的约定,并确保组件能按预期使用。 让我们看看怎么对props进行校验。它可以帮助我们在开发和调试过程中…...
vue+springboot 前后端分离 上传文件处理后再下载,并且传递参数
vue代码 <template><div><input type"file" ref"fileInput" accept".json"/><el-button click"upload">上传</el-button></div> </template><script> export default {name: "…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)
说明: 想象一下,你正在用eNSP搭建一个虚拟的网络世界,里面有虚拟的路由器、交换机、电脑(PC)等等。这些设备都在你的电脑里面“运行”,它们之间可以互相通信,就像一个封闭的小王国。 但是&#…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八
现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...
django filter 统计数量 按属性去重
在Django中,如果你想要根据某个属性对查询集进行去重并统计数量,你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求: 方法1:使用annotate()和Count 假设你有一个模型Item,并且你想…...
linux 错误码总结
1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

ServerTrust 并非唯一
NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...

《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

网络编程(UDP编程)
思维导图 UDP基础编程(单播) 1.流程图 服务器:短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...