当前位置: 首页 > article >正文

基于多层感知机(MLP)实现MNIST手写体识别

实现步骤

  1. 下载数据集
  2. 处理好数据集
  3. 确定好模型(初始化模型参数等等)
  4. 确定优化函数(损失函数也称为目标函数)和优化方法(一般选用随机梯度下降 SDG )
  5. 进行模型的训练
  6. 进行模型的评估
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 1. 下载数据集
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, transform=transforms.ToTensor(), download=True)# 2. 创建批量数据迭代器
train_iter = DataLoader(mnist_train, batch_size=256, shuffle=True)
test_iter = DataLoader(mnist_test, batch_size=256)# 3. 可视化检查数据
var = next(iter(train_iter))
plt.title(str(var[1][0]))  # 显示标签
plt.imshow(var[0][0].squeeze().numpy(), cmap='gray')  # 显示图片
plt.show()# 4. 定义模型:多层感知机
net = nn.Sequential(nn.Flatten(),nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, 10) # 注意这里是不需要加 Softmax 了的,因为后面定义了,nn.CrossEntropyLoss()这个会自动帮我们进行 Softmax 以及进行损失计算。其实就是目标函数
)# 初始化模型参数
def init_weights(m):if isinstance(m, nn.Linear):nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)# 5. 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # CrossEntropyLoss已经包含了softmax,所以不需要LogSoftmax
optimizer = optim.SGD(net.parameters(), lr=0.2)# 6. 训练模型
epoch_num = 20
for epoch in range(epoch_num):net.train()  # 设置为训练模式total_loss = 0for X, y in train_iter:optimizer.zero_grad()  # 清除梯度y_hat = net(X)  # 前向传播loss = loss_fn(y_hat, y)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数total_loss += loss.item() * X.shape[0]  # 累积损失avg_loss = total_loss / len(mnist_train)  # 计算平均损失print(f'Epoch {epoch + 1}/{epoch_num}, Loss: {avg_loss:.4f}')# 7. 评估模型
def evaluate_model(net, test_iter):net.eval()  # 设置为评估模式correct, total = 0, 0with torch.no_grad():  # 在评估时不需要计算梯度for X, y in test_iter:y_hat = net(X)_, predicted = torch.max(y_hat, 1)  # 获取预测的标签correct += (predicted == y).sum().item()  # 统计正确的个数total += y.size(0)  # 统计总数accuracy = correct / totalprint(f'Accuracy on test set: {accuracy * 100:.2f}%')# 评估模型的表现
evaluate_model(net, test_iter)

代码实践的结果:

  1. 自己不会去计算损失
  2. 在模型进行训练的时候不知道如何把数据放进去:原来只需要创建好了 DataLoader 以后,通过枚举
    就可以拿到数据了。
  3. 最后进行模型评估的时候也是用 AI 进行完成了。所以多少还是差点意思。
  4. 后面的代码多去实践实践,并且思考吧!!!

关于代码中交叉熵计算的理解

理解损失函数(loss_fn)是如何计算的,对于训练神经网络来说是非常重要的。具体到你提到的这行代码:

loss = loss_fn(y_hat, y)  # 计算损失

损失函数的定义:

在你的代码中,损失函数是:

loss_fn = nn.CrossEntropyLoss()

nn.CrossEntropyLoss() 是一种常用于多分类问题的损失函数,它实际上包含了两个步骤:

  1. Softmax:将模型的输出转换为概率分布。
  2. 交叉熵损失:计算真实标签与预测概率分布之间的差距。

为什么要用交叉熵呢?因为交叉熵可以来衡量预测差距,这个我们只需要这个知识点,并且知道上面的公式就好了。

我们逐步分析这两个步骤。

1. Softmax(概率转换)

假设模型的输出 y_hat 是一个向量,其中每个元素代表对应类别的“分数”(或者说是原始的 logits)。例如,假设有 3 个类别,模型的输出可能是:

y_hat = [2.0, 1.0, -1.0]  # 这三个数字是 logits,不是概率

通过 Softmax 函数,我们将这些 logits 转换成概率:

# 计算 softmax
softmax = torch.nn.functional.softmax(y_hat, dim=-1)

softmax 的输出会是一个概率分布,每个数值的范围在 [0, 1] 之间,且所有数值加起来为 1。例如,经过 Softmax 后可能得到:

softmax = [0.7, 0.2, 0.1]  # 类别 0 的概率是 0.7,类别 1 的概率是 0.2,类别 2 的概率是 0.1

2. 交叉熵损失(Cross Entropy Loss)

交叉熵是衡量两个概率分布之间差异的一个标准方法。在分类任务中,我们希望预测的类别概率与真实标签分布尽可能接近。

对于一个单一的样本,交叉熵损失的计算公式为:

L = − ∑ i = 1 C y i log ⁡ ( p i ) L = - \sum_{i=1}^{C} y_i \log(p_i) L=i=1Cyilog(pi)

  • ( C ) 是类别数。
  • ( y_i ) 是真实标签(在 one-hot 编码下,真实类别的标签为 1,其他类别为 0)。
  • ( p_i ) 是模型预测的概率。

对于多分类任务来说,交叉熵损失会选择对应真实标签的类别概率 ( p_{\text{true}} ) 来计算损失。例如,如果真实标签是类别 0,那么我们只关注模型在类别 0 上的预测概率。

假设真实标签 y 是类别 0,对应的 one-hot 编码是 [1, 0, 0],而模型的预测是:

softmax = [0.7, 0.2, 0.1]

那么交叉熵损失为:

L = − ( 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) + 0 ⋅ log ⁡ ( 0.1 ) ) = − log ⁡ ( 0.7 ) ≈ 0.3567 L = - (1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1)) = - \log(0.7) \approx 0.3567 L=(1log(0.7)+0log(0.2)+0log(0.1))=log(0.7)0.3567

nn.CrossEntropyLoss() 如何工作

在 PyTorch 中,nn.CrossEntropyLoss 会自动处理上述两个步骤:

  1. y_hat(logits)转换为概率。
  2. 使用真实标签 y 计算交叉熵损失。
输入和输出:
  • y_hat: 这是模型的原始输出(logits),形状为 (batch_size, num_classes)。每一行是一个样本的 logits。
  • y: 这是标签,通常是一个包含类别索引的向量,形状为 (batch_size,)。每个元素是该样本的真实类别索引。

例如:

假设我们有以下数据:

  • 模型的输出(logits)为:

    y_hat = torch.tensor([[2.0, 1.0, -1.0],  # 第一个样本[0.5, 1.5, 0.3]]) # 第二个样本
    
  • 真实标签 y 为:

    y = torch.tensor([0, 1])  # 第一个样本的标签是类别 0,第二个样本的标签是类别 1
    

使用 nn.CrossEntropyLoss() 计算损失:

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y_hat, y)

CrossEntropyLoss 会首先对 y_hat 进行 softmax 转换,然后计算每个样本的交叉熵损失。你可以通过打印出来的 loss 来查看模型的表现。

总结:

  • y_hat 是模型的原始输出(logits),表示每个类别的“分数”。
  • nn.CrossEntropyLoss 会自动处理 softmax 和交叉熵损失的计算。
  • 损失函数的目的是衡量模型的输出与真实标签之间的差异,差异越小,损失值越小,说明模型的预测越准确。

使用`nn.CrossEntropyLoss 会自动进行独热编码

在计算交叉熵损失时,nn.CrossEntropyLoss 会自动处理标签,并且不需要你手动将标签转换为独热编码(one-hot encoding)。

具体来说:

  • y_hat:是模型的原始输出(logits),形状为 (batch_size, num_classes),每一行是一个样本的预测结果,包含每个类别的分数(logits)。
  • y:是标签,形状为 (batch_size,),每个元素是该样本的真实类别的 索引,而不是独热编码。

nn.CrossEntropyLoss 会自动使用标签 y 中的类别索引(如类别 0, 1, 2)来计算损失,它会根据该类别索引选择对应的模型输出进行计算,而不需要你事先将标签转换为独热编码。

举个例子:

假设我们有一个批次的两个样本,模型的输出 y_hat 和真实标签 y 如下:

模型的输出 y_hat(logits):
y_hat = torch.tensor([[2.0, 1.0, -1.0],  # 第一个样本的 logits[0.5, 1.5, 0.3]]) # 第二个样本的 logits
真实标签 y(类别索引):
y = torch.tensor([0, 1])  # 第一个样本的标签是类别 0,第二个样本的标签是类别 1

在这个例子中,y_hat 的形状是 (2, 3),表示有两个样本,每个样本有三个类别的 logits。

  • 对于第一个样本,它的真实标签是类别 0y[0] = 0
  • 对于第二个样本,它的真实标签是类别 1y[1] = 1

当使用 nn.CrossEntropyLoss 时,它会根据真实标签中的类别索引来选择对应的类别概率(通过 softmax 后的结果),然后计算交叉熵损失。PyTorch 会自动完成:

  1. Softmax 转换:将 y_hat 的 logits 转换为概率分布。
  2. 损失计算:根据真实标签的类别索引计算交叉熵损失。

损失计算过程:

假设 y_hat 的 softmax 输出(经过 softmax 转换后)为:

softmax = torch.nn.functional.softmax(y_hat, dim=1)# softmax 结果:
# [[0.7, 0.2, 0.1],  # 第一个样本的概率
#  [0.3, 0.6, 0.1]]  # 第二个样本的概率

然后计算交叉熵损失时:

  • 对于第一个样本(真实标签是 0),损失将仅计算 softmax[0][0](即类别 0 的概率),即 -log(0.7)
  • 对于第二个样本(真实标签是 1),损失将仅计算 softmax[1][1](即类别 1 的概率),即 -log(0.6)

因此,nn.CrossEntropyLoss 会自动从 y 中提取类别索引,并计算相应的损失。

总结:

  • 不需要手动进行独热编码,因为 nn.CrossEntropyLoss 会自动处理。
  • 标签 y 只需要是一个包含类别索引的张量(每个元素是一个整数,表示类别)。
  • nn.CrossEntropyLoss 会根据标签中的类别索引,从 y_hat 中提取对应的 logits,然后计算交叉熵损失。

通过这种方式,PyTorch 简化了损失计算的过程,让你不需要显式地进行独热编码的转换。

为什么在实验的时候我去掉了参数初始化对结果并没有很大影响?

在深度学习模型中,尤其是在多层感知机(MLP)中,初始化模型参数对于训练的收敛速度和结果有重要影响。虽然初始化通常有助于提升训练过程的稳定性和准确性,但在某些情况下,即使没有显式初始化,训练仍然可能得到较好的结果。这有以下几个原因:

1. PyTorch 默认的权重初始化

如果你没有手动初始化模型参数,PyTorch 会使用默认的初始化方式。对于 nn.Linear 层,PyTorch 默认使用的是 均匀分布初始化U(-sqrt(k), sqrt(k))),其中 k 是每一层的输入特征数(即该层的输入维度)。这种初始化方式通常可以在大部分场景下有效工作,尤其是在简单的任务或小型模型中。

2. 任务本身比较简单

MNIST 数据集是一个相对简单的任务,具有以下特点:

  • 样本相对简单(28x28 的灰度图像)。
  • 类别数量较少(10 个类别)。
  • 数据集规模较小(60,000 个训练样本)。

由于这些原因,即使没有特别优化初始化方式,模型仍然能在训练过程中较好地拟合数据,因此准确率可能不会受到显著影响。

3. 优化器的鲁棒性

现代优化器(如 SGD、Adam 等)通常具有较强的鲁棒性,能够在一定范围内有效地调整模型的参数,避免了初始化差异带来的过度影响。即使没有进行显式初始化,优化器也能够逐步调整模型的参数,从而避免梯度消失或梯度爆炸等问题,保证训练的顺利进行。

4. 训练过程中参数的调整

在模型训练初期,即使初始化不完美,随着训练的进行,网络的权重会在反向传播过程中逐步调整到合适的值。因此,即使开始时的参数较为随机,优化过程仍然能够找到有效的解决方案。这就是深度学习的一个特性:即使参数初始不理想,优化过程通常能通过梯度更新找到合适的解。

5. 初始化不影响最终收敛结果

对于一些简单的任务,模型可能在多个初始化条件下都能够达到一个相对接近的局部最优解。在这种情况下,即使没有手动初始化权重,模型也能收敛到较好的解。

总结:

  • 默认初始化(PyTorch 内部的初始化方式)通常已经能在很多简单的任务中有效工作,特别是像 MNIST 这样简单的图像分类任务。
  • 优化器的鲁棒性帮助模型调整参数,避免了初始化不完美时对结果产生显著影响。
  • 对于 MNIST 这种简单任务,初始化参数的不同可能不会导致显著差异,尤其是在训练的过程中,优化器能够找到较好的解。

然而,在一些更复杂的任务中,初始化的方式会直接影响模型的训练效率和性能。在这些任务中,精心设计的初始化(例如 Xavier、He 初始化等)能够帮助模型更快地收敛并避免训练过程中遇到的问题。

相关文章:

基于多层感知机(MLP)实现MNIST手写体识别

实现步骤 下载数据集处理好数据集确定好模型(初始化模型参数等等)确定优化函数(损失函数也称为目标函数)和优化方法(一般选用随机梯度下降 SDG )进行模型的训练进行模型的评估 import torch import torch…...

如何使用useContext进行全局状态管理?

在 React 中,使用 useContext 进行全局状态管理是一种有效的方法,尤其在需要在多个组件之间共享状态时。useContext 允许你在组件树中传递数据,而无需通过每个组件的 props 逐层传递。以下是关于如何使用 useContext 进行全局状态管理的详细指…...

【机器学习】Logistic回归#1基于Scikit-Learn的简单Logistic回归

主要参考学习资料: 《机器学习算法的数学解析与Python实现》莫凡 著 前置知识:线性代数-Python 目录 问题背景数学模型类别表示Logistic函数假设函数损失函数训练步骤 代码实现特点 问题背景 分类问题是一类预测非连续(离散)值的…...

8.Dashboard的导入导出

分享自己的Dashboard 1. 在Dashboard settings中选择 JSON Model 2. 导入 后续请参考第三篇导入光放Dashboard,相近...

next.js-学习2

next.js-学习2 1. https://nextjs.org/learn/dashboard-app/getting-started2. 模拟的数据3. 添加样式4. 字体,图片5. 创建布局和页面页面导航 1. https://nextjs.org/learn/dashboard-app/getting-started /app: Contains all the routes, components, and logic …...

视频推拉流EasyDSS直播点播平台授权激活码无效,报错400的原因是什么?

在当今数字化浪潮中,视频推拉流 EasyDSS 视频直播点播平台宛如一颗璀璨的明珠,汇聚了视频直播、点播、转码、精细管理、录像、高效检索以及时移回看等一系列强大功能于一身,全方位构建起音视频服务生态。它既能助力音视频采集,精准…...

【论文详解】Transformer 论文《Attention Is All You Need》能够并行计算的原因

文章目录 前言一、传统 RNN/CNN 存在的串行计算问题二、Transformer 如何实现并行计算?三、Transformer 的 Encoder 和 Decoder 如何并行四、结论 前言 亲爱的家人们,创作很不容易,若对您有帮助的话,请点赞收藏加关注哦&#xff…...

Fisher信息矩阵(Fisher Information Matrix,简称FIM)

Fisher信息矩阵简介 Fisher信息矩阵(Fisher Information Matrix,简称FIM)是统计学和信息理论中的一个重要概念,广泛应用于参数估计、统计推断和机器学习领域。它以统计学家罗纳德费希尔(Ronald Fisher)的名…...

基础设施安全(Infrastructure Security)是什么?

基础设施安全(Infrastructure Security)指的是保护IT基础设施(包括物理和云端的服务器、网络设备、存储、数据库等)免受网络攻击、数据泄露、未授权访问、系统故障等威胁的各种安全措施和技术。 1. 基础设施安全的主要组成部分 &…...

[杂学笔记]OSI七层模型作用、HTTP协议中的各种方法、HTTP的头部字段、TLS握手、指针与引用的使用场景、零拷贝技术

1.OSI七层模型作用 物理层:负责光电信号的传输,以及将光电信号转化为二进制数据数据链路层:主要负责将收到的二进制数据进一步的封装为数据帧报文。同时因为数据在网络中传递的时候,每一个主机都能够收到报文数据,该层…...

Framework层JNI侧Binder

目录 一,Binder JNI在整个系统的位置 1.1 小结 二,代码分析 2.1 BBinder创建 2.2 Bpinder是在查找服务时候创建的 2.3 JNI实现 2.4 JNI层android_os_BinderProxy_transact 2.5 BPProxy实现 2)调用IPCThreadState发送数据到Binder驱动…...

Windows 图形显示驱动开发-WDDM 3.2-自动显示切换(九)

面板驱动程序 显示器驱动程序是根据从 EDID 生成的即插即用 (PnP) 硬件 ID 加载的。 由于 EDID 保持不变,当任何一个 GPU 控制内部面板时,都会加载面板驱动程序。 这两个驱动程序将显示相同的亮度功能。 因此,加载应该不会造成任何问题&…...

Excel大文件拆分

import pandas as pddef split_excel_file(input_file, output_prefix, num_parts10):# 读取Excel文件df pd.read_excel(input_file)# 计算每部分的行数total_rows len(df)rows_per_part total_rows // num_partsremaining_rows total_rows % num_partsstart_row 0for i i…...

OpenCV计算摄影学(7)HDR成像之多帧图像对齐的类cv::AlignMTB

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 该算法将图像转换为‌中值阈值位图‌(Median Threshold Bitmap,MTB): 1.位图生成‌:…...

JWT+redis实现三大令牌管理方案深度解析

三种令牌管理方案对比与评估 1. 仅续期Redis(不生成新令牌) 实现原理 通过延长Redis中的令牌有效期维持会话,JWT本身不包含动态过期时间。 优点 ✅ 低开销:无需生成新令牌,减少JWT签名计算成本。 ✅ 简单实现&#x…...

北京大学DeepSeek提示词工程与落地场景(PDF无套路免费下载)

近年来,大模型技术飞速发展,但许多用户发现:即使使用同一款 AI 工具,效果也可能天差地别——有人能用 AI 快速生成精准方案,有人却只能得到笼统回答。这背后的关键差异,在于提示词工程的应用能力。 北京大…...

Axure PR 9 中继器 03 翻页控制

大家好,我是大明同学。 接着上期的内容,这期内容,我们来了解一下Axure中继器图表翻页控制。 预览地址:https://pvie5g.axshare.com 翻页控制 1.打开上期RP 文件,在元件库中拖入一个矩形,宽值根据业务实际…...

IO流(师从韩顺平)

文章目录 文件什么是文件文件流 常用的文件操作创建文件对象相关构造器和方法应用案例 获取文件的相关信息应用案例 目录的操作和文件删除应用案例 IO 流原理及流的分类Java IO 流原理IO流的分类 IO 流体系图-常用的类IO 流体系图(重要!!&…...

基于Javase的停车场收费管理系统

基于Javase的停车场收费管理系统 停车场管理系统开发文档 项目概述 1.1 项目背景 随着现代化城市的不断发展,车辆数量不断增加,停车难问题也日益突出。为了更好地管理停车场资 源,提升停车效率,需要一个基于Java SE的停车场管理…...

Exoplayer(MediaX)实现音频变调和变速播放

在K歌或录音类应用中变调是个常见需求,比如需要播出萝莉音/大叔音等。变速播放在影视播放类应用中普遍存在,在传统播放器Mediaplayer中这两个功能都比较难以实现,特别在低版本SDK中,而Exoplayer作为google官方推出的Mediaplayer替…...

Spring Boot集成Jetty、Tomcat或Undertow及支持HTTP/2协议

目录 一、常用Web服务器 1、Tomcat 2、Jetty 3、Undertow 二、什么是HTTP/2协议 1、定义 2、特性 3、优点 4、与HTTP/1.1的区别 三、集成Web服务器并开启HTTP/2协议 1、生成证书 2、新建springboot项目 3、集成Web服务器 3.1 集成Tomcat 3.2 集成Jetty 3.3 集成…...

《Python实战进阶》专栏 No 5:GraphQL vs RESTful API 对比与实现

《Python实战进阶》专栏包括68集,每一集聚焦一个中高级技术知识点,涵盖Python在Web开发、数据处理、自动化、机器学习、并发编程等领域的应用,系统梳理Python开发者的知识集。本集的主题为: No4 : GraphQL vs RESTful API 对比与实…...

类和对象——static修饰类的成员

static修饰类的成员 static成员1 static成员的概念2 特性 static成员 有时会有这样的需求:计算程序中创建出了多少个类的对象,以及多少个正在使用的对象。 因为构造函数和析构函数都只会调用一次,所以可以通过设置生命周期和main函数一致的…...

RabbitMQ系列(七)基本概念之Channel

RabbitMQ 中的 Channel(信道) 是客户端与 RabbitMQ 服务器通信的虚拟会话通道,其核心作用在于优化资源利用并提升消息处理效率。以下是其核心机制与功能的详细解析: 一、Channel 的核心定义 虚拟通信链路 Channel 是建立在 TCP 连…...

你对 Spring Cloud 的理解

Spring Cloud 是一个基于 Spring Boot 的微服务架构开发工具集,为开发者提供了快速构建分布式系统的一系列解决方案,涵盖了服务发现、配置管理、熔断器、智能路由、微代理、控制总线等多个方面。 从核心组件来看: 服务发现:以 Eu…...

MYSQL 5.7数据库,关于1067报错 invalid default value for,解决方法!

???作者: 米罗学长 ???个人简介:混迹java圈十余年,精通Java、小程序、数据库等。 ???各类成品java毕设 。javaweb,ssm,springboot,mysql等项目,源码丰富,欢迎咨询。 ???…...

C# Enumerable类 之 数据筛选

总目录 前言 在 C# 中,System.Linq.Enumerable 类是 LINQ(Language Integrated Query)的核心组成部分,它提供了一系列静态方法,用于操作实现了 IEnumerable 接口的集合。通过这些方法,我们可以轻松地对集合…...

运维基础知识(一)

一:SSH端口 首先SSH是什么? SSH(Secure Shell)是Linux、Unix、Mac及其他网络设备最常用的远程CLI管理协议,SSH使用秘钥对数据进行加密,保证了远程管理数据的安全性。 Secure Shell (SSH) 是一种网络协议,允许用户通过加密的通道安全地访问另一台计算机。SSH广泛用于远程…...

权重生成图像

简介 前面提到的许多生成模型都有保存了生成器的权重,本章主要介绍如何使用训练好的权重文件通过生成器生成图像。 但是如何使用权重生成图像呢? 一、参数配置 ima_size 为图像尺寸,这个需要跟你模型训练的时候resize的时候一样。 latent_dim为噪声维度,一般的设置都是…...

【Linux基础】Linux下的C编程指南

目录 一、前言 二、Vim的使用 2.1 普通模式 2.2 插入模式 2.3 命令行模式 2.4 可视模式 三、GCC编译器 3.1 预处理阶段 3.2 编译阶段 3.3 汇编阶段 3.4 链接阶段 3.5 静态库和动态库 四、Gdb调试器 五、总结 一、前言 在Linux环境下使用C语言进行编程是一项基础且…...