TensorFlow系列:第四讲:MobileNetV2实战
一. 加载数据集
编写工具类,实现数据集的加载

import keras"""
加载数据集工具类
"""class DatasetLoader:def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):self.path_url = path_urlself.image_size = image_sizeself.batch_size = batch_sizeself.class_mode = class_mode# 不使用图像增强def load_data(self):# 加载训练数据集train_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/train', # 训练数据集的目录路径image_size=self.image_size, # 调整图像大小batch_size=self.batch_size, # 每批次的样本数量label_mode=self.class_mode, # 类别模式:返回one-hot编码的标签)# 加载验证数据集val_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/validation', # 验证数据集的目录路径image_size=self.image_size, # 调整图像大小batch_size=self.batch_size, # 每批次的样本数量label_mode=self.class_mode # 类别模式:返回one-hot编码的标签)# 加载测试数据集test_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/test', # 验证数据集的目录路径image_size=self.image_size, # 调整图像大小batch_size=self.batch_size, # 每批次的样本数量label_mode=self.class_mode # 类别模式:返回one-hot编码的标签)class_names = train_data.class_namesreturn train_data, val_data, test_data, class_names
二. 训练模型完整代码
import keras
from keras import layersfrom utils.dataset_loader import DatasetLoader"""
使用MobileNetV2,实现图像多分类
"""# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'# 图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16# 训练模型
def train():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()# 构建 MobileNet 模型base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)# 将模型的主干参数进行冻结base_model.trainable = Falsemodel = keras.Sequential([layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),# 设置主干模型base_model,# 对主干模型的输出进行全局平均池化layers.GlobalAveragePooling2D(),# 通过全连接层映射到最后的分类数目上layers.Dense(len(class_total), activation='softmax')])# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 模型结构model.summary()# 指明训练的轮数epoch,开始训练model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)if __name__ == '__main__':train()
训练模型输出如下:
模型结构:

训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。

验证正确率:

保存的模型:

三. 函数式调用方式
以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。
# 函数式调用方式
def train1():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()inputs = keras.Input(shape=IMG_SHAPE)# 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)# 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新for layer in base_model.layers:layer.trainable = False# 在 MobileNetV2 之后添加自定义的顶层分类器x = layers.GlobalAveragePooling2D()(base_model.output)predictions = layers.Dense(len(class_total), activation='softmax')(x)# 构建最终模型model = keras.Model(inputs=base_model.input, outputs=predictions)# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 查看模型结构model.summary()model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)
四. 保存训练过程曲线图
在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。
在项目中编写一个工具类如下:

上边代码简单改造:
# 训练模型history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 保存曲线图Utils.trainResult(history, RESULT_URL)
曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

五. 模型可视化批量测试

编写可视化批量测试工具类:
import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatchfrom utils.dataset_loader import DatasetLoader"""
模型工具类
"""class ModelUtil:def __init__(self, saved_model_dir, path_url):self.save_model_dir = saved_model_dir # savedModel 模型保存地址self.path_url = path_url # 模型训练数据地址# 批量识别 进行可视化显示def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()# 加载savedModel模型tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)# 创建一个新的 Keras 模型,包含 TFSMLayermodel = keras.Sequential([keras.Input(shape=image_size + (3,)), # 根据你的模型的输入形状tfs_layer])plt.figure(figsize=(10, 10))for images, labels in test_ds.take(1):# 使用模型进行预测outputs = model.predict(images)for i in range(num_images):plt.subplot(5, 5, i + 1)image = np.array(images[i]).astype("uint8")plt.imshow(image)index = int(np.argmax(outputs[i]))prediction = outputs[i][index]percentage_str = "{:.2f}%".format(prediction * 100)plt.title(f"{class_names[index]}: {percentage_str}")plt.axis("off")plt.subplots_adjust(hspace=0.5, wspace=0.5)plt.show()
使用工具类:
if __name__ == '__main__':# train()model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)model_util.batch_evaluation()
相关文章:
TensorFlow系列:第四讲:MobileNetV2实战
一. 加载数据集 编写工具类,实现数据集的加载 import keras""" 加载数据集工具类 """class DatasetLoader:def __init__(self, path_url, image_size(224, 224), batch_size32, class_modecategorical):self.path_url path_urlself…...
Redis+Caffeine 实现两级缓存实战
RedisCaffeine 实现两级缓存 背景 事情的开始是这样的,前段时间接了个需求,给公司的商城官网提供一个查询预计送达时间的接口。接口很简单,根据请求传的城市仓库发货时间查询快递的预计送达时间。因为商城下单就会调用这个接口ÿ…...
SpringBoot:SpringBoot中如何实现对Http接口进行监控
一、前言 Spring Boot Actuator是Spring Boot提供的一个模块,用于监控和管理Spring Boot应用程序的运行时信息。它提供了一组监控端点(endpoints),用于获取应用程序的健康状态、性能指标、配置信息等,并支持通过 HTTP …...
STM32-I2C硬件外设
本博文建议与我上一篇I2C 通信协议共同理解 合成一套关于I2C软硬件体系 STM32内部集成了硬件I2C收发电路,可以由硬件自动执行时钟生成、起始终止条件生成、应答位收发、数据收发等功能,减轻CPU的负担 特点: 多主机功能&#x…...
暑假第一次作业
第一步:给R1,R2,R3,R4配IP [R1-GigabitEthernet0/0/0]ip address 192.168.1.1 24 [R1-Serial4/0/0]ip address 15.0.0.1 24 [R2-GigabitEthernet0/0/0]ip address 192.168.2.1 24 [R2-Serial4/0/0]ip address 25.0.0.1 24 [R3-GigabitEthernet0/0/0]ip address 192.…...
【算法专题】快速排序
1. 颜色分类 75. 颜色分类 - 力扣(LeetCode) 依据题意,我们需要把只包含0、1、2的数组划分为三个部分,事实上,在我们前面学习过的【算法专题】双指针算法-CSDN博客中,有一道题叫做移动零,题目要…...
debian 12 PXE Server 批量部署系统
pxe server 前言 PXE(Preboot eXecution Environment,预启动执行环境)是一种网络启动协议,允许计算机通过网络启动而不是使用本地硬盘。PXE服务器是实现这一功能的服务器,它提供了启动镜像和引导加载程序,…...
【Pytorch】RNN for Image Classification
文章目录 1 RNN 的定义2 RNN 输入 input, h_03 RNN 输出 output, h_n4 多层5 小试牛刀 学习参考来自 pytorch中nn.RNN()总结RNN for Image Classification(RNN图片分类–MNIST数据集)pytorch使用-nn.RNNBuilding RNNs is Fun with PyTorch and Google Colab 1 RNN 的定义 nn.…...
基于Java的飞机大战游戏的设计与实现论文
点击下载源码 基于Java的飞机大战游戏的设计与实现 摘 要 现如今,随着智能手机的兴起与普及,加上4G(the 4th Generation mobile communication ,第四代移动通信技术)网络的深入,越来越多的IT行业开始向手机…...
初识影刀:EXCEL根据部门筛选低值易耗品
第一次知道这个办公自动化的软件还是在招聘网站上,了解之后发现对于办公中重复性的工作还是挺有帮助的,特别是那些操作非EXCEL的重复性工作,当然用在EXCEL上更加方便,有些操作比写VBA便捷。 下面就是一个了解基本操作后ÿ…...
nginx的四层负载均衡实战
目录 1 环境准备 1.1 mysql 部署 1.2 nginx 部署 1.3 关闭防火墙和selinux 2 nginx配置 2.1 修改nginx主配置文件 2.2 创建stream配置文件 2.3 重启nginx 3 测试四层代理是否轮循成功 3.1 远程链接通过代理服务器访问 3.2 动图演示 4 四层反向代理算法介绍 4.1 轮询࿰…...
中职网络安全B模块Cenots6.8数据库
任务环境说明: ✓ 服务器场景:CentOS6.8(开放链接) ✓ 用户名:root;密码:123456 进入虚拟机操作系统:CentOS 6.8,登陆数据库(用户名:root&#x…...
BGP笔记的基本概要
技术背景: 在只有IGP(诸如OSPF、IS-IS、RIP等协议,因为最初是被设计在一个单域中进行一个路由操纵,因此被统一称为Interior Gateway Protocol,内部网关协议)的时代,域间路由无法实现一个全局路由…...
【Redis】复制(Replica)
文章目录 一、复制是什么?二、 基本命令三、 配置(分为配置文件和命令配置)3.1 配置文件3.2 命令配置3.3 嵌套连接3.4 关闭从属关系 四、 复制原理五、 缺点 以下是本篇文章正文内容 一、复制是什么? 主从复制 masterÿ…...
封装了一个仿照抖音效果的iOS评论弹窗
需求背景 开发一个类似抖音评论弹窗交互效果的弹窗,支持滑动消失, 滑动查看评论 效果如下图 思路 创建一个视图,该视图上面放置一个tableView, 该视图上添加一个滑动手势,同时设置代理,实现代理方法 (BOOL)gestur…...
【JavaWeb程序设计】Servlet(二)
目录 一、改进上一篇博客Servlet(一)的第一题 1. 运行截图 2. 建表 3. 实体类 4. JSP页面 4.1 login.jsp 4.2 loginSuccess.jsp 4.3 loginFail.jsp 5. mybatis-config.xml 6. 工具类:创建SqlSessionFactory实例,进行 My…...
php探针
php探针是用来探测空间、服务器运行状况和PHP信息用的,探针可以实时查看服务器硬盘资源、内存占用、网卡流量、系统负载、服务器时间等信息。 下面就分享下我是怎样利用php探针来探测服务器网站空间速度、性能、安全功能等。 具体步骤如下: 1.从网上下…...
泰勒级数 (Taylor Series) 动画展示 包括源码
泰勒级数 (Taylor Series) 动画展示 包括源码 flyfish 泰勒级数(英语:Taylor series)用无限项连加式 - 级数来表示一个函数,这些相加的项由函数在某一点的导数求得。 定义了一个函数f(x)表示要近似的函数 sin ( x ) \sin(x) …...
蔚来汽车:拥抱TiDB,实现数据库性能与稳定性的飞跃
作者: Billdi表弟 原文来源: https://tidb.net/blog/449c3f5b 演讲嘉宾:吴记 蔚来汽车Tidb爱好者 整理编辑:黄漫绅(表妹)、李仲舒、吴记 本文来自 TiDB 社区合肥站走进蔚来汽车——来自吴记老师的演讲…...
【Django+Vue3 线上教育平台项目实战】构建高效线上教育平台之首页模块
文章目录 前言一、导航功能实现a.效果图:b.后端代码c.前端代码 二、轮播图功能实现a.效果图b.后端代码c.前端代码 三、标签栏功能实现a.效果图b.后端代码c.前端代码 四、侧边栏功能实现1.整体效果图2.侧边栏功能实现a.效果图b.后端代码c.前端代码 3.侧边栏展示分类及…...
网络编程(Modbus进阶)
思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...
深度学习在微纳光子学中的应用
深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...
多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度
一、引言:多云环境的技术复杂性本质 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,基础设施的技术债呈现指数级积累。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...
【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力
引言: 在人工智能快速发展的浪潮中,快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型(LLM)。该模型代表着该领域的重大突破,通过独特方式融合思考与非思考…...
爬虫基础学习day2
# 爬虫设计领域 工商:企查查、天眼查短视频:抖音、快手、西瓜 ---> 飞瓜电商:京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空:抓取所有航空公司价格 ---> 去哪儿自媒体:采集自媒体数据进…...
【从零开始学习JVM | 第四篇】类加载器和双亲委派机制(高频面试题)
前言: 双亲委派机制对于面试这块来说非常重要,在实际开发中也是经常遇见需要打破双亲委派的需求,今天我们一起来探索一下什么是双亲委派机制,在此之前我们先介绍一下类的加载器。 目录 编辑 前言: 类加载器 1. …...
HubSpot推出与ChatGPT的深度集成引发兴奋与担忧
上周三,HubSpot宣布已构建与ChatGPT的深度集成,这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋,但同时也存在一些关于数据安全的担忧。 许多网络声音声称,这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...
前端中slice和splic的区别
1. slice slice 用于从数组中提取一部分元素,返回一个新的数组。 特点: 不修改原数组:slice 不会改变原数组,而是返回一个新的数组。提取数组的部分:slice 会根据指定的开始索引和结束索引提取数组的一部分。不包含…...
MyBatis中关于缓存的理解
MyBatis缓存 MyBatis系统当中默认定义两级缓存:一级缓存、二级缓存 默认情况下,只有一级缓存开启(sqlSession级别的缓存)二级缓存需要手动开启配置,需要局域namespace级别的缓存 一级缓存(本地缓存&#…...
区块链技术概述
区块链技术是一种去中心化、分布式账本技术,通过密码学、共识机制和智能合约等核心组件,实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点:数据存储在网络中的多个节点(计算机),而非…...
