当前位置: 首页 > news >正文

机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成:

1. 输入层
输入层接收输入数据,这里是一个 28x28 的灰度图像,每个像素值表示图像中的亮度值。

2. Flatten 层
Flatten 层用于将输入数据展平为一维向量,以便传递给后续的全连接层。在这里,我们将 28x28 的图像展平为一个长度为 784 的向量。

3. 全连接层(Dense 层)
全连接层是神经网络中最常见的层之一,每个神经元与上一层的每个神经元都连接。在这里,我们有一个包含 128 个神经元的隐藏层,以及一个包含 10 个神经元的输出层。隐藏层使用 ReLU(Rectified Linear Unit)激活函数,输出层使用 softmax 激活函数。

4. 输出层
输出层产生神经网络的输出,这里是一个包含 10 个元素的向量,每个元素表示对应类别的概率。softmax 函数用于将网络的原始输出转换为概率分布。

5. 编译模型
在编译模型时,我们指定了优化器(optimizer)和损失函数(loss function)。在这里,我们使用 Adam 优化器和稀疏分类交叉熵损失函数。

6. 训练模型
使用训练数据集对模型进行训练,以学习如何将输入映射到正确的输出。在训练过程中,模型通过优化损失函数来调整权重和偏置,使其尽可能准确地预测输出。

总的来说,这个神经网络是一个经典的多层感知器(MLP),它在输入层和输出层之间包含一个或多个隐藏层,通过学习逐步提取和组合特征来进行分类或回归任务。

代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 构建神经网络模型
model = Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_images, train_labels, epochs=5)# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)# 保存模型
model.save('mnist_model.h5')# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(test_images)

结果:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2586 - accuracy: 0.9265
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1136 - accuracy: 0.9656
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0773 - accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0587 - accuracy: 0.9823
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0462 - accuracy: 0.9855
313/313 [==============================] - 0s 1ms/step - loss: 0.0750 - accuracy: 0.9775

Test accuracy: 0.9775000214576721

识别准确率挺高,然后我们也得到了训练好的模型

应用测试:

import tensorflow as tf
import numpy as np
from PIL import Image# 加载保存的模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 打开手写图片文件
image_path = 'pic/handwritten_digit_thick_5.png'  # 修改为你的手写图片文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像# 调整图片大小为 28x28 像素
image = image.resize((28, 28))# 将图片转换为 NumPy 数组并进行归一化处理
image_array = np.array(image) / 255.0# 将图片转换为模型输入的格式(添加批次维度)
input_image = np.expand_dims(image_array, axis=0)# 使用模型进行预测
predictions = loaded_model.predict(input_image)# 获取预测结果(最大概率的类别)
predicted_class = np.argmax(predictions)print('Predicted digit:', predicted_class)

准备了4张图片,3张自己手写,1张摘自minst:

前两张画笔比较细,第三张是minst的5,第四张是用了粗笔自己写的5,最终结果是就minst预测对了。

Predicted digit: 2

Predicted digit: 8

Predicted digit: 5

Predicted digit: 3


结论:

可见这个模型的扩展适应性能还是不够,只能预测正确训练过的minst数字。

改进:

想办法提升训练的质量,让预测能力达标

相关文章:

机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成: 1. 输入层 输入层接收输入数据,这里是一个 28x…...

vscode wsl远程连接 权限问题

问题描述:执行命令时遇到Operation not permitted 和 Permission denied问题,是有关ip地址和创建文件的权限问题,参考网络上更改wsl.conf文件等方法均无法解决,只能加sudo来解决...

VED-eBPF:一款基于eBPF的内核利用和Rootkit检测工具

关于VED-eBPF VED-eBPF是一款功能强大的内核漏洞利用和Rootkit检测工具,该工具基于eBPF技术实现其功能,可以实现Linux操作系统运行时内核安全监控和漏洞利用检测。 eBPF是一个内核内虚拟机,它允许我们直接在内核中执行代码,而无…...

配置ARM交叉编译工具的通用步骤

ARM交叉编译工具是用于编译在ARM架构上运行的代码的工具。这些工具允许开发者在一种架构(通常是x86或x64)上编写和编译代码,然后将其移植到ARM架构上运行。 ARM交叉编译工具链通常包括编译器、链接器、调试器和其他必要的工具,用…...

相机图像质量研究(5)常见问题总结:光学结构对成像的影响--景深

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结:光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结:光学结构对成…...

使用django构建一个多级评论功能

,评论系统是交流和反馈的重要工具,尤其是多级评论系统,它允许用户回复特定评论,形成丰富的对话结构。这个文章是使用Django框架从零开始构建一个多级评论系统。Django是一个高级Python Web框架,它鼓励快速开发和干净、…...

测试管理_利用python连接禅道数据库并自动统计bug数据到钉钉群

测试管理_利用python连接禅道数据库并统计bug数据到钉钉 这篇不多赘述,直接上代码文件。 另文章基础参考博文:参考博文 加以我自己的需求优化而成。 统计的前提 以下代码统计的前提是禅道的提bug流程应规范化 bug未解决不删除bug未关闭不删除 db_…...

Python 小白的 Leetcode Daily Challenge 刷题计划 - 20240209(除夕)

368. Largest Divisible Subset 难度:Medium 动态规划 方案还原 Yesterdays Daily Challenge can be reduced to the problem of shortest path in an unweighted graph while todays daily challenge can be reduced to the problem of longest path in an unwe…...

BFS——双向广搜+A—star

有时候从一个点能扩展出来的情况很多,这样几层之后搜索空间就很大了,我们采用从两端同时进行搜索的策略,压缩搜索空间。 190. 字串变换(190. 字串变换 - AcWing题库) 思路:这题因为变化规则很多,所以我们一层一层往外…...

LLM之LangChain(七)| 使用LangChain,LangSmith实现Prompt工程ToT

如下图所示,LLM仍然是自治代理的backbone,可以通过给LLM增加以下模块来增强LLM功能: Prompter AgentChecker ModuleMemory moduleToT controller 当解决具体问题时,这些模块与LLM进行多轮对话。这是基于LLM的自治代理的典型情况,…...

新零售的升维体验,摸索华为云GaussDB如何实现数据赋能

新零售商业模式 商业模式通常是由客户价值、企业资源和能力、盈利方式三个方面构成。其最主要的用途是为实现客户价值最大化。 商业模式通过把能使企业运行的内外各要素整合起来,从而形成一个完整的、高效率的、具有独特核心竞争力的运行系统,并通过最…...

vscode +git +gitee 文件管理

文章目录 前言一、gitee是什么?2. Gitee与VScode连接大概步骤 二、在vscode中安装git1.安装git2.安装过程3.安装完后记得重启 三、使用1.新建文件夹first2.vscode 使用 四、连接git1.初始化仓库2.设置git 提交用户和邮箱3.登陆gitee账号新建仓库没有的自己注册一个4…...

【力扣】用栈判断有效的括号

有效的括号原题地址 方法一:栈 对于特殊情况,当字符串的长度为奇数时,一定不是有效的括号。 对于一般情况,考虑使用数据结构栈。 遍历字符串, 遇到左括号时,就入栈。遇到右括号时, 若栈顶元…...

【目录】CSAPP的实验简介与解法总结(已包含Attack/Link/Architecture/Cache)

文章目录 Attack Lab(缓冲区溢出实验)对应书上Chap3Link Lab(链接实验) 对应书上Chap7Architecture Lab(体系结构实验)对应书上Chap4-5Cache Lab(缓存实验)对应书上Chap6 Attack Lab…...

【机器学习】数据清洗之识别缺失点

🎈个人主页:甜美的江 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步…...

【Vue】Vue基础入门

📝个人主页:五敷有你 🔥系列专栏:Vue ⛺️稳重求进,晒太阳 Vue概念 是一个用于构建用户界面的渐进式框架优点:大大提高开发效率缺点:需要理解记忆规则 创建Vue实例 步骤: …...

正点原子-STM32通用定时器学习笔记(1)

目录 1. 通用定时器简介(F1为例) 2. 通用定时器框图 ①时钟源 ②控制器 ③时基单元 ④输入捕获 ⑤捕获/比较(公共) ⑥输出比较 3.时钟源配置 3.1 计数器时钟源寄存器设置方法 3.2 外部时钟模式1 3.3 外部时钟模式2 3…...

Redis篇之redis是单线程

一、redis是单线程 Redis是单线程的,但是为什么还那么快?主要原因有下面3点原因: 1. Redis是纯内存操作,执行速度非常快。 2. 采用单线程,避免不必要的上下文切换可竞争条件,多线程还要考虑线程安全问题。 …...

随机MM引流源码PHP开源版

引流源码最新随机MM开源版PHP源码,非常简洁好看的单页全解代码没任何加密 直接上传即可用无需数据库支持主机空间...

【C++修行之道】(引用、函数提高)

目录 一、引用 1.1引用的基本使用 1.2 引用注意事项 1.3 引用做函数参数 1.4 引用做函数返回值 1.5 引用的本质 1.6 常量引用 1.7引用和指针的区别 二、函数提高 2.1 函数默认参数 2.2函数占位参数 2.3 函数重载 2.4函数重载注意事项 一、引用 1.1引用的基本使用 …...

网络六边形受到攻击

大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...

在rocky linux 9.5上在线安装 docker

前面是指南,后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个?3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制(过半机制&#xff0…...

leetcodeSQL解题:3564. 季节性销售分析

leetcodeSQL解题:3564. 季节性销售分析 题目: 表:sales ---------------------- | Column Name | Type | ---------------------- | sale_id | int | | product_id | int | | sale_date | date | | quantity | int | | price | decimal | -…...

企业如何增强终端安全?

在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...

浪潮交换机配置track检测实现高速公路收费网络主备切换NQA

浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...