当前位置: 首页 > news >正文

训练 CNN 对 CIFAR-10 数据中的图像进行分类

1. 加载 CIFAR-10 数据库

import keras
from keras.datasets import cifar10# 加载预先处理的训练数据和测试数据
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 可视化前 24 个训练图像

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inlinefig = plt.figure(figsize=(20,5))
for i in range(36):ax = fig.add_subplot(3, 12, i + 1, xticks=[], yticks=[])ax.imshow(np.squeeze(x_train[i]))

3. 通过将每幅图像中的每个像素除以 255 来调整图像比例

事实上,代价函数的形状是一个碗,但如果特征的比例非常不同,它也可能是一个拉长的碗。下图显示了梯度下降法在特征 1 和特征 2 比例相同的训练集上的应用(左图),以及在特征 1 的值远小于特征 2 的训练集上的应用(右图)。

Tips : 使用梯度下降法时,应确保所有特征的比例相似,以加快训练速度,否则收敛时间会更长。

# rescale [0,255] --> [0,1]
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

 4. 将数据集分为训练集、测试集和验证集

from keras.utils import to_categorical# 对标签进行一次热编码
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)# 将训练集分为训练集和验证集
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]# 打印训练集的形状
print('x_train shape:', x_train.shape)# 打印训练、验证和测试图像的数量
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')

5. 定义模型架构 

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropoutmodel = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))model.summary()

6. 编译模型 

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

7. 训练模型

from keras.callbacks import ModelCheckpoint   # 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)hist = model.fit(x_train, y_train, batch_size=32, epochs=100,validation_data=(x_valid, y_valid), callbacks=[checkpointer], verbose=2, shuffle=True)

8. 加载验证精度最高的模型

# 加载验证精度最高的权重
model.load_weights('model.weights.best.hdf5')

 9. 计算测试集的分类精度

# 评估和打印测试精度
score = model.evaluate(x_test, y_test, verbose=0)
print('\n', 'Test accuracy:', score[1])

10. 可视化一些预测

这可能会让你对网络错误分类某些对象的原因有所了解。

# 在测试集上得到预测
y_hat = model.predict(x_test)# 定义文本标签 (source: https://www.cs.toronto.edu/~kriz/cifar.html)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 绘制测试图像的随机样本、预测标签和基本真实图像
fig = plt.figure(figsize=(20, 8))
for i, idx in enumerate(np.random.choice(x_test.shape[0], size=32, replace=False)):ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])ax.imshow(np.squeeze(x_test[idx]))pred_idx = np.argmax(y_hat[idx])true_idx = np.argmax(y_test[idx])ax.set_title("{} ({})".format(cifar10_labels[pred_idx], cifar10_labels[true_idx]),color=("green" if pred_idx == true_idx else "red"))

相关文章:

训练 CNN 对 CIFAR-10 数据中的图像进行分类

1. 加载 CIFAR-10 数据库 import keras from keras.datasets import cifar10# 加载预先处理的训练数据和测试数据 (x_train, y_train), (x_test, y_test) cifar10.load_data() 2. 可视化前 24 个训练图像 import numpy as np import matplotlib.pyplot as plt %matplotlib …...

香港科技大学广州|智能制造学域博士招生宣讲会—天津大学专场

时间:2023年12月07日(星期四)15:30 地点:天津大学卫津路校区26楼B112 报名链接:https://www.wjx.top/vm/mmukLPC.aspx# 宣讲嘉宾: 汤凯教授 学域主任 https://facultyprofiles.hkust-gz.edu.cn/faculty-p…...

滑动窗口练习(二)— 子数组中满足max -min <= sum的个数

题目 给定一个整型数组arr&#xff0c;和一个整数num 某个arr中的子数组sub&#xff0c;如果想达标&#xff0c;必须满足&#xff1a; sub中最大值 – sub中最小值 < num&#xff0c; 返回arr中达标子数组的数量 暴力对数器 暴力对数器方法主要是用来和另一个方法互相校验正…...

用xlwings新建一个excel并同时生成多个sheet

新建一个excel并同时生成多个sheet&#xff0c;要实现如下效果&#xff1a; 一般要使用数据透视表来快速实现。 今天记录用xlwings新建一个excel并同时生成多个sheet。 import xlwings as xw # 打开excel,参数visible表示处理过程是否可视,add_book表示是否打开新的Excel程序…...

诺威信,浪潮云,微众区块链

目录 诺威信B隐私计算平台 浪潮云=星火连-澳优码 HyperChain 产品介绍 CA认证即电子认证服务...

Redux在React中的使用

Redux在React中的使用 1.构建方式 采用reduxjs/toolkitreact-redux的方式 安装方式 npm install reduxjs/toolkit react-redux2.使用 ①创建目录 创建store文件夹&#xff0c;然后创建index和对应的模块&#xff0c;如上图所示 ②编写counterStore.js 文章以counterStore…...

Go 数字类型

一、数字类型 1、Golang 数据类型介绍 Go 语言中数据类型分为&#xff1a;基本数据类型和复合数据类型基本数据类型有&#xff1a; 整型、浮点型、布尔型、字符串复合数据类型有&#xff1a; 数组、切片、结构体、函数、map、通道&#xff08;channel&#xff09;、接口 2、…...

时间序列预测 — Informer实现多变量负荷预测(PyTorch)

目录 1 实验数据集 2 如何运行自己的数据集 3 报错分析 1 实验数据集 实验数据集采用数据集4&#xff1a;2016年电工数学建模竞赛负荷预测数据集&#xff08;下载链接&#xff09;&#xff0c;数据集包含日期、最高温度℃ 、最低温度℃、平均温度℃ 、相对湿度(平均) 、降雨…...

2023年金融信创行业研究报告

第一章 行业概况 1.1 定义 金融信创是指在金融行业中应用的信息技术&#xff0c;特别是那些涉及到金融IT基础设施、基础软件、应用软件和信息安全等方面的技术和产品。这一概念源于更广泛的“信创 (信息技术应用创新)”&#xff0c;即通过中国国产信息技术替换海外信息技术&a…...

51单片机按键控制LED灯亮灭的N个玩法

51单片机按键控制LED灯亮灭的N个玩法 1.概述 这篇文章介绍按键的使用&#xff0c;以及通过控制LED灯的小实验&#xff0c;发现按键中存在的问题&#xff0c;然后思考并解决这些问题。达到熟练使用按键控制元器件。 2.搭建硬件环境 1.硬件准备 名称型号数量单片机STC12C205…...

推荐6款本周 yyds 的开源项目

&#x1f525;&#x1f525;&#x1f525;本周GitHub项目圈选: 主要包含 链接管理、视频总结、有道音色情感合成、中文文本格式校正、GPT爬虫、深度学习推理 等热点项目。 1、Dub 一个开源的链接管理工具&#xff0c;可自定义域名将繁杂的长链接生成短链接&#xff0c;便于保…...

【Git】git 更换远程仓库地址三种方法总结分享

因为公司更改了 gitlab 的网段地址&#xff0c;发现全部项目都需要重新更改远程仓库的地址了&#xff0c;所以做了个记录&#xff0c;说不定以后还会用到呢。 一、不删除远程仓库修改&#xff08;最方便&#xff09; # 查看远端地址 git remote -v # 查看远端仓库名 git rem…...

springboot 返回problem+json

spring所有配置都在WebMvcAutoConfiguration中 其中有 ProblemDetailsExceptionHandler 容器中的一个组件 -ControllerAdvice用来集中处理异常的 -点进ResponseEntityExceptionHandler 包含这些异常&#xff0c;如果出现以下异常&#xff0c;会被springboot支持以RFC 7807规…...

AI动画制作 StableDiffusion

1.brew -v 2.安装爬虫项目包所必需的python和git等系列系统支持部件 brew install cmake protobuf rust python@3.10 git wget pod --version brew link --overwrite cocoapods 3.从github网站克隆stable-diffusion-webui爬虫项目包至本地 ssh-add /Users/haijunyan/.ssh/id_r…...

【开源】基于Vue和SpringBoot的木马文件检测系统

项目编号&#xff1a; S 041 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S041&#xff0c;文末获取源码。} 项目编号&#xff1a;S041&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 木马分类模块2.3 木…...

5 动态规划解分割等和子串

来源&#xff1a;LeetCode第416题 难度&#xff1a;中等 描述&#xff1a;给你一个只包含正整数的非空数组nums,请你判断是否可以将这个数组分割成两个子集&#xff0c;使得两个子集的元素和相等 分析&#xff1a;相当于从nums数组中选取一些元素&#xff0c;使得他们的和为…...

file_get_contents() 函数详解与使用

概述 在PHP中&#xff0c;file_get_contents() 函数是一个强大的工具&#xff0c;它既可以用于读取本地文件的内容&#xff0c;也可以用于发起 HTTP 请求获取远程资源。本文将详细介绍 file_get_contents() 函数的两种主要用途&#xff0c;并探讨如何充分利用这个函数。 1. 文…...

某医生用 ChatGPT 在 4 个月内狂写 16 篇论文,其中 5 篇已发表,揭密ChatGPT进行论文润色与改写的秘籍

如果写过学术论文&#xff0c;想必会有这样的感受&#xff1a; 绞尽脑汁、茶饭不思、夜不能寐、废寝忘食、夜以继日&#xff0c;赶出一篇论文&#xff0c;然后还被导师点评&#xff0c;“写得就是一坨&#xff01;” 可是&#xff0c;却有人4个月产出了16篇论文&#xff0c;成功…...

进程等待讲解

今日为大家分享有关进程等待的知识&#xff01;希望读完本文&#xff0c;大家能有一定的收获&#xff01; 正文开始&#xff01; 进程等待的引进 既然我们今天要讲进程等待这个概念&#xff01;那么只有我们把下面这三个方面搞明白&#xff0c;才能真正的了解进程等待&#x…...

MySQL Binlog深度解析:进阶应用与实战技巧【进阶应用】

&#x1f38f;&#xff1a;你只管努力&#xff0c;剩下的交给时间 &#x1f3e0; &#xff1a;小破站 MySQL Binlog深度解析&#xff1a;进阶应用与实战技巧 前言第一&#xff1a;Binlog事件详解第二&#xff1a;关于GTIDGTID的结构&#xff1a;GTID的作用&#xff1a;GTID的事…...

Qwen-Turbo-BF16实战案例:电商主图生成——白底产品图+场景化展示图双输出

Qwen-Turbo-BF16实战案例&#xff1a;电商主图生成——白底产品图场景化展示图双输出 1. 电商主图生成的新选择 电商卖家每天都要面对一个头疼的问题&#xff1a;商品主图怎么设计&#xff1f;白底图要干净专业&#xff0c;场景图要吸引眼球&#xff0c;找设计师成本高&#…...

如何构建高效可扩展的实时数据处理系统:抖音直播弹幕采集架构深度解析

如何构建高效可扩展的实时数据处理系统&#xff1a;抖音直播弹幕采集架构深度解析 【免费下载链接】DouyinLiveWebFetcher 抖音直播间网页版的弹幕数据抓取&#xff08;2025最新版本&#xff09; 项目地址: https://gitcode.com/gh_mirrors/do/DouyinLiveWebFetcher 抖音…...

如何3步实现Windows任务栏透明美化:TranslucentTB完整使用指南

如何3步实现Windows任务栏透明美化&#xff1a;TranslucentTB完整使用指南 【免费下载链接】TranslucentTB A lightweight utility that makes the Windows taskbar translucent/transparent. 项目地址: https://gitcode.com/gh_mirrors/tr/TranslucentTB TranslucentTB…...

实时多人姿态估计终极指南:多尺度特征提取技术深度解析

实时多人姿态估计终极指南&#xff1a;多尺度特征提取技术深度解析 【免费下载链接】Realtime_Multi-Person_Pose_Estimation Code repo for realtime multi-person pose estimation in CVPR17 (Oral) 项目地址: https://gitcode.com/gh_mirrors/re/Realtime_Multi-Person_Po…...

s2-pro开源TTS价值:填补中文专业级开源语音合成模型空白

s2-pro开源TTS价值&#xff1a;填补中文专业级开源语音合成模型空白 1. 为什么我们需要专业级中文TTS 在语音技术领域&#xff0c;中文语音合成(TTS)长期面临一个尴尬局面&#xff1a;虽然商业解决方案众多&#xff0c;但高质量的开源模型却寥寥无几。这种状况直到s2-pro的出…...

告别WebSecurityConfigurerAdapter:Spring Security 5.7+组件化配置实战指南

1. 从WebSecurityConfigurerAdapter到组件化配置的转变 如果你最近在升级Spring Boot应用&#xff0c;特别是从2.x版本迁移到3.x&#xff0c;肯定会遇到一个重大变化&#xff1a;Spring Security 5.7版本中&#xff0c;WebSecurityConfigurerAdapter这个老朋友已经被正式弃用了…...

测试、项目管理、软件度量和质量

欢迎来到我的软考中级——软件设计师备考合集。这里不只是一份简单的知识点堆砌&#xff0c;而是我在备考征途中&#xff0c;对庞杂知识体系进行深度梳理与内化的结晶。 面对浩瀚的考纲&#xff0c;从计算机组成原理的底层逻辑&#xff0c;到操作系统的进程调度&#xff1b;从数…...

Facebook广告细分定位新功能解析

Facebook广告细分定位新功能的本质&#xff0c;是广告受众定位正式进入了“自然语言”时代。简单来说&#xff0c;就是把过去从庞大的标签库里找词&#xff0c;变成了直接用日常语言描述你想要触达的目标人群。这背后&#xff0c;是Meta全新的 “Andromeda”&#xff08;仙女座…...

nli-distilroberta-base生产环境:低延迟NLI服务在搜索Query改写中应用

nli-distilroberta-base生产环境&#xff1a;低延迟NLI服务在搜索Query改写中应用 1. 项目概述 在搜索引擎优化和智能问答系统中&#xff0c;Query改写是一个关键环节。nli-distilroberta-base是一个基于DistilRoBERTa模型的轻量级自然语言推理(NLI)服务&#xff0c;专门为生…...

第一篇:KNX入门实战|从协议基础到开发环境搭建,新手也能轻松上手

在智能楼宇与工业自动化领域&#xff0c;KNX协议绝对是绕不开的核心标准——作为全球通用的开放式楼宇控制协议&#xff08;ISO/IEC 14543&#xff09;&#xff0c;它融合了欧洲三大总线协议的优势&#xff0c;能实现照明、空调、传感器等各类设备的无缝联动&#xff0c;广泛应…...