【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面
上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。
这篇文章我们将了解keras高层API,将手写数字识别项目用高层API重写一遍。
写在中间
学习之前,来探讨之前一直有的疑问,运行时提示没有相应模块,原因肯定是引入包的原因,如果不懂keras和tf.keras的区别,还是很有必要看看这篇文章!
- 其实 keras 可以理解为一套搭建与训练神经网络的高层 API 协议,Keras 本身已经实现了此协议,安装标准的 Keras 库就可以方便地调用TensorFlow、CNTK 等后端完成加速计算;在 TensorFlow 中,也实现了一套 keras 协议,即 tf.keras,它与 TensorFlow 深度融合,且只能基于 TensorFlow 后端运算,并对TensorFlow 的支持更完美。对于使用 TensorFlow 的开发者来说,tf.keras 可以理解为一个普通的子模块,与其他子模块,如 tf.math,tf.data 等并没有什么差别。
但是为了方便我们操作,为避免混淆,我们就选择tf.keras来完成代码中的相关操作。
注意:tensorflow版本和keras版本一定要相兼容,不兼容的话,引入tensorflow.keras就会报错。
1. 引包
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers, models, losses
# pycharm中会出现红色波浪线,但不影响运行
2. 数据集的读取与处理
这一步就老生常谈了,直接复制粘贴过来
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [-1, 28*28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y(x, y), (x_test, y_test) = datasets.mnist.load_data()# 数据集的处理,由于返回的数据集是numpy类型的,若要使用GPU加速,需转换为张量
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000).batch(128).map(preprocess).repeat(5)
# 对测试集的简单处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000).batch(128).map(preprocess)
3. 网络层的构建
封装创建
对于常见的网络,原来需要手动调用每一层的类实例完成前向传播运算,当网络层数变得较深时,这一部分代码显得非常臃肿。可以通过 tf.keras 提供的网络容器 Sequential 将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。
在完成网络创建时,网络层类并没有创建内部权值张量等成员变量,此时通过调用类的 build 方法并指定输入大小,即可自动创建所有层的内部张量。通过 summary()函数可以方便打印出网络结构和参数量
# 创建网络
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.build(input_shape=(None, 28 * 28)) # None代表batch不定
network.summary()
就如我们会打印出以下信息:
-
Layer (type):层名称、层类型 -
Output Shape:输出形状 -
Param #:层的参数个数 -
Trainable params、Non-trainable params:可优化的参数、不可优化的参数
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense (Dense) (None, 256) 200960 dense_1 (Dense) (None, 128) 32896 dense_2 (Dense) (None, 64) 8256 dense_3 (Dense) (None, 32) 2080 dense_4 (Dense) (None, 10) 330 =================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
4. 模型装配、训练、测试
在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过 Keras 提供的模型装配与训练等高层接口实现,简洁清晰。
在 tf.keras 中提供了 compile()和 fit()函数方便实现上述逻辑。首先通过compile 函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
模型装配完成后,即可通过 fit()函数送入待训练的数据集和验证用的数据集,这一步称为模型训练。
# 指定训练集和测试集,训练5个epochs,每2个epoch验证一次
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=2)
如果只是简单的测试模型的性能,可以通过 Model.evaluate(test_db)循环测试完 test_db数据集上所有样本,并打印出性能指标
network.evaluate(test_db)
5. 模型保存与加载
在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要。一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断地保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间和计算资源。因此模型的保存与加载非常重要。
仅保存网络参数
这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是张量参数的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。
print('模型参数自动保存...')
network.save_weights('weights.ckpt')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型的参数...')# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
保存模型及参数
这是一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方法。通过 Model.save(path)函数可以将模型的结构以及模型的参数保存到path文件上,在不需要网络源文件的条件下,通过tf.keras.models.load_model(path)即可恢复网络结构和网络参数。
network.save('model.h5')
print('模型已自动保存...')print('模拟意外情况,网络删除...')
del networkprint('重新加载模型中...')
network = tf.keras.models.load_model('model.h5', compile=False)
学到这里就基本将手写数字识别的重点变化列举了出来,下面我们就去摩拳擦掌地试试吧!
写在最后
👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!
相关文章:
【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目
写在前面 上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。 这篇文章我们将了解keras高层…...
柔性数组(C语言)
也许你从来没有听说过柔性数组( flexible array )这个概念,但是它确实是存在的。 C99 中,结构中的最后一个元素允许是未知大小的数组,这就叫做柔性数组成员,但结 构中的柔性数组成员前面必须至少一个其他…...
判断推理 -- 图形推理 -- 属性规律
中心对称:取一个点,穿过中心能找到另一个对称点。把轴对称 中心对称标出来。五角星不是中心对称。 BD对称轴方向相同,但135自带对称轴,24没带,所以6应该不带对称轴。 百分号不是轴对称。 白色对称轴 平行 或者 夹角…...
【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring团队不推荐使用Field注入)
问题发生场景: 在使用 IDEA 开发 SpringBoot 项目时,在 Controller 类中使用注解 Autowired 注入一个依赖出现了警告提示,查看其他使用该注解的地方同样出现了警告提示。这是怎么回事?由于先去使用了SpringBoot并没有对Spring进行…...
Rust语法: 枚举,泛型,trait
这是我学习Rust的笔记,本文适合于有一定高级语言基础的开发者看不适合刚入门编程的人,对于一些概念像枚举,泛型等,不会再做解释,只写在Rust中怎么用。 文章目录 枚举枚举的定义与赋值枚举绑定方法和函数match匹配枚举…...
hivesql-dayofweek 函数
返回日期或时间戳的星期几。 此函数是 extract(DAYOFWEEK FROM expr) 的同义函数。 语法 dayofweek(expr) 参数 expr:一个 DATE 或 TIMESTAMP 表达式。 返回 一个 INTEGER,其中 1 Sunday 和 7 Saturday。 示例 > SELECT dayofweek(2009-07-30)…...
DIP:《Deep Image Prior》经典文献阅读总结与实现
文章目录 Deep Image Prior1. 方法原理1.1 研究动机1.2 方法 2. 实验验证2.1 去噪2.2 超分辨率2.3 图像修复2.4 消融实验 3. 总结 Deep Image Prior 1. 方法原理 1.1 研究动机 动机 深度神经网络在图像复原和生成领域有非常好的表现一般归功于神经网络学习到了图像的先验信息…...
LAXCUS如何通过技术创新管理数千台服务器
随着互联网技术的不断发展,服务器已经成为企业和个人获取信息、进行计算和存储的重要工具。然而,随着服务器数量的不断增加,传统的服务器管理和运维方式已经无法满足现代企业的需求。LAXCUS做为专注服务器集群的【数存算管】一体化平台&#…...
【Java】BF算法(串模式匹配算法)
☀️ 什么是BF算法 BF算法,即暴力算法,是普通的模式匹配算法,BF算法的思想就是将目标串S的第一个与模式串T的第一个字符串进行匹配,若相等,则继续比较S的第二个字符和T的第二个字符;若不相等,则…...
Vue:使用Promise.all()方法并行执行多个请求
在Vue中,可以使用Promise.all()方法来并行执行多个请求。当需要同时执行多个异步请求时,可以将这些请求封装为Promise对象并使用Promise.all()方法来执行它们。 示例1: 以下是一个示例代码,展示了如何通过Promise.all()方法并行…...
21.0 CSS 介绍
1. CSS层叠样式表 1.1 CSS简介 CSS(层叠样式表): 是一种用于描述网页上元素外观和布局的样式标记语言. 它可以与HTML结合使用, 通过为HTML元素添加样式来改变其外观. CSS使用选择器来选择需要应用样式的元素, 并使用属性-值对来定义这些样式.1.2 CSS版本 CSS有多个版本, 每个…...
下一代计算:嵌入AI的云/雾/边缘/量子计算
计算系统在过去几十年中推动了计算机科学的发展,现在已成为企业世界的核心,提供基于云计算、雾计算、边缘计算、无服务器计算和量子计算的服务。现代计算系统解决了现实世界中许多需要低延迟和低响应时间的问题。这有助于全球各地的青年才俊创办初创企业…...
Gitlab-第四天-CD到k8s集群的坑
一、.gitlab-ci.yml #CD到k8s集群的 stages: - deploy-test build-image-deploy-test: stage: deploy-test image: bitnami/kubectl:latest # 使用一个包含 kubectl 工具的镜像 tags: - k8s script: - ls -al - kubectl apply -f deployment.yaml # 根据实际情况替换…...
【Java基础】Java对象的生命周期
【Java基础】Java对象的生命周期 一、概述 一个类通过编译器将一个Java文件编译为Class字节码文件,然后通过JVM中的解释器编译成不同操作系统的机器码。虽然操作系统不同,但是基于解释器的虚拟机是相同的。java类的生命周期就是指一个class文件加载到类…...
【每日一题】88. 合并两个有序数组
【每日一题】88. 合并两个有序数组 88. 合并两个有序数组题目描述解题思路 88. 合并两个有序数组 题目描述 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2,另有两个整数 m 和 n ,分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 …...
Navicat Premium连接sqlserve数据库失败?你需要注意这几点看看配置对了么?
新建数据库连接的时候这么填的信息 报错 原因1:sqlserver数据库的端口和IP地址之间不是:连接而是用,连接 改成如下样式用逗号连接端口和IP地址就好了 原因2:在Navicat Premium中需要安装一个sqlserver的插件 找到安装路径的根目…...
207、仿真-51单片机脉搏心率与血氧报警Proteus仿真设计(程序+Proteus仿真+配套资料等)
毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选择 方案一&a…...
flutter 初识(开发体验,优缺点)
前言 最近有个跨平台桌面应用的需求,需要支持 windows/linux/mac 系统,要做个更新应用的小界面,主要功能就是下载更新文件并在本地进行替换,很简单的小功能。 花了几分钟构建没做 UI 优化的示例界面: 由于我们的客…...
校验vue prop的几种方式
校验vue prop的几种方式 vue 要求将传递给组件的任何数据显式声明为 props。此外,它还提供了强大的内置机制来验证该数据。这充当组件和父级组件之间的约定,并确保组件能按预期使用。 让我们看看怎么对props进行校验。它可以帮助我们在开发和调试过程中…...
vue+springboot 前后端分离 上传文件处理后再下载,并且传递参数
vue代码 <template><div><input type"file" ref"fileInput" accept".json"/><el-button click"upload">上传</el-button></div> </template><script> export default {name: "…...
DAY 47
三、通道注意力 3.1 通道注意力的定义 # 新增:通道注意力模块(SE模块) class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...
Python实现prophet 理论及参数优化
文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候,写过一篇简单实现,后期随着对该模型的深入研究,本次记录涉及到prophet 的公式以及参数调优,从公式可以更直观…...
Python爬虫(一):爬虫伪装
一、网站防爬机制概述 在当今互联网环境中,具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类: 身份验证机制:直接将未经授权的爬虫阻挡在外反爬技术体系:通过各种技术手段增加爬虫获取数据的难度…...
Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!
一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...
C++ 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
AI编程--插件对比分析:CodeRider、GitHub Copilot及其他
AI编程插件对比分析:CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展,AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者,分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...
浪潮交换机配置track检测实现高速公路收费网络主备切换NQA
浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...
JVM 内存结构 详解
内存结构 运行时数据区: Java虚拟机在运行Java程序过程中管理的内存区域。 程序计数器: 线程私有,程序控制流的指示器,分支、循环、跳转、异常处理、线程恢复等基础功能都依赖这个计数器完成。 每个线程都有一个程序计数…...
Mysql8 忘记密码重置,以及问题解决
1.使用免密登录 找到配置MySQL文件,我的文件路径是/etc/mysql/my.cnf,有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...
华为OD机考-机房布局
import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...
