当前位置: 首页 > news >正文

keras深度学习框架通过卷积神经网络cnn实现手写数字识别

    昨天通过keras构建简单神经网络实现手写数字识别,结果在最后进行我们自己的手写数字识别的时候,准确率堪忧,只有60%。今天通过卷积神经网络来实现手写数字识别。

    构建卷积神经网络和简单神经网络思路类似,只不过这里加入了卷积、池化等概念,网络结构复杂了一些,但是整体的思路没有变化,加载数据集,数据集修改,搭建网络模型,编译模型,训练模型,保存模型,利用模型预测。

    这里还是给出两个例子,一个是构建网络,最后保存训练好的网络模型,一个是通过加载保存的网络模型预测我们自己的手写数字图片。

import keras
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Conv2D, Flatten, MaxPool2D
from tensorflow.keras import datasets, utils
# 数据处理
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[1], 1)
x_train = x_train.astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[1], 1)
x_test = x_test.astype('float32') / 255
y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)
# 构建模型
model = Sequential()
model.add(Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation="relu", input_shape=(28, 28, 1)))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Conv2D(filters=36, kernel_size=(3, 3), padding='same', activation="relu"))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation="relu"))
model.add(Dropout(0.25))
model.add(Dense(10, activation="softmax"))
# 编译
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()
# 训练
model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))
# 保存模型
model.save("mnist.h5")

     训练模型,打印信息如下:

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 28, 28, 16)        160       max_pooling2d (MaxPooling2D  (None, 14, 14, 16)       0         )                                                               conv2d_1 (Conv2D)           (None, 14, 14, 36)        5220      max_pooling2d_1 (MaxPooling  (None, 7, 7, 36)         0         2D)                                                             dropout (Dropout)           (None, 7, 7, 36)          0         flatten (Flatten)           (None, 1764)              0         dense (Dense)               (None, 128)               225920    dropout_1 (Dropout)         (None, 128)               0         dense_1 (Dense)             (None, 10)                1290      =================================================================
Total params: 232,590
Trainable params: 232,590
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
2023-08-28 16:03:54.677314: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8800
469/469 [==============================] - 10s 17ms/step - loss: 0.2842 - accuracy: 0.9123 - val_loss: 0.0628 - val_accuracy: 0.9798
Epoch 2/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0836 - accuracy: 0.9743 - val_loss: 0.0473 - val_accuracy: 0.9841
Epoch 3/5
469/469 [==============================] - 7s 16ms/step - loss: 0.0627 - accuracy: 0.9801 - val_loss: 0.0325 - val_accuracy: 0.9886
Epoch 4/5
469/469 [==============================] - 7s 15ms/step - loss: 0.0497 - accuracy: 0.9844 - val_loss: 0.0346 - val_accuracy: 0.9882
Epoch 5/5
469/469 [==============================] - 7s 15ms/step - loss: 0.0422 - accuracy: 0.9867 - val_loss: 0.0298 - val_accuracy: 0.9898

    准确率最后,到达了98.5%以上。

    用模型预测

import keras
import numpy as np
import cv2
from keras.models import load_modelmodel = load_model("mnist.h5")def predict(img_path):img = cv2.imread(img_path, 0)img = img.reshape(28, 28).astype("float32") / 255  # 0 1img = img.reshape(1, 28, 28, 1)  # 28 * 28 -> (1,28,28,1)label = model.predict(img)label = np.argmax(label, axis=1)print('{} -> {}'.format(img_path, label[0]))if __name__ == '__main__':for _ in range(10):predict("number_images/b_{}.png".format(_))

    数字图片如下: 

    图片放在项目目录number_images中。

    预测结果打印:

 

    感觉就是不一样,准确率从60%提升到了90%。虽然没有达到100%,但是已经很好了。 

    对比之前的代码,改动很小,主要是网络输入的时候,数据形状发生了改变,简单神经网络需要的是(784,*)结构,卷积神经网络需要的是(1,28,28,1)的结构, 在数据处理上做了调整,另一个不一样的地方就是网络模型在添加的时候,之前就是简单的两层网络,卷积神经网络复杂了很多。

相关文章:

keras深度学习框架通过卷积神经网络cnn实现手写数字识别

昨天通过keras构建简单神经网络实现手写数字识别,结果在最后进行我们自己的手写数字识别的时候,准确率堪忧,只有60%。今天通过卷积神经网络来实现手写数字识别。 构建卷积神经网络和简单神经网络思路类似,只不过这里加入了卷积、池…...

Springboot启动异常 Command line is too long

Springboot启动异常 Command line is too long Springboot启动时直接报异常 Command line is too long. Shorten command line for xxxxxApplication or also for Spring Boot default解决方案: 修改 SystemApplication 的 Shorten command line,选择 JAR manife…...

PXE 装机(五十)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 一、PXE是什么 二、PXE的组件 三、配置vsftpd 四、配置tftp 五、准备pxelinx.0文件、引导文件、内核文件 ​六、配置dhcp 七、创建default文件 八、配置pxe无人值守…...

C++ 虚函数与纯虚函数

目录 1. 虚函数 2. 纯虚函数 C 中的虚函数和纯虚函数都是实现多态的重要机制。多态可以让不同的对象以相同的方式进行操作,从而简化代码的编写和维护。 1. 虚函数 虚函数是一种在基类中声明的函数,可以在派生类中进行重写。在运行时,根据对…...

警告:Provides transitive vulnerable dependency maven:org.yaml:snakeyaml:1.30

1. 警告 SpringBoot 的 validation 依赖包含有易受攻击的依赖 snakeyaml。 警告信息如下: Provides transitive vulnerable dependency maven:org.yaml:snakeyaml:1.30 意思是:提供了可传递的易受攻击依赖 maven:org.yaml:snakeyaml:1.30 2. 警告示例 …...

中文命名实体识别

本文通过people_daily_ner数据集,介绍两段式训练过程,第一阶段是训练下游任务模型,第二阶段是联合训练下游任务模型和预训练模型,来实现中文命名实体识别任务。 一.任务和数据集介绍 1.命名实体识别任务 NER(Named En…...

WPF CommunityToolkit.Mvvm Messenger通讯

文章目录 环境WeakReferenceMessenger方法介绍无回调订阅发送Token区分有回调订阅发送 环境 CommunityToolkit.Mvvm Messenger 十月的寒流: 如何使用 CommunityToolkit.Mvvm 中的 Messenger 来进行 ViewModel 之间的通信 WeakReferenceMessenger 我这里只讲简单的弱Messenger…...

【杂言】写在研究生开学季

这两天搬进了深研院的宿舍,比中南的本科宿舍好很多,所以个人还算满意。受台风 “苏拉” 的影响,原本的迎新计划全部打乱,导致我现在都还没报道。刚开学的半个月将被各类讲座、体检以及入学教育等活动占满,之后又是比较…...

渗透测试漏洞原理之---【任意文件读取漏洞】

文章目录 1、概述1.1、漏洞成因1.2、漏洞危害1.3、漏洞分类1.4、任意文件读取1.4.1、文件读取函数1.4.2、任意文件读取 1.5、任意文件下载1.5.1、一般情况1.5.2、PHP实现1.5.3、任意文件下载 2、任意文件读取攻防2.1、路径过滤2.1.1、过滤../ 2.2、简单绕过2.2.1、双写绕过2.2.…...

合宙Air724UG LuatOS-Air LVGL API控件-图片 (Image)

图片 (Image) 图片IMG是用于显示图像的基本对象类型,图像来源可以是文件,或者定义的符号。 示例代码 -- 创建图片控件 img lvgl.img_create(lvgl.scr_act(), nil) -- 设置图片显示的图像 lvgl.img_set_src(img, "/lua/luatos.png") -- 图片…...

仿京东 项目笔记2(注册登录)

这里写目录标题 1. 注册页面1.1 注册/登录页面——接口请求1.2 Vue开发中Element UI的样式穿透1.2.1 ::v-deep的使用1.2.2 elementUI Dialog内容区域显示滚动条 1.3 注册页面——步骤条和表单联动 stepsform1.4 注册页面——滑动拼图验证1.5 注册页面——element-ui组件Popover…...

Spark与Flink的区别

分析&回答 (1)设计理念 1、Spark的技术理念是使用微批来模拟流的计算,基于Micro-batch,数据流以时间为单位被切分为一个个批次,通过分布式数据集RDD进行批量处理,是一种伪实时。 2、Flink是基于事件驱动的,是面向流的处理框架, Flink基于…...

未来智造:珠三角引领人工智能产业集群

原创 | 文 BFT机器人 产业集群是指产业或产业群体在地理位置上集聚的现象,产业集群的研究对拉动区域经济发展,提高区域产业竞争力具有重要意义。 从我国人工智能产业集群形成及区域布局来看,我国人工智能产业发展主要集聚在京津冀、长三角、…...

【Unity db】sqlite

背景 最近使用unity,需要用到sqlite,记录下使用过程 需要的动态库 Mono.Data.Sqlite.dll,这个文件下载参考下面链接 SqliteConnection的Close和Open 连接的概念: 在数据库编程中,连接是一个重要的概念&#xff0c…...

Linux 指令心法(四)`touch` 创建一个新的空文件

文章目录 命令的概述和用途命令的用法命令行选项和参数的详细说明命令的示例命令的注意事项或提示 命令的概述和用途 touch 是一个用于在 Linux 和 Unix 系统中创建空文件或更改现有文件的访问和修改时间的命令。如果指定的文件不存在,touch会创建一个新的空文件&a…...

分类算法系列②:KNN算法

目录 KNN算法 1、简介 2、原理分析 数学原理 相关公式及其过程分析 距离度量 k值选择 分类决策规则 3、API 4、⭐案例实践 4.1、分析 4.2、代码 5、K-近邻算法总结 🍃作者介绍:准大三网络工程专业在读,努力学习Java,涉…...

12. 微积分 - 梯度积分

Hi,大家好。我是茶桁。 上一节课,我们讲了方向导数,并且在最后留了个小尾巴,是什么呢?就是梯度。 我们再来回看一下但是的这个式子: [ f x f y...

Large Language Models and Knowledge Graphs: Opportunities and Challenges

本文是LLM系列的文章,针对《Large Language Models and Knowledge Graphs: Opportunities and Challenges》的翻译。 大语言模型和知识图谱:机会与挑战 摘要1 引言2 社区内的共同辩论点3 机会和愿景4 关键研究主题和相关挑战5 前景 摘要 大型语言模型&…...

Python操作Excel教程(图文教程,超详细)Python xlwings模块详解,

「作者主页」:士别三日wyx 「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:小白零基础《Python入门到精通》 xlwings模块详解 1、快速入门1、打开Excel2、创建工作簿2.1、使用工作簿2.2、操作…...

Java入门

Java导入包 import 主要用于导入在使用类前准备好了Import 关键字可以多次使用,导入多个包Import 关键字可以多次使用,导入多个类如果同一个包中需要大量的类,那么可以使用通配符进行导入如果Import了不同包,相同名称的类&#x…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

Linux链表操作全解析

Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表&#xff1f;1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

Mac软件卸载指南,简单易懂!

刚和Adobe分手&#xff0c;它却总在Library里给你写"回忆录"&#xff1f;卸载的Final Cut Pro像电子幽灵般阴魂不散&#xff1f;总是会有残留文件&#xff0c;别慌&#xff01;这份Mac软件卸载指南&#xff0c;将用最硬核的方式教你"数字分手术"&#xff0…...

Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级

在互联网的快速发展中&#xff0c;高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司&#xff0c;近期做出了一个重大技术决策&#xff1a;弃用长期使用的 Nginx&#xff0c;转而采用其内部开发…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

微信小程序云开发平台MySQL的连接方式

注&#xff1a;微信小程序云开发平台指的是腾讯云开发 先给结论&#xff1a;微信小程序云开发平台的MySQL&#xff0c;无法通过获取数据库连接信息的方式进行连接&#xff0c;连接只能通过云开发的SDK连接&#xff0c;具体要参考官方文档&#xff1a; 为什么&#xff1f; 因为…...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索&#xff08;基于物理空间 广播范围&#xff09;2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

【Linux】Linux 系统默认的目录及作用说明

博主介绍&#xff1a;✌全网粉丝23W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...