当前位置: 首页 > news >正文

【TensorFlow】T1:实现mnist手写数字识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

1、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0],"GPU")print(gpus)
# 输出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

固定模板,直接调用即可

2、导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt# 导入mnist数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# train_images -- 训练集图片
# train_labels -- 训练集标签
# test_images  -- 测试集图片
# test_labels  -- 测试机标签

此处所需数据直接联网下载,故无需数据集

3、归一化

将图像数据从0-255缩放到0-1,主要是为了加快训练速度和提高模型的性能,使得模型训练过程更加稳定高效:

  1. 加速收敛:帮助优化算法更快找到最优解。
  2. 提升性能:确保不同规模的数据对模型的影响均衡,提高预测准确性。
  3. 简化调参:使超参数的选择更加简单有效。
  4. 数值稳定:避免因数据尺度问题导致的计算异常,如梯度爆炸。
train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

4、可视化

# figure函数--设置画布大小
# figsize--指定了画布的宽度(20英寸)和高度(10英寸)
plt.figure(figsize=(20, 10))# 遍历前20张图片
for i in range(20):# subplot函数--创建子图# 2--2行,10--10列,i+1--第i+1个位置plt.subplot(2, 10, i+1)# xticks函数--隐藏x轴plt.xticks([])# yticks函数--隐藏y轴plt.yticks([])# grid函数--关闭网格线plt.grid(False)# imshow函数--绘制图像# train_images[i]--需要显示的图片,cmap--设置颜色映射(binary)plt.imshow(train_images[i], cmap=plt.cm.binary)# xlabel函数--给图像添加标签plt.xlabel(train_labels[i])plt.show()

输出:
在这里插入图片描述

5、调整格式

# reshape函数--只改变数据形状,不改变数据内容
# 6000--6000个样本,28--28*28像素,1--颜色通道数(灰度)
train_images = train_images.reshape((60000, 28, 28, 1))
# 1000--1000个样本,28--28*28像素,1--颜色通道数(灰度)
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28, 1) (60000,) (10000, 28, 28, 1) (10000,)

6、构建CNN网络模型

关键组件:

  1. 卷积层(Conv2D):主要用于提取图像的特征。它通过滑动窗口(即滤波器或核)在输入图像上移动,计算每个位置上的点积来生成特征图。每个滤波器可以识别特定类型的特征,如边缘或纹理。
  2. 激活函数(Activation function):如ReLU(Rectified Linear Unit),用于引入非线性因素,使模型能够学习更复杂的模式。
  3. 最大池化层(MaxPooling2D):用于减小特征图的空间维度(宽度和高度),同时保留最重要的信息。这有助于减少计算量并控制过拟合。
  4. 展平层(Flatten)将多维的卷积层输出转换为一维向量,以便将其输入到全连接层中。
  5. 全连接层(Dense):在这里,每个神经元与前一层的所有神经元相连,用于最终的分类任务。最后一层通常不使用激活函数(或者使用softmax函数),以直接输出每个类别的得分。
# Sequential函数--创建了一个Sequential模型,允许按顺序堆叠各层
model = models.Sequential([# Conv2D--第一个二维卷积层# 32--32个滤波器,(3, 3)--每个滤波器的大小为(3, 3),activation0--激活函数(relu),input_shape--输入数据的大小(28*28像素、单通道)layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),# MaxPooling2D--第一个二维最大池化层# 2--池化窗口的大小为2*2,用于下采样时减少特征维度layers.MaxPooling2D(2, 2),# Conv2D--第二个二维卷积层layers.Conv2D(64, (3, 3), activation='relu'),# MaxPooling2D--第二个二维最大池化层layers.MaxPooling2D(2, 2),# Flatten--将多维的卷积层输出展平为一维向量,方便输入全连接层layers.Flatten(),# Dense--全连接层# 64--64个节点,activation0--激活函数(relu)layers.Dense(64, activation='relu'),# Dense--输出层# 10--10个节点(此处对应0-9),默认线性激活函数layers.Dense(10)
])# 打印模型结构和参数信息,包括每一层的输出形状和参数数量
print(model.summary())

输出:
在这里插入图片描述

7、编译模型

# compile函数--配置模型的编译设置
model.compile(# optimizer--优化器(adam)# Adam:一种基于一阶梯度的优化算法,能自适应地调整不同参数的学习率,适用于大规模数据和高维空间问题optimizer='adam',# loss--损失函数(Sparse Categorical Crossentropy)# Sparse Categorical Crossentropy(稀疏分类交叉熵损失):适用于多分类问题,特别是标签是整数形式时# from_logits=True:表示网络输出未经过softmax激活函数处理。这种情况下,损失函数会自动应用softmax来计算最终的分类概率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# metrics--性能指标# accuracy(准确率):表示预测正确的样本数占总样本数的比例metrics=['accuracy']
)

8、训练模型

  1. 主要参数:
  • x训练数据的输入(特征)。它可以是一个 Numpy 数组或者一个 TensorFlow 数据集等。对于图像数据,这通常是一个形状为 (样本数量, 高度, 宽度, 颜色通道) 的四维数组。
  • y目标数据(标签)。与 x 对应,表示每个输入样本的真实类别或值。对于分类任务,这通常是一个一维数组,长度等于训练样本的数量;对于多分类问题,可能是一个二维数组,采用 one-hot 编码。
  • batch_size每次梯度更新时使用的样本数。默认值是 32。较大的批量大小可以使计算更高效,但需要更多的内存。
  • epochs整个训练集迭代的次数。每次迭代称为一个 epoch。增加 epoch 数量可以让模型有更多机会学习数据中的模式,但也增加了过拟合的风险。
  • verbose:日志显示模式。0 表示不输出日志到屏幕上,1 表示输出进度条记录,2 表示每个epoch输出一行记录。
  • validation_data:在每个 epoch 结束后用于评估模型性能的数据集。这是一个元组 (x_val, y_val) 或者一个包含输入和目标数据的列表。这对于监控模型在未见过的数据上的表现非常重要。
  • validation_split:从训练数据中划分出一定比例的数据作为验证集,取值范围是 (0, 1)。注意,这个参数只会在 x 是 Numpy 数组时有效。
  • shuffle:在每个 epoch 开始前是否打乱训练数据。这对于确保模型不会因为数据顺序而产生偏差非常重要,默认值为 True。
  • callbacks:一个列表,其中包含各种回调函数对象,在训练过程中这些回调函数会在特定时间点被调用(如每轮结束、训练开始或结束等),可用于实现提前停止、动态调整学习率等功能。
  1. 返回值:
  • History对象:该对象包含了一个 history 属性,这是一个字典,包含了训练过程中损失值评估指标的变化情况。可以访问 history.history[‘loss’] 来获取每个 epoch 的训练损失值,或 history.history[‘val_accuracy’] 获取每个 epoch 的验证准确率。
    “”"
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels),
)

9、模型预测

plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])
# 输出:[-0.0146451 0.11037247 -0.01110678 0.03087252 -0.02923543 -0.10968889 -0.00841374 0.04551534  -0.02969249 -0.00869128]

输出:
在这里插入图片描述

10、完整代码

# 1.设置GPU
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0], "GPU")print(gpus)# 2.导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()# 3.归一化
train_images, test_images = train_images / 255.0, test_images / 255.0print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 4.可视化
plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(2,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()# 5.调整图片格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 6.构建CNN网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10)
])
print(model.summary())# 7.编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 8.训练模型
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels))# 9.模型预测
plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])

相关文章:

【TensorFlow】T1:实现mnist手写数字识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 1、设置GPU import tensorflow as tf gpus tf.config.list_physical_devices("GPU")if gpus:gpu0 gpus[0]tf.config.experimental.set_memory_g…...

Rapidjson 实战

Rapidjson 是一款 C 的 json 库. 支持处理 json 格式的文档. 其设计风格是头文件库, 包含头文件即可使用, 小巧轻便并且性能强悍. 本文结合样例来介绍 Rapidjson 一些常见的用法. 环境要求 有如何的几种方法可以将 Rapidjson 集成到您的项目中. Vcpkg安装: 使用 vcpkg instal…...

【React】受控组件和非受控组件

目录 受控组件非受控组件基于ref获取DOM元素1、在标签中使用2、在组件中使用 受控组件 表单元素的状态(值)由 React 组件的 state 完全控制。组件的 state 保存了表单元素的值,并且每次用户输入时,React 通过事件处理程序来更新 …...

Ollama+deepseek+Docker+Open WebUI实现与AI聊天

1、下载并安装Ollama 官方网址:Ollama 安装好后,在命令行输入, ollama --version 返回以下信息,则表明安装成功, 2、 下载AI大模型 这里以deepseek-r1:1.5b模型为例, 在命令行中,执行&…...

DEEPSEKK GPT等AI体的出现如何重构工厂数字化架构:从设备控制到ERP MES系统的全面优化

随着深度学习(DeepSeek)、GPT等先进AI技术的出现,工厂的数字化架构正在经历前所未有的变革。AI的强大处理能力、预测能力和自动化决策支持,将大幅度提升生产效率、设备管理、资源调度以及产品质量管理。本文将探讨AI体&#xff08…...

阿莱(arri)mxf文件变0字节的恢复方法

阿莱(arri)是专业级的影视产品软硬件供应商,很多影视作品都是使用阿莱(arri)的设备拍摄出来的。总体上来讲阿莱(arri)的文件格式有mov和mxf两种,这次恢复的是阿莱(arri)的mxf,机型是arri mini,素材保存在一个8t的硬盘上,使用的是e…...

初识 Node.js

在当今快速发展的互联网技术领域,Node.js 已经成为了一个非常流行且强大的平台。无论是构建高性能的网络应用、实时协作工具还是微服务架构,Node.js 都展示了其独特的优势。本文将带您走进 Node.js 的世界,了解它的基本概念、核心特性以及如何…...

debug-vscode调试方法

debug - vscode gdb调试指南 文章目录 debug - vscode gdb调试指南前言一、调试代码二、命令查看main反汇编查看寄存器打印某个变量打印寄存器,如pc打印当前函数栈信息(当前执行位置)打印程序栈局部变量x命令的语法如下所示:打印某…...

Cypher进阶(函数、索引)

文章目录 Cypher进阶Aggregationcount()函数统计函数collect()函数 unwindforeachmergeunionload csvcall 函数断言函数all()any()~~exists()~~is not nullnone()single() 标量函数coalesce()startNode()/endNode()id()length()size() 列表函数nodes()keys()range()reduce() 数…...

XML Schema 数值数据类型

XML Schema 数值数据类型 引言 XML Schema 是一种用于描述 XML 文档结构的语言。它定义了 XML 文档中数据的有效性和结构。在 XML Schema 中,数值数据类型是非常重要的一部分,它定义了 XML 文档中可以包含的数值类型。本文将详细介绍 XML Schema 中常用的数值数据类型,以及…...

Window获取界面空闲时间

‌GetLastInputInfo‌是一种Windows API函数,用于获取上次输入操作的时间。 该函数通过LASTINPUTINFO结构返回最后一次输入事件的时间。 原型如下 BOOL WINAPI GetLastInputInfo(PLASTINPUTINFO plii);那么可以利用GetLastInputInfo来得到界面没有操作的时长 uint…...

Java进阶(vue基础)

目录 1.vue简单入门 ?1.1.创建一个vue程序 1.2.使用Component模板(组件) 1.3.引入AXOIS ?1.4.vue的Methods(方法) 和?compoted(计算) 1.5.插槽slot 1.6.创建自定义事件? 2.Vue脚手架安装? 3.Element-UI的…...

Mac电脑上好用的压缩软件

在Mac电脑上,有许多优秀的压缩软件可供选择,这些软件不仅支持多种压缩格式,还提供了便捷的操作体验和强大的功能。以下是几款被广泛推荐的压缩软件: BetterZip 功能特点:BetterZip 是一款功能强大的压缩和解压缩工具&a…...

Ubuntn24.04安装

1.镜像下载 https://cn.ubuntu.com/download Ubuntu 24.04.1 (Noble Numbat) 进入下载即可 2.安装系统 打开虚拟机 选择语言 输入用户名和密码 安装ssh 安装完成重启即可。 3.可能出现的问题 关于Ubuntu系统虚拟机出现频繁闪屏,移动和屏幕适应大小问题_vmware安…...

基于ansible部署elk集群

ansible部署 ELK部署 ELK常见架构 (1)ElasticsearchLogstashKibana:这种架构是最常见的一种,也是最简单的一种架构,这种架构通过Logstash收集日志,运用Elasticsearch分析日志,最后通过Kibana中…...

解锁.NET Fiddle:在线编程的神奇之旅

在.NET 开发的广袤领域中,快速验证想法、测试代码片段以及便捷地分享代码是开发者们日常工作中不可或缺的环节。而.NET Fiddle 作为一款卓越的在线神器,正逐渐成为众多.NET 开发者的得力助手。它打破了传统开发模式中对本地开发环境的依赖,让…...

记录pve中使用libvirt创建虚拟机

pve中创建虚拟机 首先在pve网页中创建一个linux虚拟机,我用的是debian系统,过程省略 注意虚拟机cpu类型要设置为host 检查是否支持虚拟化 ssh分别进入pve和debian虚拟机 检查cpu是否支持虚拟化 egrep --color vmx|svm /proc/cpuinfo # 结果高亮显示…...

【HTML性能优化】提升网站加载速度:GZIP、懒加载与资源合并

系列文章目录 01-从零开始学 HTML:构建网页的基本框架与技巧 02-HTML常见文本标签解析:从基础到进阶的全面指南 03-HTML从入门到精通:链接与图像标签全解析 04-HTML 列表标签全解析:无序与有序列表的深度应用 05-HTML表格标签全面…...

三维空间全局光照 | 及各种扫盲

Lecture 6 SH for diffuse transport Lecture 7关于 SH for glossy transport 三维空间全局光照 diffuse case和glossy case的区别 在Lambertian模型中,BRDF是一个常数 diffuse case 跟outgoing point无关 glossy case 跟outgoing point有关 (Gloss…...

数据库开发常识(10.6)——SQL性能判断标准及索引误区(1)

10.6. 数据库开发常识 作为一名专业数据库开发人员,不但需要掌握数据库开发相关的语法和功能实现,还要掌握专业数据库开发的常识。这样,才能在保量完成工作任务的同时,也保质的完成工作任务,避免了为应用的日后维护埋…...

Phi-4-mini-reasoning科研协作:Jupyter Notebook嵌入式推理插件

Phi-4-mini-reasoning科研协作:Jupyter Notebook嵌入式推理插件 1. 模型简介 Phi-4-mini-reasoning是一个基于合成数据构建的轻量级开源模型,专注于高质量、密集推理的数据处理能力。作为Phi-4模型家族的一员,它经过专门微调以提升数学推理…...

肿瘤免疫微环境解析:8大免疫浸润工具实战指南

1. 肿瘤免疫微环境分析的核心价值 当你拿到一份肿瘤样本的转录组数据时,最令人兴奋的莫过于揭开它的免疫面纱——那些隐藏在肿瘤组织中的免疫细胞究竟在做什么?这就是免疫浸润分析的价值所在。想象一下,肿瘤组织就像一座复杂的城市&#xff0…...

区块链+AI的致命组合:深扒某DeFi项目的测试黑幕

在数字经济浪潮中,区块链与人工智能(AI)的融合被视为金融创新的“致命组合”,尤其在去中心化金融(DeFi)领域,它承诺了前所未有的效率和智能决策能力。然而,这一组合也带来了隐蔽的测…...

解锁AI编程新范式:Continue插件的颠覆性开发体验

解锁AI编程新范式:Continue插件的颠覆性开发体验 【免费下载链接】continue ⏩ Source-controlled AI checks, enforceable in CI. Powered by the open-source Continue CLI 项目地址: https://gitcode.com/GitHub_Trending/co/continue 你是否曾在深夜调试…...

ALM扩展开发教程:如何为TypeScript IDE创建自定义插件

ALM扩展开发教程:如何为TypeScript IDE创建自定义插件 【免费下载链接】alm :rose: A :cloud: ready IDE just for TypeScript :heart: 项目地址: https://gitcode.com/gh_mirrors/al/alm ALM是一款专为TypeScript和JavaScript设计的云端IDE,为开…...

智慧园区能源管理系统解决方案

某园区集成生产、办公、生活三大功能,建设有生产厂房、化学品库、辅助用房、气罐站、研发楼、综合楼及其他配套设施,涉及到多种用能,包含电能、天然气、压缩空气、冷热能等,带来日益高昂的能耗成本与能源浪费隐患。 1、制冷空调监…...

Tetrazine-amine HCl salt,CAS:1416711-59-5,四嗪-氨基盐酸盐的描述

Tetrazine-amine HCl salt(四嗪-氨基盐酸盐)是一种结合了四嗪基团和氨基盐酸盐结构的化合物,在化学、生物医药和材料科学等领域具有广泛应用。一、基本信息中文名称:四嗪-氨基盐酸盐英文名称:Tetrazine-amine HCl salt…...

Phi-3-mini-4k-instruct-gguf步骤详解:supervisor服务管理与错误日志定位方法

Phi-3-mini-4k-instruct-gguf步骤详解:supervisor服务管理与错误日志定位方法 1. 模型概述 Phi-3-mini-4k-instruct-gguf是微软Phi-3系列中的轻量级文本生成模型GGUF版本,特别适合问答、文本改写、摘要整理和简短创作等场景。这个开箱即用的解决方案已…...

SketchUp STL开源工具:让3D设计无缝转化为可打印模型的完整方案

SketchUp STL开源工具:让3D设计无缝转化为可打印模型的完整方案 【免费下载链接】sketchup-stl A SketchUp Ruby Extension that adds STL (STereoLithography) file format import and export. 项目地址: https://gitcode.com/gh_mirrors/sk/sketchup-stl 在…...

基于计算机视觉的AI头像质量评估系统

基于计算机视觉的AI头像质量评估系统 1. 引言 在数字社交时代,头像已经成为个人形象的重要代表。无论是社交平台、专业网站还是在线会议,一个高质量的头像都能显著提升个人形象和可信度。然而,如何快速评估头像的质量一直是个难题——什么样…...