深度学习中的迁移学习:预训练模型微调与实践
深度学习中的迁移学习:预训练模型微调与实践
目录
- 💡 迁移学习的核心概念
- 🧠 预训练模型的使用:ResNet与VGG的微调
- 🏥 迁移学习在医学图像分析中的应用
- 🔄 实践中的迁移学习微调过程
1. 💡 迁移学习的核心概念
迁移学习(Transfer Learning)在深度学习领域中发挥着至关重要的作用。其核心思想是:在大型数据集上训练好的模型可以被“迁移”到新的任务中,从而避免从零开始训练。深度神经网络的训练通常需要大量的数据和时间,通过利用已经在大规模数据集(如ImageNet)上训练过的模型,迁移学习能够极大地缩短训练时间,并显著提高性能。
迁移学习的关键点:
- 预训练模型:通过在通用数据集上训练模型(如ResNet、VGG等),这些模型学到了基础的特征表示,如边缘、形状和纹理。迁移学习的核心在于将这些基础特征应用到新的领域任务中。
- 微调(Fine-tuning):通过对预训练模型进行部分或全部参数的微调,模型可以适应新任务中的特定数据。微调的程度取决于新任务的相似性和目标。
- 冻结与解冻层:迁移学习过程中,通常会冻结模型的部分层,以保留通用的特征提取能力,针对新任务只对高层进行微调。
通过迁移学习,即使在拥有较少数据的情况下,也能获得优异的模型性能。接下来的部分将详细介绍如何使用经典的预训练模型,如ResNet和VGG,进行微调和迁移学习的实现。
2. 🧠 预训练模型的使用:ResNet与VGG的微调
深度学习中的经典模型如ResNet与VGG,常被用作迁移学习的预训练模型。它们在ImageNet等大规模数据集上预训练,并能够捕获图像中的通用特征。
ResNet与VGG的区别:
- ResNet(Residual Networks):ResNet通过引入残差块,解决了深度神经网络中的梯度消失问题。这使得ResNet可以训练非常深的网络(如ResNet50、ResNet101),同时保持较高的性能。
- VGG:VGG网络的特点在于其非常规则的卷积层堆叠结构,尽管深度较浅,但它能通过更宽的卷积核捕捉丰富的图像特征。
示例代码:微调ResNet进行图像分类
以下代码展示了如何使用预训练的ResNet模型并进行迁移学习,以适应新的图像分类任务。
# 引入必要的库
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder# 加载预训练的ResNet模型
resnet = models.resnet50(pretrained=True)# 冻结所有层的参数,以便只微调最后的全连接层
for param in resnet.parameters():param.requires_grad = False# 修改ResNet的最后一层,以适应新任务的分类数目
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 2) # 假设目标任务是二分类# 定义数据增强和预处理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据集
train_dataset = ImageFolder(root='path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=0.001)# 模型训练
for epoch in range(10): # 假设训练10个周期resnet.train()running_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = resnet(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
代码解析:
- 模型加载与微调:代码中使用了
torchvision.models中的resnet50预训练模型,并冻结了所有卷积层,只微调最后的全连接层以适应新任务(如二分类)。 - 数据增强与预处理:通过
transforms.Compose进行图像预处理,包括图像的缩放、裁剪和归一化。 - 训练过程:通过微调最后的全连接层,模型能够快速适应新任务。
微调深度学习模型的关键在于,冻结模型的大部分层次,并根据任务的需求重新训练部分层。通过这种方式,可以在有限数据的情况下,获得良好的性能表现。
3. 🏥 迁移学习在医学图像分析中的应用
迁移学习在医学图像分析等领域中的应用尤为广泛,特别是在这种特定领域中,通常面临数据稀缺的问题。由于医学图像数据的获取和标注成本高昂,直接从头训练深度学习模型往往不可行。因此,利用预训练模型进行迁移学习成为一种行之有效的解决方案。
医学图像分析中的挑战:
- 数据稀缺:标注的医学图像数据通常较少,这使得从零开始训练模型变得困难。
- 高精度要求:医学图像分析任务通常需要非常高的精度,因为其结果会直接影响临床诊断。
- 特征差异:尽管预训练模型在自然图像上表现优异,但医学图像的特征通常与自然图像有显著区别,因此需要对模型进行专门的微调。
通过迁移学习,医学图像分析可以借助在ImageNet等大数据集上预训练的模型提取基础特征,然后通过微调,模型可以有效学习到医学图像中特定的病变或异常区域。
示例代码:应用ResNet进行医学图像分析
# 引入必要的库
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms, datasets# 加载预训练的ResNet50模型
resnet = models.resnet50(pretrained=True)# 冻结所有层的参数
for param in resnet.parameters():param.requires_grad = False# 修改最后一层以适应医学图像分析的分类
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 5) # 假设任务为五分类# 定义数据增强和预处理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载医学图像数据集
train_dataset = datasets.ImageFolder(root='path_to_medical_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=0.0001)# 模型训练过程
for epoch in range(20): # 假设训练20个周期resnet.train()running_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = resnet(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
代码解析:
- 医学数据微调:利用预训练的ResNet模型,只微调最后的分类层,使其能够适应五分类任务,适用于医学图像分析中的不同疾病分类任务。
- 医学图像预处理:通过数据增强,如缩放、裁剪等操作,增强模型的泛化能力。
迁移学习在医学图像分析中的应用能够有效降低数据需求,同时提高模型的准确性和可靠性。
4. 🔄 实践中的迁移学习微调过程
在实际操作中,迁移学习的微调过程需要根据任务的复杂度和数据集的大小进行调整。具体微调的策略包括:
- 冻结大部分层:对于简单任务,只
需微调网络的高层特征表示层,而保留低层特征不变。
2. 解冻更多层:对于复杂任务,可能需要解冻更多层次,以学习更多领域特定的特征。
3. 调整学习率:微调时,通常使用较小的学习率,以避免破坏预训练模型中学到的有用特征。
以下是微调不同层的实践过程:
# 解冻部分层,允许更多层进行训练
for name, param in resnet.named_parameters():if "layer4" in name: # 假设只解冻ResNet的最后一层param.requires_grad = Trueelse:param.requires_grad = False# 调整学习率以适应微调
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, resnet.parameters()), lr=0.00001)# 继续进行模型的训练与微调
拓展部分:使用迁移学习进行图像分割任务
迁移学习不仅可以应用于分类任务,还可以应用于图像分割等更复杂的任务。通过调整预训练模型的结构,可以实现图像中的目标检测或分割。
相关文章:
深度学习中的迁移学习:预训练模型微调与实践
深度学习中的迁移学习:预训练模型微调与实践 目录 💡 迁移学习的核心概念🧠 预训练模型的使用:ResNet与VGG的微调🏥 迁移学习在医学图像分析中的应用🔄 实践中的迁移学习微调过程 1. 💡 迁移学…...
原生input实现时间选择器用法
2024.10.08今天我学习了如何用原生的input,实现时间选择器用法,效果如下: 代码如下: <div><input id"yf_start" type"text"> </div><script>$(#yf_start).datepicker({language: zh…...
对象的概念
对象是编程中一个重要的概念,尤其在面向对象编程(OOP)中更为核心。简单来说,对象是一种数据结构,它可以存储相关的数据和功能。以下是关于对象的详细描述: 1. 对象的定义 对象是属性(数据&…...
ARIMA|基于自回归差分移动平均模型时间序列预测
目录 一、基本内容介绍: 二、实际运行效果: 三、原理介绍: 四、完整程序下载: 一、基本内容介绍: 本代码基于Matlab平台,通过ARIMA模型对时间序列数据进行预测。程序以通过调试,解压后打开…...
sqli-labs靶场第三关less-3
sqli-labs靶场第三关less-3 1、确定注入点 http://192.168.128.3/sq/Less-3/?id1 http://192.168.128.3/sq/Less-3/?id2 有不同回显,判断可能存在注入, 2、判断注入类型 输入 http://192.168.128.3/sq/Less-3/?id1 and 11 http://192.168.128.3/sq/L…...
泡沫背后:人工智能的虚幻与现实
人工智能的盛世与泡沫 现今,人工智能热潮席卷科技行业,投资者、创业者和用户都被其光环吸引。然而,深入探讨这种现象,人工智能的泡沫正在形成,乃至具备崩溃的潜质。我们看到的,无非是一场由资本推动的狂欢…...
旅游管理智能化:SpringBoot框架的应用
第一章 绪论 1.1 研究现状 时代的发展,我们迎来了数字化信息时代,它正在渐渐的改变着人们的工作、学习以及娱乐方式。计算机网络,Internet扮演着越来越重要的角色,人们已经离不开网络了,大量的图片、文字、视频冲击着我…...
基于方块编码的图像压缩matlab仿真,带GUI界面
目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1 编码单元的表示 4.2编码单元的编码 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 下图是随着方块大小的变化,图像的压缩率以及对应的图像质量指标PSN…...
不同jdk版本间的替换
假设安装了 JDK 21 后,发现电脑有兼容性问题或其他原因需要切换回 JDK 8,替换过程很简单。你只需卸载 JDK 21 或者让系统使用 JDK 8。以下是详细步骤: 1. 卸载 JDK 21 https://www.oracle.com/java/technologies/downloads/#java21 如果你想…...
408算法题leetcode--第28天
84. 柱状图中最大的矩形 题目地址:84. 柱状图中最大的矩形 - 力扣(LeetCode) 题解思路:暴力:每一列记为矩形的高,找左边和右边比他小的位置,得到以该列为高对应的宽;这样最大的矩形…...
【无人机设计与控制】无人机三维路径规划,对比蚁群算法,ACO_Astar_RRT算法
摘要 本文探讨了三种不同的无人机三维路径规划算法,即蚁群算法(ACO)、A算法(Astar)以及快速随机树算法(RRT)。通过仿真实验对比了各算法在不同环境下的性能,包括路径长度、计算效率…...
毕设 大数据电影数据分析与可视化系统(源码+论文)
文章目录 0 前言1 项目运行效果2 设计概要3 最后 0 前言 🔥这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设题目缺少创新和亮点,往往达不到毕业答辩的要求,这两年不断有学弟学妹告诉学长自己做的项目系统达不到老师…...
10月7日刷题记录
C C...
苍穹外卖学习笔记(十五)
文章目录 一. 缓存菜品缓存菜品DishController.java清除缓存数据 缓存套餐Spring Cachemaven坐标常用注解 入门案例springcachedemo.sqlpom.xmlapplication.ymlCacheDemoApplication.javaWebMvcConfiguration.javaUserController.javaUser.javaUserMapper.java 套餐管理SkyAppl…...
知识图谱入门——5:Neo4j Desktop安装和使用手册(小白向:Cypher 查询语言:逐步教程!Neo4j 优缺点分析)
Neo4j简介 Neo4j 是一个基于图结构的 NoSQL 数据库,专门用于存储、查询和管理图形数据。它的核心思想是使用节点、关系和属性来描述数据。图数据库非常适合那些需要处理复杂关系的数据集,如社交网络、推荐系统、知识图谱等领域。 与传统的关系型数据库…...
35个数据分析模型
这些数据分析模型覆盖了战略规划、市场营销、运营管理、用户行为、财务分析等多个方面,是企业和组织在进行决策分析时常用的工具。分享给大家,如果想要PDF下载: https://edu.cda.cn/group/4/thread/178782 1、SWOT模型 SWOT模型是一种战略分…...
Java | Leetcode Java题解之第457题环形数组是否存在循环
题目: 题解: class Solution {public boolean circularArrayLoop(int[] nums) {int n nums.length;for (int i 0; i < n; i) {if (nums[i] 0) {continue;}int slow i, fast next(nums, i);// 判断非零且方向相同while (nums[slow] * nums[fast]…...
date:10.4(Content:Mr.Peng)( C language practice)
void reverse(char* p, int len) {char* left p;char* right p len - 2;while (left < right){char* temp left;*left *right;//当*left*right后,*temp已经被改为f了*right *temp;//你再*temp赋值给*right时,已经没用了left;right--;}}int main…...
【K8S系列】Kubernetes 集群中的网络常见面试题
在 Kubernetes 面试中,网络是一个重要的主题。理解 Kubernetes 网络模型、服务发现、网络策略等概念对候选人来说至关重要。以下是一些常见的 Kubernetes 网络面试题及其答案,帮助你准备面试。 1. Kubernetes 的网络模型是什么样的? 问题&am…...
Android 无Bug版 多语言设计方案!
出海业务为什么要做多语言? 1.市场扩大与本地化需求: 通过支持多种语言,出海项目可以触及更广泛的国际用户群体,进而扩大其市场份额。 本地化是吸引国际用户的重要策略之一,而语言本地化是其中的核心。使用用户的母语…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...
树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
【Java_EE】Spring MVC
目录 Spring Web MVC 编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 编辑参数重命名 RequestParam 编辑编辑传递集合 RequestParam 传递JSON数据 编辑RequestBody …...
鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南
1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发,使用DevEco Studio作为开发工具,采用Java语言实现,包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...
rnn判断string中第一次出现a的下标
# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...
TSN交换机正在重构工业网络,PROFINET和EtherCAT会被取代吗?
在工业自动化持续演进的今天,通信网络的角色正变得愈发关键。 2025年6月6日,为期三天的华南国际工业博览会在深圳国际会展中心(宝安)圆满落幕。作为国内工业通信领域的技术型企业,光路科技(Fiberroad&…...
Java求职者面试指南:Spring、Spring Boot、Spring MVC与MyBatis技术解析
Java求职者面试指南:Spring、Spring Boot、Spring MVC与MyBatis技术解析 一、第一轮基础概念问题 1. Spring框架的核心容器是什么?它的作用是什么? Spring框架的核心容器是IoC(控制反转)容器。它的主要作用是管理对…...
第八部分:阶段项目 6:构建 React 前端应用
现在,是时候将你学到的 React 基础知识付诸实践,构建一个简单的前端应用来模拟与后端 API 的交互了。在这个阶段,你可以先使用模拟数据,或者如果你的后端 API(阶段项目 5)已经搭建好,可以直接连…...
