神经网络实战--使用迁移学习完成猫狗分类
前言:
Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~
本文目录:
- 一、加载数据集
- 1.调用库函数
- 2.加载数据集
- 3.数据集管理
- 二、猫狗数据集介绍
- 1.猫狗数据集介绍:
- 2.图片展示
- 三、MobileNetV2网络介绍
- 1.加载tensorflow提供的预训练模型
- 2.轻量级网络——MobileNetV2
- 3.MobileNetV2的网络模块
- 四、搭建迁移学习
- 1.训练
- 2.训练结果可视化
- 3.输出训练的准确率
- 4.用cnn工具可视化一批数据的预测结果
- 5.数据输出
- 6.用cnn工具可视化一个数据样本的各层输出
- 7.输出结果图像
- 五、源码获取
说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:
-
实现基于tensorflow和keras的迁移学习
-
加载tensorflow提供的数据集(不得使用cifar10)
-
需要使用markdown单元格对数据集进行说明
-
加载tensorflow提供的预训练模型(不得使用vgg16)
-
需要使用markdown单元格对原始模型进行说明
-
网络末端连接任意结构的输出端网络
-
用图表显示准确率和损失函数
-
用cnn工具可视化一批数据的预测结果
-
用cnn工具可视化一个数据样本的各层输出
一、加载数据集
1.调用库函数
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout
2.加载数据集
数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。
path_to_zip = tf.keras.utils.get_file('data.zip',origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')BATCH_SIZE = 32
IMG_SIZE = (160, 160)
3.数据集管理
使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。
train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)
二、猫狗数据集介绍
1.猫狗数据集介绍:
猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。
在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。
2.图片展示
下面是数据集中的图片展示:
class_names = ['cats', 'dogs']plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
🌟🌟🌟 这里是输出的结果:✨✨✨
三、MobileNetV2网络介绍
1.加载tensorflow提供的预训练模型
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
2.轻量级网络——MobileNetV2
使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩。
3.MobileNetV2的网络模块
MobileNetV2的网络模块样子是这样的:
MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride:
其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):
四、搭建迁移学习
1.训练
inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
base_model.trainable = False
base_model.summary()
🌟🌟🌟 这里是输出的结果:✨✨✨
2.训练结果可视化
用图表显示准确率和损失函数
# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()
🌟🌟🌟 这里是输出的结果:✨✨✨
3.输出训练的准确率
# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))
🌟🌟🌟 这里是输出的结果:✨✨✨
4.用cnn工具可视化一批数据的预测结果
label_dict = {0: 'cat',1: 'dog'
}test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)
🌟🌟🌟 这里是输出的结果:✨✨✨
5.数据输出
# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]cnn_utils.get_activations(base_model, image_activation[0])
🌟🌟🌟 这里是输出的结果:✨✨✨
6.用cnn工具可视化一个数据样本的各层输出
cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))
🌟🌟🌟 这里是输出的结果:✨✨✨
7.输出结果图像
🌟🌟🌟 这里是输出的结果:✨✨✨
五、源码获取
关注此公众号:人生苦短我用Pythons
,回复 神经网络源码
获取源码,快点击我吧
🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~
最后,有任何问题,欢迎关注下面的公众号,获取第一时间消息、作者联系方式及每周抽奖等多重好礼! ↓↓↓
相关文章:

神经网络实战--使用迁移学习完成猫狗分类
前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~ 本文目录:一、加载数据集1.调用库函数2.加载数据集3.数据集管理二、猫狗数据集介绍1.猫狗数据集介…...

Attention机制 学习笔记
学习自https://easyai.tech/ai-definition/attention/ Attention本质 Attention(注意力)机制如果浅层的理解,跟他的名字非常匹配。他的核心逻辑就是“从关注全部到关注重点”。 比如我们人在看图片时,对图片的不同地方的注意力…...

数据类型与运算符
1.字符型作用: 字符型变量用于显示单个字符语法: char cc a ;注意1: 在显示字符型变量时,用单引号将字符括起来,不要用双引号注意2: 单引号内只能有一个字符,不可以是字符串C和C中字符型变量只占用1个字节。字符型变是并不是把字符本身放到内存中存储&am…...
算法刷题-二叉树的锯齿形层序遍历、用栈实现队列 栈设计、买卖股票的最佳时机 IV
文章目录二叉树的锯齿形层序遍历(树、广度优先搜索)用栈实现队列(栈、设计)买卖股票的最佳时机 IV(数组、动态规划)二叉树的锯齿形层序遍历(树、广度优先搜索) 给定一个二叉树&…...
华为OD机试 - 最小传递延迟(Python)| 代码编写思路+核心知识点
最小传递延迟 题目 通讯网络中有 N 个网络节点 用 1 ~ N 进行标识 网络通过一个有向无环图进行表示 其中图的边的值,表示节点之间的消息传递延迟 现给定相连节点之间的延时列表 times[i]={u,v,w} 其中 u 表示源节点,v 表示目的节点,w 表示 u 和 v 之间的消息传递延时 请计…...

集中供热调度系统天然气仪表内网仪表图像识别案例
一、项目需求 出于能耗采集与冬季集中供暖工作的节能和能耗分析需要,要采集现场的6块天然气表计,并存储进入客户的mySQL数据库中,现场采集的表计不允许接线,且网络环境为内网环境,需要采集表计数据并存入数据库&#…...
笔试题-2023-复旦微-数字IC设计【纯净题目版】
回到首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 推荐内容:数字IC设计学习比较实用的资料推荐 题目背景 笔试时间:2022.07.26应聘岗位:数字前端工程师笔试时长:120min笔试平台:赛码题目类型:基础题(10道)、选做题(10道)、验证题(5道)主观评价 难…...

【Linux】冯诺依曼体系结构和操作系统概念
文章目录🎪 冯诺依曼体系结构🚀1.体系概述🚀2.CPU和内存的数据交换🚀3.体系结构中数据的流动🎪 操作系统概念理解🚀1.简述🚀2.设计目的🚀3.定位🚀4.理解🚀5.管…...

HTML5之HTML基础学习笔记
列表标签 列表的应用场景 场景:在网页中按照行展示关联性的内容,如:新闻列表、排行榜、账单等特点:按照行的方式,整齐显示内容种类:无序列表、有序列表、自定义列表 这是老师PPT上的内容, 列表…...

FreeRTOS信号量 | FreeRTOS十
目录 说明: 一、信号量 1.1、信号量简介 1.2、信号量特点 二、二值信号量 2.1、二值信号量简介 2.2、获取与释放二值信号量函数 2.3、二值信号量使用过程与相关API函数 2.4、创建二值信号量函数了解 2.5、释放二值信号量了解 2.6、获取二值信号量了解 三…...
【SpringBoot】SpringBoot常用注解
一、前言首先这里说的SpringBoot常用注解是指在我们开发项目过程中,我们经常使用的注解,包含Spring、SpringBoot、SpringCloud、SpringMVC等这些框架中的注解,而不仅仅是SpringBoot中的注解。这里只是作一个注解列举,每个注解具体…...

数据一致性
目录一、AOP 动态代理切入方法(1) Aspect Oriented Programming(2) 切入点表达式二、SpringBoot 项目扫描类(1) ResourceLoader 扫描类(2) Map 的 computeIfAbsent 方法(3) 反射几个常用 api① 创建一个测试注解② 创建测试 PO 类③ 反射 api 获取指定类的指定注解信息(4) 返回…...

Docker不做虚拟化内核,对.NET有什么影响?
引子前两天刷抖音,看见了这样一个问题。问题:容器化不做虚拟内核,会有什么弊端?Java很多方法会跟CPU的核数有关,这个时候调用系统函数,读到的是宿主机信息,而不是我们限制资源的大小。思考&…...
HTML总结
CSS代码风格 空格规范: 1. 属性值前面,冒号后面,保留一个空格; 2. 选择器(标签)和大括号中间保留空格。 基本语法概述: 1.HTML标签是由尖括号包围的关键词,如<html> 2.HTM…...

ByteHouse:基于ClickHouse的实时数仓能力升级解读
更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 ByteHouse是火山引擎上的一款云原生数据仓库,为用户带来极速分析体验,能够支撑实时数据分析和海量数据离线分析。便捷的弹性扩缩容能力&…...

[SSD固态硬盘技术 15] FTL映射表的神秘面纱
为什么需要映射表?固态硬盘的存储器件采用的是闪存[5],具有以下几个特点: (1)读写基本单位是以页(Page)为单位,擦除是以块(Block)为单位。...

浅析依赖注入框架的生命周期(以 InversifyJS 为例)
在上一篇介绍了 VSCode 的依赖注入设计,并且实现了一个简单的 IOC 框架。但是距离成为一个生产环境可用的框架还差的很远。 行业内已经有许多非常优秀的开源 IOC 框架,它们划分了更为清晰地模块来应对复杂情况下依赖注入运行的正确性。 这里我将以 Inv…...

HER2靶向药物研发进展-销售数据-上市药品前景分析
HER2长期作为肿瘤领域的热门靶点之一,其原因是它在多部位、多种形式的癌症中均有异常的表达,据研究表明HER2除了在胃癌、胆道癌、胆管癌、乳腺癌、卵巢癌、结肠癌、膀胱癌、肺癌、子宫颈癌、子宫浆液性子宫内膜癌、头颈癌、食道癌中的异常表达还存在于多…...

【第38天】不同路径数问题 | 网格 dp 入门
本文已收录于专栏🌸《Java入门一百例》🌸学习指引序、专栏前言一、网格模型二、【例题1】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、【例题2】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、推荐专栏四、课后习题序、专…...

LINUX之链接命令
链接命令学习目标能够说出软链接的创建方式能够说出硬链接的创建方式1. 链接命令的介绍链接命令是创建链接文件,链接文件分为:软链接硬链接命令说明ln -s创建软链接ln创建硬链接2. 软链接类似于Windows下的快捷方式,当一个源文件的目录层级比较深&#x…...
线程同步:确保多线程程序的安全与高效!
全文目录: 开篇语前序前言第一部分:线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分:synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分ÿ…...
Matlab | matlab常用命令总结
常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...

Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...
MySQL用户和授权
开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务: test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台
🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

USB Over IP专用硬件的5个特点
USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中,从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备(如专用硬件设备),从而消除了直接物理连接的需要。USB over IP的…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...

【p2p、分布式,区块链笔记 MESH】Bluetooth蓝牙通信 BLE Mesh协议的拓扑结构 定向转发机制
目录 节点的功能承载层(GATT/Adv)局限性: 拓扑关系定向转发机制定向转发意义 CG 节点的功能 节点的功能由节点支持的特性和功能决定。所有节点都能够发送和接收网格消息。节点还可以选择支持一个或多个附加功能,如 Configuration …...
Docker拉取MySQL后数据库连接失败的解决方案
在使用Docker部署MySQL时,拉取并启动容器后,有时可能会遇到数据库连接失败的问题。这种问题可能由多种原因导致,包括配置错误、网络设置问题、权限问题等。本文将分析可能的原因,并提供解决方案。 一、确认MySQL容器的运行状态 …...