18.多分类问题代码实现
在机器学习中,多分类问题是一类常见的问题,它涉及到将输入数据划分为多个类别中的一个。例如,在图像识别中,我们可能需要将图像分为不同的类别,如手写数字识别(MNIST数据集)就是将手写数字图像分类为0-9的十个数字。本文将介绍如何使用PyTorch框架来构建一个简单的神经网络模型来解决多分类问题,并以MNIST数据集为例进行说明。
数据集
MNIST是一个包含手写数字图像的大型数据集,由NIST(美国国家标准与技术研究院)发起整理,包含了60,000个训练样本和10,000个测试样本。每个样本都是一张28x28像素的灰度图像,表示一个0-9之间的手写数字。
构建神经网络模型
首先,我们需要导入必要的库,并定义神经网络模型。这里我们将使用一个简单的全连接神经网络,包含两个隐藏层和一个输出层。
import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms # 定义神经网络模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28 * 28, 500) self.fc2 = nn.Linear(500, 100) self.fc3 = nn.Linear(100, 10) def forward(self, x): x = x.view(-1, 28 * 28) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return torch.log_softmax(x, dim=1) # 实例化模型 model = Net()
数据加载和预处理
接下来,我们需要加载MNIST数据集,并进行必要的预处理。这里我们使用torchvision.datasets.MNIST来加载数据集,并使用torch.utils.data.DataLoader来加载数据。
# 数据预处理:转换为Tensor并归一化
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))
]) # 加载训练集和测试集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
定义损失函数和优化器
对于多分类问题,我们通常使用交叉熵损失函数(CrossEntropyLoss)。在PyTorch中,nn.CrossEntropyLoss结合了LogSoftmax和NLLLoss,所以我们不需要在模型输出时显式使用LogSoftmax。
对于优化器,我们选择随机梯度下降(SGD)。
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
训练模型
现在我们可以开始训练模型了。在每个训练周期(epoch)中,我们将遍历整个训练集,计算损失,反向传播梯度,并更新模型参数。
# 训练模型
num_epochs = 10
for epoch in range(num_epochs): for i, (images, labels) in enumerate(trainloader, 0): # 清零梯度缓存 optimizer.zero_grad() # 前向传播 outputs = model(images) loss = criterion(outputs, labels) # 反向传播和优化 loss.backward() optimizer.step() if (i+1) % 1000 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(trainloader)}], Loss: {loss.item()}') print('Finished Training')
评估模型
训练完成后,我们可以使用测试集来评估模型的性能。这里我们计算了模型在测试集上的准确率。
# 评估模型
correct = 0
total = 0
with torch.no_grad(): # 不需要计算梯度,节省内存和计算资源 for images, labels in testloader: outputs = model(images) _, predicted = torch.max(outputs.data, 1) # 获取预测结果中概率最大的类别索引 total += labels.size(0) # 总样本数 correct += (predicted == labels).sum().item() # 正确预测的样本数 print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')
总结
本文介绍了如何使用PyTorch框架来构建和训练一个用于多分类问题的神经网络模型。我们以MNIST手写数字数据集为例,展示了从数据加载和预处理、模型定义、损失函数和优化器选择,到模型训练和评估的整个流程。
在实际应用中,我们可以根据具体的问题和数据集来调整模型的结构和参数,以获得更好的性能。此外,还可以使用更高级的技术和策略来优化模型的训练和评估过程,例如数据增强、正则化、学习率调整等。
通过本文的介绍,读者应该能够掌握使用PyTorch进行多分类问题建模的基本流程和关键技术,为后续的深度学习项目打下坚实的基础。
相关文章:
18.多分类问题代码实现
在机器学习中,多分类问题是一类常见的问题,它涉及到将输入数据划分为多个类别中的一个。例如,在图像识别中,我们可能需要将图像分为不同的类别,如手写数字识别(MNIST数据集)就是将手写数字图像分…...
实时通信的方式——WebRTC
文章目录 基于WebRTC实现音视频通话P2P通信原理如何发现对方? 不同的音视频编解码能力如何沟通?(媒体协商SDP)如何联系上对方?(网络协商) 常用的API音视频采集getUserMedia核心对象RTCPeerConne…...
Android 使用 ActivityResultLauncher 申请权限
前面介绍了 Android 运行时权限。 其中,申请权限的步骤有些繁琐,需要用到:ActivityCompat.requestPermissions 函数和 onRequestPermissionsResult 回调函数,今天就借助 ActivityResultLauncher 来简化书写。 步骤1:创…...
如何将前端项目打包并部署到不同服务器环境
学习源码可以看我的个人前端学习笔记 (github.com):qdxzw/frontlearningNotes 觉得有帮助的同学,可以点心心支持一下哈(笔记是根据b站尚硅谷的前端讲师【张天禹老师】整理的,用于自己复盘,有需要学习的可以去b站学习原版视频&…...
什么样的展馆场馆才是科技满满?就差一张智慧场馆大屏
随着科技的飞速发展,传统的场馆展示方式已经无法满足现代人对信息获取和体验的需求。智慧场馆大屏作为一种新型的展示方式,应运而生。它将高清大屏显示技术、智能交互技术、数据分析技术等融为一体,为观众带来更加丰富、生动的展示体验。 一…...
python核心编程(二)
python面向对象 一、基本理论二、 面向对象在python中实践2.1 如何去定义类2.2 通过类创建对象2.3 属性相关2.4 方法相关 三、python对象的生命周期,以及周期方法3.1 概念3.2 监听对象的生命周期 四、面向对象的三大特性4.1 封装4.2 继承4.2.1 概念4.2.1 目的4.2.2 分类4.2.3 t…...
【wiki知识库】02.wiki知识库SpringBoot后端的准备
📝个人主页:哈__ 期待您的关注 目录 一、🔥今日目标 二、📂打开SpringBoot项目 2.1 导入所需依赖 2.2修改application.yml配置文件 2.3导入MybatisPlus逆向工程工具 2.4创建一个公用的返回值 2.5创建CopyUtil工具类 2.6创建…...
python tuple(元组)
python list(列表)、创建、访问、内置index、判断in、not in、添加元素、insert、append、extend、列表排序、颠倒顺序、删除元素、remove、pop、clear-CSDN博客 目录 tuple: 元组的主要特点包括: tuple的创建 单个元组需要注…...
opencv调用摄像头保存视频
opencv调用摄像头保存视频 文章目录 opencv调用摄像头保存视频保存视频(采用默认分辨率640 x 480)保存视频(指定分辨率,例1280720) 保存视频(采用默认分辨率640 x 480) import cv2 import time # 定义视频捕捉对象 cap cv2.Vide…...
STM32定时器四大功能之定时器编码接口
1什么是编码器接口? 编码器接口接受编码器的正交信号,根据编码器产生的正交信号脉冲控制CNT的自增和自减,从而指示编码器的旋转方向和旋转速度。 每个高级定时器和通用定时器都有一个编码器接口,同时正交编码器产生的正交信号分…...
全国各城市间驾车耗时和距离矩阵数据集(更新至2022年)
数据简介:城市之间距离越远,耗时越长。经济发达地区的交通状况较好。各城市之间的驾车耗时和距离存在差异。有些城市之间的交通非常便捷,而有些城市之间的交通则较为不便。这表明中国的交通网络发展尚不平衡,需进一步优化。特别是…...
推荐二轮电动车仪表盘蓝牙主芯片方案-HS6621CGC
随着国内二轮电动车的火热开启,电动车的智能化程度越来越高;电动车的智能操控需求也越来越高,现在介绍蓝牙控制面板的一些功能;例如:定位(GNSS),设防,实时上报数据&#…...
『香橙派』基于Orange Pi AIpro打造高效个人云存储解决方案
读完这篇文章里你能收获到 了解Orange Pi AIpro硬件优势,为构建高效云存储基础设施的理想平台。学会使用Orange Pi AIpro硬件平台,搭载Ubuntu Server系统,打造云存储环境。掌握利用Kodbox软件,享受文件管理、多格式预览及编辑的全…...
Sylvester矩阵、子结式、辗转相除法的三者关系(第二部分)
【三者的关系】 首先,辗转相除法可以通过Sylvester矩阵进行,过程如下(以 m 8 、 l 7 m 8、l 7 m8、l7为例子)。 首先调整矩阵中 a a a系数到最后面几行,如下所示: S ( a 8 a 7 a 6 a 5 a 4 a 3 a 2 …...
PyTorch的数据处理
💥今天看一下 PyTorch数据通常的处理方法~ 一般我们会将dataset用来封装自己的数据集,dataloader用于读取数据 Dataset格式说明 💬dataset定义了这个数据集的总长度,以及会返回哪些参数,模板: from tor…...
第14章-蓝牙遥控小车 手把手做蓝牙APP遥控小车 蓝牙串口通讯讲解
本文讲解手机蓝牙如何遥控小车,如何编写串口通信指令 第14章-手机遥控功能 我们要实现蓝牙遥控功能,蓝牙遥控功能要使用:1.单片机的串口、2.蓝牙通信模块 所以我们先调试好:单片机的串口->蓝牙模块->接到一起联调 14.1-电脑控制小车 完成功能…...
【补充1】字节对齐
文章目录 1.字节对齐的基本概念2.字节对齐规则3.实践出真知(加大难度)4 位域 1.字节对齐的基本概念 (1)现代计算机中内存空间都是按照byte划分的, 从理论上讲似乎对任何类型的变量的访问可以从任何地址开始࿰…...
Java数据库连接(JDBC)
一、引言 在Java应用程序中,经常需要与数据库进行交互以存储、检索和处理数据。Java数据库连接(JDBC)是Java平台中用于执行这一任务的标准API。JDBC允许Java程序连接到关系数据库,并使用SQL语句来执行查询和更新操作。本教程将详…...
记录一次cas单点登录的集成
主要思路:浏览器访问CAS服务器登录,拿到凭证给后端,后端用此凭证到CAS服务器验证登录并拿到用户信息,之后基于该凭证维持用户的登录状态。 主要流程: 1.浏览器访问后端需认证登录地址(不带ticket…...
【吊打面试官系列】Java高并发篇 - 什么是乐观锁和悲观锁?
大家好,我是锋哥。今天分享关于 【什么是乐观锁和悲观锁?】面试题,希望对大家有帮助; 什么是乐观锁和悲观锁? 1、乐观锁: 就像它的名字一样,对于并发间操作产生的线程安全问题持乐观状态, 乐观锁认为竞争…...
Win10下ENSP USG6000镜像加载卡在###?别慌,VirtualBox网卡桥接这个设置是关键
Win10下ENSP USG6000镜像加载卡在###的终极解决方案 当你满怀期待地在Windows 10上启动ENSP模拟器,拖入USG6000防火墙设备,却只看到一串无情的 ### 符号时,那种挫败感我深有体会。作为一名曾经被这个问题折磨数小时的网络工程师,…...
量子计算模拟Hubbard模型:算法实现与噪声分析
1. Hubbard模型与量子计算模拟概述在凝聚态物理研究中,Hubbard模型堪称是研究强关联电子系统的"果蝇模型"。这个看似简单的理论框架却能展现出从金属-绝缘体相变到高温超导等丰富物理现象。模型的核心哈密顿量包含两项关键竞争:H -t∑⟨i,j⟩…...
别急着重启!深入理解Ubuntu 22.04的needrestart:守护进程、库文件与系统更新背后的原理
别急着重启!深入理解Ubuntu 22.04的needrestart:守护进程、库文件与系统更新背后的原理在Ubuntu 22.04 LTS的系统维护中,许多管理员都曾遇到过这样的场景:执行apt upgrade后,终端突然弹出"Daemons using outdated…...
告别黑窗口!保姆级教程:在Win11上用Xming给WSL2装个轻量级桌面(XFCE4)
告别黑窗口!Win11 WSL2轻量级桌面配置全指南 对于习惯Windows图形界面的开发者来说,初次接触WSL的黑窗口命令行界面总有些不适。本文将手把手教你如何用Xming和XFCE4为WSL2打造一个轻量级Linux桌面环境,无需虚拟机就能运行GIMP、VSCode等图形…...
【AI Agent旅游行业落地实战指南】:2024年已验证的7大高ROI应用场景与避坑清单
更多请点击: https://kaifayun.com 第一章:AI Agent旅游行业应用全景图 AI Agent正以前所未有的深度与广度重塑旅游产业的服务范式。它不再局限于单点智能响应,而是以目标驱动、多工具协同、自主规划与持续反思为特征,构建起覆盖…...
反事实推理:用因果视角评估与缓解AI模型偏见
1. 项目概述:当模型决策需要“如果当初”在机器学习的世界里,我们常常面临一个困境:模型预测准确率很高,但我们却不知道它为什么做出这样的决策。更棘手的是,我们越来越频繁地发现,这些“黑箱”决策背后&am…...
从微服务到 Agent 服务:架构思维的迁移
从微服务到 Agent 服务:架构思维的迁移与落地全指南 第一部分:引言与基础 (Introduction & Foundation) 1. 引人注目的标题 (Compelling Title) 副标题:深入解析微服务痛点、Agent服务原理、架构设计迁移路径与企业级生产实践 2. 摘要/引言 (Abstract / Introduction)…...
手把手教你学 Simulink-- 开关磁阻电机(SRM)的转矩分配函数(TSF)控制仿真
目录 手把手教你学 Simulink-- 开关磁阻电机(SRM)的转矩分配函数(TSF)控制仿真 🔥 前言:为什么选 SRM+TSF? 一、SRM 基础:12/8 极结构与数学模型 1.1 电压方程(第 k 相) 1.2 转矩方程(强非线性) 二、TSF 核心原理:一句话讲透 2.1 四种常用 TSF 公式(含参数…...
3分钟解决网易云音乐格式限制:免费NCM转换工具完全指南
3分钟解决网易云音乐格式限制:免费NCM转换工具完全指南 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 你是否曾因网易云音乐下载的NCM格式文件无法在车载音响或普通播放器中播放而烦恼?今天,我将…...
JDK常用类与工具(速览版)
JDK常用类与工具(速览版)JDK(Java Development Kit)提供了丰富的标准库和实用工具,它们构成了Java开发者日常工作的基石。掌握这些核心类、集合框架、并发工具、IO/NIO库、日期时间API、正则表达式、异常处理机制、日志…...
