当前位置: 首页 > news >正文

CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

from tensorflow.keras.datasets import cifar10(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

相关文章:

CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

1.数据集介绍 CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成&#xff0c;每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。 数据集分为5个训练批次和1个测试批次&#xff0c;每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像…...

《傲剑狂刀》中的人物性格——龙吟风

在《傲剑狂刀》这款经典武侠题材的格斗游戏中,龙吟风作为一位具有传奇色彩的角色,其性格特征复杂且引人入胜。以下是对龙吟风这一角色的性格特点进行深度剖析: 一、孤高独立的剑客气质 龙吟风的名字本身就流露出一种独特的江湖气息,"吟风"象征着他的飘逸与淡泊名…...

KVM和JVM的虚拟化技术有何区别?

随着虚拟化技术的不断发展&#xff0c;KVM和JVM已成为两种主流的虚拟化技术。尽管它们都提供了虚拟化的解决方案&#xff0c;但它们在实现方式、功能和性能方面存在一些重要的差异。本文将深入探讨KVM和JVM的虚拟化技术之间的区别。 KVM&#xff08;Kernel-based Virtual Mac…...

LeetCode力扣 面试经典150题 详细题解 (1~5) 持续更新中

目录 1.合并两个有序数组 2.移动元素 3.删除有序数组中的重复项 4.删除排序数组中的重复项 II 5.多数元素 暂时更新到这里&#xff0c;博主会持续更新的 1.合并两个有序数组 题目&#xff08;难度&#xff1a;简单&#xff09;&#xff1a; 给你两个按 非递减顺序 排列的…...

如何解决利用cron定时任务自动更新SSL证书后Nginx重启问题

利用cron定时任务自动更新SSL证书后&#xff0c;用浏览器访问网站&#xff0c;获取到的证书仍然是之前的。原因在于没有对Nginx进行重启。 据说certbot更新完成证书后会自动重启Nginx,但显然经我检测不是这回事儿。 所以我们需要创建一bash脚本&#xff0c;然后定时调用这个脚…...

第一个 Angular 项目 - 静态页面

第一个 Angular 项目 - 静态页面 之前的笔记&#xff1a; [Angular 基础] - Angular 渲染过程 & 组件的创建 [Angular 基础] - 数据绑定(databinding) [Angular 基础] - 指令(directives) 这是在学完了上面这三个内容后能够完成的项目&#xff0c;目前因为还没有学到数…...

网络协议与攻击模拟_17HTTPS 协议

HTTPShttpssl/tls 1、加密算法 2、PKI&#xff08;公钥基础设施&#xff09; 3、证书 4、部署HTTPS服务器 部署CA证书服务器 5、分析HTTPS流量 分析TLS的交互过程 一、HTTPS协议 在http的通道上增加了安全性&#xff0c;传输过程通过加密和身份认证来确保传输安全性 1、TLS …...

【linux系统体验】-ubuntu简易折腾

ubuntu 一、终端美化二、桌面美化2.1 插件安装2.2 主题和图标2.3 美化配置 三、常用命令 以后看不看不重要&#xff0c;咱就是想记点儿东西。一、终端美化 安装oh my posh&#xff0c;参考链接&#xff1a;Linux 终端美化 1、安装字体 oh my posh美化工具可以使用合适的字体&a…...

Android 判断通知是进度条通知

1.需求: 应用监听安卓系统中的通知,需要区分出带进度条的通知. 当使用NotificationCompat.Builder构建一个通知时&#xff0c;可以通过调用setProgress(max, progress, indeterminate)方法来添加一个进度条。这里的max参数表示最大进度值&#xff0c;progress表示当前进度值&a…...

学习数据结构和算法的第8天

顺序表的实现 顺序表 ​ 本质就是数组 概念及结构 ​ 顺序表是用一段物理地址连续的储存单元依次储存数据元素的线性结构&#xff0c;一般情况下采用数组储存&#xff0c;在数组上完成数据的增删。 顺序表就是数组&#xff0c;但是在数组的基础上&#xff0c;它还要求数据…...

JCIM | MD揭示PTP1B磷酸酶激活RtcB连接酶的机制

Background 内质网应激反应&#xff08;UPR&#xff09; 中的一个重要过程。UPR是由内质网中的三种跨膜传感器&#xff08;IRE1、PERK和ATF6&#xff09;控制的细胞应激反应&#xff0c;当内质网中的蛋白质折叠能力受到压力时&#xff0c;UPR通过减少蛋白质合成和增加未折叠或错…...

基于Java (spring-boot)的音乐管理系统

一、项目介绍 播放器的前端&#xff1a; 1.首页&#xff1a;点击歌单中的音乐播放列表中的歌曲进行播放&#xff0c;播放时跳转播放界面&#xff0c;并显示歌手信息&#xff0c;同时会匹配歌词&#xff0c;把相应的歌词显示在歌词面板中。 2.暂停&#xff1a;当歌曲正在播放时…...

在 MacOS M系列处理器上使用 Anaconda 开发 Oralce 的Python程序

在 MacOS M系列处理器上使用 Anaconda 开发 Oralce 的Python程序 因oracle官方驱动暂无 苹果 M 系列处理器版本&#xff0c;所以使用Arm的python解释器报驱动错误&#xff1a; cx_Oracle.DatabaseError: DPI-1047: Cannot locate a 64-bit Oracle Client library: "dlop…...

四、OpenAI之文本生成模型

文本生成模型 OpenAI的文本生成模型(也叫做生成预训练的转换器(Generative pre-trained transformers)或大语言模型)已经被训练成可以理解自然语言、代码和图片的模型。模型提供文本的输出作为输入的响应。对这些模型的输入内容也被称作“提示词”。设计提示词的本质是你如何对…...

CSS之flex布局

flex布局 CSS的Flex布局&#xff08;Flexible Box Layout&#xff09;是一种用于在页面上布置元素的高效方法&#xff0c;特别适合于响应式设计。Flex布局使得元素能够伸缩以适应可用空间&#xff0c;可以简化很多原本需要复杂CSS和HTML结构才能实现的布局设计。 flex布局包括…...

UnityShader——02三大主流编程语言

三大主流编程语言 Shader Language Shader language的发展方向是设计出在便携性方面可以与C/JAVA相比的高级语言&#xff0c;“赋予程序员灵活而方便的编程方式”&#xff0c;并“利用图形硬件的并行性&#xff0c;提高算法的效率” Shader language目前主要有 3 种语言&…...

Centos7安装nginx yum报错

Centos7安装nginx yum报错&#xff0c;yum源报错解决办法&#xff1a; 1、更新epel源后&#xff0c;出现yum报错 [roothacker117 ~]# yum install epel-release&#xff08;安装成功&#xff09; [roothacker117 ~]# yum install nginx&#xff08;安装失败&#xff0c;提示如…...

【机组】基于FPGA的32位算术逻辑运算单元的设计(EP2C5扩充选配类)

​&#x1f308;个人主页&#xff1a;Sarapines Programmer&#x1f525; 系列专栏&#xff1a;《机组 | 模块单元实验》⏰诗赋清音&#xff1a;云生高巅梦远游&#xff0c; 星光点缀碧海愁。 山川深邃情难晤&#xff0c; 剑气凌云志自修。 目录 一、实验目的 二、实验要求 …...

Asp .Net Core 系列:Asp .Net Core 集成 NLog

简介 NLog是一个基于.NET平台编写的日志记录类库&#xff0c;它可以在应用程序中添加跟踪调试代码&#xff0c;以便在开发、测试和生产环境中对程序进行监控和故障排除。NLog具有简单、灵活和易于配置的特点&#xff0c;支持在任何一种.NET语言中输出带有上下文的调试诊断信息…...

一个基于 .NET 7 + Vue.js 的前后端分离的通用后台管理系统框架 - DncZeus

前言 今天给大家推荐一个基于.NET 7 Vue.js(iview-admin) 的前后端分离的通用后台权限(页面访问、操作按钮控制)管理系统框架&#xff1a;DncZeus。 官方项目简介 DncZeus是一个基于 .NET 7 Vue.js 的前后端分离的通用后台管理系统框架。后端使用.NET 7 Entity Framework…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

高防服务器能够抵御哪些网络攻击呢?

高防服务器作为一种有着高度防御能力的服务器&#xff0c;可以帮助网站应对分布式拒绝服务攻击&#xff0c;有效识别和清理一些恶意的网络流量&#xff0c;为用户提供安全且稳定的网络环境&#xff0c;那么&#xff0c;高防服务器一般都可以抵御哪些网络攻击呢&#xff1f;下面…...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec&#xff1f; IPsec VPN 5.1 IPsec传输模式&#xff08;Transport Mode&#xff09; 5.2 IPsec隧道模式&#xff08;Tunne…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息&#xff0c;对客户进行统一管理&#xff0c;可以把所有客户信息录入系统&#xff0c;进行维护和统计功能。可通过文件的方式保存相关录入数据&#xff0c;对…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)

题目描述 给定一个整型数组,请从该数组中选择3个元素 组成最小数字并输出 (如果数组长度小于3,则选择数组中所有元素来组成最小数字)。 输入描述 行用半角逗号分割的字符串记录的整型数组,0<数组长度<= 100,0<整数的取值范围<= 10000。 输出描述 由3个元素组成…...

[论文阅读]TrustRAG: Enhancing Robustness and Trustworthiness in RAG

TrustRAG: Enhancing Robustness and Trustworthiness in RAG [2501.00879] TrustRAG: Enhancing Robustness and Trustworthiness in Retrieval-Augmented Generation 代码&#xff1a;HuichiZhou/TrustRAG: Code for "TrustRAG: Enhancing Robustness and Trustworthin…...

React父子组件通信:Props怎么用?如何从父组件向子组件传递数据?

系列回顾&#xff1a; 在上一篇《React核心概念&#xff1a;State是什么&#xff1f;》中&#xff0c;我们学习了如何使用useState让一个组件拥有自己的内部数据&#xff08;State&#xff09;&#xff0c;并通过一个计数器案例&#xff0c;实现了组件的自我更新。这很棒&#…...