《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>
目录
一、回顾神经网络框架
1、单层神经网络
2、多层神经网络
二、手写数字识别
1、续接上节课代码,如下所示
2、建立神经网络模型
输出结果:
3、设置训练集
4、设置测试集
5、创建损失函数、优化器
参数解析:
1)params
2)lr
3)loss
6、开始训练
运行结果:
三、总结
1、关键步骤
2、训练过程中
3、训练完成后
一、回顾神经网络框架
1、单层神经网络
图示为单层神经完了的基本构造,首先,有信号想传入输入层,然后这些信号会在传播途中发生衰减或者增强,假如想传入的信号为x,那么经过衰减或增强后的值则为wx,这个w叫权重,此时便有了传入信号,有很多传入信号,那么神经元就会对这些信号进行处理,把这些信号的加权值求和,再将这个求和的值映射到非线性激活函数上来引入非线性特征,最后得到输出结果。
2、多层神经网络
相比于上述的单层神经网络,单层神经网络就相当于上图中的绿色的第一列,即多个信号传入后只进行加权映射处理一次即输出结果,而多层神经网络即有多列神经元,传入信号经过第一列神经元处理后得到的值将其再次当做信号进行衰减或增强传入下一层神经元对其进行处理,至于多少层需要经过多次训练去决定,没有固定的值。
二、手写数字识别
1、续接上节课代码,如下所示
import torch
print(torch.__version__)"""MNIST包含70,000张手写数字图像:60,000张用于训练,10,000张用于测试。
图像是灰度的,28x28像素的,并且居中的,以减少预处理和加快运行。"""from torch import nn # 导入神经网络模块
from torch.utils.data import DataLoader # 数据包管理工具,打包数据,
from torchvision import datasets # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型的数据转换为tensor张量,numpy arrgy,"""下载训练数据集,图片+标签"""
training_data = datasets.MNIST( # 跳转到函数的内部源代码,pycharm 按下ctrl +鼠标点击root='data', # 表述下载的数据存放的根目录train=True, # 表示下载的是训练数据集,如果要下载测试集,更改为False即可download=True, # 表示如果根目录有该数据,则不再下载,如果没有则下载transform=ToTensor() # 张量,图片是不能直接传入神经网络模型# 表示制定一个数据转换操作,将下载的图片转换为pytorch张量,因为pytorch只能处理张量tensor类型的数据
)test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor() # Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如 PyTorch、TensorFlo
) # NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。print(len(training_data))
print(len(test_data))"""展示手写数字图片,把训练数据集中的前59000张图片展示一下"""# # tensor -》numpy 矩阵类型的数据,矩阵是特殊的张量,张量可以包含任意维度的数据
# from matplotlib import pyplot as plt # 导入绘图库
# figure = plt.figure() # 设置一个空白画布
# for i in range(9):
# img,label = training_data[i+59000] # 提取第59000张图片开始,共9张,返回图片及其对应的标签值
#
# figure.add_subplot(3,3,i+1) # 在画布创建3行3列的小窗口,通过遍历的值i来确定每个画布展示的图片
# plt.title(label) # 设置每个窗口的标题,设置标签为上述返回的标签值
# plt.axis('off') # 取消画布中的坐标轴的图像
# plt.imshow(img.squeeze(),cmap='gray') # plt.imshow()将NumPy数组data中的数据显示为图像,并在图形窗口中,
# a = img.squeeze() # img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1,则张量不会改变。
# plt.show()train_dataloader = DataLoader(training_data,batch_size=64) # 调用上述定义的DataLoader打包库,将训练集的图片和标签,64张图片为一个包,
test_dataloader = DataLoader(test_data,batch_size=64) # 将测试集的图片和标签,每64张打包成一份
for x,y in test_dataloader:# x是表示打包好的每一个数据包,其形状为[64,1,28,28],64表示批次大小,1表示通道数为1,即灰度图,28表示图像的宽高像素值# y表示每个图片标签print(f"shape of x[N,C,H,W]:{x.shape}") # 打印图片形状print(f"shape of y:{y.shape}{y.dtype}") # 打印标签的形状和数据类型break # 跳出并终止循环,表示只遍历一个包的数据情况"""判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU""" # 返回cuda,mps,cpu, m1,m2集显CPU+GPU RTX3060
device = "cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device") # 字符串的格式化。CUDA驱动软件的功能:pytorch能够去执行cuda的命令,cuda通过GPU指令集
# 神经网络的模型也需要传入到GPU,1个batchsize的数据集也需要传入到GPU,才可以进行训练。
2、建立神经网络模型
class NeuralNetwork(nn.Module): # 通过调用类的形式来使用神经网络,神经网络的模型def __init__(self):super().__init__() # 继承的父类初始化self.flatten = nn.Flatten() # 将输入的多维数据展开成一维数据,创建一个展开对象flattenself.hidden1 = nn.Linear(28*28,128) # 创建一个全连接层,其上包含权重的值,第1个参数表示有多少个神经元传入进来,第2个参数表示有多少个数据传出去,这里表示上述多层神经网络的第一列神经元self.hidden2 = nn.Linear(128, 256) # 再次创建一个全连接层,将第一层输出个数转变成这一层的输入信号数,然后在设定输出层个数为256self.hidden3 = nn.Linear(256,128) # 继承上一层的输出信号个数,将其当做这一层的输入为256层,设定输出为128层self.hidden4 = nn.Linear(128,64) # 输入128层,输出64层self.out = nn.Linear(64, 10) # 设定输出层,输出必需和标签的类别相同,输入必须是上一层的神经元个数def forward(self,x): # 设定前向传播函数,你得告诉它,数据的流向。是神经网络层连接起来,函数名称不能改。当你调用forward函数的时候,传入进来的图像数据x = self.flatten(x) # 传入信号x图像,首先进行展开x = self.hidden1(x) # 将其当做输入层x = torch.sigmoid(x) # 使用sigmoid激活函数得到结果,torch使用的relu函数 relu,tanhx = self.hidden2(x) # 再次对上一层的结果进行加权x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x) # 输出结果return x # 将输出结果返回出去model = NeuralNetwork().to(device) # 把刚刚创建的模型传入到gpu
print(model) # 打印模型的构造
此处的神经网络层数可以自己设置,如果想多设置几层或者少设置基层,只需改变self.hidden的个数以及前向传播函数中激活函数的个数。
输出结果:
3、设置训练集
def train(dataloader,model,loss_fn,optimizer): # 导入参数,dataloader表示打包,数据加载器,model导入上述定义的神经网络模型,loss_fn表示损失值,optimizer表示优化器model.train() # 模型设置为训练模式# 告诉模型,我要开始训练,模型中权重w进行随机化操作,已经更新w。在训练过程中,w会被修改的# #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.train(),在测试时写上model.eval()。batch_size_num = 1for x,y in dataloader: # 遍历打包的图片的每一个包中的每一张图片及其对应的标签,其中batch为每一个数据的编号x,y = x.to(device),y.to(device) # 把训练数据集和标签传入cpu或GPUpred = model.forward(x) # 模型进行前向传播,输入图片信息后得到预测结果,forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred,y) # 调用交叉熵损失函数计算损失值loss,输入参数为预测结果和真实结果,# Backpropaqation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad() # 梯度值清零,在反向传播之前先清除之前的梯度loss.backward() # 反向传播,计算得到每个参数的梯度值woptimizer.step() # 根据梯度更新权重w参数loss_value = loss.item() # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 200 == 0: # 判断遍历包的个数是否整除于200,用于将训练到的包的个数打印出来,整除200目的是节省资源print(f"loss:{loss_value:>7f} [number: {batch_size_num}]") # 打印损失值及其对应的值,损失值最大宽度为7,右对齐batch_size_num += 1 # 每遍历一个包增加一次,以达到显示出来遍历的包的个数
4、设置测试集
def test(dataloader,model,loss_fn): # 输入参数打包的图片、训练好的模型、以及损失值size = len(dataloader.dataset) # 返回测试数据集的样本总数num_batches = len(dataloader) # 返回当前dataloader配置下的批次数model.eval() # 表示此为模型测试,w就不能再更新。test_loss,correct = 0, 0 # 设置总损失值初始化为0,正确预测结果初始化为0with torch.no_grad(): # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算for x,y in dataloader: # 遍历测试集中的每个包的每个图片及其对应的标签x,y = x.to(device),y.to(device) # 将其传入gpupred = model.forward(x) # 图片数据进行前向传播test_loss += loss_fn(pred,y).item() # test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item() # pred.argmax(1) == y用于判断预测结果最大值对用的标签是否与真实值相同,然后将判断结果的bool值转变为浮点数并求和a = (pred.argmax(1) == y) # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches # 总损失值除以打包的批次数,返回测试的每一个包的损失值的均值,能来衡量模型测试的好坏。correct /= size # 平均的正确率print(f"Test result: \n Accuracy:{(100 * correct)}%, Avg loss:{test_loss}")
5、创建损失函数、优化器
loss_fn = nn.CrossEntropyLoss() # 创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
optimizer = torch.optim.Adam(model.parameters(),lr=0.0045) # 创建一个优化器,SGD为随机梯度下降算法,学习率或者叫步长为0.0045
参数解析:
1)params
表示要训练的参数,一般我们传入的都是model.parameters()
2)lr
learning_rate学习率,也就是步长
3)loss
表示模型训练后的输出结果与样本标签的差距。如果差距越小,就表示模型训练越好,越逼近于真实的型。
6、开始训练
epochs = 25 # 设置训练的轮数为25轮,因为模型中设置了权重值的更新,所以重复训练会更新模型的权值
for i in range(epochs):print(f"Epoch {i+1}\n--------------------")train(train_dataloader,model,loss_fn,optimizer)
print('Done!!')
test(test_dataloader,model,loss_fn) # 导入测试集进行测试
将上述所有代码连贯起来运行。
运行结果:
至此训练结束,用户可手动更改模型激活函数、学习率、模型训练轮数等等,以找到最优结果。
祝你成功!!
三、总结
1、关键步骤
使用PyTorch进行手写数字识别可以分为几个关键步骤。首先,需要准备手写数字数据集,通常使用MNIST数据集。然后,需要定义神经网络模型,可以使用PyTorch提供的各种层和激活函数来构建模型架构。接下来,需要定义损失函数,通常使用交叉熵损失函数来计算预测结果与真实标签之间的差异。
2、训练过程中
在训练过程中,使用优化算法(如随机梯度下降)来更新模型的权重和偏置,使其逐渐接近最优解。训练过程中会使用训练集的数据进行反向传播,并根据损失函数的结果来调整模型参数。可以设置训练的轮数和批次大小,以更好地优化模型。
3、训练完成后
在训练完成后,可以使用测试集或新的手写数字图像来评估模型的性能。通过将图像输入已训练好的模型,可以获得预测结果,并与真实标签进行比较。可以计算准确率等指标来评估模型的性能。
相关文章:

《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>
目录 一、回顾神经网络框架 1、单层神经网络 2、多层神经网络 二、手写数字识别 1、续接上节课代码,如下所示 2、建立神经网络模型 输出结果: 3、设置训练集 4、设置测试集 5、创建损失函数、优化器 参数解析: 1)para…...

【笔记】材料分析测试:晶体学
晶体与晶体结构Crystal and Crystal Structure 1.晶体主要特征 固态物质可以分为晶态和非晶态两大类,分别称为晶体和非晶体。 晶体和非晶体在微观结构上的区别在于是否具有长程有序。 晶体(长程有序)非晶(短程有序)…...
飞塔Fortigate7.4.4的DNS劫持功能
基础网络配置、上网策略、与Server的VIP配置(略)。 在FortiGate上配置DNS Translation,将DNS请求结果为202.103.12.2的DNS响应报文中的IP地址修改为Server的内网IP 10.10.2.100。 config firewall dnstranslationedit 1set src 2.13.12.2set…...

Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】
Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 目录 Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 一、简单介绍 二、状态模式(State Pattern) 1、什么时候使用状态模式 2、使用状态模式的…...

【RabbitMQ】RabbitMQ 的概念以及使用RabbitMQ编写生产者消费者代码
目录 1. RabbitMQ 核心概念 1.1生产者和消费者 1.2 Connection和Channel 1.3 Virtual host 1.4 Queue 1.5 Exchange 1.6 RabbitMO工作流程 2. AMQP 3.RabbitMO快速入门 3.1.引入依赖 3.2.编写生产者代码 3.3.编写消费者代码 4.源码 1. RabbitMQ 核心概念 在安装…...

openmv与stm32通信
控制小车视觉循迹使用 OpenMV 往往是不够的。一般使用 OpenMV 对图像进行处理,将处理过后的数据使用串口发送给STM32,使用STM32控制小车行驶。本文主要讲解 OpenMV 模块与 STM32 间的串口通信以及两种循迹方案,分别是划分检测区域和线性回归。…...

C++ STL全面解析:六大核心组件之一----序列式容器(vector和List)(STL进阶学习)
目录 序列式容器 Vector vector概述 vector的迭代器 vector的数据结构 vector的构造和内存管理 vector的元素操作 List List概述 List的设计结构 List的迭代器 List的数据结构 List的内存构造 List的元素操作 C标准模板库(STL)是一组高效的…...

【c数据结构】OJ练习篇 帮你更深层次理解链表!(相交链表、相交链表、环形链表、环形链表之寻找环形入口点、判断链表是否是回文结构、 随机链表的复制)
目录 一. 相交链表 二. 环形链表 三. 环形链表之寻找环形入口点 四. 判断链表是否是回文结构 五. 随机链表的复制 一. 相交链表 最简单粗暴的思路,遍历两个链表,分别寻找是否有相同的对应的结点。 我们对两个链表的每个对应的节点进行判断比较&…...
微软开源GraphRAG的使用教程(最全,非常详细)
GraphRAG的介绍 目前微软已经开源了GraphRAG的完整项目代码。对于某一些LLM的下游任务则可以使用GraphRAG去增强自己业务的RAG的表现。项目给出了两种使用方式: 在打包好的项目状态下运行,可进行尝试使用。在源码基础上运行,适合为了下游任…...
使用Refine构建项目(1)初始化项目
要初始化一个空的Refine项目,你可以使用Refine提供的CLI工具create-refine-app。以下是初始化步骤: 使用npx命令: 在命令行中运行以下命令来创建一个新的Refine项目: npx create-refine-applatest my-refine-project这将引导你通过…...
【Docker】安装及使用
1. 安装Docker Desktop Docker Desktop是官方提供的桌面版Docker客户端,在Mac上使用Docker需要安装这个工具。 访问 Docker官方页面 并下载Docker Desktop for Mac。打开下载的.dmg文件,并拖动Docker图标到应用程序文件夹。安装完成后,打开…...

[大语言模型-论文精读] 以《黑神话:悟空》为研究案例探讨VLMs能否玩动作角色扮演游戏?
1. 论文简介 论文《Can VLMs Play Action Role-Playing Games? Take Black Myth Wukong as a Study Case》是阿里巴巴集团的Peng Chen、Pi Bu、Jun Song和Yuan Gao,在2024.09.19提交到arXiv上的研究论文。 论文: https://arxiv.org/abs/2409.12889代码和数据: h…...
提升动态数据查询效率:应对数据库成为性能瓶颈的优化方案
引言 在现代软件系统中,数据库性能是决定整个系统响应速度和处理能力的关键因素之一。然而,当系统负载增加,特别是在高并发、大数据量场景下,数据库性能往往会成为瓶颈,导致查询响应时间延长,影响用户体验…...
Prometheus+grafana+kafka_exporter监控kafka运行情况
使用Prometheus、Grafana和kafka_exporter来监控Kafka的运行情况是一种常见且有效的方案。以下是详细的步骤和说明: 1. 部署kafka_exporter 步骤: 从GitHub下载kafka_exporter的最新版本:kafka_exporter项目地址(注意ÿ…...

在vue中:style 的几种使用方式
在日常开发中:style的使用也是比较常见的: 亲测有效 1.最通用的写法 <p :style"{fontFamily:arr.conFontFamily,color:arr.conFontColor,backgroundColor:arr.conBgColor}">{{con.title}}</p> 2.三元表达式 <a :style"{height:…...
商城小程序后端开发实践中出现的问题及其解决方法
前言 商城小程序后端开发中,开发者可能会面临多种问题。以下是一些常见的问题及其解决方法: 一、性能优化 问题:随着用户量的增加和功能的扩展,商城小程序可能会出现响应速度慢、处理效率低的问题。 解决方法: 对数…...
阿里Arthas-Java诊断工具,基本操作和命令使用
Arthas 是阿里巴巴开源的一款Java诊断工具,深受开发者喜爱。它可以帮助开发者在不需要修改代码的情况下,对运行中的Java程序进行问题诊断和性能分析。 软件具体使用方法 1 启动 Arthas,此时可能会出现好几个jvm的进程号,输入序号…...

Go 1.19.4 路径和目录-Day 15
1. 路径介绍 存储设备保存着数据,但是得有一种方便的模式让用户可以定位资源位置,操作系统采用一种路径字符 串的表达方式,这是一棵倒置的层级目录树,从根开始。 相对路径:不是以根目录开始的路径,例如 a/b…...
jEasyUI 创建标签页
jEasyUI 创建标签页 jEasyUI(jQuery EasyUI)是一个基于jQuery的框架,它为Web应用程序提供了丰富的用户界面组件。标签页(Tabs)是jEasyUI中的一个常用组件,用于在一个页面内组织多个面板,用户可…...

鸿蒙HarmonyOS开发:一次开发,多端部署(界面级)天气应用案例
文章目录 一、布局简介二、典型布局场景三、侧边栏 SideBarContainer1、子组件2、属性3、事件 四、案例 天气应用1、UX设计2、实现分析3、主页整体实现4、具体代码 五、运行效果 一、布局简介 布局可以分为自适应布局和响应式布局,二者的介绍如下表所示。 名称简介…...

超短脉冲激光自聚焦效应
前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...

黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
AtCoder 第409场初级竞赛 A~E题解
A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

如何在看板中有效管理突发紧急任务
在看板中有效管理突发紧急任务需要:设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP(Work-in-Progress)弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中,设立专门的紧急任务通道尤为重要,这能…...
如何为服务器生成TLS证书
TLS(Transport Layer Security)证书是确保网络通信安全的重要手段,它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书,可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...
【AI学习】三、AI算法中的向量
在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...

CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...