当前位置: 首页 > news >正文

机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成:

1. 输入层
输入层接收输入数据,这里是一个 28x28 的灰度图像,每个像素值表示图像中的亮度值。

2. Flatten 层
Flatten 层用于将输入数据展平为一维向量,以便传递给后续的全连接层。在这里,我们将 28x28 的图像展平为一个长度为 784 的向量。

3. 全连接层(Dense 层)
全连接层是神经网络中最常见的层之一,每个神经元与上一层的每个神经元都连接。在这里,我们有一个包含 128 个神经元的隐藏层,以及一个包含 10 个神经元的输出层。隐藏层使用 ReLU(Rectified Linear Unit)激活函数,输出层使用 softmax 激活函数。

4. 输出层
输出层产生神经网络的输出,这里是一个包含 10 个元素的向量,每个元素表示对应类别的概率。softmax 函数用于将网络的原始输出转换为概率分布。

5. 编译模型
在编译模型时,我们指定了优化器(optimizer)和损失函数(loss function)。在这里,我们使用 Adam 优化器和稀疏分类交叉熵损失函数。

6. 训练模型
使用训练数据集对模型进行训练,以学习如何将输入映射到正确的输出。在训练过程中,模型通过优化损失函数来调整权重和偏置,使其尽可能准确地预测输出。

总的来说,这个神经网络是一个经典的多层感知器(MLP),它在输入层和输出层之间包含一个或多个隐藏层,通过学习逐步提取和组合特征来进行分类或回归任务。

代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 构建神经网络模型
model = Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_images, train_labels, epochs=5)# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)# 保存模型
model.save('mnist_model.h5')# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(test_images)

结果:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2586 - accuracy: 0.9265
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1136 - accuracy: 0.9656
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0773 - accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0587 - accuracy: 0.9823
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0462 - accuracy: 0.9855
313/313 [==============================] - 0s 1ms/step - loss: 0.0750 - accuracy: 0.9775

Test accuracy: 0.9775000214576721

识别准确率挺高,然后我们也得到了训练好的模型

应用测试:

import tensorflow as tf
import numpy as np
from PIL import Image# 加载保存的模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 打开手写图片文件
image_path = 'pic/handwritten_digit_thick_5.png'  # 修改为你的手写图片文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像# 调整图片大小为 28x28 像素
image = image.resize((28, 28))# 将图片转换为 NumPy 数组并进行归一化处理
image_array = np.array(image) / 255.0# 将图片转换为模型输入的格式(添加批次维度)
input_image = np.expand_dims(image_array, axis=0)# 使用模型进行预测
predictions = loaded_model.predict(input_image)# 获取预测结果(最大概率的类别)
predicted_class = np.argmax(predictions)print('Predicted digit:', predicted_class)

准备了4张图片,3张自己手写,1张摘自minst:

前两张画笔比较细,第三张是minst的5,第四张是用了粗笔自己写的5,最终结果是就minst预测对了。

Predicted digit: 2

Predicted digit: 8

Predicted digit: 5

Predicted digit: 3


结论:

可见这个模型的扩展适应性能还是不够,只能预测正确训练过的minst数字。

改进:

想办法提升训练的质量,让预测能力达标

相关文章:

机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成: 1. 输入层 输入层接收输入数据,这里是一个 28x…...

vscode wsl远程连接 权限问题

问题描述:执行命令时遇到Operation not permitted 和 Permission denied问题,是有关ip地址和创建文件的权限问题,参考网络上更改wsl.conf文件等方法均无法解决,只能加sudo来解决...

VED-eBPF:一款基于eBPF的内核利用和Rootkit检测工具

关于VED-eBPF VED-eBPF是一款功能强大的内核漏洞利用和Rootkit检测工具,该工具基于eBPF技术实现其功能,可以实现Linux操作系统运行时内核安全监控和漏洞利用检测。 eBPF是一个内核内虚拟机,它允许我们直接在内核中执行代码,而无…...

配置ARM交叉编译工具的通用步骤

ARM交叉编译工具是用于编译在ARM架构上运行的代码的工具。这些工具允许开发者在一种架构(通常是x86或x64)上编写和编译代码,然后将其移植到ARM架构上运行。 ARM交叉编译工具链通常包括编译器、链接器、调试器和其他必要的工具,用…...

相机图像质量研究(5)常见问题总结:光学结构对成像的影响--景深

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结:光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结:光学结构对成…...

使用django构建一个多级评论功能

,评论系统是交流和反馈的重要工具,尤其是多级评论系统,它允许用户回复特定评论,形成丰富的对话结构。这个文章是使用Django框架从零开始构建一个多级评论系统。Django是一个高级Python Web框架,它鼓励快速开发和干净、…...

测试管理_利用python连接禅道数据库并自动统计bug数据到钉钉群

测试管理_利用python连接禅道数据库并统计bug数据到钉钉 这篇不多赘述,直接上代码文件。 另文章基础参考博文:参考博文 加以我自己的需求优化而成。 统计的前提 以下代码统计的前提是禅道的提bug流程应规范化 bug未解决不删除bug未关闭不删除 db_…...

Python 小白的 Leetcode Daily Challenge 刷题计划 - 20240209(除夕)

368. Largest Divisible Subset 难度:Medium 动态规划 方案还原 Yesterdays Daily Challenge can be reduced to the problem of shortest path in an unweighted graph while todays daily challenge can be reduced to the problem of longest path in an unwe…...

BFS——双向广搜+A—star

有时候从一个点能扩展出来的情况很多,这样几层之后搜索空间就很大了,我们采用从两端同时进行搜索的策略,压缩搜索空间。 190. 字串变换(190. 字串变换 - AcWing题库) 思路:这题因为变化规则很多,所以我们一层一层往外…...

LLM之LangChain(七)| 使用LangChain,LangSmith实现Prompt工程ToT

如下图所示,LLM仍然是自治代理的backbone,可以通过给LLM增加以下模块来增强LLM功能: Prompter AgentChecker ModuleMemory moduleToT controller 当解决具体问题时,这些模块与LLM进行多轮对话。这是基于LLM的自治代理的典型情况,…...

新零售的升维体验,摸索华为云GaussDB如何实现数据赋能

新零售商业模式 商业模式通常是由客户价值、企业资源和能力、盈利方式三个方面构成。其最主要的用途是为实现客户价值最大化。 商业模式通过把能使企业运行的内外各要素整合起来,从而形成一个完整的、高效率的、具有独特核心竞争力的运行系统,并通过最…...

vscode +git +gitee 文件管理

文章目录 前言一、gitee是什么?2. Gitee与VScode连接大概步骤 二、在vscode中安装git1.安装git2.安装过程3.安装完后记得重启 三、使用1.新建文件夹first2.vscode 使用 四、连接git1.初始化仓库2.设置git 提交用户和邮箱3.登陆gitee账号新建仓库没有的自己注册一个4…...

【力扣】用栈判断有效的括号

有效的括号原题地址 方法一:栈 对于特殊情况,当字符串的长度为奇数时,一定不是有效的括号。 对于一般情况,考虑使用数据结构栈。 遍历字符串, 遇到左括号时,就入栈。遇到右括号时, 若栈顶元…...

【目录】CSAPP的实验简介与解法总结(已包含Attack/Link/Architecture/Cache)

文章目录 Attack Lab(缓冲区溢出实验)对应书上Chap3Link Lab(链接实验) 对应书上Chap7Architecture Lab(体系结构实验)对应书上Chap4-5Cache Lab(缓存实验)对应书上Chap6 Attack Lab…...

【机器学习】数据清洗之识别缺失点

🎈个人主页:甜美的江 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步…...

【Vue】Vue基础入门

📝个人主页:五敷有你 🔥系列专栏:Vue ⛺️稳重求进,晒太阳 Vue概念 是一个用于构建用户界面的渐进式框架优点:大大提高开发效率缺点:需要理解记忆规则 创建Vue实例 步骤: …...

正点原子-STM32通用定时器学习笔记(1)

目录 1. 通用定时器简介(F1为例) 2. 通用定时器框图 ①时钟源 ②控制器 ③时基单元 ④输入捕获 ⑤捕获/比较(公共) ⑥输出比较 3.时钟源配置 3.1 计数器时钟源寄存器设置方法 3.2 外部时钟模式1 3.3 外部时钟模式2 3…...

Redis篇之redis是单线程

一、redis是单线程 Redis是单线程的,但是为什么还那么快?主要原因有下面3点原因: 1. Redis是纯内存操作,执行速度非常快。 2. 采用单线程,避免不必要的上下文切换可竞争条件,多线程还要考虑线程安全问题。 …...

随机MM引流源码PHP开源版

引流源码最新随机MM开源版PHP源码,非常简洁好看的单页全解代码没任何加密 直接上传即可用无需数据库支持主机空间...

【C++修行之道】(引用、函数提高)

目录 一、引用 1.1引用的基本使用 1.2 引用注意事项 1.3 引用做函数参数 1.4 引用做函数返回值 1.5 引用的本质 1.6 常量引用 1.7引用和指针的区别 二、函数提高 2.1 函数默认参数 2.2函数占位参数 2.3 函数重载 2.4函数重载注意事项 一、引用 1.1引用的基本使用 …...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

Qt Widget类解析与代码注释

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法,当前调用一个医疗行业的AI识别算法后返回…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

VM虚拟机网络配置(ubuntu24桥接模式):配置静态IP

编辑-虚拟网络编辑器-更改设置 选择桥接模式,然后找到相应的网卡(可以查看自己本机的网络连接) windows连接的网络点击查看属性 编辑虚拟机设置更改网络配置,选择刚才配置的桥接模式 静态ip设置: 我用的ubuntu24桌…...

宇树科技,改名了!

提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...

Android屏幕刷新率与FPS(Frames Per Second) 120hz

Android屏幕刷新率与FPS(Frames Per Second) 120hz 屏幕刷新率是屏幕每秒钟刷新显示内容的次数,单位是赫兹(Hz)。 60Hz 屏幕:每秒刷新 60 次,每次刷新间隔约 16.67ms 90Hz 屏幕:每秒刷新 90 次,…...

Java设计模式:责任链模式

一、什么是责任链模式? 责任链模式(Chain of Responsibility Pattern) 是一种 行为型设计模式,它通过将请求沿着一条处理链传递,直到某个对象处理它为止。这种模式的核心思想是 解耦请求的发送者和接收者,…...

[QMT量化交易小白入门]-六十二、ETF轮动中简单的评分算法如何获取历史年化收益32.7%

本专栏主要是介绍QMT的基础用法,常见函数,写策略的方法,也会分享一些量化交易的思路,大概会写100篇左右。 QMT的相关资料较少,在使用过程中不断的摸索,遇到了一些问题,记录下来和大家一起沟通,共同进步。 文章目录 相关阅读1. 策略概述2. 趋势评分模块3 代码解析4 木头…...