【TensorFlow】T1:实现mnist手写数字识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
1、设置GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0],"GPU")print(gpus)
# 输出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
固定模板,直接调用即可
2、导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt# 导入mnist数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# train_images -- 训练集图片
# train_labels -- 训练集标签
# test_images -- 测试集图片
# test_labels -- 测试机标签
此处所需数据直接联网下载,故无需数据集
3、归一化
将图像数据从0-255缩放到0-1,主要是为了加快训练速度和提高模型的性能,使得模型训练过程更加稳定高效:
- 加速收敛:帮助优化算法更快找到最优解。
- 提升性能:确保不同规模的数据对模型的影响均衡,提高预测准确性。
- 简化调参:使超参数的选择更加简单有效。
- 数值稳定:避免因数据尺度问题导致的计算异常,如梯度爆炸。
train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
4、可视化
# figure函数--设置画布大小
# figsize--指定了画布的宽度(20英寸)和高度(10英寸)
plt.figure(figsize=(20, 10))# 遍历前20张图片
for i in range(20):# subplot函数--创建子图# 2--2行,10--10列,i+1--第i+1个位置plt.subplot(2, 10, i+1)# xticks函数--隐藏x轴plt.xticks([])# yticks函数--隐藏y轴plt.yticks([])# grid函数--关闭网格线plt.grid(False)# imshow函数--绘制图像# train_images[i]--需要显示的图片,cmap--设置颜色映射(binary)plt.imshow(train_images[i], cmap=plt.cm.binary)# xlabel函数--给图像添加标签plt.xlabel(train_labels[i])plt.show()
输出:
5、调整格式
# reshape函数--只改变数据形状,不改变数据内容
# 6000--6000个样本,28--28*28像素,1--颜色通道数(灰度)
train_images = train_images.reshape((60000, 28, 28, 1))
# 1000--1000个样本,28--28*28像素,1--颜色通道数(灰度)
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28, 1) (60000,) (10000, 28, 28, 1) (10000,)
6、构建CNN网络模型
关键组件:
- 卷积层(Conv2D):主要用于提取图像的特征。它通过滑动窗口(即滤波器或核)在输入图像上移动,计算每个位置上的点积来生成特征图。每个滤波器可以识别特定类型的特征,如边缘或纹理。
- 激活函数(Activation function):如ReLU(Rectified Linear Unit),用于引入非线性因素,使模型能够学习更复杂的模式。
- 最大池化层(MaxPooling2D):用于减小特征图的空间维度(宽度和高度),同时保留最重要的信息。这有助于减少计算量并控制过拟合。
- 展平层(Flatten):将多维的卷积层输出转换为一维向量,以便将其输入到全连接层中。
- 全连接层(Dense):在这里,每个神经元与前一层的所有神经元相连,用于最终的分类任务。最后一层通常不使用激活函数(或者使用softmax函数),以直接输出每个类别的得分。
# Sequential函数--创建了一个Sequential模型,允许按顺序堆叠各层
model = models.Sequential([# Conv2D--第一个二维卷积层# 32--32个滤波器,(3, 3)--每个滤波器的大小为(3, 3),activation0--激活函数(relu),input_shape--输入数据的大小(28*28像素、单通道)layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),# MaxPooling2D--第一个二维最大池化层# 2--池化窗口的大小为2*2,用于下采样时减少特征维度layers.MaxPooling2D(2, 2),# Conv2D--第二个二维卷积层layers.Conv2D(64, (3, 3), activation='relu'),# MaxPooling2D--第二个二维最大池化层layers.MaxPooling2D(2, 2),# Flatten--将多维的卷积层输出展平为一维向量,方便输入全连接层layers.Flatten(),# Dense--全连接层# 64--64个节点,activation0--激活函数(relu)layers.Dense(64, activation='relu'),# Dense--输出层# 10--10个节点(此处对应0-9),默认线性激活函数layers.Dense(10)
])# 打印模型结构和参数信息,包括每一层的输出形状和参数数量
print(model.summary())
输出:
7、编译模型
# compile函数--配置模型的编译设置
model.compile(# optimizer--优化器(adam)# Adam:一种基于一阶梯度的优化算法,能自适应地调整不同参数的学习率,适用于大规模数据和高维空间问题optimizer='adam',# loss--损失函数(Sparse Categorical Crossentropy)# Sparse Categorical Crossentropy(稀疏分类交叉熵损失):适用于多分类问题,特别是标签是整数形式时# from_logits=True:表示网络输出未经过softmax激活函数处理。这种情况下,损失函数会自动应用softmax来计算最终的分类概率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# metrics--性能指标# accuracy(准确率):表示预测正确的样本数占总样本数的比例metrics=['accuracy']
)
8、训练模型
- 主要参数:
- x:训练数据的输入(特征)。它可以是一个 Numpy 数组或者一个 TensorFlow 数据集等。对于图像数据,这通常是一个形状为 (样本数量, 高度, 宽度, 颜色通道) 的四维数组。
- y:目标数据(标签)。与 x 对应,表示每个输入样本的真实类别或值。对于分类任务,这通常是一个一维数组,长度等于训练样本的数量;对于多分类问题,可能是一个二维数组,采用 one-hot 编码。
- batch_size:每次梯度更新时使用的样本数。默认值是 32。较大的批量大小可以使计算更高效,但需要更多的内存。
- epochs:整个训练集迭代的次数。每次迭代称为一个 epoch。增加 epoch 数量可以让模型有更多机会学习数据中的模式,但也增加了过拟合的风险。
- verbose:日志显示模式。0 表示不输出日志到屏幕上,1 表示输出进度条记录,2 表示每个epoch输出一行记录。
- validation_data:在每个 epoch 结束后用于评估模型性能的数据集。这是一个元组 (x_val, y_val) 或者一个包含输入和目标数据的列表。这对于监控模型在未见过的数据上的表现非常重要。
- validation_split:从训练数据中划分出一定比例的数据作为验证集,取值范围是 (0, 1)。注意,这个参数只会在 x 是 Numpy 数组时有效。
- shuffle:在每个 epoch 开始前是否打乱训练数据。这对于确保模型不会因为数据顺序而产生偏差非常重要,默认值为 True。
- callbacks:一个列表,其中包含各种回调函数对象,在训练过程中这些回调函数会在特定时间点被调用(如每轮结束、训练开始或结束等),可用于实现提前停止、动态调整学习率等功能。
- 返回值:
- History对象:该对象包含了一个 history 属性,这是一个字典,包含了训练过程中损失值和评估指标的变化情况。可以访问 history.history[‘loss’] 来获取每个 epoch 的训练损失值,或 history.history[‘val_accuracy’] 获取每个 epoch 的验证准确率。
“”"
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels),
)
9、模型预测
plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])
# 输出:[-0.0146451 0.11037247 -0.01110678 0.03087252 -0.02923543 -0.10968889 -0.00841374 0.04551534 -0.02969249 -0.00869128]
输出:
10、完整代码
# 1.设置GPU
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0], "GPU")print(gpus)# 2.导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()# 3.归一化
train_images, test_images = train_images / 255.0, test_images / 255.0print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 4.可视化
plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(2,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()# 5.调整图片格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 6.构建CNN网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10)
])
print(model.summary())# 7.编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 8.训练模型
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels))# 9.模型预测
plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])
相关文章:
【TensorFlow】T1:实现mnist手写数字识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 1、设置GPU import tensorflow as tf gpus tf.config.list_physical_devices("GPU")if gpus:gpu0 gpus[0]tf.config.experimental.set_memory_g…...
Rapidjson 实战
Rapidjson 是一款 C 的 json 库. 支持处理 json 格式的文档. 其设计风格是头文件库, 包含头文件即可使用, 小巧轻便并且性能强悍. 本文结合样例来介绍 Rapidjson 一些常见的用法. 环境要求 有如何的几种方法可以将 Rapidjson 集成到您的项目中. Vcpkg安装: 使用 vcpkg instal…...
【React】受控组件和非受控组件
目录 受控组件非受控组件基于ref获取DOM元素1、在标签中使用2、在组件中使用 受控组件 表单元素的状态(值)由 React 组件的 state 完全控制。组件的 state 保存了表单元素的值,并且每次用户输入时,React 通过事件处理程序来更新 …...
Ollama+deepseek+Docker+Open WebUI实现与AI聊天
1、下载并安装Ollama 官方网址:Ollama 安装好后,在命令行输入, ollama --version 返回以下信息,则表明安装成功, 2、 下载AI大模型 这里以deepseek-r1:1.5b模型为例, 在命令行中,执行&…...
DEEPSEKK GPT等AI体的出现如何重构工厂数字化架构:从设备控制到ERP MES系统的全面优化
随着深度学习(DeepSeek)、GPT等先进AI技术的出现,工厂的数字化架构正在经历前所未有的变革。AI的强大处理能力、预测能力和自动化决策支持,将大幅度提升生产效率、设备管理、资源调度以及产品质量管理。本文将探讨AI体(…...
阿莱(arri)mxf文件变0字节的恢复方法
阿莱(arri)是专业级的影视产品软硬件供应商,很多影视作品都是使用阿莱(arri)的设备拍摄出来的。总体上来讲阿莱(arri)的文件格式有mov和mxf两种,这次恢复的是阿莱(arri)的mxf,机型是arri mini,素材保存在一个8t的硬盘上,使用的是e…...
初识 Node.js
在当今快速发展的互联网技术领域,Node.js 已经成为了一个非常流行且强大的平台。无论是构建高性能的网络应用、实时协作工具还是微服务架构,Node.js 都展示了其独特的优势。本文将带您走进 Node.js 的世界,了解它的基本概念、核心特性以及如何…...
debug-vscode调试方法
debug - vscode gdb调试指南 文章目录 debug - vscode gdb调试指南前言一、调试代码二、命令查看main反汇编查看寄存器打印某个变量打印寄存器,如pc打印当前函数栈信息(当前执行位置)打印程序栈局部变量x命令的语法如下所示:打印某…...
Cypher进阶(函数、索引)
文章目录 Cypher进阶Aggregationcount()函数统计函数collect()函数 unwindforeachmergeunionload csvcall 函数断言函数all()any()~~exists()~~is not nullnone()single() 标量函数coalesce()startNode()/endNode()id()length()size() 列表函数nodes()keys()range()reduce() 数…...
XML Schema 数值数据类型
XML Schema 数值数据类型 引言 XML Schema 是一种用于描述 XML 文档结构的语言。它定义了 XML 文档中数据的有效性和结构。在 XML Schema 中,数值数据类型是非常重要的一部分,它定义了 XML 文档中可以包含的数值类型。本文将详细介绍 XML Schema 中常用的数值数据类型,以及…...
Window获取界面空闲时间
GetLastInputInfo是一种Windows API函数,用于获取上次输入操作的时间。 该函数通过LASTINPUTINFO结构返回最后一次输入事件的时间。 原型如下 BOOL WINAPI GetLastInputInfo(PLASTINPUTINFO plii);那么可以利用GetLastInputInfo来得到界面没有操作的时长 uint…...
Java进阶(vue基础)
目录 1.vue简单入门 ?1.1.创建一个vue程序 1.2.使用Component模板(组件) 1.3.引入AXOIS ?1.4.vue的Methods(方法) 和?compoted(计算) 1.5.插槽slot 1.6.创建自定义事件? 2.Vue脚手架安装? 3.Element-UI的…...
Mac电脑上好用的压缩软件
在Mac电脑上,有许多优秀的压缩软件可供选择,这些软件不仅支持多种压缩格式,还提供了便捷的操作体验和强大的功能。以下是几款被广泛推荐的压缩软件: BetterZip 功能特点:BetterZip 是一款功能强大的压缩和解压缩工具&a…...
Ubuntn24.04安装
1.镜像下载 https://cn.ubuntu.com/download Ubuntu 24.04.1 (Noble Numbat) 进入下载即可 2.安装系统 打开虚拟机 选择语言 输入用户名和密码 安装ssh 安装完成重启即可。 3.可能出现的问题 关于Ubuntu系统虚拟机出现频繁闪屏,移动和屏幕适应大小问题_vmware安…...
基于ansible部署elk集群
ansible部署 ELK部署 ELK常见架构 (1)ElasticsearchLogstashKibana:这种架构是最常见的一种,也是最简单的一种架构,这种架构通过Logstash收集日志,运用Elasticsearch分析日志,最后通过Kibana中…...
解锁.NET Fiddle:在线编程的神奇之旅
在.NET 开发的广袤领域中,快速验证想法、测试代码片段以及便捷地分享代码是开发者们日常工作中不可或缺的环节。而.NET Fiddle 作为一款卓越的在线神器,正逐渐成为众多.NET 开发者的得力助手。它打破了传统开发模式中对本地开发环境的依赖,让…...
记录pve中使用libvirt创建虚拟机
pve中创建虚拟机 首先在pve网页中创建一个linux虚拟机,我用的是debian系统,过程省略 注意虚拟机cpu类型要设置为host 检查是否支持虚拟化 ssh分别进入pve和debian虚拟机 检查cpu是否支持虚拟化 egrep --color vmx|svm /proc/cpuinfo # 结果高亮显示…...
【HTML性能优化】提升网站加载速度:GZIP、懒加载与资源合并
系列文章目录 01-从零开始学 HTML:构建网页的基本框架与技巧 02-HTML常见文本标签解析:从基础到进阶的全面指南 03-HTML从入门到精通:链接与图像标签全解析 04-HTML 列表标签全解析:无序与有序列表的深度应用 05-HTML表格标签全面…...
三维空间全局光照 | 及各种扫盲
Lecture 6 SH for diffuse transport Lecture 7关于 SH for glossy transport 三维空间全局光照 diffuse case和glossy case的区别 在Lambertian模型中,BRDF是一个常数 diffuse case 跟outgoing point无关 glossy case 跟outgoing point有关 (Gloss…...
数据库开发常识(10.6)——SQL性能判断标准及索引误区(1)
10.6. 数据库开发常识 作为一名专业数据库开发人员,不但需要掌握数据库开发相关的语法和功能实现,还要掌握专业数据库开发的常识。这样,才能在保量完成工作任务的同时,也保质的完成工作任务,避免了为应用的日后维护埋…...
树莓派机械爪项目实战:从硬件连接到Python控制全解析
1. 项目概述:当树莓派遇上机械爪最近在折腾一个挺有意思的小项目,叫Demwunz/openclaw-pi-installation。光看这个名字,就能猜到个大概:这是一个为树莓派(Raspberry Pi)准备的机械爪(Claw&#x…...
基于RP2040与CircuitPython的键盘内嵌DOOM游戏启动器DIY指南
1. 项目概述与核心思路几年前,我还在用笨重的全尺寸键盘时,就总琢磨着怎么给这每天摸上八小时的家伙加点“私货”。直到后来玩起了RP2040和CircuitPython,一个念头就冒出来了:能不能把游戏直接“焊”进键盘里?不是那种…...
火灾动力学模拟实战:如何用FDS构建精准的火灾预测系统
火灾动力学模拟实战:如何用FDS构建精准的火灾预测系统 【免费下载链接】fds Fire Dynamics Simulator 项目地址: https://gitcode.com/gh_mirrors/fd/fds 你是否曾面临这样的困境:当设计一栋大型商业建筑时,如何科学评估火灾时的人员疏…...
CircuitPython硬件交互实战:引脚命名、模块管理与内存优化
1. 项目概述:CircuitPython硬件交互的基石 如果你刚开始接触CircuitPython,或者从Arduino转过来,可能会对如何控制板子上的某个引脚感到困惑。板子上明明印着“A0”、“D13”,但在代码里到底该怎么写? board.A0 和 …...
Redis分布式锁进阶第二十二篇拆解
一、本篇前置衔接 第九十二篇我们完成Redisson源码拆解、手写复刻、底层内核穿透,彻底明白分布式锁代码层、脚本层、线程层原理。到此为止,代码、源码、坑点、运维、监控、面试全部讲透。但很多开发最大的困惑依旧存在:不同体量公司为什么锁架…...
AI团队协作镜像:Docker容器化实现环境一致性与高效复现
1. 项目概述:从开源镜像到AI协作平台的深度解构最近在GitHub上看到一个名为“team9ai/team9”的仓库,这个看似简单的镜像名背后,其实隐藏着一个非常典型的现代AI项目协作范式。它不是某个单一的算法模型,也不是一个孤立的工具&…...
AI驱动Figma设计自动化:Claude插件实现自然语言到UI生成
1. 项目概述:当设计工具遇上AI助手最近在和一些资深UI/UX设计师朋友交流时,大家不约而同地提到了一个痛点:在Figma这类设计工具里,从概念到高保真原型的转化过程,依然充满了大量重复、机械的劳动。比如,我需…...
ARM虚拟化中VTCR寄存器详解与地址转换优化
1. VTCR寄存器概述与虚拟化地址转换背景在ARM架构的虚拟化环境中,内存管理单元(MMU)通过两阶段地址转换机制实现虚拟机内存隔离。VTCR(Virtualization Translation Control Register)作为第二阶段地址转换的核心控制寄…...
用1.44寸ST7735 TFT屏DIY一个桌面天气站(附STM32/Arduino完整项目代码)
用1.44寸ST7735 TFT屏打造智能桌面天气站(STM32/Arduino全流程实战) 在创客圈里,能够实时显示天气信息的桌面小设备一直备受青睐。本文将带你从零开始,利用常见的1.44寸ST7735 TFT屏幕,构建一个功能完善的智能天气站。…...
Linux系统操作痕迹清理:Shell脚本实现与安全运维实践
1. 项目概述与核心价值在Linux系统上进行日常运维、故障排查或者一些自动化任务时,我们执行的每一条命令、访问的每一个文件,甚至系统本身的运行状态,都会留下或多或少的“痕迹”。这些痕迹,对于系统审计和安全分析来说是宝贵的日…...


