当前位置: 首页 > news >正文

TensorFlow系列:第四讲:MobileNetV2实战

一. 加载数据集

编写工具类,实现数据集的加载

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/808da38d6ad74628b869c28e937b02d9.png


import keras"""
加载数据集工具类
"""class DatasetLoader:def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):self.path_url = path_urlself.image_size = image_sizeself.batch_size = batch_sizeself.class_mode = class_mode# 不使用图像增强def load_data(self):# 加载训练数据集train_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/train',  # 训练数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode,  # 类别模式:返回one-hot编码的标签)# 加载验证数据集val_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/validation',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)# 加载测试数据集test_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/test',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)class_names = train_data.class_namesreturn train_data, val_data, test_data, class_names

二. 训练模型完整代码

import keras
from keras import layersfrom utils.dataset_loader import DatasetLoader"""
使用MobileNetV2,实现图像多分类
"""# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'#  图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16# 训练模型
def train():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()# 构建 MobileNet 模型base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)# 将模型的主干参数进行冻结base_model.trainable = Falsemodel = keras.Sequential([layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),# 设置主干模型base_model,# 对主干模型的输出进行全局平均池化layers.GlobalAveragePooling2D(),# 通过全连接层映射到最后的分类数目上layers.Dense(len(class_total), activation='softmax')])# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 模型结构model.summary()# 指明训练的轮数epoch,开始训练model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)if __name__ == '__main__':train()

训练模型输出如下:

模型结构:

在这里插入图片描述
训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。
在这里插入图片描述
验证正确率:

在这里插入图片描述
保存的模型:

在这里插入图片描述

三. 函数式调用方式

以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。

# 函数式调用方式
def train1():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()inputs = keras.Input(shape=IMG_SHAPE)# 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)# 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新for layer in base_model.layers:layer.trainable = False# 在 MobileNetV2 之后添加自定义的顶层分类器x = layers.GlobalAveragePooling2D()(base_model.output)predictions = layers.Dense(len(class_total), activation='softmax')(x)# 构建最终模型model = keras.Model(inputs=base_model.input, outputs=predictions)# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 查看模型结构model.summary()model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)

四. 保存训练过程曲线图

在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。

在项目中编写一个工具类如下:
在这里插入图片描述
上边代码简单改造:

    # 训练模型history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 保存曲线图Utils.trainResult(history, RESULT_URL)

曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

在这里插入图片描述

五. 模型可视化批量测试

在这里插入图片描述
编写可视化批量测试工具类:

import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatchfrom utils.dataset_loader import DatasetLoader"""
模型工具类
"""class ModelUtil:def __init__(self, saved_model_dir, path_url):self.save_model_dir = saved_model_dir  # savedModel 模型保存地址self.path_url = path_url  # 模型训练数据地址# 批量识别 进行可视化显示def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()# 加载savedModel模型tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)# 创建一个新的 Keras 模型,包含 TFSMLayermodel = keras.Sequential([keras.Input(shape=image_size + (3,)),  # 根据你的模型的输入形状tfs_layer])plt.figure(figsize=(10, 10))for images, labels in test_ds.take(1):# 使用模型进行预测outputs = model.predict(images)for i in range(num_images):plt.subplot(5, 5, i + 1)image = np.array(images[i]).astype("uint8")plt.imshow(image)index = int(np.argmax(outputs[i]))prediction = outputs[i][index]percentage_str = "{:.2f}%".format(prediction * 100)plt.title(f"{class_names[index]}: {percentage_str}")plt.axis("off")plt.subplots_adjust(hspace=0.5, wspace=0.5)plt.show()

使用工具类:

if __name__ == '__main__':# train()model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)model_util.batch_evaluation()

相关文章:

TensorFlow系列:第四讲:MobileNetV2实战

一. 加载数据集 编写工具类,实现数据集的加载 import keras""" 加载数据集工具类 """class DatasetLoader:def __init__(self, path_url, image_size(224, 224), batch_size32, class_modecategorical):self.path_url path_urlself…...

Redis+Caffeine 实现两级缓存实战

RedisCaffeine 实现两级缓存 背景 ​ 事情的开始是这样的,前段时间接了个需求,给公司的商城官网提供一个查询预计送达时间的接口。接口很简单,根据请求传的城市仓库发货时间查询快递的预计送达时间。因为商城下单就会调用这个接口&#xff…...

SpringBoot:SpringBoot中如何实现对Http接口进行监控

一、前言 Spring Boot Actuator是Spring Boot提供的一个模块,用于监控和管理Spring Boot应用程序的运行时信息。它提供了一组监控端点(endpoints),用于获取应用程序的健康状态、性能指标、配置信息等,并支持通过 HTTP …...

STM32-I2C硬件外设

本博文建议与我上一篇I2C 通信协议​​​​​​共同理解 合成一套关于I2C软硬件体系 STM32内部集成了硬件I2C收发电路,可以由硬件自动执行时钟生成、起始终止条件生成、应答位收发、数据收发等功能,减轻CPU的负担 特点: 多主机功能&#x…...

暑假第一次作业

第一步:给R1,R2,R3,R4配IP [R1-GigabitEthernet0/0/0]ip address 192.168.1.1 24 [R1-Serial4/0/0]ip address 15.0.0.1 24 [R2-GigabitEthernet0/0/0]ip address 192.168.2.1 24 [R2-Serial4/0/0]ip address 25.0.0.1 24 [R3-GigabitEthernet0/0/0]ip address 192.…...

【算法专题】快速排序

1. 颜色分类 75. 颜色分类 - 力扣(LeetCode) 依据题意,我们需要把只包含0、1、2的数组划分为三个部分,事实上,在我们前面学习过的【算法专题】双指针算法-CSDN博客中,有一道题叫做移动零,题目要…...

debian 12 PXE Server 批量部署系统

pxe server 前言 PXE(Preboot eXecution Environment,预启动执行环境)是一种网络启动协议,允许计算机通过网络启动而不是使用本地硬盘。PXE服务器是实现这一功能的服务器,它提供了启动镜像和引导加载程序,…...

【Pytorch】RNN for Image Classification

文章目录 1 RNN 的定义2 RNN 输入 input, h_03 RNN 输出 output, h_n4 多层5 小试牛刀 学习参考来自 pytorch中nn.RNN()总结RNN for Image Classification(RNN图片分类–MNIST数据集)pytorch使用-nn.RNNBuilding RNNs is Fun with PyTorch and Google Colab 1 RNN 的定义 nn.…...

基于Java的飞机大战游戏的设计与实现论文

点击下载源码 基于Java的飞机大战游戏的设计与实现 摘 要 现如今,随着智能手机的兴起与普及,加上4G(the 4th Generation mobile communication ,第四代移动通信技术)网络的深入,越来越多的IT行业开始向手机…...

初识影刀:EXCEL根据部门筛选低值易耗品

第一次知道这个办公自动化的软件还是在招聘网站上,了解之后发现对于办公中重复性的工作还是挺有帮助的,特别是那些操作非EXCEL的重复性工作,当然用在EXCEL上更加方便,有些操作比写VBA便捷。 下面就是一个了解基本操作后&#xff…...

nginx的四层负载均衡实战

目录 1 环境准备 1.1 mysql 部署 1.2 nginx 部署 1.3 关闭防火墙和selinux 2 nginx配置 2.1 修改nginx主配置文件 2.2 创建stream配置文件 2.3 重启nginx 3 测试四层代理是否轮循成功 3.1 远程链接通过代理服务器访问 3.2 动图演示 4 四层反向代理算法介绍 4.1 轮询&#xff0…...

中职网络安全B模块Cenots6.8数据库

任务环境说明: ✓ 服务器场景:CentOS6.8(开放链接) ✓ 用户名:root;密码:123456 进入虚拟机操作系统:CentOS 6.8,登陆数据库(用户名:root&#x…...

BGP笔记的基本概要

技术背景: 在只有IGP(诸如OSPF、IS-IS、RIP等协议,因为最初是被设计在一个单域中进行一个路由操纵,因此被统一称为Interior Gateway Protocol,内部网关协议)的时代,域间路由无法实现一个全局路由…...

【Redis】复制(Replica)

文章目录 一、复制是什么?二、 基本命令三、 配置(分为配置文件和命令配置)3.1 配置文件3.2 命令配置3.3 嵌套连接3.4 关闭从属关系 四、 复制原理五、 缺点 以下是本篇文章正文内容 一、复制是什么? 主从复制 master&#xff…...

封装了一个仿照抖音效果的iOS评论弹窗

需求背景 开发一个类似抖音评论弹窗交互效果的弹窗,支持滑动消失, 滑动查看评论 效果如下图 思路 创建一个视图,该视图上面放置一个tableView, 该视图上添加一个滑动手势,同时设置代理,实现代理方法 (BOOL)gestur…...

【JavaWeb程序设计】Servlet(二)

目录 一、改进上一篇博客Servlet(一)的第一题 1. 运行截图 2. 建表 3. 实体类 4. JSP页面 4.1 login.jsp 4.2 loginSuccess.jsp 4.3 loginFail.jsp 5. mybatis-config.xml 6. 工具类:创建SqlSessionFactory实例,进行 My…...

php探针

php探针是用来探测空间、服务器运行状况和PHP信息用的,探针可以实时查看服务器硬盘资源、内存占用、网卡流量、系统负载、服务器时间等信息。 下面就分享下我是怎样利用php探针来探测服务器网站空间速度、性能、安全功能等。 具体步骤如下: 1.从网上下…...

泰勒级数 (Taylor Series) 动画展示 包括源码

泰勒级数 (Taylor Series) 动画展示 包括源码 flyfish 泰勒级数(英语:Taylor series)用无限项连加式 - 级数来表示一个函数,这些相加的项由函数在某一点的导数求得。 定义了一个函数f(x)表示要近似的函数 sin ⁡ ( x ) \sin(x) …...

蔚来汽车:拥抱TiDB,实现数据库性能与稳定性的飞跃

作者: Billdi表弟 原文来源: https://tidb.net/blog/449c3f5b 演讲嘉宾:吴记 蔚来汽车Tidb爱好者 整理编辑:黄漫绅(表妹)、李仲舒、吴记 本文来自 TiDB 社区合肥站走进蔚来汽车——来自吴记老师的演讲…...

【Django+Vue3 线上教育平台项目实战】构建高效线上教育平台之首页模块

文章目录 前言一、导航功能实现a.效果图:b.后端代码c.前端代码 二、轮播图功能实现a.效果图b.后端代码c.前端代码 三、标签栏功能实现a.效果图b.后端代码c.前端代码 四、侧边栏功能实现1.整体效果图2.侧边栏功能实现a.效果图b.后端代码c.前端代码 3.侧边栏展示分类及…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...

JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案

JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停​​ 1. ​​安全点(Safepoint)阻塞​​ ​​现象​​:JVM暂停但无GC日志,日志显示No GCs detected。​​原因​​:JVM等待所有线程进入安全点(如…...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: ​onCreate()​​ ​调用时机​:Activity 首次创建时调用。​…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

Linux 下 DMA 内存映射浅析

序 系统 I/O 设备驱动程序通常调用其特定子系统的接口为 DMA 分配内存,但最终会调到 DMA 子系统的dma_alloc_coherent()/dma_alloc_attrs() 等接口。 关于 dma_alloc_coherent 接口详细的代码讲解、调用流程,可以参考这篇文章,我觉得写的非常…...

写一个shell脚本,把局域网内,把能ping通的IP和不能ping通的IP分类,并保存到两个文本文件里

写一个shell脚本&#xff0c;把局域网内&#xff0c;把能ping通的IP和不能ping通的IP分类&#xff0c;并保存到两个文本文件里 脚本1 #!/bin/bash #定义变量 ip10.1.1 #循环去ping主机的IP for ((i1;i<10;i)) doping -c1 $ip.$i &>/dev/null[ $? -eq 0 ] &&am…...

AT模式下的全局锁冲突如何解决?

一、全局锁冲突解决方案 1. 业务层重试机制&#xff08;推荐方案&#xff09; Service public class OrderService {GlobalTransactionalRetryable(maxAttempts 3, backoff Backoff(delay 100))public void createOrder(OrderDTO order) {// 库存扣减&#xff08;自动加全…...

性能优化中,多面体模型基本原理

1&#xff09;多面体编译技术是一种基于多面体模型的程序分析和优化技术&#xff0c;它将程序 中的语句实例、访问关系、依赖关系和调度等信息映射到多维空间中的几何对 象&#xff0c;通过对这些几何对象进行几何操作和线性代数计算来进行程序的分析和优 化。 其中&#xff0…...