使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
文章目录
- 🍋引言
- 🍋什么是多分类问题?
- 🍋处理步骤
- 🍋多分类问题
- 🍋MNIST dataset的实现
- 🍋NLLLoss 和 CrossEntropyLoss
🍋引言
当处理多分类问题时,PyTorch是一种非常有用的深度学习框架。在这篇博客中,我们将讨论如何使用PyTorch来解决多分类问题。我们将介绍多分类问题的基本概念,构建一个简单的多分类神经网络模型,并演示如何准备数据、训练模型和评估结果。
🍋什么是多分类问题?
多分类问题是一种机器学习任务,其中目标是将输入数据分为多个不同的类别或标签。与二分类问题不同,多分类问题涉及到三个或更多类别的分类任务。例如,图像分类问题可以将图像分为不同的类别,如猫、狗、鸟等。
🍋处理步骤
-
准备数据:
收集和准备数据集,确保每个样本都有相应的标签,以指明其所属类别。
划分数据集为训练集、验证集和测试集,以便进行模型训练、调优和性能评估。 -
数据预处理:
对数据进行预处理,例如归一化、标准化、缺失值处理或数据增强,以确保模型训练的稳定性和性能。 -
选择模型架构:
选择适当的深度学习模型架构,通常包括卷积神经网络(CNN)、循环神经网络(RNN)、Transformer等,具体取决于问题的性质。 -
定义损失函数:
为多分类问题选择适当的损失函数,通常是交叉熵损失(Cross-Entropy Loss)。 -
选择优化器:
选择合适的优化算法,如随机梯度下降(SGD)、Adam、RMSprop等,以训练模型并调整权重。 -
训练模型:
使用训练数据集来训练模型。在每个训练迭代中,通过前向传播和反向传播来更新模型参数,以减小损失函数的值。 -
评估模型:
使用验证集来评估模型性能。常见的性能指标包括准确性、精确度、召回率、F1分数等。 -
调优模型:
根据验证集的性能,对模型进行调优,可以尝试不同的超参数设置、模型架构变化或数据增强策略。 -
测试模型:
最终,在独立的测试数据集上评估模型的性能,以获得最终性能评估。 -
部署模型:
将训练好的模型部署到实际应用中,用于实时或批处理多分类任务。
🍋多分类问题
之前我们讨论的问题都是二分类居多,对于二分类问题,我们若求得p(0),南无p(1)=1-p(0),还是比较容易的,但是本节我们将引入多分类,那么我们所求得就转化为p(i)(i=1,2,3,4…),同时我们需要满足以上概率中每一个都大于0;且总和为1。
处理多分类问题,这里我们新引入了一个称为Softmax Layer
接下来我们一起讨论一下Softmax Layer层
首先我们计算指数计算e的zi次幂,原因很简单e的指数函数恒大于0;分母就是e的z1次幂+e的z2次幂+e的z3次幂…求和,这样所有的概率和就为1了。
下图形象的展示了Softmax,Exponent这里指指数,和上面我们说的一样,先求指数,这样有了分子,再将所有指数求和,最后一一divide,得到了每一个概率。
接下来我们一起来看看损失函数
如果使用numpy进行实现,根据刘二大人的代码,可以进行如下的实现
import numpy as np
y = np.array([1,0,0])
z = np.array([0.2,0.1,-0.1])
y_pred = np.exp(z)/np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss)
运行结果如下
注意:神经网络的最后一层不需要激活
在pytorch中
import torch
y = torch.LongTensor([0]) # 长整型
z = torch.Tensor([[0.2, 0.1, -0.1]])
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z, y)
print(loss)
运行结果如下
下面根据一个例子进行演示
criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2,0,1])
Y_pred1 = torch.Tensor([[0.1, 0.2, 0.9], [1.1, 0.1, 0.2], [0.2, 2.1, 0.1]])
Y_pred2 = torch.Tensor([[0.8, 0.2, 0.3], [0.2, 0.3, 0.5], [0.2, 0.2, 0.5]])
l1 = criterion(Y_pred1, Y)
l2 = criterion(Y_pred2, Y)
print("Batch Loss1 = ", l1.data, "\nBatch Loss2=", l2.data)
运行结果如下
根据上面的代码可以看出第一个损失比第二个损失要小。原因很简单,想对于Y_pred1每一个预测的分类与Y是一致的,而Y_pred2则相差了一下,所以损失自然就大了些
🍋MNIST dataset的实现
首先第一步还是导包
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
之后是数据的准备
batch_size = 64
# transform可以将其转化为0-1,形状的转换从28×28转换为,1×28×28
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307, ), (0.3081, )) # 均值mean和标准差std
])
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)
接下来我们构建网络
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512) self.l2 = torch.nn.Linear(512, 256) self.l3 = torch.nn.Linear(256, 128) self.l4 = torch.nn.Linear(128, 64) self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.l1(x)) x = F.relu(self.l2(x)) x = F.relu(self.l3(x)) x = F.relu(self.l4(x)) return self.l5(x) # 注意最后一层不做激活
model = Net()
之后定义损失和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
接下来就进行训练了
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0): inputs, target = dataoptimizer.zero_grad()# forward + backward + updateoutputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300)) running_loss = 0.0
def test():correct = 0total = 0with torch.no_grad(): # 这里可以防止内嵌代码不会执行梯度for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100 * correct / total))
最后调用执行
if __name__ == '__main__': for epoch in range(10): train(epoch)test()
🍋NLLLoss 和 CrossEntropyLoss
NLLLoss 和 CrossEntropyLoss(也称为交叉熵损失)是深度学习中常用的两种损失函数,用于测量模型的输出与真实标签之间的差距,通常用于分类任务。它们有一些相似之处,但也有一些不同之处。
相同点:
用途:两者都用于分类任务,评估模型的输出和真实标签之间的差异,以便进行模型的训练和优化。
数学基础:NLLLoss 和 CrossEntropyLoss 本质上都是交叉熵损失的不同变种,它们都以信息论的概念为基础,衡量两个概率分布之间的相似度。
输入格式:它们通常期望模型的输出是一个概率分布,表示各个类别的预测概率,以及真实的标签。
不同点:
输入格式:NLLLoss 通常期望输入是对数概率(log probabilities),而 CrossEntropyLoss 通常期望输入是未经对数化的概率。在实际应用中,CrossEntropyLoss 通常与softmax操作结合使用,将原始模型输出转化为概率分布,而NLLLoss可以直接使用对数概率。
对数化:NLLLoss 要求将模型输出的概率经过对数化(取对数)以获得对数概率,然后与真实标签的离散概率分布进行比较。CrossEntropyLoss 通常在 softmax 操作之后直接使用未对数化的概率值与真实标签比较。
输出维度:NLLLoss 更通用,可以用于多种情况,包括多类别分类和序列生成等任务,因此需要更多的灵活性。CrossEntropyLoss 通常用于多类别分类任务。
总之,NLLLoss 和 CrossEntropyLoss 都用于分类任务,但它们在输入格式和使用上存在一些差异。通常,选择哪个损失函数取决于你的模型输出的格式以及任务的性质。如果你的模型输出已经是对数概率形式,通常使用NLLLoss,否则通常使用CrossEntropyLoss。
挑战与创造都是很痛苦的,但是很充实。
相关文章:

使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…...

基于nodejs+vue网课学习平台
各功能简要描述如下: 1个人信息管理:包括对学生用户、老师和管理员的信息进行录入、修改,以及老师信息的审核等 2在库课程查询:用于学生用户查询相关课程的功能 3在库老师查询:用于学生用户查询相关老师教学的所有课程的功能。 4在库学校查询:用于学生用户查询相关学…...
读书笔记:Effective C++ 2.0 版,条款13(初始化顺序==声明顺序)、条款14(基类有虚析构)
条款13: 初始化列表中成员列出的顺序和它们在类中声明的顺序相同 类成员是按照它们在类里被声明的顺序进行初始化的,和它们在成员初始化列表中列出的顺序没一点关系。 根本原因可能是考虑到内存的分布,按照定义顺序进行排列。 另外,初始化列表…...

flutter开发实战-下拉刷新与上拉加载更多实现
flutter开发实战-下拉刷新与上拉加载更多实现 在开发中经常遇到列表需要下拉刷新与上拉加载更多,这里使用EasyRefresh,版本是3.3.21 一、什么是EasyRefresh EasyRefresh可以在Flutter应用程序上轻松实现下拉刷新和上拉加载。它几乎支持所有Flutter Sc…...

旧手机热点机改造成服务器方案
如果你也跟我一样有这种想法, 那真的太酷了!!! ok,前提是得有root,不然体验大打折扣 目录 目录 1.做一个能爬墙能走百度直连的热点机(做热点机用) 2.做emby视频服务器 3.做文件服务, 存取文件 4.装青龙面板,跑一些定时任务 5.做远程摄像头监控 6.做web服务器 7.内网穿…...

网工实验笔记:策略路由PBR的应用场景
一、概述 PBR(Policy-Based Routing,策略路由):PBR使得网络设备不仅能够基于报文的目的IP地址进行数据转发,更能基于其他元素进行数据转发,例如源IP地址、源MAC地址、目的MAC地址、源端口号、目的端口号、…...
webrtc快速入门——使用 WebRTC 拍摄静止的照片
文章目录 使用 getUserMedia() 拍摄静态照片HTML 标记JavaScript 代码初始化startup() 函数获取元素引用获取流媒体 监听视频开始播放处理按钮上的点击包装 startup() 方法 清理照片框从流中捕获帧 例子代码HTML代码CSS代码JavaScript代码 过滤器使用特定设备 使用 getUserMedi…...

预约按摩app软件开发定制足浴SPA上们服务小程序
同城按摩小程序是一种基于地理位置服务的小程序,它可以帮助用户快速找到附近的按摩师,并提供在线预约、评价、支付等功能。用户可以通过手机或者其他移动设备访问同城按摩小程序,实现足不出户就能预约到专业的按摩服务。 一、同城按摩小程序的…...

jenkins出错与恢复
如果你的jenkins出现了如下图所示问题(比如不能下载插件,无法保存任务等),这个时候就需要重新安装了。 一、卸载干净jenknis 要彻底卸载 Jenkins,您可以按照以下步骤进行操作: 1、停止 Jenkins 服务&…...
ssh免密登录的原理RSA非对称加密的理解
RSA非对称加密,是采用公钥加密私钥解密的原则。 举个例子SSH的免密登录 SSH免密登录是通过使用公钥加密技术实现的。以下是SSH免密登录的原理: 1. 生成密钥对:首先,在客户端上生成一对密钥,包括一个私钥和一个公钥。私…...

【监督学习】基于合取子句进化算法(CCEA)和析取范式进化算法(DNFEA)解决分类问题(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
力扣每日一题41:缺失的第一个正数
题目描述: 给你一个未排序的整数数组 nums ,请你找出其中没有出现的最小的正整数。 请你实现时间复杂度为 O(n) 并且只使用常数级别额外空间的解决方案。 示例 1: 输入:nums [1,2,0] 输出:3示例 2: 输…...
OpenCV与mediapipe实践
1. 安装前准备 开发环境:vscode venv 设置vscode, 建立项目,如: t1/src, 用vscode打开,新建终端Terminal,这时可能会有错误产生,解决办法: 运行命令:Set-ExecutionPolicy -ExecutionPolicy …...

【css拾遗】粘性布局实现有滚动条的情况下,按钮固定在页面底部展示
效果: 滚动条滚动过程中,按钮的位置位于手机的底部 滚动条滚到底部时,按钮的位置正常 这个position:sticky真的好用,我原先的想法是利用滚动条滚动事件去控制,没想到css就可以解决 <template><view class…...

git 创建并配置 GitHub 连接密钥
前记: git svn sourcetree gitee github gitlab gitblit gitbucket gitolite gogs 版本控制 | 仓库管理 ---- 系列工程笔记. Platform:Windows 10 Git version:git version 2.32.0.windows.1 Function: git 创建并配置 GitHub…...

使用Premiere、PhotoShop和Audition做视频特效
今天接到一个做视频的任务,给一个精忠报国的视频,要求: ①去掉人声,就是将唱歌的人声去掉,只留下伴奏; ②截图视频中的横幅,做一个展开的效果,类似卷纸慢慢展开;…...

vueday01——动态参数
我们现在知道了 v-bind:的语法糖是: v-on:的语法糖是 我们现在来尝试一下,定义一个动态参数模拟点击事件按钮 <div :id"idValue" ref"myDiv">我是待测div{{ resultId }}</div> <button v-on:[eventName]"doSomething&…...
双向链表C语言版本
1、声明链表节点操作函数 linklist.h #ifndef LINKLIST_H__ #define LINKLIST_H__ #include <stdio.h> #include <stdlib.h> #include <stdbool.h>//#define TAIL_ADD #define HEAD_ADD typedef int LinkDataType; // 构造节点 struct LinkNode {LinkDataTy…...

visual studio安装时候修改共享组件、工具和SDK路径方法
安装了VsStudio后,如果自己修改了Shared路径,当卸载旧版本,需要安装新版本时发现,之前的Shared路径无法进行修改,这就很坑爹了,因为我运行flutter程序的时候,报错找不到windows sdk的位置,所以我…...

Motorola IPMC761 使用边缘TPU加速神经网络
Motorola IPMC761 使用边缘TPU加速神经网络 人工智能(AI)和机器学习(ML)正在塑造和推进复杂的自动化技术解决方案。将这些功能集成到硬件中,解决方案可以识别图像中的对象,分析和检测模式中的异常或找到关键短语。这些功能对于包括但不限于自动驾驶汽车…...

【WiFi帧结构】
文章目录 帧结构MAC头部管理帧 帧结构 Wi-Fi的帧分为三部分组成:MAC头部frame bodyFCS,其中MAC是固定格式的,frame body是可变长度。 MAC头部有frame control,duration,address1,address2,addre…...

ElasticSearch搜索引擎之倒排索引及其底层算法
文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...

GitFlow 工作模式(详解)
今天再学项目的过程中遇到使用gitflow模式管理代码,因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存,无论是github还是gittee,都是一种基于git去保存代码的形式,这样保存代码…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...

什么是VR全景技术
VR全景技术,全称为虚拟现实全景技术,是通过计算机图像模拟生成三维空间中的虚拟世界,使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验,结合图文、3D、音视频等多媒体元素…...

消防一体化安全管控平台:构建消防“一张图”和APP统一管理
在城市的某个角落,一场突如其来的火灾打破了平静。熊熊烈火迅速蔓延,滚滚浓烟弥漫开来,周围群众的生命财产安全受到严重威胁。就在这千钧一发之际,消防救援队伍迅速行动,而豪越科技消防一体化安全管控平台构建的消防“…...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案
引言 在分布式系统的事务处理中,如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议(2PC)通过准备阶段与提交阶段的协调机制,以同步决策模式确保事务原子性。其改进版本三阶段提交协议(3PC…...
Spring Boot + MyBatis 集成支付宝支付流程
Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例(电脑网站支付) 1. 添加依赖 <!…...
Java 与 MySQL 性能优化:MySQL 慢 SQL 诊断与分析方法详解
文章目录 一、开启慢查询日志,定位耗时SQL1.1 查看慢查询日志是否开启1.2 临时开启慢查询日志1.3 永久开启慢查询日志1.4 分析慢查询日志 二、使用EXPLAIN分析SQL执行计划2.1 EXPLAIN的基本使用2.2 EXPLAIN分析案例2.3 根据EXPLAIN结果优化SQL 三、使用SHOW PROFILE…...