当前位置: 首页 > news >正文

《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>

目录

一、回顾神经网络框架

1、单层神经网络

2、多层神经网络

二、手写数字识别

1、续接上节课代码,如下所示

2、建立神经网络模型

输出结果:

3、设置训练集

4、设置测试集

5、创建损失函数、优化器

参数解析:

1)params

2)lr

3)loss

6、开始训练

运行结果:

三、总结

1、关键步骤

2、训练过程中

3、训练完成后


一、回顾神经网络框架

1、单层神经网络

        图示为单层神经完了的基本构造,首先,有信号想传入输入层,然后这些信号会在传播途中发生衰减或者增强,假如想传入的信号为x,那么经过衰减或增强后的值则为wx,这个w叫权重,此时便有了传入信号,有很多传入信号,那么神经元就会对这些信号进行处理,把这些信号的加权值求和,再将这个求和的值映射到非线性激活函数上来引入非线性特征,最后得到输出结果。

2、多层神经网络

        相比于上述的单层神经网络,单层神经网络就相当于上图中的绿色的第一列,即多个信号传入后只进行加权映射处理一次即输出结果,而多层神经网络即有多列神经元,传入信号经过第一列神经元处理后得到的值将其再次当做信号进行衰减或增强传入下一层神经元对其进行处理,至于多少层需要经过多次训练去决定,没有固定的值。

二、手写数字识别

1、续接上节课代码,如下所示

import torch
print(torch.__version__)"""MNIST包含70,000张手写数字图像:60,000张用于训练,10,000张用于测试。
图像是灰度的,28x28像素的,并且居中的,以减少预处理和加快运行。"""from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据,
from torchvision import datasets   # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor   # 数据转换,张量,将其他类型的数据转换为tensor张量,numpy arrgy,"""下载训练数据集,图片+标签"""
training_data = datasets.MNIST(   # 跳转到函数的内部源代码,pycharm 按下ctrl +鼠标点击root='data',   # 表述下载的数据存放的根目录train=True,   # 表示下载的是训练数据集,如果要下载测试集,更改为False即可download=True,   # 表示如果根目录有该数据,则不再下载,如果没有则下载transform=ToTensor()   # 张量,图片是不能直接传入神经网络模型# 表示制定一个数据转换操作,将下载的图片转换为pytorch张量,因为pytorch只能处理张量tensor类型的数据
)test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()  # Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如 PyTorch、TensorFlo
)  # NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。print(len(training_data))
print(len(test_data))"""展示手写数字图片,把训练数据集中的前59000张图片展示一下"""# # tensor -》numpy  矩阵类型的数据,矩阵是特殊的张量,张量可以包含任意维度的数据
# from matplotlib import pyplot as plt   # 导入绘图库
# figure = plt.figure()   # 设置一个空白画布
# for i in range(9):
#     img,label = training_data[i+59000]   # 提取第59000张图片开始,共9张,返回图片及其对应的标签值
#
#     figure.add_subplot(3,3,i+1)   # 在画布创建3行3列的小窗口,通过遍历的值i来确定每个画布展示的图片
#     plt.title(label)   # 设置每个窗口的标题,设置标签为上述返回的标签值
#     plt.axis('off')   # 取消画布中的坐标轴的图像
#     plt.imshow(img.squeeze(),cmap='gray')   # plt.imshow()将NumPy数组data中的数据显示为图像,并在图形窗口中,
#     a = img.squeeze()   # img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1,则张量不会改变。
# plt.show()train_dataloader = DataLoader(training_data,batch_size=64)  # 调用上述定义的DataLoader打包库,将训练集的图片和标签,64张图片为一个包,
test_dataloader = DataLoader(test_data,batch_size=64)   # 将测试集的图片和标签,每64张打包成一份
for x,y in test_dataloader:# x是表示打包好的每一个数据包,其形状为[64,1,28,28],64表示批次大小,1表示通道数为1,即灰度图,28表示图像的宽高像素值# y表示每个图片标签print(f"shape of x[N,C,H,W]:{x.shape}")   # 打印图片形状print(f"shape of y:{y.shape}{y.dtype}")   # 打印标签的形状和数据类型break  # 跳出并终止循环,表示只遍历一个包的数据情况"""判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""  # 返回cuda,mps,cpu, m1,m2集显CPU+GPU RTX3060
device = "cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 字符串的格式化。CUDA驱动软件的功能:pytorch能够去执行cuda的命令,cuda通过GPU指令集
# 神经网络的模型也需要传入到GPU,1个batchsize的数据集也需要传入到GPU,才可以进行训练。

2、建立神经网络模型

class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型def __init__(self):super().__init__()   # 继承的父类初始化self.flatten = nn.Flatten()   # 将输入的多维数据展开成一维数据,创建一个展开对象flattenself.hidden1 = nn.Linear(28*28,128)  # 创建一个全连接层,其上包含权重的值,第1个参数表示有多少个神经元传入进来,第2个参数表示有多少个数据传出去,这里表示上述多层神经网络的第一列神经元self.hidden2 = nn.Linear(128, 256)  # 再次创建一个全连接层,将第一层输出个数转变成这一层的输入信号数,然后在设定输出层个数为256self.hidden3 = nn.Linear(256,128)  # 继承上一层的输出信号个数,将其当做这一层的输入为256层,设定输出为128层self.hidden4 = nn.Linear(128,64)   # 输入128层,输出64层self.out = nn.Linear(64, 10)   # 设定输出层,输出必需和标签的类别相同,输入必须是上一层的神经元个数def forward(self,x):   # 设定前向传播函数,你得告诉它,数据的流向。是神经网络层连接起来,函数名称不能改。当你调用forward函数的时候,传入进来的图像数据x = self.flatten(x)  # 传入信号x图像,首先进行展开x = self.hidden1(x)  # 将其当做输入层x = torch.sigmoid(x)   # 使用sigmoid激活函数得到结果,torch使用的relu函数 relu,tanhx = self.hidden2(x)    # 再次对上一层的结果进行加权x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)  # 输出结果return x  # 将输出结果返回出去model = NeuralNetwork().to(device)   # 把刚刚创建的模型传入到gpu
print(model)   # 打印模型的构造

        此处的神经网络层数可以自己设置,如果想多设置几层或者少设置基层,只需改变self.hidden的个数以及前向传播函数中激活函数的个数。

        输出结果:

3、设置训练集

def train(dataloader,model,loss_fn,optimizer):   # 导入参数,dataloader表示打包,数据加载器,model导入上述定义的神经网络模型,loss_fn表示损失值,optimizer表示优化器model.train()   # 模型设置为训练模式# 告诉模型,我要开始训练,模型中权重w进行随机化操作,已经更新w。在训练过程中,w会被修改的# #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.train(),在测试时写上model.eval()。batch_size_num = 1for x,y in dataloader:    # 遍历打包的图片的每一个包中的每一张图片及其对应的标签,其中batch为每一个数据的编号x,y = x.to(device),y.to(device)   # 把训练数据集和标签传入cpu或GPUpred = model.forward(x)    # 模型进行前向传播,输入图片信息后得到预测结果,forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred,y)     # 调用交叉熵损失函数计算损失值loss,输入参数为预测结果和真实结果,# Backpropaqation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()    # 梯度值清零,在反向传播之前先清除之前的梯度loss.backward()     # 反向传播,计算得到每个参数的梯度值woptimizer.step()    # 根据梯度更新权重w参数loss_value = loss.item()   # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 200 == 0:  # 判断遍历包的个数是否整除于200,用于将训练到的包的个数打印出来,整除200目的是节省资源print(f"loss:{loss_value:>7f}   [number: {batch_size_num}]")  # 打印损失值及其对应的值,损失值最大宽度为7,右对齐batch_size_num += 1    # 每遍历一个包增加一次,以达到显示出来遍历的包的个数

4、设置测试集

def test(dataloader,model,loss_fn):  # 输入参数打包的图片、训练好的模型、以及损失值size = len(dataloader.dataset)   # 返回测试数据集的样本总数num_batches = len(dataloader)   # 返回当前dataloader配置下的批次数model.eval()    # 表示此为模型测试,w就不能再更新。test_loss,correct = 0, 0   # 设置总损失值初始化为0,正确预测结果初始化为0with torch.no_grad():    # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算for x,y in dataloader:   # 遍历测试集中的每个包的每个图片及其对应的标签x,y = x.to(device),y.to(device)   # 将其传入gpupred = model.forward(x)   # 图片数据进行前向传播test_loss += loss_fn(pred,y).item()    # test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # pred.argmax(1) == y用于判断预测结果最大值对用的标签是否与真实值相同,然后将判断结果的bool值转变为浮点数并求和a = (pred.argmax(1) == y)   # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches    # 总损失值除以打包的批次数,返回测试的每一个包的损失值的均值,能来衡量模型测试的好坏。correct /= size   # 平均的正确率print(f"Test result: \n Accuracy:{(100 * correct)}%, Avg loss:{test_loss}")

5、创建损失函数、优化器

loss_fn = nn.CrossEntropyLoss()  # 创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
optimizer = torch.optim.Adam(model.parameters(),lr=0.0045)  # 创建一个优化器,SGD为随机梯度下降算法,学习率或者叫步长为0.0045
参数解析:
        1)params

                表示要训练的参数,一般我们传入的都是model.parameters()

        2)lr

                learning_rate学习率,也就是步长

        3)loss

                表示模型训练后的输出结果与样本标签的差距。如果差距越小,就表示模型训练越好,越逼近于真实的型。

6、开始训练

epochs = 25  # 设置训练的轮数为25轮,因为模型中设置了权重值的更新,所以重复训练会更新模型的权值
for i in range(epochs):print(f"Epoch {i+1}\n--------------------")train(train_dataloader,model,loss_fn,optimizer)
print('Done!!')
test(test_dataloader,model,loss_fn)   # 导入测试集进行测试

        将上述所有代码连贯起来运行。

        运行结果:

        至此训练结束,用户可手动更改模型激活函数、学习率、模型训练轮数等等,以找到最优结果。

        祝你成功!!

三、总结

1、关键步骤

        使用PyTorch进行手写数字识别可以分为几个关键步骤。首先,需要准备手写数字数据集,通常使用MNIST数据集。然后,需要定义神经网络模型,可以使用PyTorch提供的各种层和激活函数来构建模型架构。接下来,需要定义损失函数,通常使用交叉熵损失函数来计算预测结果与真实标签之间的差异。

2、训练过程中

        在训练过程中,使用优化算法(如随机梯度下降)来更新模型的权重和偏置,使其逐渐接近最优解。训练过程中会使用训练集的数据进行反向传播,并根据损失函数的结果来调整模型参数。可以设置训练的轮数和批次大小,以更好地优化模型。

3、训练完成后

        在训练完成后,可以使用测试集或新的手写数字图像来评估模型的性能。通过将图像输入已训练好的模型,可以获得预测结果,并与真实标签进行比较。可以计算准确率等指标来评估模型的性能。

相关文章:

《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>

目录 一、回顾神经网络框架 1、单层神经网络 2、多层神经网络 二、手写数字识别 1、续接上节课代码,如下所示 2、建立神经网络模型 输出结果: 3、设置训练集 4、设置测试集 5、创建损失函数、优化器 参数解析: 1)para…...

【笔记】材料分析测试:晶体学

晶体与晶体结构Crystal and Crystal Structure 1.晶体主要特征 固态物质可以分为晶态和非晶态两大类,分别称为晶体和非晶体。 晶体和非晶体在微观结构上的区别在于是否具有长程有序。 晶体(长程有序)非晶(短程有序&#xff09…...

飞塔Fortigate7.4.4的DNS劫持功能

基础网络配置、上网策略、与Server的VIP配置(略)。 在FortiGate上配置DNS Translation,将DNS请求结果为202.103.12.2的DNS响应报文中的IP地址修改为Server的内网IP 10.10.2.100。 config firewall dnstranslationedit 1set src 2.13.12.2set…...

Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】

Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 目录 Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 一、简单介绍 二、状态模式(State Pattern) 1、什么时候使用状态模式 2、使用状态模式的…...

【RabbitMQ】RabbitMQ 的概念以及使用RabbitMQ编写生产者消费者代码

目录 1. RabbitMQ 核心概念 1.1生产者和消费者 1.2 Connection和Channel 1.3 Virtual host 1.4 Queue 1.5 Exchange 1.6 RabbitMO工作流程 2. AMQP 3.RabbitMO快速入门 3.1.引入依赖 3.2.编写生产者代码 ​3.3.编写消费者代码 4.源码 1. RabbitMQ 核心概念 在安装…...

openmv与stm32通信

控制小车视觉循迹使用 OpenMV 往往是不够的。一般使用 OpenMV 对图像进行处理,将处理过后的数据使用串口发送给STM32,使用STM32控制小车行驶。本文主要讲解 OpenMV 模块与 STM32 间的串口通信以及两种循迹方案,分别是划分检测区域和线性回归。…...

C++ STL全面解析:六大核心组件之一----序列式容器(vector和List)(STL进阶学习)

目录 序列式容器 Vector vector概述 vector的迭代器 vector的数据结构 vector的构造和内存管理 vector的元素操作 List List概述 List的设计结构 List的迭代器 List的数据结构 List的内存构造 List的元素操作 C标准模板库(STL)是一组高效的…...

【c数据结构】OJ练习篇 帮你更深层次理解链表!(相交链表、相交链表、环形链表、环形链表之寻找环形入口点、判断链表是否是回文结构、 随机链表的复制)

目录 一. 相交链表 二. 环形链表 三. 环形链表之寻找环形入口点 四. 判断链表是否是回文结构 五. 随机链表的复制 一. 相交链表 最简单粗暴的思路,遍历两个链表,分别寻找是否有相同的对应的结点。 我们对两个链表的每个对应的节点进行判断比较&…...

微软开源GraphRAG的使用教程(最全,非常详细)

GraphRAG的介绍 目前微软已经开源了GraphRAG的完整项目代码。对于某一些LLM的下游任务则可以使用GraphRAG去增强自己业务的RAG的表现。项目给出了两种使用方式: 在打包好的项目状态下运行,可进行尝试使用。在源码基础上运行,适合为了下游任…...

使用Refine构建项目(1)初始化项目

要初始化一个空的Refine项目,你可以使用Refine提供的CLI工具create-refine-app。以下是初始化步骤: 使用npx命令: 在命令行中运行以下命令来创建一个新的Refine项目: npx create-refine-applatest my-refine-project这将引导你通过…...

【Docker】安装及使用

1. 安装Docker Desktop Docker Desktop是官方提供的桌面版Docker客户端,在Mac上使用Docker需要安装这个工具。 访问 Docker官方页面 并下载Docker Desktop for Mac。打开下载的.dmg文件,并拖动Docker图标到应用程序文件夹。安装完成后,打开…...

[大语言模型-论文精读] 以《黑神话:悟空》为研究案例探讨VLMs能否玩动作角色扮演游戏?

1. 论文简介 论文《Can VLMs Play Action Role-Playing Games? Take Black Myth Wukong as a Study Case》是阿里巴巴集团的Peng Chen、Pi Bu、Jun Song和Yuan Gao,在2024.09.19提交到arXiv上的研究论文。 论文: https://arxiv.org/abs/2409.12889代码和数据: h…...

提升动态数据查询效率:应对数据库成为性能瓶颈的优化方案

引言 在现代软件系统中,数据库性能是决定整个系统响应速度和处理能力的关键因素之一。然而,当系统负载增加,特别是在高并发、大数据量场景下,数据库性能往往会成为瓶颈,导致查询响应时间延长,影响用户体验…...

Prometheus+grafana+kafka_exporter监控kafka运行情况

使用Prometheus、Grafana和kafka_exporter来监控Kafka的运行情况是一种常见且有效的方案。以下是详细的步骤和说明: 1. 部署kafka_exporter 步骤: 从GitHub下载kafka_exporter的最新版本:kafka_exporter项目地址(注意&#xff…...

在vue中:style 的几种使用方式

在日常开发中:style的使用也是比较常见的&#xff1a; 亲测有效 1.最通用的写法 <p :style"{fontFamily:arr.conFontFamily,color:arr.conFontColor,backgroundColor:arr.conBgColor}">{{con.title}}</p> 2.三元表达式 <a :style"{height:…...

商城小程序后端开发实践中出现的问题及其解决方法

前言 商城小程序后端开发中&#xff0c;开发者可能会面临多种问题。以下是一些常见的问题及其解决方法&#xff1a; 一、性能优化 问题&#xff1a;随着用户量的增加和功能的扩展&#xff0c;商城小程序可能会出现响应速度慢、处理效率低的问题。 解决方法&#xff1a; 对数…...

阿里Arthas-Java诊断工具,基本操作和命令使用

Arthas 是阿里巴巴开源的一款Java诊断工具&#xff0c;深受开发者喜爱。它可以帮助开发者在不需要修改代码的情况下&#xff0c;对运行中的Java程序进行问题诊断和性能分析。 软件具体使用方法 1 启动 Arthas&#xff0c;此时可能会出现好几个jvm的进程号&#xff0c;输入序号…...

Go 1.19.4 路径和目录-Day 15

1. 路径介绍 存储设备保存着数据&#xff0c;但是得有一种方便的模式让用户可以定位资源位置&#xff0c;操作系统采用一种路径字符 串的表达方式&#xff0c;这是一棵倒置的层级目录树&#xff0c;从根开始。 相对路径&#xff1a;不是以根目录开始的路径&#xff0c;例如 a/b…...

jEasyUI 创建标签页

jEasyUI 创建标签页 jEasyUI&#xff08;jQuery EasyUI&#xff09;是一个基于jQuery的框架&#xff0c;它为Web应用程序提供了丰富的用户界面组件。标签页&#xff08;Tabs&#xff09;是jEasyUI中的一个常用组件&#xff0c;用于在一个页面内组织多个面板&#xff0c;用户可…...

鸿蒙HarmonyOS开发:一次开发,多端部署(界面级)天气应用案例

文章目录 一、布局简介二、典型布局场景三、侧边栏 SideBarContainer1、子组件2、属性3、事件 四、案例 天气应用1、UX设计2、实现分析3、主页整体实现4、具体代码 五、运行效果 一、布局简介 布局可以分为自适应布局和响应式布局&#xff0c;二者的介绍如下表所示。 名称简介…...

使用 Python 模拟光的折射,反射,和全反射

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

大厂太卷了!又一款国产AI视频工具上线了,免费无限使用!(附提示词宝典)

大家好&#xff0c;我是程序员X小鹿&#xff0c;前互联网大厂程序员&#xff0c;自由职业2年&#xff0c;也一名 AIGC 爱好者&#xff0c;持续分享更多前沿的「AI 工具」和「AI副业玩法」&#xff0c;欢迎一起交流~ 记得去年刚开始分享 AI 视频工具的时候&#xff0c;介绍的大多…...

vue3扩展echart封装为组件库-快速复用

ECharts ECharts&#xff0c;全称Enterprise Charts&#xff0c;是一款由百度团队开发并开源&#xff0c;后捐赠给Apache基金会的纯JavaScript图表库。它提供了直观、生动、可交互、可个性化定制的数据可视化图表&#xff0c;广泛应用于数据分析、商业智能、网页开发等领域。以…...

随机掉落的项目足迹:Vue3 + wangEditor5富文本编辑器——toolbar.getConfig() 查看工具栏的默认配置

问题引入 小提示&#xff1a;问题引入是一个讲故事的废话环节&#xff0c;各位小伙伴可以直接跳到第二大点&#xff1a;问题解决 我的项目不需要在富文本编辑器中引入添加代码块的功能&#xff0c;于是我寻思在工具栏上把操作代码的菜单删一删 于是我来到官网文档工具栏配置 …...

更新 Git 软件

更新 Git 软件本身是指将你当前安装的 Git 版本升级到最新版本。不同的操作系统有不同的更新方法。以下是针对 Windows、macOS 和 Linux 的 Git 更新步骤&#xff1a; Windows 检查当前版本&#xff1a; git --version访问官网下载最新版本&#xff1a; 访问 Git 官方网站 (ht…...

Keil根据map文件确定单片机代码存储占用flash情况

可以从map文件中查看得知&#xff0c;代码占用内存情况大概为35KB,而在在线仿真时&#xff0c;可以看到在flash的0x8008F64地址前均有数据&#xff0c;是代码数据&#xff0c;8F64(HEX)36708(DEC),36708/102335,刚好35。因此&#xff0c;要想操作读写flash&#xff0c;必须在不…...

ByteTrack多目标跟踪流程图

ByteTrack多目标跟踪流程图 点个赞吧&#xff0c;谢谢。...

什么是L2范数

定义&#xff1a; 在数学和计算中&#xff0c;L2 范数是一种用于测量向量长度或大小的方法&#xff0c;也被称为欧几里得范数。对于一个 n 维向量 x ( x 1 , x 2 , … , x n ) \mathbf{x} (x_1, x_2, \dots, x_n) x(x1​,x2​,…,xn​)&#xff0c;其 L2 范数定义为&#x…...

Scrapy爬虫IP代理池:提升爬取效率与稳定性

在互联网时代&#xff0c;数据就是新的黄金。无论是企业还是个人&#xff0c;数据的获取和分析能力都显得尤为重要。而在众多数据获取手段中&#xff0c;使用爬虫技术无疑是一种高效且广泛应用的方法。然而&#xff0c;爬虫在实际操作中常常会遇到IP被封禁的问题。为了解决这个…...

信息技术(IT)行业的发展

近年来&#xff0c;信息技术&#xff08;IT&#xff09;行业的发展呈现出前所未有的活力和潜力。随着全球数字化转型的加速&#xff0c;IT行业正逐步成为推动社会经济发展的重要引擎。无论是互联网、大数据、人工智能&#xff0c;还是云计算、物联网&#xff0c;这些新兴技术都…...