当前位置: 首页 > news >正文

【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

#012 TensorFlow 2.0 中的 TF 迁移学习

一、说明

        在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么它被称为迁移学习;我们将现有模型的学习转移到新的数据集中。

教程概述:

  1. 介绍
  2. 使用内置的 TensorFlow 模型进行迁移学习
  3. 使用 TensorFlow Hub 进行迁移学习

二、为什么迁移学习简介

之        我们已经探讨了如何使用数据增强来提高模型性能。现在的问题是,“如果我们没有足够的数据来从头开始训练我们的网络怎么办?

对        此的解决方案是使用迁移学习方法。一篇更具理论意义的帖子已经发布在我们的博客上。如果需要,请查看它以刷新一些想法。我们可以使用迁移学习将知识从一些预先训练好的开源网络转移到我们自己的简历问题中。

计        算机视觉研究社区在互联网上发布了许多数据集,如Imagenet或MS Coco或Pascal数据集。许多计算机视觉研究人员已经在这些数据集上训练了他们的算法。有时,此培训需要数周时间,并且可能需要许多 GPU。事实上,其他人已经完成了这项任务并经历了痛苦的高性能研究过程,这意味着我们经常可以下载开源权重。

有        很多网络已经过训练。例如,Imagenet 数据集,它由 1000 个不同的类和超过 14 万张图像组成。因此,网络可能有一个 softmax 单元,它输出一千个可能的类之一。我们可以做的是摆脱softmax层并创建我们自己的输出单元来表示例如猫或

        由于我们使用下载的权重,我们将只训练与我们的输出层关联的参数,在我们的例子中,这将是一个 sigmoid 输出层。

三、使用内置的 TensorFlow 模型进行迁移学习

        让我们首先为训练准备数据集。我们将使用 wget.download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile
wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        现在,让我们导入所有必需的库并构建模型。我们将使用一个名为MobileNetV2的预训练网络,该网络在ImageNet数据集上进行训练。在这里,我们希望使用除顶部分类图层之外的所有图层,因此我们不会将它们包含在我们的网络中。

        obileNetV2 体系结构概述:https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html

这        个模型和其他预训练模型已经在TensorFlow中可用。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimgfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGeneratorbase_model = MobileNetV2(input_shape=(224, 224, 3),include_top=False,weights='imagenet')base_model.summary()


输出:
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]  

        果最后一层输出不同数量的类,那么我们需要有自己的输出单元来输出以下类:。有几种方法可以做到这一点:

  • 取最后几层的权重,将它们用作初始化并进行梯度下降。这样,我们将重新训练网络的一部分。
  • 删除最后几层的权重,使用我们自己的新隐藏单元和我们自己的最终 sigmoid(或 softmax)输出。通过这种方式,我们可以更改输出的数量。

        因此,这两种方法中的任何一种都值得尝试。

        现在让我们冻结预训练层并添加一个新层,称为 GlobalAveragePooling2D,之后是具有 sigmoid 激活函数的 Dense 层。

现        

base_model.trainable = Falsemodel = Sequential([base_model,GlobalAveragePooling2D(),Dense(1, activation='sigmoid')])
model.summary()

    

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

          现在是训练步骤的时候了。我们将使用图像数据生成器来迭代图像。现在没有必要为大量时期训练网络,因为我们已经预先训练了一部分网络。

model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(224, 224),batch_size=32,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(224, 224),batch_size=32,class_mode='binary')history = model.fit(train_generator,epochs=6,validation_data=validation_generator,verbose=2)

让        我们看看结果。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy, label="Training")
plt.plot(epochs, val_accuracy, label="Validation")
plt.legend()
plt.title('Training and validation accuracy')
plt.figure()plt.plot(epochs, loss, label="Training")
plt.plot(epochs, val_loss, label="Validation")
plt.legend()
plt.title('Training and validation loss')

3.Text(0.5, 1.0, 'Training and validation loss')

四、 使用 TensorFlow Hub 进行迁移学习

        访问预训练模型的另一种方法是TensorFlow Hub。TensorFlow Hub是一个库,用于发布,发现和使用机器学习模型的可重用部分。您可以在此处找到更多预训练模型。

        我们将冻结图层并添加一个用于分类的新图层。全连接网络的输入称为瓶颈要素。它们表示网络中最后一个卷积层的激活图。

我们为给定数量的 epoch 训练模型。

        在这里,我们也可以使用 TensorBoard,但这次让我们保持简单。

        最后,让我们可视化一些预测。在此数据集中,猫标记为 0,狗标记为 1。我们将用蓝色显示正确的预测,用红色显示假。

预测

五、总结

        正如我们在上面看到的,使用迁移学习可以帮助我们在短时间内取得非常好的结果。使用数据增强,可以进一步增强结果。

         在下一篇文章中,我们将展示如何创建网络并将其转换以在移动设备上使用。

相关文章:

【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

#012 TensorFlow 2.0 中的 TF 迁移学习 一、说明 在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么…...

mysql面试题46:MySQL中datetime和timestamp的区别

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:MySQL中DATETIME和TIMESTAMP的区别 在MySQL中,DATETIME和TIMESTAMP是两种用于存储日期和时间的数据类型。虽然它们都可以用于存储日期和时间信息…...

【Spring Boot】RabbitMQ消息队列 — RabbitMQ入门

💠一名热衷于分享知识的程序员 💠乐于在CSDN上与广大开发者交流学习。 💠希望通过每一次学习,让更多读者了解我 💠也希望能结识更多志同道合的朋友。 💠将继续努力,不断提升自己的专业技能,创造更多价值。🌿欢迎来到@"衍生星球"的CSDN博文🌿 🍁本…...

Navicat定时任务

Navicat定时任务 1、启动Navicat for MySQL工具,连接数据库。 2、查询定时任务选项是否开启 查询命令:SHOW VARIABLES LIKE ‘%event_scheduler%’; ON表示打开,OFF表示关闭。 打开定时任务命令 SET GLOBAL event_scheduler 0; 或者 SET G…...

小白必备:简单几步, 使用Cpolar+Emlog在Ubuntu上搭建个人博客网站

文章目录 前言1. 网站搭建1.1 Emolog网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar临时数据隧道2.2.Cpolar稳定隧道(云端设置)2.3.Cpolar稳定隧道(本地设置) 3. 公网访问测试总结 前言 博客作为使…...

封装 Token

什么是token? 作为计算机术语,是“令牌”的意思 。 Token 是服务端生成的一串字符串,以作客户端进行请求的一个令牌,当第一次登录后,服务器生成一个Token便将此Token返回给客户端,以后客户端只需带上这个Token前来请…...

CloudCompare 二次开发(17)——点云添加均匀分布的随机噪声

目录 一、概述二、代码集成三、结果展示一、概述 不依赖任何第三方点云相关库,使用CloudCompare编程实现点云添加随机噪声。添加随机噪声的算法原理见:PCL 点云添加均匀分布的随机噪声。 二、代码集成 1、mainwindow.h文件public中添加: void doActionAddRandomNoise(); …...

研发必会-异步编程利器之CompletableFuture(含源码 中)

近期热推文章: 1、springBoot对接kafka,批量、并发、异步获取消息,并动态、批量插入库表; 2、SpringBoot用线程池ThreadPoolTaskExecutor异步处理百万级数据; 3、基于Redis的Geo实现附近商铺搜索(含源码) 4、基于Redis实现关注、取关、共同关注及消息推送(含源码) 5…...

上海亚商投顾:沪指高开高走 锂电等新能源赛道大反攻

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日高开后强势震荡,创业板指盘中一度翻绿,随后探底回升再度走高。碳酸锂期货合约…...

力扣第235题 二又搜索树的最近公共祖先 c++

题目 235. 二叉搜索树的最近公共祖先 中等 (简单) 相关标签 树 深度优先搜索 二叉搜索树 二叉树 给定一个二叉搜索树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个结点 p、q&…...

时代风口中的Web3.0基建平台,重新定义Web3.0!

近年来,Web3.0概念的广泛兴起,给加密行业带来了崭新的叙事方式,同时也为加密行业提供了更加具有想象力的应用场景与商业空间,并让越来越多的行业从业者们意识到只有更大众化的市场共性需求才能推动加密市场的持续繁荣。当前围绕这…...

React学习笔记 001

什么是React 1.发送请求获取数据 处理数据(过滤、整理格式等) 3.操作DOM呈现页面 react 主要是负责第三部 操作dom 处理页面 数据渲染为HTML视图的开源js库。 好处 避免dom繁琐 组件化 提升复用率 特点 声明式编程: 简单 组件化编程…...

2023 | github无法访问或速度慢的问题解决方案

github无法访问或速度慢的问题解决方案 前言: 最近经常遇到github无法访问, 或者访问特别慢的问题, 在搜索了一圈解决方案后, 有些不再有效了, 但是其中有几个还特别好用, 总结一下. 首选方案 直接在github.com的域名上加一个fast > githubfast.com, 访问的是与github完全相…...

unity各种插件集合(自用)

2D Animation——2D序列帧/骨骼动画相关 2D PSD Importer——psb骨骼动画(unity官方建议使用psb而非psd) (Advanced —show preview package 勾选)出现 2D IK——反向动力学IK Universal RP——升级项目到Urp(通用渲…...

内网收集哈希传递

1.内网收集的前提 获得一个主机权限 补丁提权 可以使用 systeminfo 然后使用python脚本找到缺少的补丁 下载下来 让后使用exp提权 收集信息 路由信息 补丁接口 dns域看一看是不是域控 扫描别的端口 看看有没有内在的web网站 哈希传递 哈希是啥 哈希…...

前端目录笔记

HTML HTML 笔记:初识 HTML(HTML文本标签、文本列表、嵌入图片、背景色、网页链接)-CSDN博客html 笔记:CSS_UQI-LIUWJ的博客-CSDN博客HTML 笔记 表格_UQI-LIUWJ的博客-CSDN博客 javascript JavaScript 笔记 初识JavaScript&…...

Sui主网升级至V1.11.2版本

Sui主网现已升级至V1.11.2版本,同时Sui协议升级至27版本。其他升级要点如下: 对于一些更高级别的交易,更改了一些gas费设置,使其gas费消耗的更快。这些更改不影响以前在网络上运行的任何交易,只是为了确保在开始大量使…...

Mysql-数据库和数据表的基本操作

Mysql数据库和数据表的基本操作 一.数据库 1.创建数据库 创建数据库就是在数据库系统中划分一块空间存储数据 (1)语法 create database 数据库名称;(2)查看数据库 show create database 数据库名;(3)…...

拓扑排序求最长路

P1807 最长路 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 题目要求我们求出第1号到第n号节点之间最长的距离。 我们想到使用拓扑排序来求最长路。 正常来讲,我们应该把1号节点入队列,再出队列,把一号节点能到达的所有的点的入度减一&a…...

sqli-lab靶场通关

文章目录 less-1less-2less-3less-4less-5less-6less-7less-8less-9less-10 less-1 1、提示输入参数id,且值为数字; 2、判断是否存在注入点 id1报错,说明存在 SQL注入漏洞。 3、判断字符型还是数字型 id1 and 11 --id1 and 12 --id1&quo…...

2025年能源电力系统与流体力学国际会议 (EPSFD 2025)

2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...

day52 ResNet18 CBAM

在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件,常用于在两个集合之间进行数据转移,如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model:绑定右侧列表的值&…...

NFT模式:数字资产确权与链游经济系统构建

NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年,作为行业领先的3D工业相机及视觉系统供应商,累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成,通过稳定、易用、高回报的AI3D视觉系统,为汽车、新能源、金属制造等行…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作:ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等(ArcGIS出图图例8大技巧),那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...

【Oracle】分区表

个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...