基于多层感知机(MLP)实现MNIST手写体识别
实现步骤
- 下载数据集
- 处理好数据集
- 确定好模型(初始化模型参数等等)
- 确定优化函数(损失函数也称为目标函数)和优化方法(一般选用随机梯度下降 SDG )
- 进行模型的训练
- 进行模型的评估
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 1. 下载数据集
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, transform=transforms.ToTensor(), download=True)# 2. 创建批量数据迭代器
train_iter = DataLoader(mnist_train, batch_size=256, shuffle=True)
test_iter = DataLoader(mnist_test, batch_size=256)# 3. 可视化检查数据
var = next(iter(train_iter))
plt.title(str(var[1][0])) # 显示标签
plt.imshow(var[0][0].squeeze().numpy(), cmap='gray') # 显示图片
plt.show()# 4. 定义模型:多层感知机
net = nn.Sequential(nn.Flatten(),nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, 10) # 注意这里是不需要加 Softmax 了的,因为后面定义了,nn.CrossEntropyLoss()这个会自动帮我们进行 Softmax 以及进行损失计算。其实就是目标函数
)# 初始化模型参数
def init_weights(m):if isinstance(m, nn.Linear):nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)# 5. 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss() # CrossEntropyLoss已经包含了softmax,所以不需要LogSoftmax
optimizer = optim.SGD(net.parameters(), lr=0.2)# 6. 训练模型
epoch_num = 20
for epoch in range(epoch_num):net.train() # 设置为训练模式total_loss = 0for X, y in train_iter:optimizer.zero_grad() # 清除梯度y_hat = net(X) # 前向传播loss = loss_fn(y_hat, y) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数total_loss += loss.item() * X.shape[0] # 累积损失avg_loss = total_loss / len(mnist_train) # 计算平均损失print(f'Epoch {epoch + 1}/{epoch_num}, Loss: {avg_loss:.4f}')# 7. 评估模型
def evaluate_model(net, test_iter):net.eval() # 设置为评估模式correct, total = 0, 0with torch.no_grad(): # 在评估时不需要计算梯度for X, y in test_iter:y_hat = net(X)_, predicted = torch.max(y_hat, 1) # 获取预测的标签correct += (predicted == y).sum().item() # 统计正确的个数total += y.size(0) # 统计总数accuracy = correct / totalprint(f'Accuracy on test set: {accuracy * 100:.2f}%')# 评估模型的表现
evaluate_model(net, test_iter)
代码实践的结果:
- 自己不会去计算损失
- 在模型进行训练的时候不知道如何把数据放进去:原来只需要创建好了 DataLoader 以后,通过枚举
就可以拿到数据了。- 最后进行模型评估的时候也是用 AI 进行完成了。所以多少还是差点意思。
- 后面的代码多去实践实践,并且思考吧!!!
关于代码中交叉熵计算的理解
理解损失函数(loss_fn
)是如何计算的,对于训练神经网络来说是非常重要的。具体到你提到的这行代码:
loss = loss_fn(y_hat, y) # 计算损失
损失函数的定义:
在你的代码中,损失函数是:
loss_fn = nn.CrossEntropyLoss()
nn.CrossEntropyLoss()
是一种常用于多分类问题的损失函数,它实际上包含了两个步骤:
- Softmax:将模型的输出转换为概率分布。
- 交叉熵损失:计算真实标签与预测概率分布之间的差距。
为什么要用交叉熵呢?因为交叉熵可以来衡量预测差距,这个我们只需要这个知识点,并且知道上面的公式就好了。
我们逐步分析这两个步骤。
1. Softmax(概率转换)
假设模型的输出 y_hat
是一个向量,其中每个元素代表对应类别的“分数”(或者说是原始的 logits)。例如,假设有 3 个类别,模型的输出可能是:
y_hat = [2.0, 1.0, -1.0] # 这三个数字是 logits,不是概率
通过 Softmax 函数,我们将这些 logits 转换成概率:
# 计算 softmax
softmax = torch.nn.functional.softmax(y_hat, dim=-1)
softmax
的输出会是一个概率分布,每个数值的范围在 [0, 1] 之间,且所有数值加起来为 1。例如,经过 Softmax 后可能得到:
softmax = [0.7, 0.2, 0.1] # 类别 0 的概率是 0.7,类别 1 的概率是 0.2,类别 2 的概率是 0.1
2. 交叉熵损失(Cross Entropy Loss)
交叉熵是衡量两个概率分布之间差异的一个标准方法。在分类任务中,我们希望预测的类别概率与真实标签分布尽可能接近。
对于一个单一的样本,交叉熵损失的计算公式为:
L = − ∑ i = 1 C y i log ( p i ) L = - \sum_{i=1}^{C} y_i \log(p_i) L=−i=1∑Cyilog(pi)
- ( C ) 是类别数。
- ( y_i ) 是真实标签(在 one-hot 编码下,真实类别的标签为 1,其他类别为 0)。
- ( p_i ) 是模型预测的概率。
对于多分类任务来说,交叉熵损失会选择对应真实标签的类别概率 ( p_{\text{true}} ) 来计算损失。例如,如果真实标签是类别 0,那么我们只关注模型在类别 0 上的预测概率。
假设真实标签 y
是类别 0,对应的 one-hot 编码是 [1, 0, 0]
,而模型的预测是:
softmax = [0.7, 0.2, 0.1]
那么交叉熵损失为:
L = − ( 1 ⋅ log ( 0.7 ) + 0 ⋅ log ( 0.2 ) + 0 ⋅ log ( 0.1 ) ) = − log ( 0.7 ) ≈ 0.3567 L = - (1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1)) = - \log(0.7) \approx 0.3567 L=−(1⋅log(0.7)+0⋅log(0.2)+0⋅log(0.1))=−log(0.7)≈0.3567
nn.CrossEntropyLoss()
如何工作
在 PyTorch 中,nn.CrossEntropyLoss
会自动处理上述两个步骤:
- 将
y_hat
(logits)转换为概率。 - 使用真实标签
y
计算交叉熵损失。
输入和输出:
y_hat
: 这是模型的原始输出(logits),形状为(batch_size, num_classes)
。每一行是一个样本的 logits。y
: 这是标签,通常是一个包含类别索引的向量,形状为(batch_size,)
。每个元素是该样本的真实类别索引。
例如:
假设我们有以下数据:
-
模型的输出(logits)为:
y_hat = torch.tensor([[2.0, 1.0, -1.0], # 第一个样本[0.5, 1.5, 0.3]]) # 第二个样本
-
真实标签
y
为:y = torch.tensor([0, 1]) # 第一个样本的标签是类别 0,第二个样本的标签是类别 1
使用 nn.CrossEntropyLoss()
计算损失:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y_hat, y)
CrossEntropyLoss
会首先对 y_hat
进行 softmax 转换,然后计算每个样本的交叉熵损失。你可以通过打印出来的 loss
来查看模型的表现。
总结:
y_hat
是模型的原始输出(logits),表示每个类别的“分数”。nn.CrossEntropyLoss
会自动处理 softmax 和交叉熵损失的计算。- 损失函数的目的是衡量模型的输出与真实标签之间的差异,差异越小,损失值越小,说明模型的预测越准确。
使用`nn.CrossEntropyLoss 会自动进行独热编码
在计算交叉熵损失时,nn.CrossEntropyLoss
会自动处理标签,并且不需要你手动将标签转换为独热编码(one-hot encoding)。
具体来说:
y_hat
:是模型的原始输出(logits),形状为(batch_size, num_classes)
,每一行是一个样本的预测结果,包含每个类别的分数(logits)。y
:是标签,形状为(batch_size,)
,每个元素是该样本的真实类别的 索引,而不是独热编码。
nn.CrossEntropyLoss
会自动使用标签 y
中的类别索引(如类别 0, 1, 2)来计算损失,它会根据该类别索引选择对应的模型输出进行计算,而不需要你事先将标签转换为独热编码。
举个例子:
假设我们有一个批次的两个样本,模型的输出 y_hat
和真实标签 y
如下:
模型的输出 y_hat
(logits):
y_hat = torch.tensor([[2.0, 1.0, -1.0], # 第一个样本的 logits[0.5, 1.5, 0.3]]) # 第二个样本的 logits
真实标签 y
(类别索引):
y = torch.tensor([0, 1]) # 第一个样本的标签是类别 0,第二个样本的标签是类别 1
在这个例子中,y_hat
的形状是 (2, 3)
,表示有两个样本,每个样本有三个类别的 logits。
- 对于第一个样本,它的真实标签是类别
0
,y[0] = 0
。 - 对于第二个样本,它的真实标签是类别
1
,y[1] = 1
。
当使用 nn.CrossEntropyLoss
时,它会根据真实标签中的类别索引来选择对应的类别概率(通过 softmax 后的结果),然后计算交叉熵损失。PyTorch 会自动完成:
- Softmax 转换:将
y_hat
的 logits 转换为概率分布。 - 损失计算:根据真实标签的类别索引计算交叉熵损失。
损失计算过程:
假设 y_hat
的 softmax 输出(经过 softmax 转换后)为:
softmax = torch.nn.functional.softmax(y_hat, dim=1)# softmax 结果:
# [[0.7, 0.2, 0.1], # 第一个样本的概率
# [0.3, 0.6, 0.1]] # 第二个样本的概率
然后计算交叉熵损失时:
- 对于第一个样本(真实标签是
0
),损失将仅计算softmax[0][0]
(即类别0
的概率),即-log(0.7)
。 - 对于第二个样本(真实标签是
1
),损失将仅计算softmax[1][1]
(即类别1
的概率),即-log(0.6)
。
因此,nn.CrossEntropyLoss
会自动从 y
中提取类别索引,并计算相应的损失。
总结:
- 不需要手动进行独热编码,因为
nn.CrossEntropyLoss
会自动处理。 - 标签
y
只需要是一个包含类别索引的张量(每个元素是一个整数,表示类别)。 nn.CrossEntropyLoss
会根据标签中的类别索引,从y_hat
中提取对应的 logits,然后计算交叉熵损失。
通过这种方式,PyTorch 简化了损失计算的过程,让你不需要显式地进行独热编码的转换。
为什么在实验的时候我去掉了参数初始化对结果并没有很大影响?
在深度学习模型中,尤其是在多层感知机(MLP)中,初始化模型参数对于训练的收敛速度和结果有重要影响。虽然初始化通常有助于提升训练过程的稳定性和准确性,但在某些情况下,即使没有显式初始化,训练仍然可能得到较好的结果。这有以下几个原因:
1. PyTorch 默认的权重初始化
如果你没有手动初始化模型参数,PyTorch 会使用默认的初始化方式。对于 nn.Linear
层,PyTorch 默认使用的是 均匀分布初始化(U(-sqrt(k), sqrt(k))
),其中 k
是每一层的输入特征数(即该层的输入维度)。这种初始化方式通常可以在大部分场景下有效工作,尤其是在简单的任务或小型模型中。
2. 任务本身比较简单
MNIST 数据集是一个相对简单的任务,具有以下特点:
- 样本相对简单(28x28 的灰度图像)。
- 类别数量较少(10 个类别)。
- 数据集规模较小(60,000 个训练样本)。
由于这些原因,即使没有特别优化初始化方式,模型仍然能在训练过程中较好地拟合数据,因此准确率可能不会受到显著影响。
3. 优化器的鲁棒性
现代优化器(如 SGD、Adam 等)通常具有较强的鲁棒性,能够在一定范围内有效地调整模型的参数,避免了初始化差异带来的过度影响。即使没有进行显式初始化,优化器也能够逐步调整模型的参数,从而避免梯度消失或梯度爆炸等问题,保证训练的顺利进行。
4. 训练过程中参数的调整
在模型训练初期,即使初始化不完美,随着训练的进行,网络的权重会在反向传播过程中逐步调整到合适的值。因此,即使开始时的参数较为随机,优化过程仍然能够找到有效的解决方案。这就是深度学习的一个特性:即使参数初始不理想,优化过程通常能通过梯度更新找到合适的解。
5. 初始化不影响最终收敛结果
对于一些简单的任务,模型可能在多个初始化条件下都能够达到一个相对接近的局部最优解。在这种情况下,即使没有手动初始化权重,模型也能收敛到较好的解。
总结:
- 默认初始化(PyTorch 内部的初始化方式)通常已经能在很多简单的任务中有效工作,特别是像 MNIST 这样简单的图像分类任务。
- 优化器的鲁棒性帮助模型调整参数,避免了初始化不完美时对结果产生显著影响。
- 对于 MNIST 这种简单任务,初始化参数的不同可能不会导致显著差异,尤其是在训练的过程中,优化器能够找到较好的解。
然而,在一些更复杂的任务中,初始化的方式会直接影响模型的训练效率和性能。在这些任务中,精心设计的初始化(例如 Xavier、He 初始化等)能够帮助模型更快地收敛并避免训练过程中遇到的问题。
相关文章:
基于多层感知机(MLP)实现MNIST手写体识别
实现步骤 下载数据集处理好数据集确定好模型(初始化模型参数等等)确定优化函数(损失函数也称为目标函数)和优化方法(一般选用随机梯度下降 SDG )进行模型的训练进行模型的评估 import torch import torch…...
QT和有道词典有冲突,导致内存溢出,闪退。
提示:本文为学习记录,若有疑问,请联系作者。 前言 具体详细查看此博主:原文链接 在使用Qt Designer时,如果开启了有道词典,会导致Qt Designer崩溃。估计应该是把有道词典屏幕取词功能打开后,有…...
4. 示例:创建带约束的随机地址生成器(范围0x1000-0xFFFF)
文章目录 前言代码示例:运行方法:查看结果:关键功能说明:扩展功能建议: 前言 以下是一个完整的SystemVerilog测试平台示例,包含约束随机地址生成、日志输出和波形生成功能: 代码示例࿱…...

VSCode轻松调试运行C#控制台程序
1.背景 我一直都是用VS来开发C#项目的,用的比较顺手,也习惯了。看其他技术文章有介绍VS Code更轻量,更方便。所以我专门花时间来使用VS Code,看看它是如何调试代码、如何运行C#控制台。这篇文章是一个记录的过程。 2.操作 2.1 V…...

内容中台是什么?内容管理平台解析
内容中台的核心价值 现代企业数字化转型进程中,内容中台作为中枢系统,通过构建统一化的内容管理平台实现数据资产的高效整合与智能调度。其核心价值体现在打破传统信息孤岛,将分散于CRM、ERP等系统的文档、知识库、产品资料进行标准化归集&a…...

sqlmap:自动SQL注入和数据库接管工具
SQL 注入攻击是 Web 安全领域最常见的漏洞之一,今天给大家介绍一个自动化 SQL 注入和数据库接管工具:sqlmap。sqlmap 作为一款开源渗透测试工具,能帮助安全测试人员快速发现并利用 SQL 注入漏洞接管数据库服务器。 功能特性 sqlmap 使用 Pyt…...
Python设置阿里云镜像源教程:解决PIP安装依赖包下载速度慢的问题
在 Python 中,你可以通过修改 pip 的配置文件来设置阿里云镜像源,以加速包的安装。以下是具体步骤: 1. 临时使用阿里云镜像源 你可以在使用 pip 安装包时,通过 -i 参数临时指定阿里云镜像源: pip install <packa…...

基于专利合作地址匹配的数据构建区域协同矩阵
文章目录 地区地址提取完成的处理代码 在专利合作申请表中,有多家公司合作申请。在专利权人地址中, 有多个公司的地址信息。故想利用这里多个地址。想用这里的地址来代表区域之间的专利合作情况代表区域之间的协同、协作情况。 下图是专利合作表的一部分…...

Java集合List快速实现重复判断的10种方法深度解析
文章目录 引言:为什么需要关注List重复判断?一、基础实现方法1.1 暴力双循环法1.2 HashSet法 二、进阶实现方案2.1 Stream API实现2.2 TreeSet排序法 三、高性能优化方案3.1 并行流处理3.2 BitSet位图法(仅限整数) 四、第三方库实…...

List的模拟实现(2)
前言 上一节我们讲解了list的基本功能,那么本节我们就结合底层代码来分析list是怎么实现的,那么废话不多说,我们正式进入今天的学习:) List的底层结构 我们先来看一下list的底层基本结构: 这里比较奇怪的…...
如何使用SaltStack批量替换SSL证书方案
以下是借助 SaltStack 批量替换 SSL 证书的完整方案,该方案结合了自动化更新与回滚机制,以保障操作的高效性与安全性: 一、准备工作 目录结构搭建 在 Salt Master 的 /home/salt/ssl_update 目录下构建如下结构:ssl_update/ ├──…...
Golang快速上手01/Golang基础
最近有需求,需要使用go,这几天快速过一遍基础语法,这是今天的总结 项目结构 #mermaid-svg-qpF09pnIik9bqQ4E {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-qpF09pnIik9bqQ4E .e…...
[Web 安全] 反序列化漏洞 - 学习笔记
关注这个专栏的其他相关笔记:[Web 安全] Web 安全攻防 - 学习手册-CSDN博客 0x01:反序列化漏洞 — 漏洞介绍 反序列化漏洞是一种常见的安全漏洞,主要出现在应用程序将 序列化数据 重新转换为对象(即反序列化)的过程中…...

【学习笔记】Google的Lyra项目:基于神经网络的超低比特率语音编解码技术
一、引言:语音通信的带宽挑战与技术突破 在实时音视频通信占据全球数字化生活核心地位的今天,Google于2021年推出的Lyra编解码器标志着语音编码技术进入新的时代。这款基于机器学习的新型音频编解码器以3kbps的极低比特率实现接近原始音质的语音重建能力…...

Unity Dedicated Server 控制台 输出日志LOg 中文 乱码
现象: 中文乱码 原因: Unity打包出来的.exe文件,语言一栏是英文,VS控制台出来不一样 解决方案: 新建.bat文件 ,并使用命令chcp 65001,运行时启动.bat,而不是.exe, 改不了exe属性,虽然有点奇怪ÿ…...

【Excel】 Power Query抓取多页数据导入到Excel
抓取多页数据想必大多数人都会,只要会点编程技项的人都不会是难事儿。那么,如果只是单纯的利用Excel软件,我还真的没弄过。昨天,我就因为这个在网上找了好久发好久。 1、在数据-》新建查询-》从其他源-》自网站 ,如图 …...

去耦电容的作用详解
在霍尔元件的实际应用过程中,经常会用到去耦电容。去耦电容是电路中装设在元件的电源端的电容,其作用详解如下: 一、基本概念 去耦电容,也称退耦电容,是把输出信号的干扰作为滤除对象。它通常安装在集成电路…...

HTTPS 与 HTTP 的区别在哪?
HTTP与HTTPS作为互联网数据传输的核心协议,其通信机制与安全特性深刻影响着现代网络应用的可靠性与用户体验。本文将解析两者的通信流程、安全机制及核心差异。 一、HTTP的通信机制 先来看看HTTP是什么吧。 HTTP基于TCP/IP协议栈,采用经典客户端-服务…...

let、const【ES6】
“我唯一知道的就是我一无所知。” - 苏格拉底 目录 块级作用域:var、let、const的对比:Object.freeze(): 块级作用域: 块级作用域指由 {} 包围的代码块(如 if、for、while、单独代码块等)形成的独立作用…...

openharmony5.0中hdf框架中实现驱动程序的动态加载和管理的技术细节分析
在分析openharmony的hdf框架的设备驱动加载器(IDriverLoader)时发现在创建实例时会首先判断一下是否完成了驱动入口的构建(HdfDriverEntryConstruct),如果没有构建会重新构建,这与我开始以为的不一致(我一直以为是采用的linux内核方式,只是由…...
Vim 调用外部命令学习笔记
Vim 外部命令集成完全指南 文章目录 Vim 外部命令集成完全指南核心概念理解命令语法解析语法对比 常用外部命令详解文本排序与去重文本筛选与搜索高级 grep 搜索技巧文本替换与编辑字符处理高级文本处理编程语言处理其他实用命令 范围操作示例指定行范围处理复合命令示例 实用技…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析
1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...

【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
在四层代理中还原真实客户端ngx_stream_realip_module
一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡(如 HAProxy、AWS NLB、阿里 SLB)发起上游连接时,将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后,ngx_stream_realip_module 从中提取原始信息…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案
随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
MySQL账号权限管理指南:安全创建账户与精细授权技巧
在MySQL数据库管理中,合理创建用户账号并分配精确权限是保障数据安全的核心环节。直接使用root账号进行所有操作不仅危险且难以审计操作行为。今天我们来全面解析MySQL账号创建与权限分配的专业方法。 一、为何需要创建独立账号? 最小权限原则…...
重启Eureka集群中的节点,对已经注册的服务有什么影响
先看答案,如果正确地操作,重启Eureka集群中的节点,对已经注册的服务影响非常小,甚至可以做到无感知。 但如果操作不当,可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...

算法岗面试经验分享-大模型篇
文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer (1)资源 论文&a…...

LLMs 系列实操科普(1)
写在前面: 本期内容我们继续 Andrej Karpathy 的《How I use LLMs》讲座内容,原视频时长 ~130 分钟,以实操演示主流的一些 LLMs 的使用,由于涉及到实操,实际上并不适合以文字整理,但还是决定尽量整理一份笔…...