当前位置: 首页 > news >正文

任何使用 Keras 进行迁移学习

在前面的文章中,我们介绍了如何使用 Keras 构建和训练全连接神经网络(MLP)、卷积神经网络(CNN)和循环神经网络(RNN)。本文将带你深入学习如何使用 迁移学习(Transfer Learning) 来加速和提升模型性能。我们将使用 Keras 和预训练的卷积神经网络(如 VGG16)来完成一个图像分类任务。

目录

  1. 什么是迁移学习
  2. 环境准备
  3. 导入必要的库
  4. 加载和预处理数据
  5. 加载预训练模型
  6. 构建迁移学习模型
  7. 编译模型
  8. 训练模型
  9. 评估模型
  10. 保存和加载模型
  11. 总结

1. 什么是迁移学习

迁移学习 是一种机器学习技术,它利用在一个任务上训练的模型来解决另一个相关任务。通过迁移学习,我们可以:

  • 加速训练: 利用预训练模型的特征提取能力,减少训练时间。
  • 提高性能: 在数据量有限的情况下,迁移学习可以显著提高模型的泛化能力。
  • 减少数据需求: 预训练模型已经在大规模数据集上训练过,可以减少对新数据的需求。

在图像分类任务中,迁移学习通常涉及使用在 ImageNet 等大型数据集上预训练的卷积神经网络(如 VGG16、ResNet、Inception 等),并将其应用到新的图像分类任务中。

2. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

3. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • ImageDataGenerator: 用于数据增强和预处理。
  • applications: 预训练模型模块,包含 VGG16、ResNet 等。

4. 加载和预处理数据

我们将使用 猫狗数据集(Cats vs Dogs),这是一个二分类图像数据集,包含 25,000 张猫和狗的图片。我们将使用 Keras 的 ImageDataGenerator 进行数据增强和预处理。

# 数据集路径
train_dir = 'data/train'
validation_dir = 'data/validation'# 图像参数
img_height, img_width = 150, 150
batch_size = 32# 训练数据生成器(数据增强)
train_datagen = ImageDataGenerator(rescale=1./255,               # 归一化rotation_range=40,            # 随机旋转width_shift_range=0.2,        # 随机水平平移height_shift_range=0.2,       # 随机垂直平移shear_range=0.2,              # 随机剪切zoom_range=0.2,               # 随机缩放horizontal_flip=True,         # 随机水平翻转fill_mode='nearest'           # 填充方式
)# 测试数据生成器(仅归一化)
test_datagen = ImageDataGenerator(rescale=1./255)# 加载训练数据
train_generator = train_datagen.flow_from_directory(train_dir,target_size=(img_height, img_width),batch_size=batch_size,class_mode='binary'  # 二分类
)# 加载验证数据
validation_generator = test_datagen.flow_from_directory(validation_dir,target_size=(img_height, img_width),batch_size=batch_size,class_mode='binary'
)

说明:

  • 使用 ImageDataGenerator 进行数据增强,可以提高模型的泛化能力。
  • flow_from_directory 方法从目录中加载数据,目录结构应为 train_dir/class1/train_dir/class2/

5. 加载预训练模型

我们将使用预训练的 VGG16 模型,并冻结其卷积基(convolutional base),只训练顶部的全连接层。

# 加载预训练的 VGG16 模型,不包括顶部的全连接层
conv_base = applications.VGG16(weights='imagenet',include_top=False,input_shape=(img_height, img_width, 3))# 冻结卷积基
conv_base.trainable = False# 查看模型结构
conv_base.summary()

说明:

  • weights='imagenet': 使用在 ImageNet 数据集上预训练的权重。
  • include_top=False: 不包括顶部的全连接层,以便我们添加自己的分类器。
  • conv_base.trainable = False: 冻结卷积基,防止其权重在训练过程中被更新。

6. 构建迁移学习模型

我们将添加自己的全连接层来进行分类。

model = models.Sequential([conv_base,  # 预训练的卷积基layers.Flatten(),  # 展平层layers.Dense(256, activation='relu'),  # 全连接层layers.Dropout(0.5),  # Dropout 层,防止过拟合layers.Dense(1, activation='sigmoid')  # 输出层,二分类
])# 查看模型结构
model.summary()

说明:

  • 添加 Flatten 层将多维输出展平。
  • 添加 Dense 层和 Dropout 层进行分类。
  • 输出层使用 sigmoid 激活函数进行二分类。

7. 编译模型

model.compile(optimizer=keras.optimizers.Adam(),loss='binary_crossentropy',metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和二元交叉熵损失函数。
  • 评估指标为准确率。

8. 训练模型

# 设置训练参数
epochs = 10# 训练模型
history = model.fit(train_generator,steps_per_epoch=train_generator.samples // batch_size,epochs=epochs,validation_data=validation_generator,validation_steps=validation_generator.samples // batch_size
)

说明:

  • steps_per_epoch: 每个 epoch 的步数,通常为训练样本数除以批量大小。
  • validation_steps: 每个 epoch 的验证步数,通常为验证样本数除以批量大小。

9. 评估模型

test_loss, test_acc = model.evaluate(validation_generator, steps=validation_generator.samples // batch_size)
print(f"\n测试准确率: {test_acc:.4f}")

10. 保存和加载模型

# 保存模型
model.save("cats_vs_dogs_transfer_learning.h5")# 加载模型
new_model = keras.models.load_model("cats_vs_dogs_transfer_learning.h5")

11. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')plt.show()

12. 解冻部分卷积基进行微调

为了进一步提高模型性能,可以解冻部分卷积基,进行微调。

# 解冻最后几个卷积层
conv_base.trainable = True# 查看可训练的参数
for layer in conv_base.layers:if layer.name == 'block5_conv1':breaklayer.trainable = False# 重新编译模型
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # 使用较低的学习率loss='binary_crossentropy',metrics=['accuracy'])# 继续训练模型
history_fine = model.fit(train_generator,steps_per_epoch=train_generator.samples // batch_size,epochs=5,validation_data=validation_generator,validation_steps=validation_generator.samples // batch_size
)

说明:

  • 解冻部分卷积层,并使用较低的学习率进行微调。
  • 继续训练模型以微调预训练模型的权重。

13. 课程回顾

本文其实不算什么知识点,只是利用迁移学习来加速训练的一个实际操作的例子。

作者简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师

相关文章:

任何使用 Keras 进行迁移学习

在前面的文章中,我们介绍了如何使用 Keras 构建和训练全连接神经网络(MLP)、卷积神经网络(CNN)和循环神经网络(RNN)。本文将带你深入学习如何使用 迁移学习(Transfer Learning&#…...

Mac 使用mac 原生工具将mp4视频文件提取其中的 mp3 音频文件

简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~ ଘ(੭ˊᵕˋ)੭ 昵称:海轰 标签:程序猿|C++选手|学生 简介:因C语言结识编程,随后转入计算机专业,获得过国家奖学金,有幸在竞赛中拿过一些国奖、省奖…已保研 学习经验:扎实基础 + 多做笔…...

【SQL】一文速通SQL

SQL知识概念介绍 1. Relation Schema vs Relation Instance 简单而言,Relation Schema 是一个表,有变量还有数据类型 R (A1, A2, … , An) e.g. Student (sid: integer, name: string, login: string, addr: string, gender: char) Relation insta…...

【学习】【HTML】块级元素,行内元素,行内块级元素

块级元素 块级元素是 HTML 中一类重要的元素&#xff0c;它们在页面布局中占据整行空间&#xff0c;通常用于创建页面的主要结构组件。 常见的块级元素有哪些&#xff1f; <div>: 通用的容器元素&#xff0c;常用于创建布局块。<p>&#xff1a;段落元素&#xf…...

握手协议是如何在SSL VPN中发挥作用的?

SSL握手协议&#xff1a;客户端和服务器通过握手协议建立一个会话。会话包含一组参数&#xff0c;主要有会话ID、对方的证书、加密算法列表&#xff08;包括密钥交换算法、数据加密算法和MAC算法&#xff09;、压缩算法以及主密钥。SSL会话可以被多个连接共享&#xff0c;以减少…...

机器学习 - 为 Jupyter Notebook 安装新的 Kernel

https://ipython.readthedocs.io/en/latest/install/kernel_install.html 当使用jupyter-notebook --no-browser 启动一个 notebook 时&#xff0c;默认使用了该 jupyter module 所在的 Python 环境作为 kernel&#xff0c;比如 C:\devel\Python\Python311。 如果&#xff0c…...

CTF攻防世界小白刷题自学笔记13

1.fileinclude,难度&#xff1a;1,方向&#xff1a;Web 题目来源:宜兴网信办 题目描述:无 给一下题目链接&#xff1a;攻防世界Web方向新手模式第16题。 打开一看给了很多提示&#xff0c;什么language在index.php的第九行&#xff0c;flag在flag.php中&#xff0c;但事情显…...

Rust 模板匹配——根据指定图片查找处于大图中的位置(支持GPU加速)

Rust 模板匹配——根据指定图片查找处于大图中的位置(支持GPU加速) 01 前言 在手搓RPA工具的时候,总会碰到不好定位的情况,那么,就需要根据小图来找到对应屏幕上的位置(以图识图),这个需求也比较简单。想到市面上也有不少RPA工具都有这个功能,那么人家有的,俺也可以…...

JVM详解:类的加载过程

JVM中类的加载主要分为三个部分&#xff0c;分别为加载&#xff08;loading&#xff09;&#xff0c;链接&#xff08;linking&#xff09;&#xff0c;初始化&#xff08;initing&#xff09;。其中加载负责的主要是讲类文件加载到内存中变为类对象&#xff0c;不过此时只有基…...

Python →爬虫实践

爬取研究中心的书目 现在&#xff0c;想要把如下网站中的书目信息爬取出来。 案例一 耶鲁 Publications | Yale Law School 分析网页&#xff0c;如下图所示&#xff0c;需要爬取的页面&#xff0c;标签信息是“<p>”&#xff0c;所以用 itemssoup.find_all("p&…...

Visitor 访问者模式

1)意图 表示一个作用于某对象结构中的各元素的操作。它允许在不改变各元素的类的前提下定义用于这些元素的新操作。 2)结构 访问者模式的结构图如图 7-48 所示。 其中: Visitor(访问者) 为该对象结构中ConcreteElement 的每一个类声明一个 Vsit 操作。该操作的名字和特征标识…...

Mac解压包安装MongoDB8并设置launchd自启动

记录一下在mac上安装mongodb8过程&#xff0c;本机是M3芯片所以下载m芯片的安装包&#xff0c;intel芯片的类似操作。 首先下载安装程序包。 # M芯片下载地址 https://fastdl.mongodb.org/osx/mongodb-macos-arm64-8.0.3.tgz # intel芯片下载地址 https://fastdl.mongodb.org…...

Springboot采用jasypt加密配置

目录 前言 一、Jasypt简介 二、运用场景 三、整合Jasypt 2.1.环境配置 2.2.添加依赖 2.3.添加Jasypt配置 2.4.编写加/解密工具类 2.5.自定义加密属性前缀和后缀 2.6.防止密码泄露措施 2.61.自定义加密器 2.6.2通过环境变量指定加密盐值 总结 前言 在以往的多数项目中&#xff0…...

加载shellcode

​​​​​​ #include <stdio.h>#include <windows.h>DWORD GetHash(const char* fun_name){ DWORD digest 0; while (*fun_name) { digest ((digest << 25) | (digest >> 7)); //循环右移 7 位 digest *fun_name; //累加…...

K8S如何基于Istio实现全链路HTTPS

K8S如何基于Istio实现全链路HTTPS Istio 简介Istio 是什么?为什么选择 Istio?Istio 的核心概念Service Mesh(服务网格)Data Plane(数据平面)Sidecar Mode(边车模式)Ambient Mode(环境模式)Control Plane(控制平面)Istio 的架构与组件Envoy ProxyIstiod其他组件Istio 的流量管…...

React Query在现代前端开发中的应用

&#x1f493; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4dd; Gitee主页&#xff1a;瑕疵的gitee主页 ⏩ 文章专栏&#xff1a;《热点资讯》 React Query在现代前端开发中的应用 React Query在现代前端开发中的应用 React Query在现代前端开发中的应用 引言 React Query …...

【HAProxy09】企业级反向代理HAProxy高级功能之压缩功能与后端服务器健康性监测

HAProxy 高级功能 介绍 HAProxy 高级配置及实用案例 压缩功能 对响应给客户端的报文进行压缩&#xff0c;以节省网络带宽&#xff0c;但是会占用部分CPU性能 建议在后端服务器开启压缩功能&#xff0c;而非在HAProxy上开启压缩 注意&#xff1a;默认Ubuntu的包安装nginx开…...

PostgreSQL中表的数据量很大且索引过大时怎么办

在PostgreSQL中&#xff0c;当表的数据量很大且索引过大时&#xff0c;可能会导致性能问题。以下是一些优化索引和表数据的方法&#xff1a; 1. 评估和删除不必要的索引 识别未使用的索引&#xff1a;使用pg_stat_user_indexes和pg_index系统视图来查找未被使用的索引&#x…...

【QML】QML多线程应用(WorkerScript)

1. 实现功能 QML项目中&#xff0c;点击一个按键后&#xff0c;运行一段比较耗时的程序&#xff0c;此时ui线程会卡住。如何避免ui线程卡住。 2. 单线程&#xff08;会卡住&#xff09; 2.1 界面 2.2 现象 点击delay btn后&#xff0c;执行耗时函数&#xff08;TestJs.func…...

认证鉴权框架SpringSecurity-1--概念和原理篇

1、基本概念 Spring Security 是一个强大且高度可定制的框架&#xff0c;用于构建安全的 Java 应用程序。它是 Spring 生态系统的一部分&#xff0c;提供了全面的安全解决方案&#xff0c;包括认证、授权、CSRF防护、会话管理等功能。 2、认证、授权和鉴权 &#xff08;1&am…...

如何突破Office功能限制?本地化激活方案全解析

如何突破Office功能限制&#xff1f;本地化激活方案全解析 【免费下载链接】ohook An universal Office "activation" hook with main focus of enabling full functionality of subscription editions 项目地址: https://gitcode.com/gh_mirrors/oh/ohook 当…...

PyTorch 2.8镜像真实效果:物理实验→电磁场/流体力学可视化视频

PyTorch 2.8镜像真实效果&#xff1a;物理实验→电磁场/流体力学可视化视频 1. 开箱即用的专业级物理模拟环境 当你第一次启动这个基于RTX 4090D优化的PyTorch 2.8镜像时&#xff0c;最直接的感受就是"专业工具就该这样"。这个镜像不是普通的深度学习环境&#xff…...

Ostrakon-VL终端效果展示:深夜食堂风格终端打印输出全过程录屏

Ostrakon-VL终端效果展示&#xff1a;深夜食堂风格终端打印输出全过程录屏 1. 像素特工终端概览 在零售与餐饮行业的数字化转型浪潮中&#xff0c;我们开发了这款基于Ostrakon-VL-8B多模态大模型的Web交互终端。与传统工业级UI不同&#xff0c;我们采用了高饱和度的像素艺术风…...

自学嵌入式第三天

选择排序 先找再换&#xff0c;每轮先找最小的再做交换每一轮的第一个数作为最小&#xff0c;从第二个数开始比较选出最小值交换到第一个位置。插入排序 一列数从第二个开始作为要插入的数a&#xff0c;逐个向前比较&#xff0c;比a大的数向后移一位&#xff0c;直到遇见比a小…...

Z-Image-Turbo-rinaiqiao-huiyewunv实战落地:高校动漫社AI辅助创作工作流搭建

Z-Image-Turbo-rinaiqiao-huiyewunv实战落地&#xff1a;高校动漫社AI辅助创作工作流搭建 1. 项目背景与核心价值 高校动漫社团经常面临创作效率低、人手不足的问题。传统手绘方式需要大量时间&#xff0c;而通用AI绘图工具又难以保持角色一致性。Z-Image Turbo (辉夜大小姐-…...

无损视频剪辑效率全攻略:5分钟掌握革新性剪辑技术

无损视频剪辑效率全攻略&#xff1a;5分钟掌握革新性剪辑技术 【免费下载链接】lossless-cut The swiss army knife of lossless video/audio editing 项目地址: https://gitcode.com/gh_mirrors/lo/lossless-cut 你是否曾因视频剪辑软件的漫长渲染过程而错失发布良机&a…...

cool-admin(midway版)前端表单验证:AsyncValidator与异步校验完整指南

cool-admin(midway版)前端表单验证&#xff1a;AsyncValidator与异步校验完整指南 【免费下载链接】cool-admin-midway &#x1f525; cool-admin(midway版)一个很酷的后台权限管理框架&#xff0c;模块化、插件化、CRUD极速开发&#xff0c;永久开源免费&#xff0c;基于midwa…...

Mermaid Live Editor:重新定义图表创作的开源利器

Mermaid Live Editor&#xff1a;重新定义图表创作的开源利器 【免费下载链接】mermaid-live-editor Edit, preview and share mermaid charts/diagrams. New implementation of the live editor. 项目地址: https://gitcode.com/GitHub_Trending/me/mermaid-live-editor …...

手把手教你配置Figma MCP:打造属于你自己的AI驱动设计组件库(以阅读题为例)

智能设计革命&#xff1a;用Figma MCP构建AI驱动的交互式学习组件库 当设计系统遇上生成式AI&#xff0c;一场关于效率与智能化的变革正在悄然发生。在Figma中构建可动态响应数据的智能组件库&#xff0c;已成为中高级UI/UX设计师突破传统设计边界的必备技能。本文将深入解析如…...

UE4实战:利用VaRest与VictoryBPLibrary实现高效本地文件读写

1. 为什么需要本地文件读写 在虚幻引擎4开发过程中&#xff0c;我们经常需要保存游戏配置、玩家进度或者关卡数据。想象一下你正在开发一个RPG游戏&#xff0c;需要记录玩家背包里的所有物品、当前任务进度和角色属性。如果每次退出游戏这些数据都消失&#xff0c;玩家肯定会抓…...