当前位置: 首页 > news >正文

Keras实战之图像分类识别

文章目录

    • 整体流程
      • 数据加载与预处理
      • 搭建网络模型
      • 优化网络模型
        • 学习率
        • Drop-out操作
        • 权重初始化方法对比
        • 正则化
        • 加载模型进行测试

实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。

问:为什么选择Keras?
答:使用Keras便捷快速。用起来简单,入门容易,上手快。没有tensorflow那么复杂的规范。

整体流程

  1. 读取数据
  2. 数据预处理
  3. 切分数据集(分为训练集和测试集)
  4. 搭建网络模型(初始化参数)
  5. 训练网络模型
  6. 评估测试模型(通过对比不同参数下损失函数不断优化模型)
  7. 保存模型到本地

(1)手动配置参数,设置数据存储路径、模型保存路径、图片保存路径

# 输入参数,手动设置数据存储路径、模型保存路径、图片保存路径等
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,help="path to input dataset of images")
ap.add_argument("-m", "--model", required=True,help="path to output trained model")
ap.add_argument("-l", "--label-bin", required=True,help="path to output label binarizer")
ap.add_argument("-p", "--plot", required=True,help="path to output accuracy/loss plot")
args = vars(ap.parse_args())

在这里插入图片描述

数据加载与预处理

# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images(args["dataset"])))
random.seed(42)
random.shuffle(imagePaths)
# 数据洗牌前设置随机种子确保后面调参过程中训练数据集一样# 遍历读取数据
for imagePath in imagePaths:# 读取图像数据,由于使用神经网络,需要输入数据给定成一维image = cv2.imread(imagePath)# 而最初获取的图像数据是三维的,则需要将三维数据进行拉长image = cv2.resize(image, (32, 32)).flatten()data.append(image)# 读取标签,通过读取数据存储位置文件夹来判断图片标签label = imagePath.split(os.path.sep)[-2]labels.append(label)# scale图像数据,归一化
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)# 转换标签,one-hot格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

数据预处理:①通过数据除以255进行数据归一化;②对数据标签进行格式转换。

搭建网络模型

  1. 创建序列结构
model = Sequential()
  1. 添加全连接层
  • 第一层全连接层Dense设计512个神经元,当前输入特征个数(输入神经元个数)为3072,设置激活函数为"relu";
  • 第二层设计256个神经元;
  • 第三层设计类别数个神经元(即3个),并作softmax操作得到最终分类类别。
# 第一层
model.add(Dense(512, input_shape=(3072,),activation="relu"))
# 第二层
model.add(Dense(256, activation="relu",))
# 第三层
model.add(Dense(len(lb.classes_), activation="softmax",))
  1. 初始化参数
# 学习率
INIT_LR = 0.01
# 迭代次数
EPOCHS = 200
  1. 训练网络模型
# 给定损失函数和评估方法
opt = SGD(lr=INIT_LR) # 指定优化器为梯度下降的优化器
model.compile(loss="categorical_crossentropy", optimizer=opt,metrics=["accuracy"])# 训练网络模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)
  1. 测试网络模型

使用上面训练所得网络模型对测试集进行预测,并对比预测解国和数据集真实结果打印结果报告(包括准确率、recall、f1-score),并将损失函数以折线图的效果直观展示出来

predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=lb.classes_))
  1. 评估结果
    在这里插入图片描述
    在这里插入图片描述
    从损失函数图像中可看出,模型出现明显过拟合现象,故而该初始参数所构建的模型效果较差,需要通过调参优化模型。

优化网络模型

学习率

对比学习率为0.01和0.001的损失函数图像。

在这里插入图片描述

train_loss与val_loss之间差异仍然存在,但是可看出学习率越大,过拟合现象越明显。

Drop-out操作

Dropout操作:在搭建网络模型中,通过设置一0到1范围内的参数从而防止过拟合。
在这里插入图片描述

权重初始化方法对比

(1)RandomNormal随机高斯初始化

kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)
model.add(Dense(512, input_shape=(3072,),activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))

在这里插入图片描述
图中可看出,添加RandomNormal初始化后,过拟合现象减弱了一丢丢。

(2)TruncatedNormal截断

kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)

相比于正常高斯分布截断了两边,只取小于2倍stddev的值

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))

在这里插入图片描述

对比stddev取不同值时的loss函数图可得,TruncatedNormal中stddev值越小,过拟合风险越低,模型效果越好。TruncatedNormal消除过拟合的效果RandomNormal好。

正则化
kernel_regularizer=regularizers.l2(0.01)

正则化后,损失函数loss = 初始loss + aR(W)。正则化惩罚W,让稳定的W减少过拟合。

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))

对比正则化前后取迭代150到200的loss波动图,可发现正则化后虽然开始时loss值较大,但后期过拟合现象有明显减弱
在这里插入图片描述
再对比正则化参数l2 = 0.01和0.05的结果可得,l2越大,W的惩罚力度越大,过拟合风险越小
在这里插入图片描述

加载模型进行测试
# 导入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2# 设置输入参数
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,help="path to input image we are going to classify")
ap.add_argument("-m", "--model", required=True,help="path to trained Keras model")
ap.add_argument("-l", "--label-bin", required=True,help="path to label binarizer")
ap.add_argument("-w", "--width", type=int, default=28,help="target spatial dimension width")
ap.add_argument("-e", "--height", type=int, default=28,help="target spatial dimension height")
ap.add_argument("-f", "--flatten", type=int, default=-1,help="whether or not we should flatten the image")
args = vars(ap.parse_args())# 加载测试数据并进行相同预处理操作
image = cv2.imread(args["image"])
output = image.copy()
image = cv2.resize(image, (args["width"], args["height"]))# scale the pixel values to [0, 1]
image = image.astype("float") / 255.0# 对图像进行拉平操作
image = image.flatten()
image = image.reshape((1, image.shape[0]))# 读取模型和标签
print("[INFO] loading network and label binarizer...")
model = load_model(args["model"])
lb = pickle.loads(open(args["label_bin"], "rb").read())# 预测
preds = model.predict(image)# 得到预测结果以及其对应的标签
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]# 在图像中把结果画出来
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)# 绘图
cv2.imshow("Image", output)
cv2.waitKey(0)

分类结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
通过预测结果可得:该模型在预测猫上存在较大误差,在预测熊猫上较为准确。或许改进增加迭代次数可进一步优化模型。

相关文章:

Keras实战之图像分类识别

文章目录 整体流程数据加载与预处理搭建网络模型优化网络模型学习率Drop-out操作权重初始化方法对比正则化加载模型进行测试 实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。 问:为什么选择Keras&am…...

Celery,一个实时处理的 Python 分布式系统

大家好!我是爱摸鱼的小鸿,关注我,收看每期的编程干货。 一个简单的库,也许能够开启我们的智慧之门, 一个普通的方法,也许能在危急时刻挽救我们于水深火热, 一个新颖的思维方式,也许能…...

源码编译安装 LAMP

源码编译安装 LAMP Apache 网站服务基础Apache 简介安装 httpd 服务器 httpd 服务器的基本配置Web 站点的部署过程httpd.conf 配置文件 构建虚拟 Web 主机基于域名的虚拟主机基于IP 地址、基于端口的虚拟主机 MySQL 的编译安装构建 PHP 运行环境安装PHP软件包设置 LAMP 组件环境…...

PostgreSQL的pg_filedump工具

PostgreSQL的pg_filedump工具 基础信息 OS版本:Red Hat Enterprise Linux Server release 7.9 (Maipo) DB版本:16.2 pg软件目录:/home/pg16/soft pg数据目录:/home/pg16/data 端口:5777pg_filedump 是一个工具&#x…...

Java语言+后端+前端Vue,ElementUI 数字化产科管理平台 产科电子病历系统源码

Java语言后端前端Vue,ElementUI 数字化产科管理平台 产科电子病历系统源码 Java开发的数字化产科管理系统,已在多家医院实施,支持直接部署。系统涵盖孕产全程,包括门诊、住院、统计和移动服务,整合高危管理、智能提醒、档案追踪等…...

Linux 服务器环境搭建

一、安装 JDK 官网下载地址:https://www.oracle.com/java/technologies/downloads # 创建目录 mkdir /usr/local/java/# 解压 tar -zxvf jdk-8u333-linux-x64.tar.gz -C /usr/local/java/# 配置环境变量 vim /etc/profileexport export JAVA_HOME/usr/local/java/…...

RabbitMQ 更改服务端口号

需求 windows环境下,将RabbitMQ默认的端口号 5672 改为 11001 实现 本机RabbitMQ版本为3.8.16,找到配置文件位置,路径为:C:\Users\%USERNAME%\AppData\Roaming\RabbitMQ\advanced.config 配置文件默认内容为空 填写修改端口号…...

16:9横屏短视频素材库有哪些?横屏短视频素材网站分享

在这个视觉内容至关重要的时代,16:9横屏视频因其宽广的画面和优越的观赏体验,已经成为无数创作者和营销专家的首选格式。但要创造出吸引人的横屏视频,高质量的视频素材库是不可或缺的。不管你是资深视频制作人还是刚入行的新手,下…...

在Java中,创建一个实现了Callable接口的类可以提供强大的灵活性,特别是当你需要在多线程环境中执行任务并获取返回结果时。

在Java中,创建一个实现了Callable接口的类可以提供强大的灵活性,特别是当你需要在多线程环境中执行任务并获取返回结果时。以下是一个简单的案例,演示了如何创建一个实现了Callable接口的类,并在线程池中执行它。 首先&#xff0…...

Vuforia AR篇(八)— AR塔防上篇

目录 前言一、设置Vuforia AR环境1. 添加AR Camera2. 设置目标图像 二、创建塔防游戏基础1. 导入素材2. 搭建场景3. 创建敌人4. 创建脚本 前言 在增强现实(AR)技术快速发展的今天,Vuforia作为一个强大的AR开发平台,为开发者提供了…...

Spring AOP源码篇四之 数据库事务

了解了Spring AOP执行过程&#xff0c;再看Spring事务源码其实非常简单。 首先从简单使用开始, 演示Spring事务使用过程 Xml配置&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema…...

小波与傅里叶变换的对比(Python)

直接上代码&#xff0c;理论可以去知乎看。 #Import necessary libraries %matplotlib inline import numpy as np import matplotlib.pyplot as plt import seaborn as snsimport pywt from scipy.ndimage import gaussian_filter1d from scipy.signal import chirp import m…...

Linux-sqlplus安装

1.下载安装包 下载入口&#xff1a;安装包 下载对应版本&#xff1a; oracle-instantclient-sqlplus-21.14.0.0.0-1.x86_64.rpm oracle-instantclient-basic-21.14.0.0.0-1.x86_64.rpm oracle-instantclient-devel-21.14.0.0.0-1.x86_64.rpm 2.安装 [rootpromethues-01 tmp…...

LeetCode 算法:课程表 c++

原题链接&#x1f517;&#xff1a;课程表 难度&#xff1a;中等⭐️⭐️ 题目 你这个学期必须选修 numCourses 门课程&#xff0c;记为 0 到 numCourses - 1 。 在选修某些课程之前需要一些先修课程。 先修课程按数组 prerequisites 给出&#xff0c;其中 prerequisites[i]…...

前端面试题30(闭包和作用域链的关系)

闭包和作用域链在JavaScript中是紧密相关的两个概念&#xff0c;理解它们之间的关系对于深入掌握JavaScript的执行机制至关重要。 作用域链 作用域链是一个链接列表&#xff0c;它包含了当前执行上下文的所有父级执行上下文的变量对象。每当函数被调用时&#xff0c;JavaScri…...

A股本周在3000点以下继续筑底,本周依然继续探底?

夜已深&#xff0c;市场传来了3个浓烈的消息&#xff0c;炸锅了&#xff0c;恐有大事发生&#xff0c;马上告诉所有人&#xff1a; 消息面&#xff1a; 1、中国经济周刊首席评论员钮文新称&#xff1a;不要等中小投资者都彻底希望&#xff0c;销户离场了&#xff0c;才发现该…...

Javadoc介绍

Javadoc 是用于生成 Java 代码文档的工具。它利用特定的注释格式,将 Java 源代码中的注释提取出来,并生成 HTML 文档。Javadoc 注释通常位于类、接口、构造函数、方法和字段的声明之前,以 /** 开始,以 */ 结束。以下是 Javadoc 注释的一些主要元素和使用方法: 基本语法 …...

C# Application.DoEvents()的作用

文章目录 1、详解 Application.DoEvents()2、示例处理用户事件响应系统事件控制台输出游戏和多媒体应用与操作系统的交互 3、注意事项总结 Application.DoEvents() 是 .NET 框架中的一个方法&#xff0c;它主要用于处理消息队列中的事件。在 Windows 应用程序中&#xff0c;当一…...

IDEA如何创建原生maven子模块

文件 -> 新建 -> 新模块 -> Maven ArcheTypeMaven ArcheType界面中的输入框介绍 名称&#xff1a;子模块的名称位置&#xff1a;子模块存放的路径名创建Git仓库&#xff1a;子模块不单独作为一个git仓库&#xff0c;无需勾选JDK&#xff1a;JDK版本号父项&#xff1a;…...

LCD EMC 辐射 测试随想

最近做几个产品过认证。 有带2.8寸 MCU8080接口的小屏&#xff08;320 X 240&#xff09;&#xff0c;也有RGB接口的10.1寸的大屏(800*600). 以下为个人随想&#xff0c;不知道是否正确&#xff0c;仅作记录。 测试发现辐射的核心问题还是在于时钟及其倍频所产生的尖峰。 记得读…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

【Oracle APEX开发小技巧12】

有如下需求&#xff1a; 有一个问题反馈页面&#xff0c;要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据&#xff0c;方便管理员及时处理反馈。 我的方法&#xff1a;直接将逻辑写在SQL中&#xff0c;这样可以直接在页面展示 完整代码&#xff1a; SELECTSF.FE…...

模型参数、模型存储精度、参数与显存

模型参数量衡量单位 M&#xff1a;百万&#xff08;Million&#xff09; B&#xff1a;十亿&#xff08;Billion&#xff09; 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的&#xff0c;但是一个参数所表示多少字节不一定&#xff0c;需要看这个参数以什么…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

《通信之道——从微积分到 5G》读书总结

第1章 绪 论 1.1 这是一本什么样的书 通信技术&#xff0c;说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号&#xff08;调制&#xff09; 把信息从信号中抽取出来&am…...

镜像里切换为普通用户

如果你登录远程虚拟机默认就是 root 用户&#xff0c;但你不希望用 root 权限运行 ns-3&#xff08;这是对的&#xff0c;ns3 工具会拒绝 root&#xff09;&#xff0c;你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案&#xff1a;创建非 roo…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署&#xff0c;直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型&#xff0c;但是目前国内可能使用不多&#xff0c;至少实践例子很少看见。开发训练模型就不介绍了&am…...