当前位置: 首页 > news >正文

项目实例_FashionMNIST_CNN

前言

提醒:
文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。
其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展及意见建议,欢迎评论区讨论交流。

文章目录

  • 前言
  • CNN介绍
  • 数据集介绍
      • FashionMNIST 数据集概述
        • 主要特点:
        • 类别标签:
        • 目标:
        • 为什么使用 FashionMNIST:
        • 图像示例:
        • 数据加载与预处理:
        • 下载和加载 FashionMNIST(PyTorch 代码示例):
        • 数据集使用场景:
      • 数据集可视化
      • 总结:
  • 项目实例
    • 代码


CNN介绍

可见于:
MNIST数据集_CNN
在这里插入图片描述

数据集介绍

FashionMNIST 数据集概述

FashionMNIST 是一个包含 10 个类别的图像数据集,用于训练和测试机器学习模型,特别是在图像分类任务中的应用。它是由 Zalando 提供的,目的是为了给机器学习社区提供一个标准化、简洁但具有挑战性的视觉分类数据集。FashionMNIST 是 MNIST 数据集的一个变种,后者包含手写数字图像。

主要特点:
  • 图像大小:每张图像的分辨率为 28x28 像素,每个像素为灰度值(单通道,值在 0 到 255 之间)。
  • 图像类型:这些图像展示了 10 种不同类型的时尚商品,例如鞋子、T 恤、外套等。
  • 数据集结构
    • 训练集:60,000 张图像
    • 测试集:10,000 张图像
  • 类别:数据集包含 10 个类别,分别对应不同的服装商品(每个类别有对应的标签)。
类别标签:
  • 0: T 恤/上衣
  • 1: 裤子
  • 2: 套头衫
  • 3: 连衣裙
  • 4: 外套
  • 5: 凉鞋
  • 6: 衬衫
  • 7: 运动鞋
  • 8: 包
  • 9: 靴子
目标:

FashionMNIST 的目标是对每张 28x28 的灰度图像进行分类,判定该图像属于哪个类别(例如是 “T 恤”、“裤子” 还是 “运动鞋” 等)。因此,它是一个 多类分类 问题,通常被用来评估各种机器学习模型,尤其是在图像分类任务中的表现。

为什么使用 FashionMNIST:

FashionMNIST 与原始的 MNIST 数据集相似,但相比 MNIST(手写数字),FashionMNIST 的图像内容更加复杂且多样。这使得它在很多机器学习和深度学习领域成为了一个较为简单但富有挑战性的测试集。它被广泛应用于:

  • 深度学习模型的评估:特别是用于测试卷积神经网络(CNN)等模型的性能。
  • 学习和研究:它提供了一个简单且标准化的图像分类数据集,适用于机器学习入门或新模型的验证。
图像示例:

每张图像都是 28x28 像素的灰度图,显示的是一件衣物的图片。图像尺寸较小,适合在初学者的机器学习项目中进行训练,因为它不需要大量的计算资源。

数据加载与预处理:

FashionMNIST 数据集一般会在加载时进行一些标准的预处理步骤,如:

  • 归一化:将像素值从 [0, 255] 范围映射到 [0, 1] 范围,或者进行标准化,常常帮助提升模型的收敛速度。
  • 转换:通常使用 transforms.ToTensor() 将图像转换为 PyTorch 中的 Tensor 格式。
下载和加载 FashionMNIST(PyTorch 代码示例):
import torchvision
from torchvision import transforms# 加载训练集数据
train_data = torchvision.datasets.FashionMNIST(root='data',     # 存储路径train=True,      # 训练集download=True,   # 下载数据集transform=transforms.ToTensor(),  # 转换为 Tensor 格式
)# 加载测试集数据
test_data = torchvision.datasets.FashionMNIST(root='data',train=False,     # 测试集download=True,transform=transforms.ToTensor(),
)

这段代码使用 torchvision.datasets.FashionMNIST 来下载并加载训练和测试数据集,图像会被转换为 PyTorch 的 Tensor 格式。

数据集使用场景:
  • 入门项目:对于初学者,FashionMNIST 是一个非常适合入门的图像分类数据集,因为它相对简单,且具有一定的挑战性。
  • 模型对比与验证:在机器学习和深度学习领域,FashionMNIST 常常作为测试模型性能的标准数据集之一,用来对比不同算法(如支持向量机、KNN、神经网络等)的表现。
  • 神经网络训练:尤其是卷积神经网络(CNN)在图像分类任务中的应用,FashionMNIST 为其提供了一个理想的训练平台。

数据集可视化

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np# 下载 FashionMNIST 数据集
transform = transforms.Compose([transforms.ToTensor()])  # 只需要转换为 Tensor
train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)# 获取前 10 张图像以及对应的标签
images, labels = zip(*[(train_data[i][0], train_data[i][1]) for i in range(10)])# 类别名称映射
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]# 设置绘图
fig, axes = plt.subplots(1, 10, figsize=(15, 15))  # 创建 1 行 10 列的子图for i in range(10):ax = axes[i]ax.imshow(images[i].squeeze(), cmap='gray')  # squeeze 去掉多余的维度,cmap 为灰度色ax.set_title(class_names[labels[i]])  # 标注类别名称ax.axis('off')  # 不显示坐标轴plt.show()  # 显示图像

运行结果:
在这里插入图片描述

总结:

FashionMNIST 是一个标准化的图像分类数据集,由 Zalando 提供。它由 10 类不同的时尚商品构成,训练集包含 60,000 张图像,测试集包含 10,000 张图像。它适用于深度学习、机器学习模型的训练和评估,尤其适合初学者学习和实验。

项目实例

代码

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt# 下载训练数据集FashionMNIST
training_data = torchvision.datasets.FashionMNIST(root="data",  # 数据集存储位置train=True,    # 使用训练集download=True, # 如果数据集不存在,则下载transform=transforms.ToTensor(),  # 转换为Tensor
)# 下载测试数据集FashionMNIST
test_data = torchvision.datasets.FashionMNIST(root="data",train=False,   # 使用测试集download=True,transform=transforms.ToTensor(),  # 转换为Tensor
)# 标签的映射字典,数字标签对应的衣物类别名称
labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}# 可视化FashionMNIST数据集中前9张图片
figure = plt.figure(figsize=(8, 8))  # 创建一个8x8的图像画布
cols, rows = 3, 3  # 设置行列数
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()  # 随机选取一张图片img, label = training_data[sample_idx]  # 获取图片和标签figure.add_subplot(rows, cols, i)  # 在画布上添加子图plt.title(labels_map[label])  # 设置图片的标题为标签对应的衣物类别plt.axis("off")  # 关闭坐标轴显示plt.imshow(img.squeeze(), cmap="gray")  # 显示图像,squeeze去掉多余维度,cmap设置为灰度图
plt.show()  # 展示图像# 设置训练过程中的超参数
num_epochs = 10       # 训练的轮数
batch_size = 32       # 批大小
weight_decay = 1e-4    # 权重衰减(L2正则化)
learning_rate = 0.001 # 学习率# 创建数据加载器,训练集和测试集
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)# 打印测试集的一个batch的尺寸和标签类型
for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")  # X是输入图像,N是批大小,C是通道数,H和W是图像的高和宽print(f"Shape of y: {y.shape} {y.dtype}")      # y是标签,显示其维度和数据类型break  # 只打印一个batch的信息# 检测是否使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")  # 输出当前使用的设备(CUDA or CPU)# 定义一个简单的神经网络模型
class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)   # 第一层全连接层,输入784个特征,输出hidden_size个特征self.relu = nn.ReLU()                           # ReLU激活函数self.fc2 = nn.Linear(hidden_size, num_classes)  # 第二层全连接层,输出num_classes个类别的预测值def forward(self, x):out = self.fc1(x)    # 输入通过第一层out = self.relu(out)  # ReLU激活out = self.fc2(out)   # 输入通过第二层,输出结果return out# 假设输入是28x28的图像,展开为784维,隐藏层大小为500,分类数为10
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数,适用于多分类问题
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)  # 使用Adam优化器# 开始训练模型
total_step = len(train_dataloader)  # 获取训练集的总批次数
for epoch in range(num_epochs):  # 遍历每一个epochfor i, (images, labels) in enumerate(train_dataloader):  # 遍历每一个batchimages = images.reshape(-1, 28*28).to(device)  # 将28x28的图像展开成784维向量,转移到device(GPU/CPU)labels = labels.to(device)  # 标签转移到设备上# 前向传播outputs = model(images)  # 将输入传入模型,得到预测输出loss = criterion(outputs, labels)  # 计算损失# 反向传播和优化optimizer.zero_grad()  # 清零之前的梯度loss.backward()  # 计算当前梯度optimizer.step()  # 更新模型参数# 每100个batch输出一次训练状态if (i + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')# 训练完成后,在测试集上评估模型的准确率
model.eval()  # 设置模型为评估模式(此时BatchNorm等层使用移动平均值而不是批量值)
with torch.no_grad():  # 不需要计算梯度correct = 0total = 0for images, labels in test_dataloader:images = images.reshape(-1, 28*28).to(device)  # 将图像展开为784维labels = labels.to(device)  # 标签转移到设备上outputs = model(images)  # 获取模型的输出_, predicted = torch.max(outputs.data, 1)  # 获取预测类别,outputs.data返回模型的预测结果total += labels.size(0)  # 统计总样本数correct += (predicted == labels).sum().item()  # 统计预测正确的样本数# 输出模型在测试集上的准确率print(f'Test Accuracy of the model on the 10000 test images: {100 * correct / total} %')

在这里插入图片描述

相关文章:

项目实例_FashionMNIST_CNN

前言 提醒: 文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。 其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展…...

Ubuntu 安装 web 服务器

安装 apach sudo apt install apache2 -y 查看 apach2 版本号 apache2 -v 检查是否启动服务器 sudo service apache2 status 检查可用的 ufw 防火墙应用程序配置 sudo ufw app list 关闭防火墙 sudo ufw disable 更改允许通过端口流量 sudo ufw allow Apache Full 开启…...

burp的编解码,日志,比较器

声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&a…...

2.1、模版语法

2.1.1、插值语法 1、代码示例 <body><!-- 准备容器 --><div id"app"><!-- 在data中声明的 --><!--1、 data中声明的变量 --><h1>{{msg}}</h1><h1>{{sayHello()}}</h1><!-- 不在data中的变量不可以 -->…...

最小二乘法拟合出二阶响应面近似模型

背景&#xff1a;根据样本试验数据拟合出二阶响应面近似模型&#xff08;正交二次型&#xff09;&#xff0c;并使用决定系数R和调整的决定系数R_adj来判断二阶响应面模型的拟合精度。 1、样本数据&#xff08;来源&#xff1a;硕士论文《航空发动机用W形金属密封环密封性能分析…...

【汽车】-- 常见的汽车悬挂系统

汽车悬挂系统是车辆的重要组成部分&#xff0c;其主要功能是连接车轮和车身&#xff0c;减缓路面颠簸对车身的影响&#xff0c;提高行驶的平顺性、舒适性和操控性。以下是常见的汽车悬挂系统类型及其特点&#xff1a; 1. 独立悬挂系统 每个车轮可以独立上下运动&#xff0c;不…...

VMware Workstation Pro 17 下载 以及 安装 Ubuntu 20.04.6 Ubuntu 启用 root 登录

1、个人免费版本 VMware Workstation Pro 17 下载链接怎么找&#xff1f;直接咕咕 VMware 找到如下链接。链接如下&#xff1a;Workstation 和 Fusion 对个人使用完全免费&#xff0c;企业许可转向订阅 - VMware 中文博客 点进去链接之后你会看到如下&#xff0c;注意安装之后仍…...

记录ubuntu22.04重启以后无法获取IP地址的问题处理方案

现象描述&#xff1a;我的虚拟机网络设置为桥接模式&#xff0c;输入ifconfig只显示127.0.0.1&#xff0c;不能连上外网。&#xff0c;且无法上网&#xff0c;用ifconfig只有如下显示&#xff1a; 1、sudo -i切换为root用户 2、输入dhclient -v 再输入ifconfig就可以看到多了…...

linux 删除系统特殊的的用户帐号

禁止所有默认的被操作系统本身启动的且不需要的帐号&#xff0c;当你第一次装上系统时就应该做此检查&#xff0c;Linux提供了各种帐号,你可能不需要&#xff0c;如果你不需要这个帐号,就移走它&#xff0c;你有的帐号越多,就越容易受到攻击。 1.为删除你系统上的用户,用下面的…...

core Webapi jwt 认证

core cookie 验证 Web API Jwt 》》》》用户信息 namespace WebAPI001.Coms {public class Account{public string UserName { get; set; }public string UserPassword { get; set; }public string UserRole { get; set; }} }》》》获取jwt类 using Microsoft.AspNetCore.Mvc…...

【Redis】Redis基础——Redis的安装及启动

一、初识Redis 1. 认识NoSQL 数据结构&#xff1a;对于SQL来说&#xff0c;表是有结构的&#xff0c;如字段约束、字段存储大小等。 关联性&#xff1a;SQL 的关联性体现在两张表之间可以通过外键&#xff0c;将两张表的数据关联查询出完整的数据。 查询方式&#xff1a; 2.…...

Oracle Recovery Tools工具一键解决ORA-00376 ORA-01110故障(文件offline)---惜分飞

客户在win上面迁移数据文件,由于原库非归档,结果导致有两个文件scn不一致,无法打开库,结果他们选择offline文件,然后打开数据库 Wed Dec 04 14:06:04 2024 alter database open Errors in file d:\app\administrator\diag\rdbms\orcl\orcl\trace\orcl_ora_6056.trc: ORA-01113:…...

常用环境部署(二十四)——Docker部署开源物联网平台Thingsboard

1、Docker和Docker-compose安装 参考网址如下&#xff1a; CENTOS8.0安装DOCKER&DOCKER-COMPOSE以及常见报错解决_centos8安装docker-compose-CSDN博客 2、 Thingsboard安装 &#xff08;1&#xff09;在/home目录下创建docker-compose.yml文件 vim /home/docker-com…...

SqlServer Doris Flink SQL 类型映射关系

SqlServer 对应 Flink SQL 数据类型映射关系 SQL Server TypeFlink SQL Typechar(n)CHAR(n)varchar(n)VARCHAR(n)nvarchar(n)VARCHAR(n)nchar(n)VARCHAR(n)textSTRINGntextSTRINGxmlSTRINGdecimal(p, s)DECIMAL(p, s)moneyDECIMAL(p, s)smallmoneyDECIMAL(p, s)numericNUMERIC…...

Java 中的方法重写

在 Java 中&#xff0c;方法重写&#xff08;Method Overriding&#xff09;是面向对象编程的一个重要概念&#xff0c;它指的是子类中存在一个与父类中相同名称、相同参数列表和相同返回类型的方法。方法重写使得子类可以提供特定的实现&#xff0c;从而覆盖&#xff08;或改变…...

v-for遍历多个el-popover;el-popover通过visible控制显隐;点击其他隐藏el-popover

场景:el-popover通过visible控制显隐;同时el-popover是遍历生成的多个。 原文档的使用visible后就不能点击其他地方使其隐藏;同时解决实现点击其他区域隐藏 <template><div><template v-for="(item,index) in arr" :key="index"><…...

从 Excel 文件中读取数据生成 SQL 语句[快捷main方法]

从 Excel 文件中读取数据生成 SQL 语句的实现 在日常工作中&#xff0c;我们经常需要从 Excel 文件中提取数据&#xff0c;并将其转换为 SQL 插入语句&#xff0c;以便于将数据导入到数据库中。在这篇文章中&#xff0c;我将展示如何使用 Java 来实现这一需求。 项目需求 我…...

从0到1实现项目Docker编排部署

在深入讨论 Docker 编排之前&#xff0c;首先让我们了解一下 Docker 技术本身。Docker 是一个开源平台&#xff0c;旨在帮助开发者自动化应用程序的部署、扩展和管理。自 2013 年推出以来&#xff0c;Docker 迅速发展成为现代软件开发和运维领域不可或缺的重要工具。 Docker 采…...

Vue框架入门

Author&#xff1a;Dawn_T17?? 目录 什么是框架 一.Vue 的使用方向 二.Vue 框架的使用场景 &#xff08;TIP&#xff09;MVVM思想 三.Vue入门案例 TIP&#xff1a;插值表达式 四.Vue-指令? &#xff08;1&#xff09;v-bind 和 v-model? ? &#xff08;2&#x…...

vue入门实战(二)父子组件显示,参数传递

经过上次的写法&#xff0c;我们已经写出每个list项&#xff0c;现在要在每个父组件下面加入自己的子项 一、新建子组件&#xff1a; smallItem.vue&#xff1a; <script> export default{props:[text,id,status] //父组件传来的参数 } </script> <template>…...

内存分配函数malloc kmalloc vmalloc

内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...

汽车生产虚拟实训中的技能提升与生产优化​

在制造业蓬勃发展的大背景下&#xff0c;虚拟教学实训宛如一颗璀璨的新星&#xff0c;正发挥着不可或缺且日益凸显的关键作用&#xff0c;源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例&#xff0c;汽车生产线上各类…...

今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存

文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...

嵌入式学习之系统编程(九)OSI模型、TCP/IP模型、UDP协议网络相关编程(6.3)

目录 一、网络编程--OSI模型 二、网络编程--TCP/IP模型 三、网络接口 四、UDP网络相关编程及主要函数 ​编辑​编辑 UDP的特征 socke函数 bind函数 recvfrom函数&#xff08;接收函数&#xff09; sendto函数&#xff08;发送函数&#xff09; 五、网络编程之 UDP 用…...

【Linux】Linux安装并配置RabbitMQ

目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的&#xff0c;需要先安…...

Python环境安装与虚拟环境配置详解

本文档旨在为Python开发者提供一站式的环境安装与虚拟环境配置指南&#xff0c;适用于Windows、macOS和Linux系统。无论你是初学者还是有经验的开发者&#xff0c;都能在此找到适合自己的环境搭建方法和常见问题的解决方案。 快速开始 一分钟快速安装与虚拟环境配置 # macOS/…...

js 设置3秒后执行

如何在JavaScript中延迟3秒执行操作 在JavaScript中&#xff0c;要设置一个操作在指定延迟后&#xff08;例如3秒&#xff09;执行&#xff0c;可以使用 setTimeout 函数。setTimeout 是JavaScript的核心计时器方法&#xff0c;它接受两个参数&#xff1a; 要执行的函数&…...

表单设计器拖拽对象时添加属性

背景&#xff1a;因为项目需要。自写设计器。遇到的坑在此记录 使用的拖拽组件时vuedraggable。下面放上局部示例截图。 坑1。draggable标签在拖拽时可以获取到被拖拽的对象属性定义 要使用 :clone, 而不是clone。我想应该是因为draggable标签比较特。另外在使用**:clone时要将…...

Spring是如何实现无代理对象的循环依赖

无代理对象的循环依赖 什么是循环依赖解决方案实现方式测试验证 引入代理对象的影响创建代理对象问题分析 源码见&#xff1a;mini-spring 什么是循环依赖 循环依赖是指在对象创建过程中&#xff0c;两个或多个对象相互依赖&#xff0c;导致创建过程陷入死循环。以下通过一个简…...

使用python进行图像处理—图像变换(6)

图像变换是指改变图像的几何形状或空间位置的操作。常见的几何变换包括平移、旋转、缩放、剪切&#xff08;shear&#xff09;以及更复杂的仿射变换和透视变换。这些变换在图像配准、图像校正、创建特效等场景中非常有用。 6.1仿射变换(Affine Transformation) 仿射变换是一种…...