TensorFlow2实战-系列教程6:迁移学习实战
🧡💛💚TensorFlow2实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传
1、迁移学习
- 用已经训练好模型的权重参数当做自己任务的模型权重初始化
- 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层
一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练
2、猫狗识别
import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样
3、加载预训练模型
from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3
从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。
pre_trained_model = ResNet101(input_shape = (75, 75, 3), include_top = False, weights = 'imagenet')
- 加载导入的模型
- input_shape 为输入大小
- include_top为False就是表示不要最后的全连接层
- 这段代码执行后,会自动进行下载
downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step
for layer in pre_trained_model.layers:layer.trainable = False
选择要进行重新训练的层
4、callback模块
在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:
4.1 callback示例
callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]
上面是一个模板,继续我们的猫狗识别的迁移学习项目:
4.2 定义callback
class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
- 定义一个类,继承Callback
- 定义一个函数,传入epoch值和日志
- 从当前epoch的日志中取出准确率,如果准确率大于95%
- 打印信息
- 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(1, activation='sigmoid')(x)
model = Model(pre_trained_model.input, x)
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
- 导入优化器
- 将预训练模型的输出展平为一维
- 定义一个1024的全连接层
- 在这层加入dropout
- 输出全连接层
- 构建模型
- 指定优化器、损失函数、验证方法等配置训练器
5、模型训练
定义需要重新训练的层
train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75)) validation_generator = test_datagen.flow_from_directory( validation_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))
前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据
callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])
指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志
打印结果:
Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910
6、预测效果展示
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()


相关文章:
TensorFlow2实战-系列教程6:迁移学习实战
🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 1、迁移学习 用已经训练好模型的权重参数当做自己任务的模型权重初始化一般全连接层需…...
怎样开发adobe indesign插件,具体流程?
文章目录 第一.流程步骤第二.如何调试indesign插件第三.相关资源第四.总结 第一.流程步骤 开发Adobe InDesign插件通常涉及以下步骤: 获取SDK和工具: 从Adobe官方网站下载最新的Adobe InDesign SDK(Software Development Kit)&am…...
Docker 安装与基本操作
目录 一、Docker 概述 1、Docker 简述 2、Docker 的优势 3、Docker与虚拟机的区别 4、Docker 的核心概念 1)镜像 2)容器 3)仓库 二、Docker 安装 1、命令: 2、实操: 三、Docker 镜像操作 1、命令࿱…...
译文带你理解Python的dataclass装饰器
dataclass 是 Python dataclasses 模块中的一个 decorator。当使用 dataclass 装饰器时,它会自动生成一些特殊方法,包括: _ _ init _ _:用于初始化字段的构造函数_ _ repr _ _:对象的字符串表示_ _ eq _ _:…...
【C语言】实现程序的暂停
编写程序时,有时候需要让程序在某些地方暂停执行,等待用户输入或者观察程序执行结果。在 C 语言中,有多种方法可以实现程序的暂停,包括 system("pause")、getchar() 和 while ((c getchar()) ! \n && c ! EOF)…...
Hana SQL+正则表达式
目录 一、Pre 前言 二、知识点拆解 1)case when…then…else 2)json_value 函数 拓展资料 3)CAST 函数 拓展资料 4) ROUND 函数 5)occurences_regexpr 函数 拓展资料 6)正则表达式 拓展资料 三、整合分析…...
【笔记】顺利通过EMC试验(16-41)-视频笔记
目录 视频链接 P1:电子设备中有哪些主要骚扰源 P2:怎样减小DC模块的骚扰 P3:PCB上的辐射源究竟在哪里 P4:怎样控制PCB板的电磁辐射 P5:多层线路板是解决电磁兼容问题的简单方法 P6:怎样处理地线上的裂缝 P7:怎样降低时钟信号的辐射 P8:为什么IO接口的处理特别重要 P9…...
Qlik Sense 调用NPrinting生成On-Demand报表
安装 Qlik Sense On-Demand 报表控件 On-Demand 报表控件添加按钮,该按钮按需生成 Qlik NPrinting 报表。它包括在 Dashboard bundle 中。 当您希望用户能够使用应用程序中的选择作为过滤器在 Qlik Sense 中打印预定义 Qlik NPrinting 报表时,On-Deman…...
ElasticSearch重建/创建/删除索引操作 - 第501篇
历史文章(文章累计500) 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 E…...
数据写入HBase(scala)
package sourceimport org.apache.hadoop.hbase.{HBaseConfiguration, TableName} import org.apache.hadoop.hbase.client.{ConnectionFactory, Put} import org.apache.hadoop.hbase.util.Bytesobject ffff {def main(args: Array[String]): Unit {//hbase连接配置val conf …...
Codeforces Round 799 (Div. 4)
目录 A. Marathon B. All Distinct C. Where’s the Bishop? D. The Clock E. Binary Deque F. 3SUM G. 2^Sort H. Gambling A. Marathon 直接模拟 void solve() {int ans0;for(int i1;i<4;i) {cin>>a[i];if(i>1&&a[i]>a[1]) ans;}cout<&l…...
为什么要用云手机养tiktok账号
在拓展海外电商市场的过程中,许多用户选择采用tiktok短视频平台引流的策略,以提升在电商平台上的流量,吸引更多消费者。而要进行tiktok引流,养号是必不可少的一个环节。tiktok云手机成为实现国内跨境养号的一种有效方式࿰…...
vue pc端网页实现自适应
一、基本原理 pc端做自适应可以用rem来实现,啥是rem,自己百度 二、新建rem.ts文件 // rem等比适配配置文件 // 基准大小 const baseSize 14 // 设置 rem 函数 function setRem () {// 当前页面宽度相对于 1920宽的缩放比例,可根据自己需要…...
Android 13以上版本读写SD卡权限适配
如题,最近工作上处理的问题,把解决方案简单逻列出来,供有需要的朋友参考之 解决方案: 1、配置权限 <uses-permission android:name"android.permission.READ_MEDIA_IMAGES" /><uses-permission android:name&q…...
并查集模板:食物链详解
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class Main {static int N 50010;static int n,m; //n个动物,m局判断static int[] p new int[N]; //p[i]是i的根节点static int[] d new int[N]; //d[i]表示i到…...
使用WAF防御网络上的隐蔽威胁之反序列化攻击
什么是反序列化 反序列化是将数据结构或对象状态从某种格式转换回对象的过程。这种格式通常是二进制流或者字符串(如JSON、XML),它是对象序列化(即对象转换为可存储或可传输格式)的逆过程。 反序列化的安全风险 反…...
05. 交换机的基本配置
文章目录 一. 初识交换机1.1. 交换机的概述1.2. Ethernet_ll格式1.3. MAC分类1.4. 冲突域1.5. 广播域1.6. 交换机的原理1.7. 交换机的3种转发行为 二. 初识ARP2.1. ARP概述2.2. ARP报文格式2.3. ARP的分类2.4. 免费ARP的作用 三. 实验专题3.1. 实验1:交换机的基本原…...
yolo将标签数据打到原图上形成目标框
第一章 目标:为了查看自己在标注标签时是否准确,写了这段代码来将标注的框打到原图上 第二章 步骤:进行反归一化得到坐标画出矩形框 第二行是目标图片对应的txt,第三行是目标图片 第三章 全部代码如下: import cv2 import …...
002-00-02【大红ai源码】dolphinscheduler3.2.0 源码环境搭建------by孤山村头王大爷家女儿大红
【ai阅读源码-dolphinscheduler】 DolphinScheduler 开发手册1、软件要求2、克隆代码库3、编译打包4、代码风格5、新建数据库,导入元数据。6, 启动后端6.1 启动api-server 6.2 启动master-server6.3 启动worker-server 7 启动前端 DolphinScheduler 开发…...
python-自动化篇-运维-监控-如何使⽤Python处理和解析⽇志⽂件?-实操记录
文章目录 1. 选择日志文件格式: 确定要处理的日志文件的格式。不同的日志文件可能具有不同的格式,如文本日志、CSV、JSON、XML等。了解日志文件的格式对解析⾮常重要。2. 打开日志文件: 使⽤Python的文件操作功能打开日志文件,以便…...
【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型
摘要 拍照搜题系统采用“三层管道(多模态 OCR → 语义检索 → 答案渲染)、两级检索(倒排 BM25 向量 HNSW)并以大语言模型兜底”的整体框架: 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后,分别用…...
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列,以便知晓哪些列包含有价值的数据,…...
服务器--宝塔命令
一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...
基于Java+MySQL实现(GUI)客户管理系统
客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息,对客户进行统一管理,可以把所有客户信息录入系统,进行维护和统计功能。可通过文件的方式保存相关录入数据,对…...
GruntJS-前端自动化任务运行器从入门到实战
Grunt 完全指南:从入门到实战 一、Grunt 是什么? Grunt是一个基于 Node.js 的前端自动化任务运行器,主要用于自动化执行项目开发中重复性高的任务,例如文件压缩、代码编译、语法检查、单元测试、文件合并等。通过配置简洁的任务…...
LOOI机器人的技术实现解析:从手势识别到边缘检测
LOOI机器人作为一款创新的AI硬件产品,通过将智能手机转变为具有情感交互能力的桌面机器人,展示了前沿AI技术与传统硬件设计的完美结合。作为AI与玩具领域的专家,我将全面解析LOOI的技术实现架构,特别是其手势识别、物体识别和环境…...
AD学习(3)
1 PCB封装元素组成及简单的PCB封装创建 封装的组成部分: (1)PCB焊盘:表层的铜 ,top层的铜 (2)管脚序号:用来关联原理图中的管脚的序号,原理图的序号需要和PCB封装一一…...
Java并发编程实战 Day 11:并发设计模式
【Java并发编程实战 Day 11】并发设计模式 开篇 这是"Java并发编程实战"系列的第11天,今天我们聚焦于并发设计模式。并发设计模式是解决多线程环境下常见问题的经典解决方案,它们不仅提供了优雅的设计思路,还能显著提升系统的性能…...
【技巧】dify前端源代码修改第一弹-增加tab页
回到目录 【技巧】dify前端源代码修改第一弹-增加tab页 尝试修改dify的前端源代码,在知识库增加一个tab页"HELLO WORLD",完成后的效果如下 [gif01] 1. 前端代码进入调试模式 参考 【部署】win10的wsl环境下启动dify的web前端服务 启动调试…...
MySQL基本操作(续)
第3章:MySQL基本操作(续) 3.3 表操作 表是关系型数据库中存储数据的基本结构,由行和列组成。在MySQL中,表操作包括创建表、查看表结构、修改表和删除表等。本节将详细介绍这些操作。 3.3.1 创建表 在MySQL中&#…...
