当前位置: 首页 > news >正文

基于PyTorch的MNIST手写体分类实战

第2章对MNIST数据做了介绍,描述了其构成方式及其数据的特征和标签的含义等。了解这些有助于编写合适的程序来对MNIST数据集进行分析和识别。本节将使用同样的数据集完成对其进行分类的任务。

3.1.1  数据图像的获取与标签的说明

MNIST数据集的详细介绍在第2章中已经完成,读者可以使用相同的代码对数据进行获取,代码如下:

import numpy as np
x_train = np.load("./dataset/mnist/x_train.npy")
y_train_label = np.load("./dataset/mnist/y_train_label.npy")

基本数据的获取与第2章类似,这里就不过多阐述了,不过需要注意的是,在第2章介绍数据集时只使用了图像数据,没有对标签进行说明,在这里重点对数据标签,也就是y_train_label进行介绍。

我们可以使用下面语句打印出数据集的前10个标签:

print(y_train_label[:10])

结果如下:

import numpy as np
import torch
x_train = np.load("./dataset/mnist/x_train.npy")
y_train_label = np.load("./dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
ptint(y) 

结果如下:

tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])

可以看到,one_hot的作用是将一个序列转换成以one_hot形式表示的数据集。所有的行或者列都被设置成0,而每个特定的位置都对应一个1来表示,如图3-1所示。

图3-1  one_hot形式表示的数据集

对于MNIST数据集的标签来说,这实际上就是一个60 000幅图片的60 000×10大小的矩阵张量[60 000,10]。前面的数指的是数据集中图片的个数为60 000个,后面的10指的是10个列向量。

下面使用PyTorch 2.0框架完成手写体的识别。

3.1.2  模型的准备(多层感知机)

在第2章已经讲过了,PyTorch最重要的一项内容是模型的准备与设计,而模型的设计最关键的一点就是了解输出和输入的数据结构类型。

通过第2章有关图像去噪的演示,读者已经了解了我们的输入数据格式是一个[28,28]大小的二维图像。而通过对数据结构的分析,我们可以知道,对于每个图形都有一个确定的分类结果,也就是0~10的一个确定数字。

下面将按这个想法来设计模型。从前面对图像的分析来看,对整体图形进行判别的一个基本想法就是将图像作为一个整体直观地进行判别,因此基于这种解决问题的思路,简单的模型设计就是同时对图像所有参数进行计算,即使用一个多层感知机(Multi-Layer Perceptron,MLP)对图像进行分类。整体的模型设计结构如图3-2所示。

图3-2  整体的模型设计结构

从图3-2可以看到,一个多层感知机模型就是将数据输入后,分散到每个模型的节点(隐藏层),进行数据计算后,再将计算结果输出到对应的输出层中。多层感知机的模型结构如下:

class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28,312),nn.ReLU(),nn.Linear(312, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logits

3.1.3  损失函数的表示与计算

第2章使用了MSELoss作为目标图形与预测图形的损失值,而在本例中,我们需要预测的目标是图形的“分类”,而不是图形表示本身,因此我们需要寻找并使用一种新的能够对类别归属进行“计算”的函数。

本例所使用的交叉熵损失函数为torch.nn.CrossEntropyLoss。PyTorch官方网站对其介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction='mean', label_smoothing=0.0)

该损失函数计算输入值(Input)和目标值(Target)之间的交叉熵损失。交叉熵损失函数CrossEntropyLoss可用于训练单类别或者多类别的分类问题。给定参数weight时,会为传递进来的每个类别的计算数值重新加载一个修正权重。当数据集分布不均衡时,这是很有用的。

同样需要注意的是,因为torch.nn.CrossEntropyLoss内置了Softmax运算,而Softmax的作用是计算分类结果中最大的那个类。从图3-3所示的对PyTorch 2.0中CrossEntropyLoss的实现可以看到,此时CrossEntropyLoss已经在计算的同时实现了Softmax计算,因此在使用torch.nn.CrossEntropyLoss作为损失函数时,不需要在网络的最后添加Softmax层。此外,label应为一个整数,而不是One-Hot编码形式。

图3-3  使用torch.nn.CrossEntropyLoss()作为损失函数

CrossEntropyLoss示例代码如下:

import torch
y = torch.LongTensor([0])
z = torch.Tensor([[0.2,0.1,-0.1]])
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z,y)
print(loss)	

CrossEntropyLoss的数学公式较为复杂,建议学有余力的读者查阅相关内容进行学习,目前只需要掌握这方面内容即可。

3.1.4  基于PyTorch的手写体识别的实现

下面介绍基于PyTorch的手写体识别的实现。通过前文的介绍,我们还需要定义深度学习的优化器部分,在这里采用Adam优化器,相关代码如下:

model = NeuralNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

在这个实战案例中首先需要定义模型,之后将模型参数传入优化器中,lr是对学习率的设定,根据设定的学习率进行模型计算。完整的手写体识别模型如下:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编号
import torch
import numpy as np
from tqdm import tqdm
batch_size =320#设定每次训练的批次数
epochs=1024   	#设定训练次数
#device="cpu"	#PyTorch的特性,需要指定计算的硬件,如果没有GPU,就使用CPU进行计算
device="cuda"	#在这里默认使用GPU,如果读者运行出现问题,可以将其改成CPU模式#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)                	#将计算模型传入GPU硬件等待计算
model = torch.compile(model)            	#PyTorch 2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)  	#设定优化函数#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")
train_num = len(x_train)//batch_size#开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

此时模型的训练结果如图3-4所示。

图3-4  模型的训练结果

可以看到随着模型循环次数的增加,模型的损失值在降低,而准确率在增高,具体请读者自行验证测试。

本文节选自《PyTorch 2.0深度学习从零开始学》。

相关文章:

基于PyTorch的MNIST手写体分类实战

第2章对MNIST数据做了介绍,描述了其构成方式及其数据的特征和标签的含义等。了解这些有助于编写合适的程序来对MNIST数据集进行分析和识别。本节将使用同样的数据集完成对其进行分类的任务。 3.1.1 数据图像的获取与标签的说明 MNIST数据集的详细介绍在第2章中已…...

conda 复制系统环境

直接复制 想要通过 conda 直接复制一个已存在的环境,你可以使用 conda create 命令并配合 --clone 参数。以下是具体步骤: 查看现有的环境: 首先,你可以使用以下命令来查看所有的 conda 环境: conda env list这会给你一个环境列表…...

如何在Microsoft Visual Studio 中使用Cpp代码调用python代码

Microsoft Visual Studio中Cpp调用Python代码 本文介绍如何在Microsoft Visual Studio中,开发cpp项目时,调用python代码。 文章目录 Microsoft Visual Studio中Cpp调用Python代码前言一、Cpp生成exe文件1.1 安装python环境1.2 配置Microsoft Visual Stu…...

DAY35 435. 无重叠区间 + 763.划分字母区间 + 56. 合并区间

435. 无重叠区间 题目要求:给定一个区间的集合,找到需要移除区间的最小数量,使剩余区间互不重叠。 注意: 可以认为区间的终点总是大于它的起点。 区间 [1,2] 和 [2,3] 的边界相互“接触”,但没有相互重叠。 示例 1: 输入: [ […...

代码随想录算法训练营第2天| 977有序数组的平方、209长度最小的子数组。

JAVA代码编写 977. 有序数组的平方 给你一个按 非递减顺序 排序的整数数组 nums,返回 每个数字的平方 组成的新数组,要求也按 非递减顺序 排序。 示例 1: 输入:nums [-4,-1,0,3,10] 输出:[0,1,9,16,100] 解释&…...

微信小程序通过startLocationUpdate,onLocationChange获取当前地理位置信息,配合腾讯地图解析获取到地址

先创建个getLocation.js文件 //获取用户当前所在的位置 const getLocation () > {return new Promise((resolve, reject) > {let _locationChangeFn (res) > {resolve(res) // 回传地里位置信息wx.offLocationChange(_locationChangeFn) // 关闭实时定位wx.stopLoc…...

C/C++字符三角形 2020年12月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C字符三角形 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C字符三角形 2020年12月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 给定一个字符,用它构造一个底边长5个字…...

Python数据挖掘:入门、进阶与实用案例分析——基于非侵入式负荷检测与分解的电力数据挖掘

文章目录 摘要01 案例背景02 分析目标03 分析过程04 数据准备05 属性构造06 模型训练07 性能度量08 推荐阅读赠书活动 摘要 本案例将根据已收集到的电力数据,深度挖掘各电力设备的电流、电压和功率等情况,分析各电力设备的实际用电量,进而为电…...

基于 Qt控制开发板 LED和C语言控制LED渐变亮度效果

## 资源简介 在STM32开发板,板载资源上有两个可自由控制的 LED。如下图原理 图其中我们以操作 LED1 为示例,LED1 为出厂系统的心跳指示灯。 ## 应用实例 想要控制这个 LED,首先出厂内核已经默认将这个 LED 注册成了 gpio-leds类型设备。所以我们可以直接在应用层接口直接…...

Android 11.0 禁用插入耳机时弹出的保护听力对话框

1.前言 在11.0的系统开发中,在某些产品中会对耳机音量调节过高限制,在调高到最大音量的70%的时候,会弹出音量过高弹出警告,所以产品 开发的需要要求去掉这个音量弹窗警告功能 2.禁用插入耳机时弹出的保护听力对话框的核心类 frameworks\base\packages\SystemUI\src\com\and…...

微信小程序案例2-3:婚礼邀请函

文章目录 一、运行效果二、知识储备(一)导航栏配置(二)标签栏配置(三)vw、vh单位(四)video组件(五)表单组件(六)Node.js概述 三、实现…...

K8S部署Dashboard

获取recommended.yaml文件 Dashboard是官方提供的一个UI,可用于基本管理K8s资源。 YAML下载地址: wget https://raw.githubusercontent.com/kubernetes/dashboard/v2.4.0/aio/deploy/recommended.yaml如果网络错误无法直接下载,可以直接访问…...

【OJ比赛日历】快周末了,不来一场比赛吗? #10.29-11.04 #7场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…)比赛。本账号会推送最新的比赛消息,欢迎关注! 以下信息仅供参考,以比赛官网为准 目录 2023-10-29(周日) #3场比赛2023-10-30…...

常用应用安装教程---在centos7系统上安装Docker

在centos7系统上安装Docker 1:切换镜像源 wget https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo -O /etc/yum.repos.d/docker-ce.repo2:查看当前镜像源中支持的docker版本 yum list docker-ce --showduplicates | sort -r3&#x…...

CTFHub-SSRF-读取伪协议

WEB攻防-SSRF服务端请求&Gopher伪协议&无回显利用&黑白盒挖掘&业务功能点-CSDN博客 伪协议有: file:/// — 访问本地文件系统 http:/// — 访问 HTTP(s) 网址 ftp:/// — 访问 FTP(s) URLs php:/// — 访问各个输入/输出流(I/O streams) dic…...

推荐一款适合科技行业的CRM系统

推荐您一款科技行业好用的CRM系统——Zoho CRM客户管理系统,旨在帮助企业管理客户数据、销售过程、营销活动以及服务支持,助力业务增长及数字化转型,实现“以客户为中心”的企业管理和运营模式。 近些年,随着政府鼓励政策的出台、…...

ChatGPT 与 Python Echarts 完成热力图实例

热力图是一种数据可视化方式,它通过颜色的变化来表示数据的差异和分布。以下是使用热力图的一些作用和好处: 数据可视化:热力图可以将复杂的数据集转化为更直观、更易理解的形式。这对于很多人来说,尤其是那些没有深入统计学或数…...

vue3项目报错The template root requires exactly one element.eslint-plugin-vue

解决方案: 1.禁用 Vetur 并改用Volar》它现在是 Vue 3 项目的官方推荐。【必须重启vsCode】 从官方迁移指南: 建议使用带有我们官方扩展 Volar (opens new window) 的 VSCode,它为 Vue 3 提供了全面的 IDE 支持。 2.package.json文件中 &…...

【C++系列】STL容器——vector类的例题应用(12)

前言 大家好吖,欢迎来到 YY 滴C系列 ,热烈欢迎!本章主要内容面向接触过C的老铁,下面是收纳的一些例题与解析~ 主要内容含: 目录 【例1] 只出现一次的数字i(范围for与模等(^))【例2]…...

常用应用安装教程---在centos7系统上安装JDK8

在centos7系统上安装JDK8 1:进入oracle官网下载jdk8的tar.gz包: 2:将下载好的包上传到每个服务器上: 3:查看是否上传成功: [rootkafka01 ~]# ls anaconda-ks.cfg jdk-8u333-linux-x64.tar.gz4&#xf…...

浅谈 React Hooks

React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

【Oracle APEX开发小技巧12】

有如下需求&#xff1a; 有一个问题反馈页面&#xff0c;要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据&#xff0c;方便管理员及时处理反馈。 我的方法&#xff1a;直接将逻辑写在SQL中&#xff0c;这样可以直接在页面展示 完整代码&#xff1a; SELECTSF.FE…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

测试markdown--肇兴

day1&#xff1a; 1、去程&#xff1a;7:04 --11:32高铁 高铁右转上售票大厅2楼&#xff0c;穿过候车厅下一楼&#xff0c;上大巴车 &#xffe5;10/人 **2、到达&#xff1a;**12点多到达寨子&#xff0c;买门票&#xff0c;美团/抖音&#xff1a;&#xffe5;78人 3、中饭&a…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

【配置 YOLOX 用于按目录分类的图片数据集】

现在的图标点选越来越多&#xff0c;如何一步解决&#xff0c;采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集&#xff08;每个目录代表一个类别&#xff0c;目录下是该类别的所有图片&#xff09;&#xff0c;你需要进行以下配置步骤&#x…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问&#xff1a;说说对 IP 数据报中 TTL 的理解&#xff1f;我们都知道&#xff0c;IP 数据报由首部和数据两部分组成&#xff0c;首部又分为两部分&#xff1a;固定部分和可变部分&#xff0c;共占 20 字节&#xff0c;而即将讨论的 TTL 就位于首…...