TensorFlow2实战-系列教程6:迁移学习实战
🧡💛💚TensorFlow2实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传
1、迁移学习
- 用已经训练好模型的权重参数当做自己任务的模型权重初始化
- 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层
一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练
2、猫狗识别
import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样
3、加载预训练模型
from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3
从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。
pre_trained_model = ResNet101(input_shape = (75, 75, 3), include_top = False, weights = 'imagenet')
- 加载导入的模型
- input_shape 为输入大小
- include_top为False就是表示不要最后的全连接层
- 这段代码执行后,会自动进行下载
downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step
for layer in pre_trained_model.layers:layer.trainable = False
选择要进行重新训练的层
4、callback模块
在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:
4.1 callback示例
callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]
上面是一个模板,继续我们的猫狗识别的迁移学习项目:
4.2 定义callback
class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
- 定义一个类,继承Callback
- 定义一个函数,传入epoch值和日志
- 从当前epoch的日志中取出准确率,如果准确率大于95%
- 打印信息
- 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(1, activation='sigmoid')(x)
model = Model(pre_trained_model.input, x)
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
- 导入优化器
- 将预训练模型的输出展平为一维
- 定义一个1024的全连接层
- 在这层加入dropout
- 输出全连接层
- 构建模型
- 指定优化器、损失函数、验证方法等配置训练器
5、模型训练
定义需要重新训练的层
train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75)) validation_generator = test_datagen.flow_from_directory( validation_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))
前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据
callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])
指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志
打印结果:
Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910
6、预测效果展示
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()


相关文章:
TensorFlow2实战-系列教程6:迁移学习实战
🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 1、迁移学习 用已经训练好模型的权重参数当做自己任务的模型权重初始化一般全连接层需…...
怎样开发adobe indesign插件,具体流程?
文章目录 第一.流程步骤第二.如何调试indesign插件第三.相关资源第四.总结 第一.流程步骤 开发Adobe InDesign插件通常涉及以下步骤: 获取SDK和工具: 从Adobe官方网站下载最新的Adobe InDesign SDK(Software Development Kit)&am…...
Docker 安装与基本操作
目录 一、Docker 概述 1、Docker 简述 2、Docker 的优势 3、Docker与虚拟机的区别 4、Docker 的核心概念 1)镜像 2)容器 3)仓库 二、Docker 安装 1、命令: 2、实操: 三、Docker 镜像操作 1、命令࿱…...
译文带你理解Python的dataclass装饰器
dataclass 是 Python dataclasses 模块中的一个 decorator。当使用 dataclass 装饰器时,它会自动生成一些特殊方法,包括: _ _ init _ _:用于初始化字段的构造函数_ _ repr _ _:对象的字符串表示_ _ eq _ _:…...
【C语言】实现程序的暂停
编写程序时,有时候需要让程序在某些地方暂停执行,等待用户输入或者观察程序执行结果。在 C 语言中,有多种方法可以实现程序的暂停,包括 system("pause")、getchar() 和 while ((c getchar()) ! \n && c ! EOF)…...
Hana SQL+正则表达式
目录 一、Pre 前言 二、知识点拆解 1)case when…then…else 2)json_value 函数 拓展资料 3)CAST 函数 拓展资料 4) ROUND 函数 5)occurences_regexpr 函数 拓展资料 6)正则表达式 拓展资料 三、整合分析…...
【笔记】顺利通过EMC试验(16-41)-视频笔记
目录 视频链接 P1:电子设备中有哪些主要骚扰源 P2:怎样减小DC模块的骚扰 P3:PCB上的辐射源究竟在哪里 P4:怎样控制PCB板的电磁辐射 P5:多层线路板是解决电磁兼容问题的简单方法 P6:怎样处理地线上的裂缝 P7:怎样降低时钟信号的辐射 P8:为什么IO接口的处理特别重要 P9…...
Qlik Sense 调用NPrinting生成On-Demand报表
安装 Qlik Sense On-Demand 报表控件 On-Demand 报表控件添加按钮,该按钮按需生成 Qlik NPrinting 报表。它包括在 Dashboard bundle 中。 当您希望用户能够使用应用程序中的选择作为过滤器在 Qlik Sense 中打印预定义 Qlik NPrinting 报表时,On-Deman…...
ElasticSearch重建/创建/删除索引操作 - 第501篇
历史文章(文章累计500) 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 E…...
数据写入HBase(scala)
package sourceimport org.apache.hadoop.hbase.{HBaseConfiguration, TableName} import org.apache.hadoop.hbase.client.{ConnectionFactory, Put} import org.apache.hadoop.hbase.util.Bytesobject ffff {def main(args: Array[String]): Unit {//hbase连接配置val conf …...
Codeforces Round 799 (Div. 4)
目录 A. Marathon B. All Distinct C. Where’s the Bishop? D. The Clock E. Binary Deque F. 3SUM G. 2^Sort H. Gambling A. Marathon 直接模拟 void solve() {int ans0;for(int i1;i<4;i) {cin>>a[i];if(i>1&&a[i]>a[1]) ans;}cout<&l…...
为什么要用云手机养tiktok账号
在拓展海外电商市场的过程中,许多用户选择采用tiktok短视频平台引流的策略,以提升在电商平台上的流量,吸引更多消费者。而要进行tiktok引流,养号是必不可少的一个环节。tiktok云手机成为实现国内跨境养号的一种有效方式࿰…...
vue pc端网页实现自适应
一、基本原理 pc端做自适应可以用rem来实现,啥是rem,自己百度 二、新建rem.ts文件 // rem等比适配配置文件 // 基准大小 const baseSize 14 // 设置 rem 函数 function setRem () {// 当前页面宽度相对于 1920宽的缩放比例,可根据自己需要…...
Android 13以上版本读写SD卡权限适配
如题,最近工作上处理的问题,把解决方案简单逻列出来,供有需要的朋友参考之 解决方案: 1、配置权限 <uses-permission android:name"android.permission.READ_MEDIA_IMAGES" /><uses-permission android:name&q…...
并查集模板:食物链详解
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class Main {static int N 50010;static int n,m; //n个动物,m局判断static int[] p new int[N]; //p[i]是i的根节点static int[] d new int[N]; //d[i]表示i到…...
使用WAF防御网络上的隐蔽威胁之反序列化攻击
什么是反序列化 反序列化是将数据结构或对象状态从某种格式转换回对象的过程。这种格式通常是二进制流或者字符串(如JSON、XML),它是对象序列化(即对象转换为可存储或可传输格式)的逆过程。 反序列化的安全风险 反…...
05. 交换机的基本配置
文章目录 一. 初识交换机1.1. 交换机的概述1.2. Ethernet_ll格式1.3. MAC分类1.4. 冲突域1.5. 广播域1.6. 交换机的原理1.7. 交换机的3种转发行为 二. 初识ARP2.1. ARP概述2.2. ARP报文格式2.3. ARP的分类2.4. 免费ARP的作用 三. 实验专题3.1. 实验1:交换机的基本原…...
yolo将标签数据打到原图上形成目标框
第一章 目标:为了查看自己在标注标签时是否准确,写了这段代码来将标注的框打到原图上 第二章 步骤:进行反归一化得到坐标画出矩形框 第二行是目标图片对应的txt,第三行是目标图片 第三章 全部代码如下: import cv2 import …...
002-00-02【大红ai源码】dolphinscheduler3.2.0 源码环境搭建------by孤山村头王大爷家女儿大红
【ai阅读源码-dolphinscheduler】 DolphinScheduler 开发手册1、软件要求2、克隆代码库3、编译打包4、代码风格5、新建数据库,导入元数据。6, 启动后端6.1 启动api-server 6.2 启动master-server6.3 启动worker-server 7 启动前端 DolphinScheduler 开发…...
python-自动化篇-运维-监控-如何使⽤Python处理和解析⽇志⽂件?-实操记录
文章目录 1. 选择日志文件格式: 确定要处理的日志文件的格式。不同的日志文件可能具有不同的格式,如文本日志、CSV、JSON、XML等。了解日志文件的格式对解析⾮常重要。2. 打开日志文件: 使⽤Python的文件操作功能打开日志文件,以便…...
掌控散热:OmenSuperHub开源风扇控制与性能优化工具深度解析
掌控散热:OmenSuperHub开源风扇控制与性能优化工具深度解析 【免费下载链接】OmenSuperHub 项目地址: https://gitcode.com/gh_mirrors/om/OmenSuperHub OmenSuperHub是一款专为惠普暗影精灵系列游戏本打造的开源控制软件,提供完全离线的硬件监控…...
awk实战:从基础语法到高效文本处理技巧
1. 为什么你应该掌握awk文本处理 第一次接触awk是在处理服务器日志的时候,当时我需要从几GB的访问日志中统计每个IP的出现次数。同事随手写了个awk命令,一行代码就解决了让我头疼半天的问题。从那时起,我就把这个"文本处理瑞士军刀&quo…...
从数据故事到视觉叙事:用Matplotlib定制专属渐变色,让你的图表会‘说话’
从数据故事到视觉叙事:用Matplotlib定制专属渐变色,让你的图表会‘说话’ 在数据爆炸的时代,图表早已不再是简单的数字呈现工具。当一位市场分析师需要向董事会展示季度业绩趋势,当一位科研人员需要向同行解释复杂的气候变化模式…...
各工厂产能负荷不透明?SAP 集团生产模块实现服装多工厂协同生产
在服装企业规模化扩张过程中,多工厂布局成为提升产能、覆盖市场的重要选择,但 “各工厂产能负荷不透明” 却成为制约协同效率的关键瓶颈。很多服装集团面临这样的困境:总部不清楚 A 工厂的高端定制生产线是否饱和,B 工厂的批量生产…...
如何快速搭建Windows syslog服务器:开源日志监控终极指南
如何快速搭建Windows syslog服务器:开源日志监控终极指南 【免费下载链接】visualsyslog Syslog Server for Windows with a graphical user interface 项目地址: https://gitcode.com/gh_mirrors/vi/visualsyslog 在Windows环境下高效监控Unix/Linux系统和网…...
07_gstack并行开发:Git Worktrees与Conductor多会话管理
07_gstack并行开发:Git Worktrees与Conductor多会话管理关键字:gstack、Git Worktrees、Conductor、并行开发、多会话管理、Claude Code、并行sprint、Garry Tan、AI并行工作流“One sprint, one person, one feature — that takes about 30 minutes wi…...
uni-app实战:驰腾打印机蓝牙对接与二维码打印全解析
1. 为什么选择uni-app对接驰腾打印机? 在移动开发领域,跨平台解决方案越来越受到开发者青睐。uni-app作为一款基于Vue.js的跨平台框架,可以一次开发同时发布到iOS、Android以及各种小程序平台。这种特性使得它成为对接硬件设备的理想选择&am…...
Qwen3-ASR-0.6B与Anaconda环境配置:一站式语音识别开发平台
Qwen3-ASR-0.6B与Anaconda环境配置:一站式语音识别开发平台 1. 引言 语音识别技术正在改变我们与设备交互的方式,从智能助手到实时字幕,从会议记录到语音搜索,这项技术已经深入到我们生活的方方面面。今天我要跟大家分享的是如何…...
用ESP32-S3给OV2640摄像头上‘网课’:手把手实现低延迟MJPEG监控系统
基于ESP32-S3与OV2640构建低延迟MJPEG监控系统的工程实践 在物联网和边缘计算领域,实时视频监控系统的需求日益增长。本文将深入探讨如何利用ESP32-S3微控制器和OV2640摄像头模组构建一个完整的低延迟MJPEG监控系统,从硬件连接到软件优化,全…...
OptiScaler终极指南:一键解锁三大显卡厂商的免费超采样神器
OptiScaler终极指南:一键解锁三大显卡厂商的免费超采样神器 【免费下载链接】OptiScaler DLSS replacement for AMD/Intel/Nvidia cards with multiple upscalers (XeSS/FSR2/DLSS) 项目地址: https://gitcode.com/GitHub_Trending/op/OptiScaler 还在为游戏…...
