【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面
上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。
这篇文章我们将了解keras高层API,将手写数字识别项目用高层API重写一遍。
写在中间
学习之前,来探讨之前一直有的疑问,运行时提示没有相应模块,原因肯定是引入包的原因,如果不懂keras
和tf.keras
的区别,还是很有必要看看这篇文章!
- 其实 keras 可以理解为一套搭建与训练神经网络的高层 API 协议,Keras 本身已经实现了此协议,安装标准的 Keras 库就可以方便地调用TensorFlow、CNTK 等后端完成加速计算;在 TensorFlow 中,也实现了一套 keras 协议,即 tf.keras,它与 TensorFlow 深度融合,且只能基于 TensorFlow 后端运算,并对TensorFlow 的支持更完美。对于使用 TensorFlow 的开发者来说,tf.keras 可以理解为一个普通的子模块,与其他子模块,如 tf.math,tf.data 等并没有什么差别。
但是为了方便我们操作,为避免混淆,我们就选择tf.keras来完成代码中的相关操作。
注意:tensorflow版本和keras版本一定要相兼容,不兼容的话,引入tensorflow.keras
就会报错。
1. 引包
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers, models, losses
# pycharm中会出现红色波浪线,但不影响运行
2. 数据集的读取与处理
这一步就老生常谈了,直接复制粘贴过来
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [-1, 28*28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y(x, y), (x_test, y_test) = datasets.mnist.load_data()# 数据集的处理,由于返回的数据集是numpy类型的,若要使用GPU加速,需转换为张量
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000).batch(128).map(preprocess).repeat(5)
# 对测试集的简单处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000).batch(128).map(preprocess)
3. 网络层的构建
封装创建
对于常见的网络,原来需要手动调用每一层的类实例完成前向传播运算,当网络层数变得较深时,这一部分代码显得非常臃肿。可以通过 tf.keras
提供的网络容器 Sequential
将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。
在完成网络创建时,网络层类并没有创建内部权值张量等成员变量,此时通过调用类的 build
方法并指定输入大小,即可自动创建所有层的内部张量。通过 summary()
函数可以方便打印出网络结构和参数量
# 创建网络
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.build(input_shape=(None, 28 * 28)) # None代表batch不定
network.summary()
就如我们会打印出以下信息:
-
Layer (type)
:层名称、层类型 -
Output Shape
:输出形状 -
Param #
:层的参数个数 -
Trainable params、Non-trainable params
:可优化的参数、不可优化的参数
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense (Dense) (None, 256) 200960 dense_1 (Dense) (None, 128) 32896 dense_2 (Dense) (None, 64) 8256 dense_3 (Dense) (None, 32) 2080 dense_4 (Dense) (None, 10) 330 =================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
4. 模型装配、训练、测试
在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过 Keras 提供的模型装配与训练等高层接口实现,简洁清晰。
在 tf.keras
中提供了 compile()
和 fit()
函数方便实现上述逻辑。首先通过compile
函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
模型装配完成后,即可通过 fit()
函数送入待训练的数据集和验证用的数据集,这一步称为模型训练。
# 指定训练集和测试集,训练5个epochs,每2个epoch验证一次
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=2)
如果只是简单的测试模型的性能,可以通过 Model.evaluate(test_db)
循环测试完 test_db数据集上所有样本,并打印出性能指标
network.evaluate(test_db)
5. 模型保存与加载
在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要。一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断地保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间和计算资源。因此模型的保存与加载非常重要。
仅保存网络参数
这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是张量参数的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。
print('模型参数自动保存...')
network.save_weights('weights.ckpt')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型的参数...')# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
保存模型及参数
这是一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方法。通过 Model.save(path)
函数可以将模型的结构以及模型的参数保存到path
文件上,在不需要网络源文件的条件下,通过tf.keras.models.load_model(path)
即可恢复网络结构和网络参数。
network.save('model.h5')
print('模型已自动保存...')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型中...')
network = tf.keras.models.load_model('model.h5', compile=False)
学到这里就基本将手写数字识别的重点变化列举了出来,下面我们就去摩拳擦掌地试试吧!
写在最后
👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!
相关文章:
【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面 上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。 这篇文章我们将了解keras高层…...
柔性数组(C语言)
也许你从来没有听说过柔性数组( flexible array )这个概念,但是它确实是存在的。 C99 中,结构中的最后一个元素允许是未知大小的数组,这就叫做柔性数组成员,但结 构中的柔性数组成员前面必须至少一个其他…...

判断推理 -- 图形推理 -- 属性规律
中心对称:取一个点,穿过中心能找到另一个对称点。把轴对称 中心对称标出来。五角星不是中心对称。 BD对称轴方向相同,但135自带对称轴,24没带,所以6应该不带对称轴。 百分号不是轴对称。 白色对称轴 平行 或者 夹角…...

【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring团队不推荐使用Field注入)
问题发生场景: 在使用 IDEA 开发 SpringBoot 项目时,在 Controller 类中使用注解 Autowired 注入一个依赖出现了警告提示,查看其他使用该注解的地方同样出现了警告提示。这是怎么回事?由于先去使用了SpringBoot并没有对Spring进行…...
Rust语法: 枚举,泛型,trait
这是我学习Rust的笔记,本文适合于有一定高级语言基础的开发者看不适合刚入门编程的人,对于一些概念像枚举,泛型等,不会再做解释,只写在Rust中怎么用。 文章目录 枚举枚举的定义与赋值枚举绑定方法和函数match匹配枚举…...
hivesql-dayofweek 函数
返回日期或时间戳的星期几。 此函数是 extract(DAYOFWEEK FROM expr) 的同义函数。 语法 dayofweek(expr) 参数 expr:一个 DATE 或 TIMESTAMP 表达式。 返回 一个 INTEGER,其中 1 Sunday 和 7 Saturday。 示例 > SELECT dayofweek(2009-07-30)…...

DIP:《Deep Image Prior》经典文献阅读总结与实现
文章目录 Deep Image Prior1. 方法原理1.1 研究动机1.2 方法 2. 实验验证2.1 去噪2.2 超分辨率2.3 图像修复2.4 消融实验 3. 总结 Deep Image Prior 1. 方法原理 1.1 研究动机 动机 深度神经网络在图像复原和生成领域有非常好的表现一般归功于神经网络学习到了图像的先验信息…...

LAXCUS如何通过技术创新管理数千台服务器
随着互联网技术的不断发展,服务器已经成为企业和个人获取信息、进行计算和存储的重要工具。然而,随着服务器数量的不断增加,传统的服务器管理和运维方式已经无法满足现代企业的需求。LAXCUS做为专注服务器集群的【数存算管】一体化平台&#…...

【Java】BF算法(串模式匹配算法)
☀️ 什么是BF算法 BF算法,即暴力算法,是普通的模式匹配算法,BF算法的思想就是将目标串S的第一个与模式串T的第一个字符串进行匹配,若相等,则继续比较S的第二个字符和T的第二个字符;若不相等,则…...
Vue:使用Promise.all()方法并行执行多个请求
在Vue中,可以使用Promise.all()方法来并行执行多个请求。当需要同时执行多个异步请求时,可以将这些请求封装为Promise对象并使用Promise.all()方法来执行它们。 示例1: 以下是一个示例代码,展示了如何通过Promise.all()方法并行…...

21.0 CSS 介绍
1. CSS层叠样式表 1.1 CSS简介 CSS(层叠样式表): 是一种用于描述网页上元素外观和布局的样式标记语言. 它可以与HTML结合使用, 通过为HTML元素添加样式来改变其外观. CSS使用选择器来选择需要应用样式的元素, 并使用属性-值对来定义这些样式.1.2 CSS版本 CSS有多个版本, 每个…...

下一代计算:嵌入AI的云/雾/边缘/量子计算
计算系统在过去几十年中推动了计算机科学的发展,现在已成为企业世界的核心,提供基于云计算、雾计算、边缘计算、无服务器计算和量子计算的服务。现代计算系统解决了现实世界中许多需要低延迟和低响应时间的问题。这有助于全球各地的青年才俊创办初创企业…...

Gitlab-第四天-CD到k8s集群的坑
一、.gitlab-ci.yml #CD到k8s集群的 stages: - deploy-test build-image-deploy-test: stage: deploy-test image: bitnami/kubectl:latest # 使用一个包含 kubectl 工具的镜像 tags: - k8s script: - ls -al - kubectl apply -f deployment.yaml # 根据实际情况替换…...

【Java基础】Java对象的生命周期
【Java基础】Java对象的生命周期 一、概述 一个类通过编译器将一个Java文件编译为Class字节码文件,然后通过JVM中的解释器编译成不同操作系统的机器码。虽然操作系统不同,但是基于解释器的虚拟机是相同的。java类的生命周期就是指一个class文件加载到类…...
【每日一题】88. 合并两个有序数组
【每日一题】88. 合并两个有序数组 88. 合并两个有序数组题目描述解题思路 88. 合并两个有序数组 题目描述 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2,另有两个整数 m 和 n ,分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 …...

Navicat Premium连接sqlserve数据库失败?你需要注意这几点看看配置对了么?
新建数据库连接的时候这么填的信息 报错 原因1:sqlserver数据库的端口和IP地址之间不是:连接而是用,连接 改成如下样式用逗号连接端口和IP地址就好了 原因2:在Navicat Premium中需要安装一个sqlserver的插件 找到安装路径的根目…...

207、仿真-51单片机脉搏心率与血氧报警Proteus仿真设计(程序+Proteus仿真+配套资料等)
毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选择 方案一&a…...

flutter 初识(开发体验,优缺点)
前言 最近有个跨平台桌面应用的需求,需要支持 windows/linux/mac 系统,要做个更新应用的小界面,主要功能就是下载更新文件并在本地进行替换,很简单的小功能。 花了几分钟构建没做 UI 优化的示例界面: 由于我们的客…...
校验vue prop的几种方式
校验vue prop的几种方式 vue 要求将传递给组件的任何数据显式声明为 props。此外,它还提供了强大的内置机制来验证该数据。这充当组件和父级组件之间的约定,并确保组件能按预期使用。 让我们看看怎么对props进行校验。它可以帮助我们在开发和调试过程中…...
vue+springboot 前后端分离 上传文件处理后再下载,并且传递参数
vue代码 <template><div><input type"file" ref"fileInput" accept".json"/><el-button click"upload">上传</el-button></div> </template><script> export default {name: "…...

【Linux操作系统】举例解释Linux系统编程中文件io常用的函数
在Linux系统编程中,文件IO操作是非常常见和重要的操作之一。通过文件IO操作,我们可以打开、读取、写入和关闭文件,对文件进行定位、复制、删除和重命名等操作。本篇博客将介绍一些常用的文件IO操作函数。 文章目录 1. open()1.1 原型、参数及…...
Ubuntu和centos版本有哪些区别
Ubuntu和CentOS是两个非常流行的Linux发行版,它们在一些方面有一些区别,如下所示: CentOS的版本发布周期相对较长,主要是因为它是基于RedHatEnterpriseLinux(RHEL)的。这意味着在RHEL发布后才能推出对应的CentOS版本。而Ubuntu则在…...

Netty:ChannelHandler抛出异常,对应的channel被关闭
说明 使用Netty框架构建的socket服务端在处理客户端请求时,每接到一个客户端的连接请求,服务端会分配一个channel处理跟该客户端的交互。如果处理该channel数据的ChannelHandler抛出异常没有捕获,那么该channel会关闭。但服务端和其它客户端…...

pytest结合 allure 打标记之的详细使用
前言 前面我们提到使用allure 可以生成漂亮的测试报告,下面就Allure 标记我们做详细介绍。 allure 标记 包含:epic,feature, story, title, testcase, issue, description, step, serverity, link, attachment 常用的标记 allure.feature…...

【linux】2 软件管理器yum和编辑器vim
目录 1. linux软件包管理器yum 1.1 什么是软件包 1.2 关于rzsz 1.3 注意事项 1.4 查看软件包 1.5 如何安装、卸载软件 1.6 centos 7设置成国内yum源 2. linux开发工具-Linux编辑器-vim使用 2.1 vim的基本概念 2.2 vim的基本操作 2.3 vim正常模式命令集 2.4 vim末行…...
Angular中的ActivatedRoute和Router
Angular中的ActivatedRoute和Router解释 在Angular中,ActivatedRoute和Router是两个核心的路由服务。他们都提供可以用来检查和操作当前页面路由信息的方法和属性。 ActivatedRoute ActivatedRoute是一个保存关于当前路由状态(如路由参数、查询参数以…...

Layui精简版,快速入门
目录 LayUI之入门 1.什么是layui 2.layui入门 3.自定义模块 4.用户登录 5.主页搭建 LayUI之动态树 main.jsp main.js LayUI之动态选项卡 1.选项卡 main.jsp main.js 2.用户登录 User.java UserDao.java UserAction.java R.java LayUI之用户管理 1.用户查询…...
SSH远程Ubuntu教程
SSH远程Ubuntu教程 目录 什么是SSH?SSH的优点在Ubuntu上启用SSH服务连接到远程Ubuntu服务器SSH的常用命令 1. 什么是SSH? SSH(Secure Shell)是一种网络协议,用于在不安全的网络中安全地远程登录和执行命令。它使用…...

NPM与外部服务的集成(下)
目录 1、撤消访问令牌 2、在CI/CD工作流中使用私有包 2.1 创建新的访问令牌 持续整合 持续部署 交互式工作流 CIDR白名单 2.2 将令牌设置为CI/CD服务器上的环境变量 2.3 创建并签入特定于项目的.npmrc文件 2.4 令牌安全 3、Docker和私有模块 3.1 背景:运…...

Flask Web开发实战(狼书)| 笔记第1、2章
前言 2023-8-11 以前对网站开发萌生了想法,又有些急于求成,在B站照着视频敲了一个基于flask的博客系统。但对于程序的代码难免有些囫囵吞枣,存在许多模糊或不太理解的地方,只会照葫芦画瓢。 而当自己想开发一个什么网站的时&…...