当前位置: 首页 > article >正文

# 手写数字识别:使用PyTorch构建MNIST分类器

手写数字识别:使用PyTorch构建MNIST分类器

在这篇文章中,我将引导你通过使用PyTorch框架构建一个简单的神经网络模型,用于识别MNIST数据集中的手写数字。MNIST数据集是一个经典的机器学习数据集,包含了60,000张训练图像和10,000张测试图像,每张图像都是28x28像素的灰度手写数字。
在这里插入图片描述

在这里插入图片描述

环境准备

首先,确保你的环境中安装了PyTorch和torchvision。可以通过以下命令安装:

pip install torch torchvision

数据加载与预处理

我们首先加载MNIST数据集,并将图像转换为PyTorch张量格式,以便模型可以处理。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下载训练数据集(包含训练图片+标签)'''
training_data = datasets.MNIST( #跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击 training_data:Datasetroot="data",#表示下载的手写数字 到哪个路径。60000train=True, #读取下载后的数据 中的 训练集download=True,#如果你之前已经下载过了,就不用再下载transform=ToTensor(), #张量,图片是不能直接传入神经网络模型
)   #对于pytorch库能够识别的数据一般是tensor张量。'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
print(len(training_data))

数据可视化

为了更好地理解数据,我们可以展示一些手写数字图像。

''展示手写字图片,把训练数据集中的前59000张图片展示一下'''from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i+59000] #提取第59000张图片figure.add_subplot(3, 3, i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off") # plt.show(I)#是示矢量,plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

创建DataLoader

为了高效地加载数据,我们使用DataLoader来批量加载数据。

# '"创建数据DataLoader(数据加载器)开'
#  'batch_size:将数据集分成多份,每一份为batch_size个数据'
#  '优点:可以减少内存的使用,提高训练速度。train_dataloader = DataLoader(training_data, batch_size=64) #64张图片为一个包,train_dataloader:<torch
test_dataloader = DataLoader(test_data, batch_size=64)

模型定义

接下来,我们定义一个简单的神经网络模型,包含两个隐藏层和一个输出层。

'''定义神经网络类的继承这种方式'''
class NeuralNetwork(nn.Module):  #通过调用类的形式来使用神经网络,神经网络的模型,nn.moduledef __init__(self): #python基础关于类,self类自已本身super().__init__() #继承的父类初始化self.flatten = nn.Flatten() #展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28*28, 128 ) #第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x) #图像进行展开x = self.hidden1(x)x = torch.relu(x) #激活函数,torch使用的relu函数 relu,tanhx = self.hidden2(x)x = torch.relu(x)x = self.out(x)return xmodel = NeuralNetwork().to(device) #把刚刚创建的模型传入到Gpu
print(model)

训练与测试

我们定义训练和测试函数,使用交叉熵损失函数和随机梯度下降优化器。

def train(dataloader, model, loss_fn, optimizer):model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
# #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()batch_size_num = 1for X, y in dataloader: #其中batch为每一个数据的编号X, y = X.to(device), y.to(device) #把训练数据集和标签传入cpu或GPUpred = model.forward(X) #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化wloss= loss_fn(pred, y) #通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad() #梯度值清零loss.backward() #反向传播计算得到每个参数的梯度值woptimizer.step() #根据梯度更新网络w参数loss_value = loss.item() #从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset) #10000num_batches = len(dataloader) #打包的数量model.eval() #测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad(): #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  #test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)   #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能来衡量模型测试的好坏。correct /= size #平均的正确率print(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")

训练模型

最后,我们训练模型并测试其性能。

loss_fn = nn.CrossEntropyLoss() #创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果optimizer = torch.optim.SGD(model.parameters(), lr=0.01) #创建一个优化器,SGD为随机梯度下降算法
# #params:要训练的参数,一般我们传入的都是model.parameters()# #lr:learning_rate学习率,也就是步长#loss表示模型训练后的输出结果与,样本标签的差距。如果差距越小,就表示模型训练越好,越逼近干真实的模型。# train(train_dataloader, model, loss_fn, optimizer)
# test(test_dataloader, model, loss_fn)epochs = 30
for t in range(epochs):print(f"Epoch {t+1}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

运行结果

在这里插入图片描述

结论

通过这篇文章,我们成功构建了一个简单的神经网络模型来识别MNIST数据集中的手写数字。这个模型展示了如何使用PyTorch进行数据处理、模型定义、训练和测试。希望这能帮助你开始自己的深度学习项目!

相关文章:

# 手写数字识别:使用PyTorch构建MNIST分类器

手写数字识别&#xff1a;使用PyTorch构建MNIST分类器 在这篇文章中&#xff0c;我将引导你通过使用PyTorch框架构建一个简单的神经网络模型&#xff0c;用于识别MNIST数据集中的手写数字。MNIST数据集是一个经典的机器学习数据集&#xff0c;包含了60,000张训练图像和10,000张…...

扩展虚拟机磁盘空间并使其在Linux系统中可用的步骤总结

VMware在虚拟机扩展空间时&#xff0c;若想扩展到150G&#xff0c;那么所在盘的空闲空间须大于150G&#xff0c;否则VM将不允许扩展。 1&#xff1a;确认新磁盘空间是否被识别 使用 lsblk 或 fdisk -l 命令检查 /dev/sda 的大小是否已经更新到新的容量&#xff08;例如从原来的…...

A股周度复盘与下周策略 的deepseek提示词模板

以下是反向整理的股票大盘分析提示词模板&#xff0c;采用结构化框架数据占位符设计&#xff0c;可直接套用每周市场数据&#xff1a; 请根据一下markdown格式的模板&#xff0c;帮我检索整理并输出本周股市复盘和下周投资策略 【A股周度复盘与下周策略提示词模板】 一、市场…...

dev_set_drvdata、dev_get_drvdata使用详解

在Linux内核驱动开发中&#xff0c;dev_set_drvdata() 及相关函数用于管理设备驱动的私有数据&#xff0c;是模块化设计和数据隔离的核心工具。以下从函数定义、使用场景、示例及注意事项等方面进行详细解析&#xff1a; 一、函数定义与作用 核心函数 dev_set_drvdata() 和 dev…...

数据驱动未来:大数据在智能网联汽车中的深度应用

数据驱动未来:大数据在智能网联汽车中的深度应用 引言 随着智能网联汽车(Intelligent Connected Vehicles,ICV)的快速发展,数据已成为其核心驱动力。从实时交通数据到车辆传感器信息,大数据的深度应用正在让智能汽车更安全、更高效、更智能化。那么,大数据如何赋能智能…...

LeetCode:DFS综合练习

简单 1863. 找出所有子集的异或总和再求和 一个数组的 异或总和 定义为数组中所有元素按位 XOR 的结果&#xff1b;如果数组为 空 &#xff0c;则异或总和为 0 。 例如&#xff0c;数组 [2,5,6] 的 异或总和 为 2 XOR 5 XOR 6 1 。 给你一个数组 nums &#xff0c;请你求出 n…...

Perf学习

重要的能解决的问题是这些&#xff1a; perf_events is an event-oriented observability tool, which can help you solve advanced performance and troubleshooting functions. Questions that can be answered include: Why is the kernel on-CPU so much? What code-pa…...

齐次坐标变换+Unity矩阵变换

矩阵变换 变换&#xff08;transform)&#xff1a;指的是我们把一些数据&#xff0c;如点&#xff0c;方向向量甚至是颜色&#xff0c;通过某种方式&#xff08;矩阵运算&#xff09;&#xff0c;进行转换的过程。 变换类型 线性变换&#xff1a;保留矢量加和标量乘的计算 f(x)…...

Pandas取代Excel?

有人在知乎上提问&#xff1a;为什么大公司不用pandas取代excel&#xff1f; 而且列出了几个理由&#xff1a;Pandas功能比Excel强大&#xff0c;运行速度更快&#xff0c;Excel除了简单和可视化界面外&#xff0c;没有其他更多的优势。 有个可怕的现实是&#xff0c;对比Exce…...

启动vite项目报Unexpected “\x88“ in JSON

启动vite项目报Unexpected “\x88” in JSON 通常是文件被防火墙加密需要寻找运维解决 重启重装npm install...

HTTP测试智能化升级:动态变量管理实战与效能跃迁

在Web应用、API接口测试等领域&#xff0c;测试场景的动态性和复杂性对测试数据的灵活管理提出了极高要求。传统的静态测试数据难以满足多用户并发、参数化请求及响应内容验证等需求。例如&#xff0c;在电商系统性能测试中&#xff0c;若无法动态生成用户ID、订单号或实时提取…...

关于一对多关系(即E-R图中1:n)中的界面展示优化和数据库设计

前言 一对多&#xff0c;是常见的数据库关系。在界面设计时&#xff0c;有时为了方便&#xff0c;就展示成逗号分割的字符串。例如&#xff1a;学生和爱好的界面。 存储 如果是简单存储&#xff0c;建立数据库&#xff1a;爱好&#xff0c;课程&#xff0c;存在一张表中。 但…...

【gpt生成-总览】怎样才算开发了一门编程语言,需要通过什么测试

开发一门真正的编程语言需要经历完整的设计、实现和验证过程&#xff0c;并通过系统的测试体系验证其完备性。以下是分阶段开发标准及测试方法&#xff1a; 一、语言开发核心阶段 1. 语言规范设计&#xff08;ISO/IEC 标准级别&#xff09; ​​语法规范​​&#xff1a;BNF/…...

JVM笔记【一】java和Tomcat类加载机制

JVM笔记一java和Tomcat类加载机制 java和Tomcat类加载机制 Java类加载 * loadClass加载步骤类加载机制类加载器初始化过程双亲委派机制全盘负责委托机制类关系图自定义类加载器打破双亲委派机制 Tomcat类加载器 * 为了解决以上问题&#xff0c;tomcat是如何实现类加载机制的…...

React 组件类型详解:类组件 vs. 函数组件

React 是一个用于构建用户界面的 JavaScript 库&#xff0c;其核心思想是组件化开发。React 组件可以分为类组件&#xff08;Class Components&#xff09;和函数组件&#xff08;Function Components&#xff09;&#xff0c;它们在设计理念、使用方式和适用场景上有所不同。随…...

GPT-SoVITS 使用指南

一、简介 TTS&#xff08;Text-to-Speech&#xff0c;文本转语音&#xff09;&#xff1a;是一种将文字转换为自然语音的技术&#xff0c;通过算法生成人类可听的语音输出&#xff0c;广泛应用于语音助手、无障碍服务、导航系统等场景。类似的还有SVC&#xff08;歌声转换&…...

美信监控易:数据采集与整合的卓越之选

在当今复杂多变的运维环境中&#xff0c;一款具备强大数据采集与整合能力的运维管理软件对于企业的稳定运行和高效决策至关重要。美信监控易正是这样一款在数据采集与整合方面展现出显著优势的软件&#xff0c;以下是它的一些关键技术优势&#xff0c;值得每一个运维团队深入了…...

基于Redis的3种分布式ID生成策略

在分布式系统设计中&#xff0c;全局唯一ID是一个基础而关键的组件。随着业务规模扩大和系统架构向微服务演进&#xff0c;传统的单机自增ID已无法满足需求。高并发、高可用的分布式ID生成方案成为构建可靠分布式系统的必要条件。 Redis具备高性能、原子操作及简单易用的特性&…...

OCR技术与视觉模型技术的区别、应用及展望

在计算机视觉技术飞速发展的当下&#xff0c;OCR技术与视觉模型技术成为推动各行业智能化变革的重要力量。它们在原理、应用等方面存在诸多差异&#xff0c;在自动化测试领域也展现出不同的表现与潜力&#xff0c;下面将为你详细剖析。 一、技术区别 &#xff08;一&#xff…...

End-to-End从混沌到秩序:基于LLM的Pipeline将非结构化数据转化为知识图谱

摘要:本文介绍了一种将非结构化数据转换为知识图谱的端到端方法。通过使用大型语言模型(LLM)和一系列数据处理技术,我们能够从原始文本中自动提取结构化的知识。这一过程包括文本分块、LLM 提示设计、三元组提取、归一化与去重,最终利用 NetworkX 和 ipycytoscape 构建并可…...

比特币的跨输入签名聚合(Cross-Input Signature Aggregation,CISA)

1. 引言 2024 年&#xff0c;人权基金会&#xff08;Human Rights Foundation&#xff0c;简称 HRF&#xff09;启动了一项研究奖学金计划&#xff0c;旨在探讨“跨输入签名聚合”&#xff08;Cross-Input Signature Aggregation&#xff0c;简称 CISA&#xff09;的潜在影响。…...

洛谷P1177【模板】排序:十种排序算法全解(2)

我们接着上一篇继续讲【洛谷P1177【模板】排序&#xff1a;十种排序算法全解(1)】 三、计数排序&#xff08;Counting Sort&#xff09; ‌仅适用于数据范围较小的情况‌ // Java import java.io.*; public class Main {static final int OFFSET 100000;public static void…...

MySql 三大日志(redolog、undolog、binlog)详解

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/aa730ab3f84049638f6c9a785e6e51e9.png 1. redo log&#xff1a;“你他妈别丢数据啊&#xff01;” 干啥的&#xff1f; 这货是InnoDB的“紧急备忘录”。比如你改了一条数据&#xff0c;MySQL怕自己突然断电嗝屁了&am…...

Docker使用、容器迁移

Docker 简介 Docker 是一个开源的容器化平台&#xff0c;用于打包、部署和运行应用程序及其依赖环境。Docker 容器是轻量级的虚拟化单元&#xff0c;运行在宿主机操作系统上&#xff0c;通过隔离机制&#xff08;如命名空间和控制组&#xff09;确保应用运行环境的一致性和可移…...

HTTP:九.WEB机器人

概念 Web机器人是能够在无需人类干预的情况下自动进行一系列Web事务处理的软件程序。人们根据这些机器人探查web站点的方式,形象的给它们取了一个饱含特色的名字,比如“爬虫”、“蜘蛛”、“蠕虫”以及“机器人”等!爬虫概述 网络爬虫(英语:web crawler),也叫网络蜘蛛(…...

2025妈妈杯数学建模C题完整分析论文(共36页)(含模型建立、可运行代码、数据)

2025 年第十五届 MathorCup 数学建模C题完整分析论文 目录 摘 要 一、问题分析 二、问题重述 三、模型假设 四、 模型建立与求解 4.1问题1 4.1.1问题1思路分析 4.1.2问题1模型建立 4.1.3问题1代码&#xff08;仅供参考&#xff09; 4.1.4问题1求解结果&#xff08;仅…...

数据结构排序算法全解析:从基础原理到实战应用

在计算机科学领域&#xff0c;排序算法是数据处理的核心技术之一。无论是小规模数据的简单整理&#xff0c;还是大规模数据的高效处理&#xff0c;选择合适的排序算法直接影响着程序的性能。本文将深入解析常见排序算法的核心思想、实现细节、特性对比及适用场景&#xff0c;帮…...

UMG:ListView

1.创建WBP_ListView,添加Border和ListView。 2.创建Object,命名为Item(数据载体&#xff0c;可以是其他类型)。新增变量name。 3.创建User Widget&#xff0c;命名为Entry(循环使用的UI载体).添加Border和Text。 4.设置Entry继承UserObjectListEntry接口。 5.Entry中对象生成时…...

每天学一个 Linux 命令(18):mv

​​可访问网站查看&#xff0c;视觉品味拉满&#xff1a; http://www.616vip.cn/18/index.html 每天学一个 Linux 命令&#xff08;18&#xff09;&#xff1a;mv 命令功能 mv&#xff08;全称&#xff1a;move&#xff09;用于移动文件/目录或重命名文件/目录&#xff0c;是…...

ubuntu24.04上使用qemu和buildroot模拟vexpress-ca9开发板构建嵌入式arm linux环境

1 准备工作 1.1 安装qemu 在ubuntu系统中使用以下命令安装qemu。 sudo apt install qemu-system-arm 安装完毕后&#xff0c;在终端输入: qemu- 后按TAB键&#xff0c;弹出下列命令证明安装成功。 1.2 安装arm交叉编译工具链 sudo apt install gcc-arm-linux-gnueabihf 安装之…...