当前位置: 首页 > news >正文

01、Tensorflow实现二元手写数字识别

01、Tensorflow实现二元手写数字识别(二分类问题)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0

1、识别目标

识别手写仅仅是为了区分手写的0和1,所以实际上是一个二分类问题。

2、Tensorflow算法实现

STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。

from sklearn.metrics import accuracy_score:这是从scikit-learn库中引入accuracy_score函数。这个函数用于计算分类准确率,常用于评估分类模型的性能。


STEP2:屏蔽无用警告并允许中文

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

logging.getLogger(“tensorflow”).setLevel(logging.ERROR):这行代码用于设置 TensorFlow 的日志级别为 ERROR。这意味着只有当 TensorFlow 中发生错误时,才会在日志中输出相关信息。较低级别的日志信息(如 WARNING、INFO、DEBUG)将被忽略。

tf.autograph.set_verbosity(0):这行代码用于设置 TensorFlow 的自动图形(Autograph)日志的冗长级别为 0。这意味着在将 Python 代码转换为 TensorFlow 图形代码时,将不会输出任何日志信息。这有助于减少日志噪音,使日志更加干净。

warnings.simplefilter(action=‘ignore’,category=FutureWarning):这行代码用于忽略所有 FutureWarning 类型的警告。在 Python中,当使用某些即将过时或未来版本中可能发生变化的特性时,通常会发出 FutureWarning。通过设置action=‘ignore’,代码将不会输出这类警告,使控制台输出更加干净。

plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:这行代码用于设置 matplotlib 中的默认无衬线字体为 SimHei。SimHei 是一种常用于显示中文的字体,这样设置后,matplotlib 将在绘图时使用 SimHei 字体来显示中文,从而避免中文乱码问题。

plt.rcParams[‘axes.unicode_minus’]=False:这行代码用于解决 matplotlib
中负号显示异常的问题。默认情况下,matplotlib 可能无法正确显示负号,将其设置为 False 可以使用 ASCII字符作为负号,从而正常显示。


STEP3:导入并划分数据集

划分10%作为测试:

X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

STEP4:模型构建与训练

# 构建模型,三层模型进行分类,第一层输入100个神经元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)

STEP5:结果可视化与打印准确度信息
原始的输入的数据集是400 * 1000的数组,共包含1000个手写数字的数据,其中400为20*20像素的图片,因此对每个400的数组进行reshape((20, 20))可以得到原始的图片进而绘图。

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

3、运行结果

按照最初的划分,数据集包含1000个数据,划分10%为测试集,也就是100个数据。结果可视化随机选择其中的64个数据绘图,每个图像的上方标明了其真实标签和预测的结果,这个是一个非常简单的示例,准确度还是非常高的。
在这里插入图片描述

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现二元手写数字识别(二分类问题)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_scorelogging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_data/X.npy")y = np.load("Handwritten_Digit_Recognition_data/y.npy")X = X[0:1000]y = y[0:1000]return X, y# 加载数据集,查看数据集大小,可以看到有1000个数据集,每个输入是20*20=400大小的图片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 下面画图,随便从原数据取出几个画图,可以注释
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # 将1*400的数据转换为20*20的图像格式
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
# plt.show()# 构建模型,三层模型进行分类,第一层输入25个神经元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

相关文章:

01、Tensorflow实现二元手写数字识别

01、Tensorflow实现二元手写数字识别(二分类问题) 开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。 基于Tensorflow 2.10.0 1、…...

HCIA-Datacom跟官方路线学习第二部分

接着前面第六章,通过VLAN技术, 可以将物理的局域网划分成多个广播域, 实现同一VLAN内的网络设备可以直接进行二层通信, 不同VLAN内的设备不可以直接进行二层通信。 第七章 生成树 在以太网交换网络会使用冗余链路, 但…...

BIO、NIO和AIO的区别

一、基础知识: I/O 模型的简单理解: 1.BIO(Blocking I/O):同步阻塞,一个线程处理一个通道上的事件。 2.NIO(Non-blocking I/O):同步非阻塞,使用选择器&…...

makefile 学习(5)完整的makefile模板

参考自: (1)深度学习部署笔记(二): g, makefile语法,makefile自己的CUDA编程模板(2)https://zhuanlan.zhihu.com/p/396448133(3) 一个挺好的工程模板,(https://github.com/shouxieai/cpp-proj-template) 1. c 编译流…...

专业远程控制如何塑造安全体系?向日葵“全流程安全闭环”解析

安全是远程控制的重中之重,作为国民级远程控制品牌,向日葵远程控制就极为注重安全远控服务的塑造。近期向日葵发布了以安全和核心的新版“向日葵15”以及同步发布《贝锐向日葵远控安全标准白皮书》(下简称《白皮书》),…...

node.js解决输出中文乱码问题

个人简介 👨🏻‍💻个人主页:九黎aj 🏃🏻‍♂️幸福源自奋斗,平凡造就不凡 🌟如果文章对你有用,麻烦关注点赞收藏走一波,感谢支持! 🌱欢迎订阅我的…...

CentOS 7 使用异步网络框架Libevent

CentOS 7 安装Libevent库 libevent github地址:https://github.com/libevent/libevent 步骤1:首先,你需要下载libevent的源代码。你可以从github或者源代码官方网站下载。并上传至/usr/local/source_code/ 步骤2:下载完成后&…...

枚举 B. Lorry

Problem - B - Codeforces 题目大意:给物品数量 n n n,体积为 v ( 0 ≤ v ≤ 1 e 9 ) v_{(0 \le v \le 1e9)} v(0≤v≤1e9)​,第一行读入 n , v n, v n,v,之后 n n n行,读入 n n n个物品,之后每行依次是体…...

ON1 Photo RAW 2024 for Mac——专业照片编辑的终极利器

ON1 Photo RAW 2024 for Mac是一款专为Mac用户打造的照片编辑器,以其强大的功能和易用的操作,让你的照片编辑工作变得轻松愉快。 一、强大的RAW处理能力 ON1 Photo RAW 2024支持大量的RAW格式照片,能够让你在编辑过程中获得更多的自由度和更…...

从word复制内容到wangEditor富文本框的时候会把html标签也复制过来,如果只想实现直接复制纯文本,有什么好的实现方式

从word复制内容到wangEditor富文本框的时候会把html标签也复制过来,如果只想实现直接复制纯文本,有什么好的实现方式? 将 Word 中的内容复制到富文本编辑器时,常常会带有大量的 HTML 标签和样式,这可能导致不必要的格式…...

项目中如何配置数据可视化展现

在现今数据驱动的时代,可视化已逐渐成为数据分析的主要途径,可视化大屏的广泛使用便应运而生。很多公司及政务机构,常利用大屏的手段展现其实力或演示业务,可视化的效果能让观者更快速的理解结果并直观的看到数据展现。因此&#…...

ArkUI开发进阶—@Builder函数@BuilderParam装饰器的妙用与场景应用

ArkUI开发进阶—@Builder函数@BuilderParam装饰器的妙用与场景应用 HarmonyOS,作为一款全场景分布式操作系统,为了推动更广泛的应用开发,采用了一种先进而灵活的编程语言——ArkTS。ArkTS是在TypeScript(TS)的基础上发展而来,为HarmonyOS提供了丰富的应用开发工具,使开…...

大语言模型概述(三):基于亚马逊云科技的研究分析与实践

上期介绍了基于亚马逊云科技的大语言模型相关研究方向,以及大语言模型的训练和构建优化。本期将介绍大语言模型训练在亚马逊云科技上的最佳实践。 大语言模型训练在亚马逊云科技上的最佳实践 本章节内容,将重点关注大语言模型在亚马逊云科技上的最佳训…...

键入网址到网页显示,期间发生了什么?

文章目录 键入网址到网页显示,期间发生了什么?1. HTTP2. 真实地址查询 —— DNS3. 指南好帮手 —— 协议栈4. 可靠传输 —— TCP5. 远程定位 —— IP6. 两点传输 —— MAC7. 出口 —— 网卡8. 送别者 —— 交换机9. 出境大门 —— 路由器10. 互相扒皮 —…...

深度学习基于Python+TensorFlow+Django的水果识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介简介技术组合系统功能使用流程 二、功能三、系统四. 总结 一项目简介 # 深度学习基于PythonTensorFlowDjango的水果识别系统介绍 简介 该水果识别系统基于…...

vs动态库生成过程中还存在静态库

为什么VS生成动态库dll同时还会生成lib静态库 动态库与静态库(Windows环境下) ​ 动态库和静态库都是一种可执行代码的二进制形式,可以被操作系统载入内存执行。 ​ 静态库实际上是在链接时被链接到exe的,编译后,静态…...

P13 C++ 类 | 结构体内部的静态static

目录 01 前言 02 类内部创建静态变量的例子 03 在类的内部创建静态变量的作用 04 最后的话 01 前言 本期我们讨论 static 在一个类或一个结构体中的具体情况。 在几乎所有面向对象的语言中,静态在一个类中意味着特定的东西。这意味着在类的所有实例中&#xff…...

【腾讯云云上实验室-向量数据库】Tencent Cloud VectorDB在实战项目中替换Milvus测试

为什么尝试使用Tencent Cloud VectorDB替换Milvus向量库? 亮点:Tencent Cloud VectorDB支持Embedding,免去自己搭建模型的负担(搭建一个生产环境的模型实在耗费精力和体力)。 腾讯云向量数据库是什么? 腾…...

git clone -mirror 和 git clone 的区别

目录 前言两则区别git clone --mirrorgit clone 获取到的文件有什么不同瘦身仓库如何选择结语开源项目 前言 Git是一款强大的版本控制系统,通过Git可以方便地管理代码的版本和协作开发。在使用Git时,常见的操作之一就是通过git clone命令将远程仓库克隆…...

基于51单片机的公交自动报站系统

**单片机设计介绍, 基于51单片机的公交自动报站系统 文章目录 一 概要公交自动报站系统概述工作原理应用与优势 二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 很高兴为您介绍基于51单片机的公交自动报站系统: 公交自动报…...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题: 下面创建一个简单的Flask RESTful API示例。首先,我们需要创建环境,安装必要的依赖,然后…...

爬虫基础学习day2

# 爬虫设计领域 工商:企查查、天眼查短视频:抖音、快手、西瓜 ---> 飞瓜电商:京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空:抓取所有航空公司价格 ---> 去哪儿自媒体:采集自媒体数据进…...

力扣热题100 k个一组反转链表题解

题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...

Python实现简单音频数据压缩与解压算法

Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中&#xff0c;压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言&#xff0c;提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...

【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统

Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...

SQL注入篇-sqlmap的配置和使用

在之前的皮卡丘靶场第五期SQL注入的内容中我们谈到了sqlmap&#xff0c;但是由于很多朋友看不了解命令行格式&#xff0c;所以是纯手动获取数据库信息的 接下来我们就用sqlmap来进行皮卡丘靶场的sql注入学习&#xff0c;链接&#xff1a;https://wwhc.lanzoue.com/ifJY32ybh6vc…...

用js实现常见排序算法

以下是几种常见排序算法的 JS实现&#xff0c;包括选择排序、冒泡排序、插入排序、快速排序和归并排序&#xff0c;以及每种算法的特点和复杂度分析 1. 选择排序&#xff08;Selection Sort&#xff09; 核心思想&#xff1a;每次从未排序部分选择最小元素&#xff0c;与未排…...

【Java】Ajax 技术详解

文章目录 1. Filter 过滤器1.1 Filter 概述1.2 Filter 快速入门开发步骤:1.3 Filter 执行流程1.4 Filter 拦截路径配置1.5 过滤器链2. Listener 监听器2.1 Listener 概述2.2 ServletContextListener3. Ajax 技术3.1 Ajax 概述3.2 Ajax 快速入门服务端实现:客户端实现:4. Axi…...

ffmpeg(三):处理原始数据命令

FFmpeg 可以直接处理原始音频和视频数据&#xff08;Raw PCM、YUV 等&#xff09;&#xff0c;常见场景包括&#xff1a; 将原始 YUV 图像编码为 H.264 视频将 PCM 音频编码为 AAC 或 MP3对原始音视频数据进行封装&#xff08;如封装为 MP4、TS&#xff09; 处理原始 YUV 视频…...

二维数组 行列混淆区分 js

二维数组定义 行 row&#xff1a;是“横着的一整行” 列 column&#xff1a;是“竖着的一整列” 在 JavaScript 里访问二维数组 grid[i][j] 表示 第i行第j列的元素 let grid [[1, 2, 3], // 第0行[4, 5, 6], // 第1行[7, 8, 9] // 第2行 ];// grid[i][j] 表示 第i行第j列的…...