当前位置: 首页 > news >正文

keras深度学习框架通过卷积神经网络cnn实现手写数字识别

    昨天通过keras构建简单神经网络实现手写数字识别,结果在最后进行我们自己的手写数字识别的时候,准确率堪忧,只有60%。今天通过卷积神经网络来实现手写数字识别。

    构建卷积神经网络和简单神经网络思路类似,只不过这里加入了卷积、池化等概念,网络结构复杂了一些,但是整体的思路没有变化,加载数据集,数据集修改,搭建网络模型,编译模型,训练模型,保存模型,利用模型预测。

    这里还是给出两个例子,一个是构建网络,最后保存训练好的网络模型,一个是通过加载保存的网络模型预测我们自己的手写数字图片。

import keras
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Conv2D, Flatten, MaxPool2D
from tensorflow.keras import datasets, utils
# 数据处理
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[1], 1)
x_train = x_train.astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[1], 1)
x_test = x_test.astype('float32') / 255
y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)
# 构建模型
model = Sequential()
model.add(Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation="relu", input_shape=(28, 28, 1)))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Conv2D(filters=36, kernel_size=(3, 3), padding='same', activation="relu"))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation="relu"))
model.add(Dropout(0.25))
model.add(Dense(10, activation="softmax"))
# 编译
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()
# 训练
model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))
# 保存模型
model.save("mnist.h5")

     训练模型,打印信息如下:

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 28, 28, 16)        160       max_pooling2d (MaxPooling2D  (None, 14, 14, 16)       0         )                                                               conv2d_1 (Conv2D)           (None, 14, 14, 36)        5220      max_pooling2d_1 (MaxPooling  (None, 7, 7, 36)         0         2D)                                                             dropout (Dropout)           (None, 7, 7, 36)          0         flatten (Flatten)           (None, 1764)              0         dense (Dense)               (None, 128)               225920    dropout_1 (Dropout)         (None, 128)               0         dense_1 (Dense)             (None, 10)                1290      =================================================================
Total params: 232,590
Trainable params: 232,590
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
2023-08-28 16:03:54.677314: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8800
469/469 [==============================] - 10s 17ms/step - loss: 0.2842 - accuracy: 0.9123 - val_loss: 0.0628 - val_accuracy: 0.9798
Epoch 2/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0836 - accuracy: 0.9743 - val_loss: 0.0473 - val_accuracy: 0.9841
Epoch 3/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0627 - accuracy: 0.9801 - val_loss: 0.0325 - val_accuracy: 0.9886
Epoch 4/5
469/469 [==============================] - 7s 15ms/step - loss: 0.0497 - accuracy: 0.9844 - val_loss: 0.0346 - val_accuracy: 0.9882
Epoch 5/5
469/469 [==============================] - 7s 15ms/step - loss: 0.0422 - accuracy: 0.9867 - val_loss: 0.0298 - val_accuracy: 0.9898

    准确率最后,到达了98.5%以上。

    用模型预测

import keras
import numpy as np
import cv2
from keras.models import load_modelmodel = load_model("mnist.h5")def predict(img_path):img = cv2.imread(img_path, 0)img = img.reshape(28, 28).astype("float32") / 255  # 0 1img = img.reshape(1, 28, 28, 1)  # 28 * 28 -> (1,28,28,1)label = model.predict(img)label = np.argmax(label, axis=1)print('{} -> {}'.format(img_path, label[0]))if __name__ == '__main__':for _ in range(10):predict("number_images/b_{}.png".format(_))

    数字图片如下: 

    图片放在项目目录number_images中。

    预测结果打印:

 

    感觉就是不一样,准确率从60%提升到了90%。虽然没有达到100%,但是已经很好了。 

    对比之前的代码,改动很小,主要是网络输入的时候,数据形状发生了改变,简单神经网络需要的是(784,*)结构,卷积神经网络需要的是(1,28,28,1)的结构, 在数据处理上做了调整,另一个不一样的地方就是网络模型在添加的时候,之前就是简单的两层网络,卷积神经网络复杂了很多。

相关文章:

keras深度学习框架通过卷积神经网络cnn实现手写数字识别

昨天通过keras构建简单神经网络实现手写数字识别,结果在最后进行我们自己的手写数字识别的时候,准确率堪忧,只有60%。今天通过卷积神经网络来实现手写数字识别。 构建卷积神经网络和简单神经网络思路类似,只不过这里加入了卷积、池…...

Springboot启动异常 Command line is too long

Springboot启动异常 Command line is too long Springboot启动时直接报异常 Command line is too long. Shorten command line for xxxxxApplication or also for Spring Boot default解决方案: 修改 SystemApplication 的 Shorten command line,选择 JAR manife…...

PXE 装机(五十)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 一、PXE是什么 二、PXE的组件 三、配置vsftpd 四、配置tftp 五、准备pxelinx.0文件、引导文件、内核文件 ​六、配置dhcp 七、创建default文件 八、配置pxe无人值守…...

C++ 虚函数与纯虚函数

目录 1. 虚函数 2. 纯虚函数 C 中的虚函数和纯虚函数都是实现多态的重要机制。多态可以让不同的对象以相同的方式进行操作,从而简化代码的编写和维护。 1. 虚函数 虚函数是一种在基类中声明的函数,可以在派生类中进行重写。在运行时,根据对…...

警告:Provides transitive vulnerable dependency maven:org.yaml:snakeyaml:1.30

1. 警告 SpringBoot 的 validation 依赖包含有易受攻击的依赖 snakeyaml。 警告信息如下: Provides transitive vulnerable dependency maven:org.yaml:snakeyaml:1.30 意思是:提供了可传递的易受攻击依赖 maven:org.yaml:snakeyaml:1.30 2. 警告示例 …...

中文命名实体识别

本文通过people_daily_ner数据集,介绍两段式训练过程,第一阶段是训练下游任务模型,第二阶段是联合训练下游任务模型和预训练模型,来实现中文命名实体识别任务。 一.任务和数据集介绍 1.命名实体识别任务 NER(Named En…...

WPF CommunityToolkit.Mvvm Messenger通讯

文章目录 环境WeakReferenceMessenger方法介绍无回调订阅发送Token区分有回调订阅发送 环境 CommunityToolkit.Mvvm Messenger 十月的寒流: 如何使用 CommunityToolkit.Mvvm 中的 Messenger 来进行 ViewModel 之间的通信 WeakReferenceMessenger 我这里只讲简单的弱Messenger…...

【杂言】写在研究生开学季

这两天搬进了深研院的宿舍,比中南的本科宿舍好很多,所以个人还算满意。受台风 “苏拉” 的影响,原本的迎新计划全部打乱,导致我现在都还没报道。刚开学的半个月将被各类讲座、体检以及入学教育等活动占满,之后又是比较…...

渗透测试漏洞原理之---【任意文件读取漏洞】

文章目录 1、概述1.1、漏洞成因1.2、漏洞危害1.3、漏洞分类1.4、任意文件读取1.4.1、文件读取函数1.4.2、任意文件读取 1.5、任意文件下载1.5.1、一般情况1.5.2、PHP实现1.5.3、任意文件下载 2、任意文件读取攻防2.1、路径过滤2.1.1、过滤../ 2.2、简单绕过2.2.1、双写绕过2.2.…...

合宙Air724UG LuatOS-Air LVGL API控件-图片 (Image)

图片 (Image) 图片IMG是用于显示图像的基本对象类型,图像来源可以是文件,或者定义的符号。 示例代码 -- 创建图片控件 img lvgl.img_create(lvgl.scr_act(), nil) -- 设置图片显示的图像 lvgl.img_set_src(img, "/lua/luatos.png") -- 图片…...

仿京东 项目笔记2(注册登录)

这里写目录标题 1. 注册页面1.1 注册/登录页面——接口请求1.2 Vue开发中Element UI的样式穿透1.2.1 ::v-deep的使用1.2.2 elementUI Dialog内容区域显示滚动条 1.3 注册页面——步骤条和表单联动 stepsform1.4 注册页面——滑动拼图验证1.5 注册页面——element-ui组件Popover…...

Spark与Flink的区别

分析&回答 (1)设计理念 1、Spark的技术理念是使用微批来模拟流的计算,基于Micro-batch,数据流以时间为单位被切分为一个个批次,通过分布式数据集RDD进行批量处理,是一种伪实时。 2、Flink是基于事件驱动的,是面向流的处理框架, Flink基于…...

未来智造:珠三角引领人工智能产业集群

原创 | 文 BFT机器人 产业集群是指产业或产业群体在地理位置上集聚的现象,产业集群的研究对拉动区域经济发展,提高区域产业竞争力具有重要意义。 从我国人工智能产业集群形成及区域布局来看,我国人工智能产业发展主要集聚在京津冀、长三角、…...

【Unity db】sqlite

背景 最近使用unity,需要用到sqlite,记录下使用过程 需要的动态库 Mono.Data.Sqlite.dll,这个文件下载参考下面链接 SqliteConnection的Close和Open 连接的概念: 在数据库编程中,连接是一个重要的概念&#xff0c…...

Linux 指令心法(四)`touch` 创建一个新的空文件

文章目录 命令的概述和用途命令的用法命令行选项和参数的详细说明命令的示例命令的注意事项或提示 命令的概述和用途 touch 是一个用于在 Linux 和 Unix 系统中创建空文件或更改现有文件的访问和修改时间的命令。如果指定的文件不存在,touch会创建一个新的空文件&a…...

分类算法系列②:KNN算法

目录 KNN算法 1、简介 2、原理分析 数学原理 相关公式及其过程分析 距离度量 k值选择 分类决策规则 3、API 4、⭐案例实践 4.1、分析 4.2、代码 5、K-近邻算法总结 🍃作者介绍:准大三网络工程专业在读,努力学习Java,涉…...

12. 微积分 - 梯度积分

Hi,大家好。我是茶桁。 上一节课,我们讲了方向导数,并且在最后留了个小尾巴,是什么呢?就是梯度。 我们再来回看一下但是的这个式子: [ f x f y...

Large Language Models and Knowledge Graphs: Opportunities and Challenges

本文是LLM系列的文章,针对《Large Language Models and Knowledge Graphs: Opportunities and Challenges》的翻译。 大语言模型和知识图谱:机会与挑战 摘要1 引言2 社区内的共同辩论点3 机会和愿景4 关键研究主题和相关挑战5 前景 摘要 大型语言模型&…...

Python操作Excel教程(图文教程,超详细)Python xlwings模块详解,

「作者主页」:士别三日wyx 「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:小白零基础《Python入门到精通》 xlwings模块详解 1、快速入门1、打开Excel2、创建工作簿2.1、使用工作簿2.2、操作…...

Java入门

Java导入包 import 主要用于导入在使用类前准备好了Import 关键字可以多次使用,导入多个包Import 关键字可以多次使用,导入多个类如果同一个包中需要大量的类,那么可以使用通配符进行导入如果Import了不同包,相同名称的类&#x…...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念,其实 Fiori当中还有 V4,咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务),代理中间件(ui5-middleware-simpleproxy)-CSDN博客…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...

RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill

视觉语言模型(Vision-Language Models, VLMs),为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展,机器人仍难以胜任复杂的长时程任务(如家具装配),主要受限于人…...

LRU 缓存机制详解与实现(Java版) + 力扣解决

📌 LRU 缓存机制详解与实现(Java版) 一、📖 问题背景 在日常开发中,我们经常会使用 缓存(Cache) 来提升性能。但由于内存有限,缓存不可能无限增长,于是需要策略决定&am…...

上位机开发过程中的设计模式体会(1):工厂方法模式、单例模式和生成器模式

简介 在我的 QT/C 开发工作中,合理运用设计模式极大地提高了代码的可维护性和可扩展性。本文将分享我在实际项目中应用的三种创造型模式:工厂方法模式、单例模式和生成器模式。 1. 工厂模式 (Factory Pattern) 应用场景 在我的 QT 项目中曾经有一个需…...

基于鸿蒙(HarmonyOS5)的打车小程序

1. 开发环境准备 安装DevEco Studio (鸿蒙官方IDE)配置HarmonyOS SDK申请开发者账号和必要的API密钥 2. 项目结构设计 ├── entry │ ├── src │ │ ├── main │ │ │ ├── ets │ │ │ │ ├── pages │ │ │ │ │ ├── H…...

前端高频面试题2:浏览器/计算机网络

本专栏相关链接 前端高频面试题1:HTML/CSS 前端高频面试题2:浏览器/计算机网络 前端高频面试题3:JavaScript 1.什么是强缓存、协商缓存? 强缓存: 当浏览器请求资源时,首先检查本地缓存是否命中。如果命…...

《信号与系统》第 6 章 信号与系统的时域和频域特性

目录 6.0 引言 6.1 傅里叶变换的模和相位表示 6.2 线性时不变系统频率响应的模和相位表示 6.2.1 线性与非线性相位 6.2.2 群时延 6.2.3 对数模和相位图 6.3 理想频率选择性滤波器的时域特性 6.4 非理想滤波器的时域和频域特性讨论 6.5 一阶与二阶连续时间系统 6.5.1 …...