01、Tensorflow实现二元手写数字识别
01、Tensorflow实现二元手写数字识别(二分类问题)
开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。
基于Tensorflow 2.10.0
1、识别目标
识别手写仅仅是为了区分手写的0和1,所以实际上是一个二分类问题。
2、Tensorflow算法实现
STEP1:导入相关包
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score
import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。
import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。
from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。
from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。
from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。
import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。
import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。
import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。
from sklearn.metrics import accuracy_score:这是从scikit-learn库中引入accuracy_score函数。这个函数用于计算分类准确率,常用于评估分类模型的性能。
STEP2:屏蔽无用警告并允许中文
logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
logging.getLogger(“tensorflow”).setLevel(logging.ERROR):这行代码用于设置 TensorFlow 的日志级别为 ERROR。这意味着只有当 TensorFlow 中发生错误时,才会在日志中输出相关信息。较低级别的日志信息(如 WARNING、INFO、DEBUG)将被忽略。
tf.autograph.set_verbosity(0):这行代码用于设置 TensorFlow 的自动图形(Autograph)日志的冗长级别为 0。这意味着在将 Python 代码转换为 TensorFlow 图形代码时,将不会输出任何日志信息。这有助于减少日志噪音,使日志更加干净。
warnings.simplefilter(action=‘ignore’,category=FutureWarning):这行代码用于忽略所有 FutureWarning 类型的警告。在 Python中,当使用某些即将过时或未来版本中可能发生变化的特性时,通常会发出 FutureWarning。通过设置action=‘ignore’,代码将不会输出这类警告,使控制台输出更加干净。
plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:这行代码用于设置 matplotlib 中的默认无衬线字体为 SimHei。SimHei 是一种常用于显示中文的字体,这样设置后,matplotlib 将在绘图时使用 SimHei 字体来显示中文,从而避免中文乱码问题。
plt.rcParams[‘axes.unicode_minus’]=False:这行代码用于解决 matplotlib
中负号显示异常的问题。默认情况下,matplotlib 可能无法正确显示负号,将其设置为 False 可以使用 ASCII字符作为负号,从而正常显示。
STEP3:导入并划分数据集
划分10%作为测试:
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
STEP4:模型构建与训练
# 构建模型,三层模型进行分类,第一层输入100个神经元...
model = Sequential([tf.keras.Input(shape=(400,)), #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)
STEP5:结果可视化与打印准确度信息
原始的输入的数据集是400 * 1000的数组,共包含1000个手写数字的数据,其中400为20*20像素的图片,因此对每个400的数组进行reshape((20, 20))可以得到原始的图片进而绘图。
# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92]) # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))
3、运行结果
按照最初的划分,数据集包含1000个数据,划分10%为测试集,也就是100个数据。结果可视化随机选择其中的64个数据绘图,每个图像的上方标明了其真实标签和预测的结果,这个是一个非常简单的示例,准确度还是非常高的。


4、工程下载与全部代码
工程链接:Tensorflow实现二元手写数字识别(二分类问题)
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_scorelogging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_data/X.npy")y = np.load("Handwritten_Digit_Recognition_data/y.npy")X = X[0:1000]y = y[0:1000]return X, y# 加载数据集,查看数据集大小,可以看到有1000个数据集,每个输入是20*20=400大小的图片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 下面画图,随便从原数据取出几个画图,可以注释
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
# # Select random indices
# random_index = np.random.randint(m)
# # Select rows corresponding to the random indices and
# # 将1*400的数据转换为20*20的图像格式
# X_random_reshaped = X[random_index].reshape((20, 20)).T
# # Display the image
# ax.imshow(X_random_reshaped, cmap='gray')
# # Display the label above the image
# ax.set_title(y[random_index, 0])
# ax.set_axis_off()
# plt.show()# 构建模型,三层模型进行分类,第一层输入25个神经元...
model = Sequential([tf.keras.Input(shape=(400,)), #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92]) # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))
相关文章:
01、Tensorflow实现二元手写数字识别
01、Tensorflow实现二元手写数字识别(二分类问题) 开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。 基于Tensorflow 2.10.0 1、…...
HCIA-Datacom跟官方路线学习第二部分
接着前面第六章,通过VLAN技术, 可以将物理的局域网划分成多个广播域, 实现同一VLAN内的网络设备可以直接进行二层通信, 不同VLAN内的设备不可以直接进行二层通信。 第七章 生成树 在以太网交换网络会使用冗余链路, 但…...
BIO、NIO和AIO的区别
一、基础知识: I/O 模型的简单理解: 1.BIO(Blocking I/O):同步阻塞,一个线程处理一个通道上的事件。 2.NIO(Non-blocking I/O):同步非阻塞,使用选择器&…...
makefile 学习(5)完整的makefile模板
参考自: (1)深度学习部署笔记(二): g, makefile语法,makefile自己的CUDA编程模板(2)https://zhuanlan.zhihu.com/p/396448133(3) 一个挺好的工程模板,(https://github.com/shouxieai/cpp-proj-template) 1. c 编译流…...
专业远程控制如何塑造安全体系?向日葵“全流程安全闭环”解析
安全是远程控制的重中之重,作为国民级远程控制品牌,向日葵远程控制就极为注重安全远控服务的塑造。近期向日葵发布了以安全和核心的新版“向日葵15”以及同步发布《贝锐向日葵远控安全标准白皮书》(下简称《白皮书》),…...
node.js解决输出中文乱码问题
个人简介 👨🏻💻个人主页:九黎aj 🏃🏻♂️幸福源自奋斗,平凡造就不凡 🌟如果文章对你有用,麻烦关注点赞收藏走一波,感谢支持! 🌱欢迎订阅我的…...
CentOS 7 使用异步网络框架Libevent
CentOS 7 安装Libevent库 libevent github地址:https://github.com/libevent/libevent 步骤1:首先,你需要下载libevent的源代码。你可以从github或者源代码官方网站下载。并上传至/usr/local/source_code/ 步骤2:下载完成后&…...
枚举 B. Lorry
Problem - B - Codeforces 题目大意:给物品数量 n n n,体积为 v ( 0 ≤ v ≤ 1 e 9 ) v_{(0 \le v \le 1e9)} v(0≤v≤1e9),第一行读入 n , v n, v n,v,之后 n n n行,读入 n n n个物品,之后每行依次是体…...
ON1 Photo RAW 2024 for Mac——专业照片编辑的终极利器
ON1 Photo RAW 2024 for Mac是一款专为Mac用户打造的照片编辑器,以其强大的功能和易用的操作,让你的照片编辑工作变得轻松愉快。 一、强大的RAW处理能力 ON1 Photo RAW 2024支持大量的RAW格式照片,能够让你在编辑过程中获得更多的自由度和更…...
从word复制内容到wangEditor富文本框的时候会把html标签也复制过来,如果只想实现直接复制纯文本,有什么好的实现方式
从word复制内容到wangEditor富文本框的时候会把html标签也复制过来,如果只想实现直接复制纯文本,有什么好的实现方式? 将 Word 中的内容复制到富文本编辑器时,常常会带有大量的 HTML 标签和样式,这可能导致不必要的格式…...
项目中如何配置数据可视化展现
在现今数据驱动的时代,可视化已逐渐成为数据分析的主要途径,可视化大屏的广泛使用便应运而生。很多公司及政务机构,常利用大屏的手段展现其实力或演示业务,可视化的效果能让观者更快速的理解结果并直观的看到数据展现。因此&#…...
ArkUI开发进阶—@Builder函数@BuilderParam装饰器的妙用与场景应用
ArkUI开发进阶—@Builder函数@BuilderParam装饰器的妙用与场景应用 HarmonyOS,作为一款全场景分布式操作系统,为了推动更广泛的应用开发,采用了一种先进而灵活的编程语言——ArkTS。ArkTS是在TypeScript(TS)的基础上发展而来,为HarmonyOS提供了丰富的应用开发工具,使开…...
大语言模型概述(三):基于亚马逊云科技的研究分析与实践
上期介绍了基于亚马逊云科技的大语言模型相关研究方向,以及大语言模型的训练和构建优化。本期将介绍大语言模型训练在亚马逊云科技上的最佳实践。 大语言模型训练在亚马逊云科技上的最佳实践 本章节内容,将重点关注大语言模型在亚马逊云科技上的最佳训…...
键入网址到网页显示,期间发生了什么?
文章目录 键入网址到网页显示,期间发生了什么?1. HTTP2. 真实地址查询 —— DNS3. 指南好帮手 —— 协议栈4. 可靠传输 —— TCP5. 远程定位 —— IP6. 两点传输 —— MAC7. 出口 —— 网卡8. 送别者 —— 交换机9. 出境大门 —— 路由器10. 互相扒皮 —…...
深度学习基于Python+TensorFlow+Django的水果识别系统
欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介简介技术组合系统功能使用流程 二、功能三、系统四. 总结 一项目简介 # 深度学习基于PythonTensorFlowDjango的水果识别系统介绍 简介 该水果识别系统基于…...
vs动态库生成过程中还存在静态库
为什么VS生成动态库dll同时还会生成lib静态库 动态库与静态库(Windows环境下) 动态库和静态库都是一种可执行代码的二进制形式,可以被操作系统载入内存执行。 静态库实际上是在链接时被链接到exe的,编译后,静态…...
P13 C++ 类 | 结构体内部的静态static
目录 01 前言 02 类内部创建静态变量的例子 03 在类的内部创建静态变量的作用 04 最后的话 01 前言 本期我们讨论 static 在一个类或一个结构体中的具体情况。 在几乎所有面向对象的语言中,静态在一个类中意味着特定的东西。这意味着在类的所有实例中ÿ…...
【腾讯云云上实验室-向量数据库】Tencent Cloud VectorDB在实战项目中替换Milvus测试
为什么尝试使用Tencent Cloud VectorDB替换Milvus向量库? 亮点:Tencent Cloud VectorDB支持Embedding,免去自己搭建模型的负担(搭建一个生产环境的模型实在耗费精力和体力)。 腾讯云向量数据库是什么? 腾…...
git clone -mirror 和 git clone 的区别
目录 前言两则区别git clone --mirrorgit clone 获取到的文件有什么不同瘦身仓库如何选择结语开源项目 前言 Git是一款强大的版本控制系统,通过Git可以方便地管理代码的版本和协作开发。在使用Git时,常见的操作之一就是通过git clone命令将远程仓库克隆…...
基于51单片机的公交自动报站系统
**单片机设计介绍, 基于51单片机的公交自动报站系统 文章目录 一 概要公交自动报站系统概述工作原理应用与优势 二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 很高兴为您介绍基于51单片机的公交自动报站系统: 公交自动报…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...
多场景 OkHttpClient 管理器 - Android 网络通信解决方案
下面是一个完整的 Android 实现,展示如何创建和管理多个 OkHttpClient 实例,分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...
docker 部署发现spring.profiles.active 问题
报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...
GraphQL 实战篇:Apollo Client 配置与缓存
GraphQL 实战篇:Apollo Client 配置与缓存 上一篇:GraphQL 入门篇:基础查询语法 依旧和上一篇的笔记一样,主实操,没啥过多的细节讲解,代码具体在: https://github.com/GoldenaArcher/graphql…...
Windows电脑能装鸿蒙吗_Windows电脑体验鸿蒙电脑操作系统教程
鸿蒙电脑版操作系统来了,很多小伙伴想体验鸿蒙电脑版操作系统,可惜,鸿蒙系统并不支持你正在使用的传统的电脑来安装。不过可以通过可以使用华为官方提供的虚拟机,来体验大家心心念念的鸿蒙系统啦!注意:虚拟…...
vxe-table vue 表格复选框多选数据,实现快捷键 Shift 批量选择功能
vxe-table vue 表格复选框多选数据,实现快捷键 Shift 批量选择功能 查看官网:https://vxetable.cn 效果 代码 通过 checkbox-config.isShift 启用批量选中,启用后按住快捷键和鼠标批量选取 <template><div><vxe-grid v-bind"gri…...
HTML版英语学习系统
HTML版英语学习系统 这是一个完全免费、无需安装、功能完整的英语学习工具,使用HTML CSS JavaScript实现。 功能 文本朗读练习 - 输入英文文章,系统朗读帮助练习听力和发音,适合跟读练习,模仿学习;实时词典查询 - 双…...
开源 vGPU 方案:HAMi,实现细粒度 GPU 切分
本文主要分享一个开源的 GPU 虚拟化方案:HAMi,包括如何安装、配置以及使用。 相比于上一篇分享的 TimeSlicing 方案,HAMi 除了 GPU 共享之外还可以实现 GPU core、memory 得限制,保证共享同一 GPU 的各个 Pod 都能拿到足够的资源。…...
rk3506上移植lvgl应用
本文档介绍如何在开发板上运行以及移植LVGL。 1. 移植准备 硬件环境:开发板及其配套屏幕 开发板镜像 主机环境:Ubuntu 22.04.5 2. LVGL启动 出厂系统默认配置了 LVGL,并且上电之后默认会启动 一个LVGL应用 。 LVGL 的启动脚本为/etc/init.d/pre_init/S00-lv_demo,…...
