当前位置: 首页 > news >正文

神经网络-MNIST数据集训练

文章目录

  • 一、MNIST数据集
    • 1.数据集概述
    • 2.数据集组成
    • 3.文件结构
    • 4.数据特点
  • 二、代码实现
    • 1.数据加载与预处理
    • 2. 模型定义
    • 3. 训练和测试函数
    • 4.训练和测试结果
  • 三、总结

一、MNIST数据集

MNIST数据集是深度学习和计算机视觉领域非常经典且基础的数据集,它包含了大量的手写数字图片,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。

1.数据集概述

  • 来源:MNIST数据集由Yann LeCun等人于1994年创建,它是NIST(美国国家标准与技术研究所)数据集的一个子集。
  • 内容:数据集主要包含手写数字(0~9)的图片及其对应的标签。
  • 用途:作为深度学习和计算机视觉领域的入门级数据集,它适合初学者练习建立模型、训练和预测。

2.数据集组成

MNIST数据集总共包含两个子数据集:训练数据集和测试数据集。

训练数据集:

  • 包含了60,000张28x28像素的灰度图像。
  • 对应的标签文件包含了60,000个标签,每个标签对应一张图像中的手写数字。

测试数据集:

  • 包含了10,000张28x28像素的灰度图像。
  • 对应的标签文件包含了10,000个标签。

3.文件结构

MNIST数据集包含四个文件,分别是训练集图像、训练集标签、测试集图像和测试集标签。这些文件以gzip格式压缩,并且不是标准的图像格式,需要通过专门的编程方式读取。

  • 训练集图像:train-images-idx3-ubyte.gz
  • 训练集标签:train-labels-idx1-ubyte.gz)
  • 测试集图像:t10k-images-idx3-ubyte.gz
  • 测试集标签:t10k-labels-idx1-ubyte.gz

4.数据特点

  • 图像大小:每张图像的大小为28x28像素,是一个灰度图像,位深度为8(灰度值范围为0~255)。
  • 数据来源:手写数字来自250个不同的人。
  • 数据格式:图像数据以字节的形式存储在二进制文件中,标签文件则存储了每张图像对应的数字标签。

二、代码实现

1.数据加载与预处理

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor"""下载训练集数据(包含训练图片和标签)"""
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),  # 张量,图片是不能直接传入神经网络模型
)"""下载测试集数据(包括训练图片和标签)"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)  # 64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)
  • 下载数据集:使用torchvision.datasets.MNIST下载并加载MNIST数据集。数据集分为训练集和测试集,train=True为训练集数据,train=False为测试集数据。
  • 数据转换:数据通过transform=ToTensor()进行预处理,将图片转换为PyTorch张量(Tensor),并自动将像素值归一化到[0,1]区间。
  • 数据封装:使用DataLoader将数据集封装成批次(batch)形式,便于后续的训练和测试过程。

2. 模型定义

class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型,nn.moduledef __init__(self):  # python基础关于类,self类自己本身super().__init__()  # 继承的父类初始化self.flatten = nn.Flatten()  # 展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28 * 28, 128)  # 第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出去前一层神经元的个数,当前本层神经元个数self.hidden2 = nn.Linear(128, 256)self.hidden3 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)def forward(self, x):  # 前向传播,告诉它,数据的流向。x = self.flatten(x)  # 图像进行展开x = self.hidden1(x)x = torch.sigmoid(x) x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.out(x)return xmodel = NeuralNetwork().to(device)  # 把刚刚创建的模型传入到gpu
print(model)

定义类:定义了一个名为NeuralNetwork的类,该类继承自nn.Module,用于构建神经网络模型。
模型结构:模型包含输入层,输出层,隐藏层,其中隐藏层使用了Sigmoid激活函数,最后输出10个类别的得分(对应0-9的数字)
打印模型结构:打印了模型的结构,有助于理解模型的架构。
在这里插入图片描述

3. 训练和测试函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:  # 其中batch为每一个数据的编号X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPUpred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值loss# Backpropaqation 进来-个bqtch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 100 == 0:print(f"loss:{loss_value:>7f}  [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)  # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)  # 把预测值Ture、False 转换为01test_loss /= num_batches  # 评判模型的好坏correct /= size  # 平均的准确率print(f"Test result:\n Accuracy:{(100 * correct)}%,Avg loss:{test_loss}")
  • train函数负责训练模型。它遍历训练数据集的每个批次,计算模型的预测、损失,并执行反向传播和参数更新。
  • test函数用于评估模型在测试集上的性能。它遍历测试数据集的每个批次,计算模型的预测和损失,但不进行反向传播或参数更新。
  • 在训练和测试过程中,都使用了torch.no_grad()上下文管理器来关闭梯度计算,这可以节省内存和计算资源。

4.训练和测试结果

loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 创建一个优化器,S6D为随机梯度下降算法epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
  • 使用torch.optim.Adam优化器来优化模型的参数,这里的学习率设置为0.01。
  • 定义了训练轮次(epochs),并在每个epoch中调用train函数来训练模型。
  • 最后,使用test函数来评估模型在测试集上的性能,并打印出准确率和平均损失。
    在这里插入图片描述

三、总结

本文为大家介绍了MNIST数据集的组成、文件结构与数据集特点,然后为大家提供了MNIST数据集训练的相关代码,通过对数据集进行处理,训练来得出准确率与损失率,为大家更好的展示。总之,MNIST数据集是深度学习和计算机视觉领域不可或缺的基础数据集之一,对于初学者来说是一个非常好的练手项目,同时也为相关领域的研究和实验提供了宝贵的数据资源。

相关文章:

神经网络-MNIST数据集训练

文章目录 一、MNIST数据集1.数据集概述2.数据集组成3.文件结构4.数据特点 二、代码实现1.数据加载与预处理2. 模型定义3. 训练和测试函数4.训练和测试结果 三、总结 一、MNIST数据集 MNIST数据集是深度学习和计算机视觉领域非常经典且基础的数据集,它包含了大量的手…...

数据结构二

求 sizeof(name1)?(晟安信息) struct name1{ char str; short x; int num; }; sizeof name1内存对齐 8个字节 char分配8个字节 然后 short节省空间在4个字节中 而这个int独自分配分配内存 4个字节所以共8个字节 (电工时代) typedef struct _a { char c1; long i…...

Python|基于Kimi大模型,删除已上传的“指定文档”或“全部文档”(6)

前言 本文是该专栏的第6篇,后面会持续分享AI大模型干货知识,记得关注。 在本专栏上一篇《Python|基于Kimi大模型,实现上传文档并进行对话(5)》中,笔者有详细介绍“基于kimi大模型,上传指定文档并结合prompt,获取目标文本数据”。对此感兴趣的同学,可以直接点击翻阅查…...

CenterPoint-KITTI:环境配置、模型训练、效果展示;KITTI 3D 目标检测数据集下载

目录 前言 Python虚拟环境创建以及使用 KITTI3D目标检测数据集 CenterPoint-KITTI编译遇到问题合集 ImportError: cannot import name VoxelGenerator from spconv.utils 失败案例 最终解决方案 对于可选参数,road plane的处理 E: Unable to locate packag…...

【Android】ViewPager

1.ViewPager的简介和作用 ViewPager是android扩展包v4包中的类,这个类可以让用户左右切换当前的view,用于允许用户在几个页面(或称为碎片)之间左右滑动切换。它通常用于创建像画廊或轮播图那样的用户体验。 ViewPager类直接继承了…...

[go] 命令模式

命令模式 将“请求”封装成对象,以便使用不同的请求、队列或者日志来参数化其他对象。命令模式也支持可撤销的操作。 模型说明 触发者类负责对请求进行初始化,其中必须包含一个成员变量来存储对于命令对象的引用。触发命令,而不同接受者直接…...

代码随想录冲冲冲 Day48 单调栈Part2

42. 接雨水 关键点有以下几个 首先是怎么去理解接雨水 其实就是找每一个段的左边第一个最大值和右边第一个最大值 既然是最大值 那么单调栈就是递增的 左边第一个最大值其实就是pop掉中间的之后st.top 由于是出现大于等于情况时候进行操作 所以右边最大值就是i 接下来就…...

企业内训|Nvidia智算中心深度技术研修-某智算厂商研发中心

课程概述 此企业内训课程“Nvidia智算中心的深度技术研修”专为某智算厂商研发中心设计,内容涵盖了从基础设施构建到高性能计算优化的全方位技术要点。课程为期七天,分模块详细讲解了NV算力资源的网络架构、存储优化、智算集群的建设与自动化管理、NCCL…...

《算法笔记》例题解析 第3章入门模拟--3图形输出(9题)2021-03-03

例题 旋转方阵 题目描述 Time Limit: 1000 ms Memory Limit: 256 mb 打印出一个旋转方阵,见样例输出。 输入描述: 输入一个整数n(1 < n < 20), n为方阵的行数。 输出描述: 输出一个大小为n*n的距阵 输入 5 输出 1 16 15 14 13 2 17 24 23 12 3 18 25 22 11 4 1…...

合宙Air201模组LuatOS:PWRKEY控制,一键解决解决关机难问题

不知不觉间&#xff0c;我们已经发布拉期课程&#xff1a;hello world初体验&#xff0c;点灯、远程控制、定位和扩展功能&#xff0c;你学的怎么样&#xff1f;很多伙伴表示已经有点上瘾啦&#xff01;合宙Air201&#xff0c;如同我们一路升级打怪的得力法器&#xff0c;让开发…...

Kafka 命令详解及使用示例

文章目录 Kafka 命令详解及使用示例Kafka 命令详解kafka-topics.sh&#xff1a;主题管理创建主题创建带副本的主题修改主题分区数了解分区分布列出主题查看主题详情删除主题 kafka-console-producer.sh&#xff1a;消息生产者发送消息到主题带键值对的消息消息生产性能优化带分…...

重生归来之挖掘stm32底层知识(1)——寄存器

概念理解 要使用stm32首先要知道什么是引脚和寄存器。 如下图所示&#xff0c;芯片通过这些金属丝与电路板连接&#xff0c;这些金属丝叫做引脚。一般做软件开发是不需要了解芯片是怎么焊的&#xff0c;只要会使用就行。我们平常通过编程来控制这些引脚的输入和输出&#xff0c…...

Qt构建JSON及解析JSON

目录 一.JSON简介 JSON对象 JSON数组 二.Qt中JSON介绍 QJsonvalue Qt中JSON对象 Qt中JSON数组 QJsonDocument 三.Qt构建JSON数组 四.解析JSON数组 一.JSON简介 一般来讲C类和对象在java中是无法直接直接使用的&#xff0c;因为压根就不是一个规则。但是他们在内存中…...

合宙Air201模组LuatOS扩展功能:温湿度传感器篇!

通过前面几期的学习&#xff0c;同学们的学习热情越来越高。 合宙Air201模组除了支持3种定位方式外&#xff0c;还具有丰富的扩展功能&#xff0c;比如&#xff1a;通过外扩BTB链接方案&#xff0c;最多可支持21个IO接口&#xff1a;SPI、I2C、UART等多种接口全部支持。 本期…...

主流敏捷工具scrum工具

在当今的快速变化和高需求的业务环境中&#xff0c;敏捷开发已经成为许多企业实现快速迭代和响应市场需求的重要方法。而在众多敏捷工具中&#xff0c;选择适合自己团队的工具尤为重要。 今天&#xff0c;我们将对比几款主流的敏捷工具&#xff0c;供参考 1. Leangoo领歌&…...

探索微服务架构:从理论到实践,深度剖析其优缺点

微服务架构&#xff08;Microservice Architecture&#xff09;是一种软件开发架构形式&#xff0c;它的核心 思想是将大型应用程序拆分成一组小的服务&#xff0c;每个服务都运行在其独立的进程中&#xff0c;并且 服务与服务之间通过轻量级的通信机制&#xff08;如HTTP REST…...

2024 年最佳 Chrome 验证码扩展,解决 reCAPTCHA 问题

验证码&#xff0c;特别是 reCAPTCHA&#xff0c;已成为在线安全的不可或缺的一部分。虽然它们在区分人类和机器人方面起着至关重要的作用&#xff0c;但它们也可能成为合法用户和从事网络自动化的企业的主要障碍。无论您是试图简化在线体验的个人&#xff0c;还是依赖自动化工…...

Go语言现代web开发defer 延迟执行

The defer statement will delay the execution of a function until the surrounding function is completed. Although execution is postponed, funciton arguments will be evaluated immediately. defer语句将延迟函数的执行&#xff0c;直到周围的函数完成。虽然执行被延…...

Vue路由二(嵌套多级路由、路由query传参、路由命名、路由params传参、props配置、<router-link>的replace属性)

目录 1. 嵌套(多级)路由2. 路由query传参3. 路由命名4. 路由params传参5. props配置6. <router-link>的replace属性 1. 嵌套(多级)路由 pages/Car.vue <template><ul><li>car1</li><li>car2</li><li>car3</li></ul…...

【RabbitMQ】可靠性传输

概述 作为消息中间件来说&#xff0c;最重要的任务就是收发消息。因此我们在收发消息的过程中&#xff0c;就要考虑消息是否会丢失的问题。结果是必然的&#xff0c;假设我们没有采取任何措施&#xff0c;那么消息一定会丢失。对于一些不那么重要的业务来说&#xff0c;消息丢失…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

调用支付宝接口响应40004 SYSTEM_ERROR问题排查

在对接支付宝API的时候&#xff0c;遇到了一些问题&#xff0c;记录一下排查过程。 Body:{"datadigital_fincloud_generalsaas_face_certify_initialize_response":{"msg":"Business Failed","code":"40004","sub_msg…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖

在前面的练习中&#xff0c;每个页面需要使用ref&#xff0c;onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入&#xff0c;需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台&#xff08;Launchpad&#xff09;多出来了&#xff1a;Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显&#xff0c;都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store&#xff1a; 我们在使用异步的时候理应是要使用中间件的&#xff0c;但是configureStore 已经自动集成了 redux-thunk&#xff0c;注意action里面要返回函数 import { configureS…...

Fabric V2.5 通用溯源系统——增加图片上传与下载功能

fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

基于TurtleBot3在Gazebo地图实现机器人远程控制

1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...