当前位置: 首页 > article >正文

【TensorFlow】T1:实现mnist手写数字识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

1、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0],"GPU")print(gpus)
# 输出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

固定模板,直接调用即可

2、导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt# 导入mnist数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# train_images -- 训练集图片
# train_labels -- 训练集标签
# test_images  -- 测试集图片
# test_labels  -- 测试机标签

此处所需数据直接联网下载,故无需数据集

3、归一化

将图像数据从0-255缩放到0-1,主要是为了加快训练速度和提高模型的性能,使得模型训练过程更加稳定高效:

  1. 加速收敛:帮助优化算法更快找到最优解。
  2. 提升性能:确保不同规模的数据对模型的影响均衡,提高预测准确性。
  3. 简化调参:使超参数的选择更加简单有效。
  4. 数值稳定:避免因数据尺度问题导致的计算异常,如梯度爆炸。
train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

4、可视化

# figure函数--设置画布大小
# figsize--指定了画布的宽度(20英寸)和高度(10英寸)
plt.figure(figsize=(20, 10))# 遍历前20张图片
for i in range(20):# subplot函数--创建子图# 2--2行,10--10列,i+1--第i+1个位置plt.subplot(2, 10, i+1)# xticks函数--隐藏x轴plt.xticks([])# yticks函数--隐藏y轴plt.yticks([])# grid函数--关闭网格线plt.grid(False)# imshow函数--绘制图像# train_images[i]--需要显示的图片,cmap--设置颜色映射(binary)plt.imshow(train_images[i], cmap=plt.cm.binary)# xlabel函数--给图像添加标签plt.xlabel(train_labels[i])plt.show()

输出:
在这里插入图片描述

5、调整格式

# reshape函数--只改变数据形状,不改变数据内容
# 6000--6000个样本,28--28*28像素,1--颜色通道数(灰度)
train_images = train_images.reshape((60000, 28, 28, 1))
# 1000--1000个样本,28--28*28像素,1--颜色通道数(灰度)
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28, 1) (60000,) (10000, 28, 28, 1) (10000,)

6、构建CNN网络模型

关键组件:

  1. 卷积层(Conv2D):主要用于提取图像的特征。它通过滑动窗口(即滤波器或核)在输入图像上移动,计算每个位置上的点积来生成特征图。每个滤波器可以识别特定类型的特征,如边缘或纹理。
  2. 激活函数(Activation function):如ReLU(Rectified Linear Unit),用于引入非线性因素,使模型能够学习更复杂的模式。
  3. 最大池化层(MaxPooling2D):用于减小特征图的空间维度(宽度和高度),同时保留最重要的信息。这有助于减少计算量并控制过拟合。
  4. 展平层(Flatten)将多维的卷积层输出转换为一维向量,以便将其输入到全连接层中。
  5. 全连接层(Dense):在这里,每个神经元与前一层的所有神经元相连,用于最终的分类任务。最后一层通常不使用激活函数(或者使用softmax函数),以直接输出每个类别的得分。
# Sequential函数--创建了一个Sequential模型,允许按顺序堆叠各层
model = models.Sequential([# Conv2D--第一个二维卷积层# 32--32个滤波器,(3, 3)--每个滤波器的大小为(3, 3),activation0--激活函数(relu),input_shape--输入数据的大小(28*28像素、单通道)layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),# MaxPooling2D--第一个二维最大池化层# 2--池化窗口的大小为2*2,用于下采样时减少特征维度layers.MaxPooling2D(2, 2),# Conv2D--第二个二维卷积层layers.Conv2D(64, (3, 3), activation='relu'),# MaxPooling2D--第二个二维最大池化层layers.MaxPooling2D(2, 2),# Flatten--将多维的卷积层输出展平为一维向量,方便输入全连接层layers.Flatten(),# Dense--全连接层# 64--64个节点,activation0--激活函数(relu)layers.Dense(64, activation='relu'),# Dense--输出层# 10--10个节点(此处对应0-9),默认线性激活函数layers.Dense(10)
])# 打印模型结构和参数信息,包括每一层的输出形状和参数数量
print(model.summary())

输出:
在这里插入图片描述

7、编译模型

# compile函数--配置模型的编译设置
model.compile(# optimizer--优化器(adam)# Adam:一种基于一阶梯度的优化算法,能自适应地调整不同参数的学习率,适用于大规模数据和高维空间问题optimizer='adam',# loss--损失函数(Sparse Categorical Crossentropy)# Sparse Categorical Crossentropy(稀疏分类交叉熵损失):适用于多分类问题,特别是标签是整数形式时# from_logits=True:表示网络输出未经过softmax激活函数处理。这种情况下,损失函数会自动应用softmax来计算最终的分类概率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# metrics--性能指标# accuracy(准确率):表示预测正确的样本数占总样本数的比例metrics=['accuracy']
)

8、训练模型

  1. 主要参数:
  • x训练数据的输入(特征)。它可以是一个 Numpy 数组或者一个 TensorFlow 数据集等。对于图像数据,这通常是一个形状为 (样本数量, 高度, 宽度, 颜色通道) 的四维数组。
  • y目标数据(标签)。与 x 对应,表示每个输入样本的真实类别或值。对于分类任务,这通常是一个一维数组,长度等于训练样本的数量;对于多分类问题,可能是一个二维数组,采用 one-hot 编码。
  • batch_size每次梯度更新时使用的样本数。默认值是 32。较大的批量大小可以使计算更高效,但需要更多的内存。
  • epochs整个训练集迭代的次数。每次迭代称为一个 epoch。增加 epoch 数量可以让模型有更多机会学习数据中的模式,但也增加了过拟合的风险。
  • verbose:日志显示模式。0 表示不输出日志到屏幕上,1 表示输出进度条记录,2 表示每个epoch输出一行记录。
  • validation_data:在每个 epoch 结束后用于评估模型性能的数据集。这是一个元组 (x_val, y_val) 或者一个包含输入和目标数据的列表。这对于监控模型在未见过的数据上的表现非常重要。
  • validation_split:从训练数据中划分出一定比例的数据作为验证集,取值范围是 (0, 1)。注意,这个参数只会在 x 是 Numpy 数组时有效。
  • shuffle:在每个 epoch 开始前是否打乱训练数据。这对于确保模型不会因为数据顺序而产生偏差非常重要,默认值为 True。
  • callbacks:一个列表,其中包含各种回调函数对象,在训练过程中这些回调函数会在特定时间点被调用(如每轮结束、训练开始或结束等),可用于实现提前停止、动态调整学习率等功能。
  1. 返回值:
  • History对象:该对象包含了一个 history 属性,这是一个字典,包含了训练过程中损失值评估指标的变化情况。可以访问 history.history[‘loss’] 来获取每个 epoch 的训练损失值,或 history.history[‘val_accuracy’] 获取每个 epoch 的验证准确率。
    “”"
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels),
)

9、模型预测

plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])
# 输出:[-0.0146451 0.11037247 -0.01110678 0.03087252 -0.02923543 -0.10968889 -0.00841374 0.04551534  -0.02969249 -0.00869128]

输出:
在这里插入图片描述

10、完整代码

# 1.设置GPU
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0], "GPU")print(gpus)# 2.导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()# 3.归一化
train_images, test_images = train_images / 255.0, test_images / 255.0print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 4.可视化
plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(2,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()# 5.调整图片格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 6.构建CNN网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10)
])
print(model.summary())# 7.编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 8.训练模型
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels))# 9.模型预测
plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])

相关文章:

【TensorFlow】T1:实现mnist手写数字识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 1、设置GPU import tensorflow as tf gpus tf.config.list_physical_devices("GPU")if gpus:gpu0 gpus[0]tf.config.experimental.set_memory_g…...

【ArcGIS_Python】使用arcpy脚本将shape数据转换为三维白膜数据

说明: 该专栏之前的文章中python脚本使用的是ArcMap10.6自带的arcpy(好几年前的文章),从本篇开始使用的是ArcGIS Pro 3.3.2版本自带的arcpy,需要注意不同版本对应的arcpy函数是存在差异的 数据准备:准备一…...

动静态库的学习

动静态库中,不需要包含main函数 文件分为内存级(被打开的)文件和磁盘级文件 库 每个程序都要依赖很多基础的底层库,本质上来说库是一种可执行代码的二进制形式,可以被载入内存执行 静态库 linux .a windows .lib 动态库 linux .…...

Rapidjson 实战

Rapidjson 是一款 C 的 json 库. 支持处理 json 格式的文档. 其设计风格是头文件库, 包含头文件即可使用, 小巧轻便并且性能强悍. 本文结合样例来介绍 Rapidjson 一些常见的用法. 环境要求 有如何的几种方法可以将 Rapidjson 集成到您的项目中. Vcpkg安装: 使用 vcpkg instal…...

DeepSeek的多模态AI模型-Janus-pro,可生图,可读图

简介 Janus-Pro 是由 DeepSeek 开发的一款多模态理解与生成模型,是 Janus 模型的升级版。它能够同时处理文本和图像,既能理解图像内容,又能根据文本描述生成高质量图像。Janus-Pro 的核心目标是通过解耦视觉编码路径,解决多模态理…...

Python爬虫实战:一键采集电商数据,掌握市场动态!

电商数据分析是个香饽饽,可市面上的数据采集工具要不贵得吓人,要不就是各种广告弹窗。干脆自己动手写个爬虫,想抓啥抓啥,还能学点技术。今天咱聊聊怎么用Python写个简单的电商数据爬虫。 打好基础:搞定请求头 别看爬虫…...

最短木板长度

最短木板长度 真题目录: 点击去查看 E 卷 100分题型 题目描述 小明有 n 块木板,第 i ( 1 ≤ i ≤ n ) 块木板长度为 ai。 小明买了一块长度为 m 的木料,这块木料可以切割成任意块,拼接到已有的木板上,用来加长木板。 小明想让最…...

【人工智能】掌握图像风格迁移:使用Python实现艺术风格的自动化迁移

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 图像风格迁移(Image Style Transfer)是一种基于深度学习的计算机视觉技术,通过将一张图像的内容与另一张图像的艺术风格结合,生成一幅具…...

git submodules

当代码仓库中包含 .gitmodules 文件时,这意味着该仓库使用了 Git 子模块(Git Submodules)。.gitmodules 文件记录了子模块的相关信息,如子模块的仓库地址、路径等。若要在下载代码时一并同步子模块,可以按照以下几种常…...

7 与mint库对象互转宏(macros.rs)

macros.rs代码定义了一个Rust宏mint_vec,它用于在启用mint特性时,为特定的向量类型实现与mint库中对应类型的相互转换。mint库是一个提供基本数学类型(如点、向量、矩阵等)的Rust库,旨在与多个图形和数学库兼容。这个宏…...

游戏引擎 Unity - Unity 下载与安装

Unity Unity 首次发布于 2005 年,属于 Unity Technologies Unity 使用的开发技术有:C# Unity 的适用平台:PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域:开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…...

[openwrt]openwrt slaac only模式下部分终端无法获取到IPv6 DNS

问题描述 OpenWrt 中,如果启用了 RA 单播(ra_unicast),但部分终端无法获取到 DNS 信息 问题分析 RA 单播的局限性 并非所有终端都完全支持通过单播接收 RA 消息。部分终端可能无法正确解析单播 RA 中的 RDNSS(Recursive DNS Server)选项,从而导致无法获取 DNS 信息。终…...

Java 面试真题

本题适合一到三年 Java 开发 ,以下问题都是按照原面试官提问记录 文章目录 我要进大厂系列面试题二面 我要进大厂系列面试题 全部真题,欢迎投稿你的面试经验。 本篇涉及基础较多,但要耐性看完。 JVM内存模型垃圾回收器用的哪个gc各个算法…...

验证工具:GVIM和VIM

一、定义与关系 gVim:gVim是Vim的图形界面版本,提供了更多的图形化功能,如菜单栏、工具栏和鼠标支持。它使得Vim的使用更加直观和方便,尤其对于不习惯命令行界面的用户来说。Vim:Vim是一个在命令行界面下运行的文本编…...

理解 C 与 C++ 中的 const 常量与数组大小的关系

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C语言 文章目录 💯前言💯数组大小的常量要求💯C 语言中的数组大小要求💯C 中的数组大小要求💯为什么 C 中 const 变量可以作为数组大小💯进一步的…...

孟加拉国_行政边界省市边界arcgis数据shp格式wgs84坐标

这篇内容将深入探讨孟加拉国的行政边界省市边界数据,该数据是以arcgis的shp格式提供的,并采用WGS84坐标系统。ArcGIS是一款广泛应用于地理信息系统(GIS)的专业软件,它允许用户处理、分析和展示地理空间数据。在GIS领域…...

安心即美的生活方式

如果你的心是安定的,那么,外界也就安静了。就像陶渊明说的:心远地自偏。不是走到偏远无人的边荒才能得到片刻清净,不需要使用洪荒之力去挣脱生活的枷锁,这是陶渊明式的中国知识分子的雅量。如果你自己是好的男人或女人…...

APT (Advanced Package Tool) 安装与使用-linux014

APT (Advanced Package Tool) APT (Advanced Package Tool) 是一个用于管理 Debian 和 Ubuntu 系列 Linux 发行版上的软件包的工具。它简化了软件的安装、升级、配置和删除过程。APT 为用户提供了一个统一的命令行接口,使得管理和安装软件变得更加简单。 APT 主要…...

深度学习篇---深度学习中的超参数张量转换模型训练

文章目录 前言第一部分:深度学习中的超参数1. 学习率(Learning Rate)定义重要性常见设置 2. 批处理大小(Batch Size)定义重要性常见设置 3. 迭代次数(Number of Epochs)定义重要性常见设置 4. 优…...

Java设计模式:行为型模式→状态模式

Java 状态模式详解 1. 定义 状态模式(State Pattern)是一种行为型设计模式,它允许对象在内部状态改变时改变其行为。状态模式通过将状态需要的行为封装在不同的状态类中,实现对象行为的动态改变。该模式的核心思想是分离不同状态…...

快速幂,错位排序笔记

​ 记一下刚学明白的快速幂和错位排序的原理和代码 快速幂 原理: a^b (a^(b/2)) ^ 2(b为偶数) a^b a*(a^( (b-1)/2))^2(b为奇数) 指数为偶数时…...

机器人基础深度学习基础

参考: (1)【具身抓取课程-1】机器人基础 (2)【具身抓取课程-2】深度学习基础 1 机器人基础 从平面二连杆理解机器人学 正运动学:从关节角度到末端执行器位置的一个映射 逆运动学:已知末端位置…...

Java语法进阶

目录: Object类、常用APICollection、泛型List、Set、数据结构、CollectionsMap与斗地主案例异常、线程线程、同步等待与唤醒案例、线程池、Lambda表达式File类、递归字节流、字符流缓冲流、转换流、序列化流、Files网络编程 十二、函数式接口Stream流、方法引用 一…...

探索 paraphrase-MiniLM-L6-v2 模型在自然语言处理中的应用

在自然语言处理(NLP)领域,将文本数据转换为机器学习模型可以处理的格式是至关重要的。近年来,sentence-transformers 库因其在文本嵌入方面的卓越表现而受到广泛关注。本文将深入探讨 paraphrase-MiniLM-L6-v2 模型,这…...

《chatwise:DeepSeek的界面部署》

ChatWise:DeepSeek的界面部署 摘要 本文详细描述了DeepSeek公司针对其核心业务系统进行的界面部署工作。从需求分析到技术实现,再到测试与优化,全面阐述了整个部署过程中的关键步骤和解决方案。通过本文,读者可以深入了解DeepSee…...

论计算机网络技术专业如何?创新

计算机网络技术专业是顺应数字化时代发展的朝阳专业,前景十分广阔。它聚焦于计算机网络的规划、建设、维护与管理,从基础的网络布线、设备配置,到复杂的网络安全防护、云计算架构搭建,都在专业学习范畴内。该专业毕业生就业面广,可在互联网企业从事网络工程师岗位,负责搭…...

2. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--什么是微服务--微服务概述与演变

在软件架构不断演进的今天,微服务架构已成为许多企业构建现代化应用的首选方案。本文将深入探讨微服务的基本概念、演变历程及其出现的背景和推动因素,同时分析当前微服务在业界的应用现状和未来趋势,为读者提供一个全面的视角,理…...

单节锂电池外部供电自动切换的电路学习

文章目录 前言一、原理分析:①当VBUS处有外部电源输入时②当VBUS处无外部电源输入时 二、器件选择1、二极管2、MOS管3、其他 总结 前言 学习一种广泛应用的锂电池供电自动切换电路 电路存在外部电源时,优先使用外部电源供电,并为电池供电&…...

数据结构-堆和PriorityQueue

1.堆&#xff08;Heap&#xff09; 1.1堆的概念 堆是一种非常重要的数据结构&#xff0c;通常被实现为一种特殊的完全二叉树 如果有一个关键码的集合K{k0,k1,k2,...,kn-1}&#xff0c;把它所有的元素按照完全二叉树的顺序存储在一个一维数组中&#xff0c;如果满足ki<k2i…...

如何打造一个更友好的网站结构?

在SEO优化中&#xff0c;网站的结构往往被忽略&#xff0c;但它其实是决定谷歌爬虫抓取效率的关键因素之一。一个清晰、逻辑合理的网站结构&#xff0c;不仅能让用户更方便地找到他们需要的信息&#xff0c;还能提升搜索引擎的抓取效率 理想的网站结构应该像一棵树&#xff0c;…...