神经网络实战--使用迁移学习完成猫狗分类

前言:Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~
本文目录:
- 一、加载数据集
- 1.调用库函数
- 2.加载数据集
- 3.数据集管理
- 二、猫狗数据集介绍
- 1.猫狗数据集介绍:
- 2.图片展示
- 三、MobileNetV2网络介绍
- 1.加载tensorflow提供的预训练模型
- 2.轻量级网络——MobileNetV2
- 3.MobileNetV2的网络模块
- 四、搭建迁移学习
- 1.训练
- 2.训练结果可视化
- 3.输出训练的准确率
- 4.用cnn工具可视化一批数据的预测结果
- 5.数据输出
- 6.用cnn工具可视化一个数据样本的各层输出
- 7.输出结果图像
- 五、源码获取
说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:
-
实现基于tensorflow和keras的迁移学习
-
加载tensorflow提供的数据集(不得使用cifar10)
-
需要使用markdown单元格对数据集进行说明
-
加载tensorflow提供的预训练模型(不得使用vgg16)
-
需要使用markdown单元格对原始模型进行说明
-
网络末端连接任意结构的输出端网络
-
用图表显示准确率和损失函数
-
用cnn工具可视化一批数据的预测结果
-
用cnn工具可视化一个数据样本的各层输出
一、加载数据集
1.调用库函数
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout
2.加载数据集
数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。
path_to_zip = tf.keras.utils.get_file('data.zip',origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')BATCH_SIZE = 32
IMG_SIZE = (160, 160)
3.数据集管理
使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。
train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)
二、猫狗数据集介绍
1.猫狗数据集介绍:
猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。

在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。

2.图片展示
下面是数据集中的图片展示:
class_names = ['cats', 'dogs']plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
🌟🌟🌟 这里是输出的结果:✨✨✨

三、MobileNetV2网络介绍
1.加载tensorflow提供的预训练模型
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
2.轻量级网络——MobileNetV2
使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩。

3.MobileNetV2的网络模块
MobileNetV2的网络模块样子是这样的:

MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride:

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):

四、搭建迁移学习
1.训练
inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
base_model.trainable = False
base_model.summary()
🌟🌟🌟 这里是输出的结果:✨✨✨

2.训练结果可视化
用图表显示准确率和损失函数
# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()
🌟🌟🌟 这里是输出的结果:✨✨✨

3.输出训练的准确率
# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))
🌟🌟🌟 这里是输出的结果:✨✨✨

4.用cnn工具可视化一批数据的预测结果
label_dict = {0: 'cat',1: 'dog'
}test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)
🌟🌟🌟 这里是输出的结果:✨✨✨

5.数据输出
# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]cnn_utils.get_activations(base_model, image_activation[0])
🌟🌟🌟 这里是输出的结果:✨✨✨

6.用cnn工具可视化一个数据样本的各层输出
cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))
🌟🌟🌟 这里是输出的结果:✨✨✨

7.输出结果图像
🌟🌟🌟 这里是输出的结果:✨✨✨



五、源码获取
关注此公众号:人生苦短我用Pythons,回复 神经网络源码获取源码,快点击我吧
🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~


最后,有任何问题,欢迎关注下面的公众号,获取第一时间消息、作者联系方式及每周抽奖等多重好礼! ↓↓↓
相关文章:
神经网络实战--使用迁移学习完成猫狗分类
前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~ 本文目录:一、加载数据集1.调用库函数2.加载数据集3.数据集管理二、猫狗数据集介绍1.猫狗数据集介…...
Attention机制 学习笔记
学习自https://easyai.tech/ai-definition/attention/ Attention本质 Attention(注意力)机制如果浅层的理解,跟他的名字非常匹配。他的核心逻辑就是“从关注全部到关注重点”。 比如我们人在看图片时,对图片的不同地方的注意力…...
数据类型与运算符
1.字符型作用: 字符型变量用于显示单个字符语法: char cc a ;注意1: 在显示字符型变量时,用单引号将字符括起来,不要用双引号注意2: 单引号内只能有一个字符,不可以是字符串C和C中字符型变量只占用1个字节。字符型变是并不是把字符本身放到内存中存储&am…...
算法刷题-二叉树的锯齿形层序遍历、用栈实现队列 栈设计、买卖股票的最佳时机 IV
文章目录二叉树的锯齿形层序遍历(树、广度优先搜索)用栈实现队列(栈、设计)买卖股票的最佳时机 IV(数组、动态规划)二叉树的锯齿形层序遍历(树、广度优先搜索) 给定一个二叉树&…...
华为OD机试 - 最小传递延迟(Python)| 代码编写思路+核心知识点
最小传递延迟 题目 通讯网络中有 N 个网络节点 用 1 ~ N 进行标识 网络通过一个有向无环图进行表示 其中图的边的值,表示节点之间的消息传递延迟 现给定相连节点之间的延时列表 times[i]={u,v,w} 其中 u 表示源节点,v 表示目的节点,w 表示 u 和 v 之间的消息传递延时 请计…...
集中供热调度系统天然气仪表内网仪表图像识别案例
一、项目需求 出于能耗采集与冬季集中供暖工作的节能和能耗分析需要,要采集现场的6块天然气表计,并存储进入客户的mySQL数据库中,现场采集的表计不允许接线,且网络环境为内网环境,需要采集表计数据并存入数据库&#…...
笔试题-2023-复旦微-数字IC设计【纯净题目版】
回到首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 推荐内容:数字IC设计学习比较实用的资料推荐 题目背景 笔试时间:2022.07.26应聘岗位:数字前端工程师笔试时长:120min笔试平台:赛码题目类型:基础题(10道)、选做题(10道)、验证题(5道)主观评价 难…...
【Linux】冯诺依曼体系结构和操作系统概念
文章目录🎪 冯诺依曼体系结构🚀1.体系概述🚀2.CPU和内存的数据交换🚀3.体系结构中数据的流动🎪 操作系统概念理解🚀1.简述🚀2.设计目的🚀3.定位🚀4.理解🚀5.管…...
HTML5之HTML基础学习笔记
列表标签 列表的应用场景 场景:在网页中按照行展示关联性的内容,如:新闻列表、排行榜、账单等特点:按照行的方式,整齐显示内容种类:无序列表、有序列表、自定义列表 这是老师PPT上的内容, 列表…...
FreeRTOS信号量 | FreeRTOS十
目录 说明: 一、信号量 1.1、信号量简介 1.2、信号量特点 二、二值信号量 2.1、二值信号量简介 2.2、获取与释放二值信号量函数 2.3、二值信号量使用过程与相关API函数 2.4、创建二值信号量函数了解 2.5、释放二值信号量了解 2.6、获取二值信号量了解 三…...
【SpringBoot】SpringBoot常用注解
一、前言首先这里说的SpringBoot常用注解是指在我们开发项目过程中,我们经常使用的注解,包含Spring、SpringBoot、SpringCloud、SpringMVC等这些框架中的注解,而不仅仅是SpringBoot中的注解。这里只是作一个注解列举,每个注解具体…...
数据一致性
目录一、AOP 动态代理切入方法(1) Aspect Oriented Programming(2) 切入点表达式二、SpringBoot 项目扫描类(1) ResourceLoader 扫描类(2) Map 的 computeIfAbsent 方法(3) 反射几个常用 api① 创建一个测试注解② 创建测试 PO 类③ 反射 api 获取指定类的指定注解信息(4) 返回…...
Docker不做虚拟化内核,对.NET有什么影响?
引子前两天刷抖音,看见了这样一个问题。问题:容器化不做虚拟内核,会有什么弊端?Java很多方法会跟CPU的核数有关,这个时候调用系统函数,读到的是宿主机信息,而不是我们限制资源的大小。思考&…...
HTML总结
CSS代码风格 空格规范: 1. 属性值前面,冒号后面,保留一个空格; 2. 选择器(标签)和大括号中间保留空格。 基本语法概述: 1.HTML标签是由尖括号包围的关键词,如<html> 2.HTM…...
ByteHouse:基于ClickHouse的实时数仓能力升级解读
更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 ByteHouse是火山引擎上的一款云原生数据仓库,为用户带来极速分析体验,能够支撑实时数据分析和海量数据离线分析。便捷的弹性扩缩容能力&…...
[SSD固态硬盘技术 15] FTL映射表的神秘面纱
为什么需要映射表?固态硬盘的存储器件采用的是闪存[5],具有以下几个特点: (1)读写基本单位是以页(Page)为单位,擦除是以块(Block)为单位。...
浅析依赖注入框架的生命周期(以 InversifyJS 为例)
在上一篇介绍了 VSCode 的依赖注入设计,并且实现了一个简单的 IOC 框架。但是距离成为一个生产环境可用的框架还差的很远。 行业内已经有许多非常优秀的开源 IOC 框架,它们划分了更为清晰地模块来应对复杂情况下依赖注入运行的正确性。 这里我将以 Inv…...
HER2靶向药物研发进展-销售数据-上市药品前景分析
HER2长期作为肿瘤领域的热门靶点之一,其原因是它在多部位、多种形式的癌症中均有异常的表达,据研究表明HER2除了在胃癌、胆道癌、胆管癌、乳腺癌、卵巢癌、结肠癌、膀胱癌、肺癌、子宫颈癌、子宫浆液性子宫内膜癌、头颈癌、食道癌中的异常表达还存在于多…...
【第38天】不同路径数问题 | 网格 dp 入门
本文已收录于专栏🌸《Java入门一百例》🌸学习指引序、专栏前言一、网格模型二、【例题1】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、【例题2】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、推荐专栏四、课后习题序、专…...
LINUX之链接命令
链接命令学习目标能够说出软链接的创建方式能够说出硬链接的创建方式1. 链接命令的介绍链接命令是创建链接文件,链接文件分为:软链接硬链接命令说明ln -s创建软链接ln创建硬链接2. 软链接类似于Windows下的快捷方式,当一个源文件的目录层级比较深&#x…...
R语言AI模型部署方案:精准离线运行详解
R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...
剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...
Spring AI 入门:Java 开发者的生成式 AI 实践之路
一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...
前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...
Typeerror: cannot read properties of undefined (reading ‘XXX‘)
最近需要在离线机器上运行软件,所以得把软件用docker打包起来,大部分功能都没问题,出了一个奇怪的事情。同样的代码,在本机上用vscode可以运行起来,但是打包之后在docker里出现了问题。使用的是dialog组件,…...
AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机
这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机,因为在使用过程中发现 Airsim 对外部监控相机的描述模糊,而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置,最后在源码示例中找到了,所以感…...
Proxmox Mail Gateway安装指南:从零开始配置高效邮件过滤系统
💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storms…...
Linux系统部署KES
1、安装准备 1.版本说明V008R006C009B0014 V008:是version产品的大版本。 R006:是release产品特性版本。 C009:是通用版 B0014:是build开发过程中的构建版本2.硬件要求 #安全版和企业版 内存:1GB 以上 硬盘…...
热烈祝贺埃文科技正式加入可信数据空间发展联盟
2025年4月29日,在福州举办的第八届数字中国建设峰会“可信数据空间分论坛”上,可信数据空间发展联盟正式宣告成立。国家数据局党组书记、局长刘烈宏出席并致辞,强调该联盟是推进全国一体化数据市场建设的关键抓手。 郑州埃文科技有限公司&am…...
WEB3全栈开发——面试专业技能点P7前端与链上集成
一、Next.js技术栈 ✅ 概念介绍 Next.js 是一个基于 React 的 服务端渲染(SSR)与静态网站生成(SSG) 框架,由 Vercel 开发。它简化了构建生产级 React 应用的过程,并内置了很多特性: ✅ 文件系…...
