当前位置: 首页 > news >正文

神经网络实战--使用迁移学习完成猫狗分类

在这里插入图片描述

前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~

本文目录:

  • 一、加载数据集
    • 1.调用库函数
    • 2.加载数据集
    • 3.数据集管理
  • 二、猫狗数据集介绍
    • 1.猫狗数据集介绍:
    • 2.图片展示
  • 三、MobileNetV2网络介绍
    • 1.加载tensorflow提供的预训练模型
    • 2.轻量级网络——MobileNetV2
    • 3.MobileNetV2的网络模块
  • 四、搭建迁移学习
    • 1.训练
    • 2.训练结果可视化
    • 3.输出训练的准确率
    • 4.用cnn工具可视化一批数据的预测结果
    • 5.数据输出
    • 6.用cnn工具可视化一个数据样本的各层输出
    • 7.输出结果图像
  • 五、源码获取

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:

  1. 实现基于tensorflow和keras的迁移学习

  2. 加载tensorflow提供的数据集(不得使用cifar10)

  3. 需要使用markdown单元格对数据集进行说明

  4. 加载tensorflow提供的预训练模型(不得使用vgg16)

  5. 需要使用markdown单元格对原始模型进行说明

  6. 网络末端连接任意结构的输出端网络

  7. 用图表显示准确率和损失函数

  8. 用cnn工具可视化一批数据的预测结果

  9. 用cnn工具可视化一个数据样本的各层输出

一、加载数据集

1.调用库函数

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout

2.加载数据集

数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。

path_to_zip = tf.keras.utils.get_file('data.zip',origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')BATCH_SIZE = 32
IMG_SIZE = (160, 160)

3.数据集管理

使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。

train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)

二、猫狗数据集介绍

1.猫狗数据集介绍:

猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。
在这里插入图片描述
在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。
在这里插入图片描述

2.图片展示

下面是数据集中的图片展示:

class_names = ['cats', 'dogs']plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

三、MobileNetV2网络介绍

1.加载tensorflow提供的预训练模型

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

2.轻量级网络——MobileNetV2

使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩
在这里插入图片描述

3.MobileNetV2的网络模块

MobileNetV2的网络模块样子是这样的:
在这里插入图片描述
MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride
在这里插入图片描述

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):
在这里插入图片描述

四、搭建迁移学习

1.训练

inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
base_model.trainable = False
base_model.summary()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

2.训练结果可视化

用图表显示准确率和损失函数

# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

3.输出训练的准确率

# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

4.用cnn工具可视化一批数据的预测结果

label_dict = {0: 'cat',1: 'dog'
}test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

5.数据输出

# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]cnn_utils.get_activations(base_model, image_activation[0])

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

6.用cnn工具可视化一个数据样本的各层输出

cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

7.输出结果图像

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

五、源码获取

关注此公众号:人生苦短我用Pythons,回复 神经网络源码获取源码,快点击我吧

🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~
在这里插入图片描述
在这里插入图片描述

最后,有任何问题,欢迎关注下面的公众号,获取第一时间消息、作者联系方式及每周抽奖等多重好礼! ↓↓↓

相关文章:

神经网络实战--使用迁移学习完成猫狗分类

前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~ 本文目录:一、加载数据集1.调用库函数2.加载数据集3.数据集管理二、猫狗数据集介绍1.猫狗数据集介…...

Attention机制 学习笔记

学习自https://easyai.tech/ai-definition/attention/ Attention本质 Attention(注意力)机制如果浅层的理解,跟他的名字非常匹配。他的核心逻辑就是“从关注全部到关注重点”。 比如我们人在看图片时,对图片的不同地方的注意力…...

数据类型与运算符

1.字符型作用: 字符型变量用于显示单个字符语法: char cc a ;注意1: 在显示字符型变量时,用单引号将字符括起来,不要用双引号注意2: 单引号内只能有一个字符,不可以是字符串C和C中字符型变量只占用1个字节。字符型变是并不是把字符本身放到内存中存储&am…...

算法刷题-二叉树的锯齿形层序遍历、用栈实现队列 栈设计、买卖股票的最佳时机 IV

文章目录二叉树的锯齿形层序遍历(树、广度优先搜索)用栈实现队列(栈、设计)买卖股票的最佳时机 IV(数组、动态规划)二叉树的锯齿形层序遍历(树、广度优先搜索) 给定一个二叉树&…...

华为OD机试 - 最小传递延迟(Python)| 代码编写思路+核心知识点

最小传递延迟 题目 通讯网络中有 N 个网络节点 用 1 ~ N 进行标识 网络通过一个有向无环图进行表示 其中图的边的值,表示节点之间的消息传递延迟 现给定相连节点之间的延时列表 times[i]={u,v,w} 其中 u 表示源节点,v 表示目的节点,w 表示 u 和 v 之间的消息传递延时 请计…...

集中供热调度系统天然气仪表内网仪表图像识别案例

一、项目需求 出于能耗采集与冬季集中供暖工作的节能和能耗分析需要,要采集现场的6块天然气表计,并存储进入客户的mySQL数据库中,现场采集的表计不允许接线,且网络环境为内网环境,需要采集表计数据并存入数据库&#…...

笔试题-2023-复旦微-数字IC设计【纯净题目版】

回到首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 推荐内容:数字IC设计学习比较实用的资料推荐 题目背景 笔试时间:2022.07.26应聘岗位:数字前端工程师笔试时长:120min笔试平台:赛码题目类型:基础题(10道)、选做题(10道)、验证题(5道)主观评价 难…...

【Linux】冯诺依曼体系结构和操作系统概念

文章目录🎪 冯诺依曼体系结构🚀1.体系概述🚀2.CPU和内存的数据交换🚀3.体系结构中数据的流动🎪 操作系统概念理解🚀1.简述🚀2.设计目的🚀3.定位🚀4.理解🚀5.管…...

HTML5之HTML基础学习笔记

列表标签 列表的应用场景 场景:在网页中按照行展示关联性的内容,如:新闻列表、排行榜、账单等特点:按照行的方式,整齐显示内容种类:无序列表、有序列表、自定义列表 这是老师PPT上的内容, 列表…...

FreeRTOS信号量 | FreeRTOS十

目录 说明: 一、信号量 1.1、信号量简介 1.2、信号量特点 二、二值信号量 2.1、二值信号量简介 2.2、获取与释放二值信号量函数 2.3、二值信号量使用过程与相关API函数 2.4、创建二值信号量函数了解 2.5、释放二值信号量了解 2.6、获取二值信号量了解 三…...

【SpringBoot】SpringBoot常用注解

一、前言首先这里说的SpringBoot常用注解是指在我们开发项目过程中,我们经常使用的注解,包含Spring、SpringBoot、SpringCloud、SpringMVC等这些框架中的注解,而不仅仅是SpringBoot中的注解。这里只是作一个注解列举,每个注解具体…...

数据一致性

目录一、AOP 动态代理切入方法(1) Aspect Oriented Programming(2) 切入点表达式二、SpringBoot 项目扫描类(1) ResourceLoader 扫描类(2) Map 的 computeIfAbsent 方法(3) 反射几个常用 api① 创建一个测试注解② 创建测试 PO 类③ 反射 api 获取指定类的指定注解信息(4) 返回…...

Docker不做虚拟化内核,对.NET有什么影响?

引子前两天刷抖音,看见了这样一个问题。问题:容器化不做虚拟内核,会有什么弊端?Java很多方法会跟CPU的核数有关,这个时候调用系统函数,读到的是宿主机信息,而不是我们限制资源的大小。思考&…...

HTML总结

CSS代码风格 空格规范&#xff1a; 1. 属性值前面&#xff0c;冒号后面&#xff0c;保留一个空格&#xff1b; 2. 选择器&#xff08;标签&#xff09;和大括号中间保留空格。 基本语法概述&#xff1a; 1.HTML标签是由尖括号包围的关键词&#xff0c;如<html> 2.HTM…...

ByteHouse:基于ClickHouse的实时数仓能力升级解读

更多技术交流、求职机会&#xff0c;欢迎关注字节跳动数据平台微信公众号&#xff0c;回复【1】进入官方交流群 ByteHouse是火山引擎上的一款云原生数据仓库&#xff0c;为用户带来极速分析体验&#xff0c;能够支撑实时数据分析和海量数据离线分析。便捷的弹性扩缩容能力&…...

[SSD固态硬盘技术 15] FTL映射表的神秘面纱

为什么需要映射表?固态硬盘的存储器件采用的是闪存[5],具有以下几个特点: (1)读写基本单位是以页(Page)为单位,擦除是以块(Block)为单位。...

浅析依赖注入框架的生命周期(以 InversifyJS 为例)

在上一篇介绍了 VSCode 的依赖注入设计&#xff0c;并且实现了一个简单的 IOC 框架。但是距离成为一个生产环境可用的框架还差的很远。 行业内已经有许多非常优秀的开源 IOC 框架&#xff0c;它们划分了更为清晰地模块来应对复杂情况下依赖注入运行的正确性。 这里我将以 Inv…...

HER2靶向药物研发进展-销售数据-上市药品前景分析

HER2长期作为肿瘤领域的热门靶点之一&#xff0c;其原因是它在多部位、多种形式的癌症中均有异常的表达&#xff0c;据研究表明HER2除了在胃癌、胆道癌、胆管癌、乳腺癌、卵巢癌、结肠癌、膀胱癌、肺癌、子宫颈癌、子宫浆液性子宫内膜癌、头颈癌、食道癌中的异常表达还存在于多…...

【第38天】不同路径数问题 | 网格 dp 入门

本文已收录于专栏&#x1f338;《Java入门一百例》&#x1f338;学习指引序、专栏前言一、网格模型二、【例题1】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、【例题2】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、推荐专栏四、课后习题序、专…...

LINUX之链接命令

链接命令学习目标能够说出软链接的创建方式能够说出硬链接的创建方式1. 链接命令的介绍链接命令是创建链接文件&#xff0c;链接文件分为:软链接硬链接命令说明ln -s创建软链接ln创建硬链接2. 软链接类似于Windows下的快捷方式&#xff0c;当一个源文件的目录层级比较深&#x…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中&#xff0c;各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过&#xff0c;在涉及到多个子类派生于基类进行多态模拟的场景下&#xff0c;…...

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

工程地质软件市场:发展现状、趋势与策略建议

一、引言 在工程建设领域&#xff0c;准确把握地质条件是确保项目顺利推进和安全运营的关键。工程地质软件作为处理、分析、模拟和展示工程地质数据的重要工具&#xff0c;正发挥着日益重要的作用。它凭借强大的数据处理能力、三维建模功能、空间分析工具和可视化展示手段&…...

【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】

1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件&#xff08;System Property Definition File&#xff09;&#xff0c;用于声明和管理 Bluetooth 模块相…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

HTML前端开发:JavaScript 常用事件详解

作为前端开发的核心&#xff0c;JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例&#xff1a; 1. onclick - 点击事件 当元素被单击时触发&#xff08;左键点击&#xff09; button.onclick function() {alert("按钮被点击了&#xff01;&…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

服务器--宝塔命令

一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行&#xff01; sudo su - 1. CentOS 系统&#xff1a; yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...

Fabric V2.5 通用溯源系统——增加图片上传与下载功能

fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

代码随想录刷题day30

1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币&#xff0c;另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额&#xff0c;返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...