Keras实战之图像分类识别
文章目录
- 整体流程
- 数据加载与预处理
- 搭建网络模型
- 优化网络模型
- 学习率
- Drop-out操作
- 权重初始化方法对比
- 正则化
- 加载模型进行测试
实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。
问:为什么选择Keras?
答:使用Keras便捷快速。用起来简单,入门容易,上手快。没有tensorflow那么复杂的规范。
整体流程
- 读取数据
- 数据预处理
- 切分数据集(分为训练集和测试集)
- 搭建网络模型(初始化参数)
- 训练网络模型
- 评估测试模型(通过对比不同参数下损失函数不断优化模型)
- 保存模型到本地
(1)手动配置参数,设置数据存储路径、模型保存路径、图片保存路径
# 输入参数,手动设置数据存储路径、模型保存路径、图片保存路径等
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,help="path to input dataset of images")
ap.add_argument("-m", "--model", required=True,help="path to output trained model")
ap.add_argument("-l", "--label-bin", required=True,help="path to output label binarizer")
ap.add_argument("-p", "--plot", required=True,help="path to output accuracy/loss plot")
args = vars(ap.parse_args())

数据加载与预处理
# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images(args["dataset"])))
random.seed(42)
random.shuffle(imagePaths)
# 数据洗牌前设置随机种子确保后面调参过程中训练数据集一样# 遍历读取数据
for imagePath in imagePaths:# 读取图像数据,由于使用神经网络,需要输入数据给定成一维image = cv2.imread(imagePath)# 而最初获取的图像数据是三维的,则需要将三维数据进行拉长image = cv2.resize(image, (32, 32)).flatten()data.append(image)# 读取标签,通过读取数据存储位置文件夹来判断图片标签label = imagePath.split(os.path.sep)[-2]labels.append(label)# scale图像数据,归一化
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)# 转换标签,one-hot格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
数据预处理:①通过数据除以255进行数据归一化;②对数据标签进行格式转换。
搭建网络模型
- 创建序列结构
model = Sequential()
- 添加全连接层
- 第一层全连接层Dense设计512个神经元,当前输入特征个数(输入神经元个数)为3072,设置激活函数为"relu";
- 第二层设计256个神经元;
- 第三层设计类别数个神经元(即3个),并作softmax操作得到最终分类类别。
# 第一层
model.add(Dense(512, input_shape=(3072,),activation="relu"))
# 第二层
model.add(Dense(256, activation="relu",))
# 第三层
model.add(Dense(len(lb.classes_), activation="softmax",))
- 初始化参数
# 学习率
INIT_LR = 0.01
# 迭代次数
EPOCHS = 200
- 训练网络模型
# 给定损失函数和评估方法
opt = SGD(lr=INIT_LR) # 指定优化器为梯度下降的优化器
model.compile(loss="categorical_crossentropy", optimizer=opt,metrics=["accuracy"])# 训练网络模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)
- 测试网络模型
使用上面训练所得网络模型对测试集进行预测,并对比预测解国和数据集真实结果打印结果报告(包括准确率、recall、f1-score),并将损失函数以折线图的效果直观展示出来
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=lb.classes_))
- 评估结果


从损失函数图像中可看出,模型出现明显过拟合现象,故而该初始参数所构建的模型效果较差,需要通过调参优化模型。
优化网络模型
学习率
对比学习率为0.01和0.001的损失函数图像。

train_loss与val_loss之间差异仍然存在,但是可看出学习率越大,过拟合现象越明显。
Drop-out操作
Dropout操作:在搭建网络模型中,通过设置一0到1范围内的参数从而防止过拟合。

权重初始化方法对比
(1)RandomNormal随机高斯初始化
kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)
model.add(Dense(512, input_shape=(3072,),activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))

图中可看出,添加RandomNormal初始化后,过拟合现象减弱了一丢丢。
(2)TruncatedNormal截断
kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)
相比于正常高斯分布截断了两边,只取小于2倍stddev的值
model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))

对比stddev取不同值时的loss函数图可得,TruncatedNormal中stddev值越小,过拟合风险越低,模型效果越好。TruncatedNormal消除过拟合的效果RandomNormal好。
正则化
kernel_regularizer=regularizers.l2(0.01)
正则化后,损失函数loss = 初始loss + aR(W)。正则化惩罚W,让稳定的W减少过拟合。
model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
对比正则化前后取迭代150到200的loss波动图,可发现正则化后虽然开始时loss值较大,但后期过拟合现象有明显减弱

再对比正则化参数l2 = 0.01和0.05的结果可得,l2越大,W的惩罚力度越大,过拟合风险越小

加载模型进行测试
# 导入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2# 设置输入参数
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,help="path to input image we are going to classify")
ap.add_argument("-m", "--model", required=True,help="path to trained Keras model")
ap.add_argument("-l", "--label-bin", required=True,help="path to label binarizer")
ap.add_argument("-w", "--width", type=int, default=28,help="target spatial dimension width")
ap.add_argument("-e", "--height", type=int, default=28,help="target spatial dimension height")
ap.add_argument("-f", "--flatten", type=int, default=-1,help="whether or not we should flatten the image")
args = vars(ap.parse_args())# 加载测试数据并进行相同预处理操作
image = cv2.imread(args["image"])
output = image.copy()
image = cv2.resize(image, (args["width"], args["height"]))# scale the pixel values to [0, 1]
image = image.astype("float") / 255.0# 对图像进行拉平操作
image = image.flatten()
image = image.reshape((1, image.shape[0]))# 读取模型和标签
print("[INFO] loading network and label binarizer...")
model = load_model(args["model"])
lb = pickle.loads(open(args["label_bin"], "rb").read())# 预测
preds = model.predict(image)# 得到预测结果以及其对应的标签
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]# 在图像中把结果画出来
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)# 绘图
cv2.imshow("Image", output)
cv2.waitKey(0)
分类结果:



通过预测结果可得:该模型在预测猫上存在较大误差,在预测熊猫上较为准确。或许改进增加迭代次数可进一步优化模型。
相关文章:
Keras实战之图像分类识别
文章目录 整体流程数据加载与预处理搭建网络模型优化网络模型学习率Drop-out操作权重初始化方法对比正则化加载模型进行测试 实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。 问:为什么选择Keras&am…...
Celery,一个实时处理的 Python 分布式系统
大家好!我是爱摸鱼的小鸿,关注我,收看每期的编程干货。 一个简单的库,也许能够开启我们的智慧之门, 一个普通的方法,也许能在危急时刻挽救我们于水深火热, 一个新颖的思维方式,也许能…...
源码编译安装 LAMP
源码编译安装 LAMP Apache 网站服务基础Apache 简介安装 httpd 服务器 httpd 服务器的基本配置Web 站点的部署过程httpd.conf 配置文件 构建虚拟 Web 主机基于域名的虚拟主机基于IP 地址、基于端口的虚拟主机 MySQL 的编译安装构建 PHP 运行环境安装PHP软件包设置 LAMP 组件环境…...
PostgreSQL的pg_filedump工具
PostgreSQL的pg_filedump工具 基础信息 OS版本:Red Hat Enterprise Linux Server release 7.9 (Maipo) DB版本:16.2 pg软件目录:/home/pg16/soft pg数据目录:/home/pg16/data 端口:5777pg_filedump 是一个工具&#x…...
Java语言+后端+前端Vue,ElementUI 数字化产科管理平台 产科电子病历系统源码
Java语言后端前端Vue,ElementUI 数字化产科管理平台 产科电子病历系统源码 Java开发的数字化产科管理系统,已在多家医院实施,支持直接部署。系统涵盖孕产全程,包括门诊、住院、统计和移动服务,整合高危管理、智能提醒、档案追踪等…...
Linux 服务器环境搭建
一、安装 JDK 官网下载地址:https://www.oracle.com/java/technologies/downloads # 创建目录 mkdir /usr/local/java/# 解压 tar -zxvf jdk-8u333-linux-x64.tar.gz -C /usr/local/java/# 配置环境变量 vim /etc/profileexport export JAVA_HOME/usr/local/java/…...
RabbitMQ 更改服务端口号
需求 windows环境下,将RabbitMQ默认的端口号 5672 改为 11001 实现 本机RabbitMQ版本为3.8.16,找到配置文件位置,路径为:C:\Users\%USERNAME%\AppData\Roaming\RabbitMQ\advanced.config 配置文件默认内容为空 填写修改端口号…...
16:9横屏短视频素材库有哪些?横屏短视频素材网站分享
在这个视觉内容至关重要的时代,16:9横屏视频因其宽广的画面和优越的观赏体验,已经成为无数创作者和营销专家的首选格式。但要创造出吸引人的横屏视频,高质量的视频素材库是不可或缺的。不管你是资深视频制作人还是刚入行的新手,下…...
在Java中,创建一个实现了Callable接口的类可以提供强大的灵活性,特别是当你需要在多线程环境中执行任务并获取返回结果时。
在Java中,创建一个实现了Callable接口的类可以提供强大的灵活性,特别是当你需要在多线程环境中执行任务并获取返回结果时。以下是一个简单的案例,演示了如何创建一个实现了Callable接口的类,并在线程池中执行它。 首先࿰…...
Vuforia AR篇(八)— AR塔防上篇
目录 前言一、设置Vuforia AR环境1. 添加AR Camera2. 设置目标图像 二、创建塔防游戏基础1. 导入素材2. 搭建场景3. 创建敌人4. 创建脚本 前言 在增强现实(AR)技术快速发展的今天,Vuforia作为一个强大的AR开发平台,为开发者提供了…...
Spring AOP源码篇四之 数据库事务
了解了Spring AOP执行过程,再看Spring事务源码其实非常简单。 首先从简单使用开始, 演示Spring事务使用过程 Xml配置: <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema…...
小波与傅里叶变换的对比(Python)
直接上代码,理论可以去知乎看。 #Import necessary libraries %matplotlib inline import numpy as np import matplotlib.pyplot as plt import seaborn as snsimport pywt from scipy.ndimage import gaussian_filter1d from scipy.signal import chirp import m…...
Linux-sqlplus安装
1.下载安装包 下载入口:安装包 下载对应版本: oracle-instantclient-sqlplus-21.14.0.0.0-1.x86_64.rpm oracle-instantclient-basic-21.14.0.0.0-1.x86_64.rpm oracle-instantclient-devel-21.14.0.0.0-1.x86_64.rpm 2.安装 [rootpromethues-01 tmp…...
LeetCode 算法:课程表 c++
原题链接🔗:课程表 难度:中等⭐️⭐️ 题目 你这个学期必须选修 numCourses 门课程,记为 0 到 numCourses - 1 。 在选修某些课程之前需要一些先修课程。 先修课程按数组 prerequisites 给出,其中 prerequisites[i]…...
前端面试题30(闭包和作用域链的关系)
闭包和作用域链在JavaScript中是紧密相关的两个概念,理解它们之间的关系对于深入掌握JavaScript的执行机制至关重要。 作用域链 作用域链是一个链接列表,它包含了当前执行上下文的所有父级执行上下文的变量对象。每当函数被调用时,JavaScri…...
A股本周在3000点以下继续筑底,本周依然继续探底?
夜已深,市场传来了3个浓烈的消息,炸锅了,恐有大事发生,马上告诉所有人: 消息面: 1、中国经济周刊首席评论员钮文新称:不要等中小投资者都彻底希望,销户离场了,才发现该…...
Javadoc介绍
Javadoc 是用于生成 Java 代码文档的工具。它利用特定的注释格式,将 Java 源代码中的注释提取出来,并生成 HTML 文档。Javadoc 注释通常位于类、接口、构造函数、方法和字段的声明之前,以 /** 开始,以 */ 结束。以下是 Javadoc 注释的一些主要元素和使用方法: 基本语法 …...
C# Application.DoEvents()的作用
文章目录 1、详解 Application.DoEvents()2、示例处理用户事件响应系统事件控制台输出游戏和多媒体应用与操作系统的交互 3、注意事项总结 Application.DoEvents() 是 .NET 框架中的一个方法,它主要用于处理消息队列中的事件。在 Windows 应用程序中,当一…...
IDEA如何创建原生maven子模块
文件 -> 新建 -> 新模块 -> Maven ArcheTypeMaven ArcheType界面中的输入框介绍 名称:子模块的名称位置:子模块存放的路径名创建Git仓库:子模块不单独作为一个git仓库,无需勾选JDK:JDK版本号父项:…...
LCD EMC 辐射 测试随想
最近做几个产品过认证。 有带2.8寸 MCU8080接口的小屏(320 X 240),也有RGB接口的10.1寸的大屏(800*600). 以下为个人随想,不知道是否正确,仅作记录。 测试发现辐射的核心问题还是在于时钟及其倍频所产生的尖峰。 记得读…...
华为云AI开发平台ModelArts
华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...
eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)
说明: 想象一下,你正在用eNSP搭建一个虚拟的网络世界,里面有虚拟的路由器、交换机、电脑(PC)等等。这些设备都在你的电脑里面“运行”,它们之间可以互相通信,就像一个封闭的小王国。 但是&#…...
多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度
一、引言:多云环境的技术复杂性本质 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,基础设施的技术债呈现指数级积累。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...
Cursor实现用excel数据填充word模版的方法
cursor主页:https://www.cursor.com/ 任务目标:把excel格式的数据里的单元格,按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例,…...
【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...
逻辑回归:给不确定性划界的分类大师
想象你是一名医生。面对患者的检查报告(肿瘤大小、血液指标),你需要做出一个**决定性判断**:恶性还是良性?这种“非黑即白”的抉择,正是**逻辑回归(Logistic Regression)** 的战场&a…...
PHP和Node.js哪个更爽?
先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...
相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了: 这一篇我们开始讲: 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下: 一、场景操作步骤 操作步…...
全球首个30米分辨率湿地数据集(2000—2022)
数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...
