当前位置: 首页 > news >正文

[PyTorch][chapter 52][迁移学习]

前言:

     迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

      这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一  简介

    

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

      将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2  网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3  特征迁移:

     将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

     word2vec

1.4 参数迁移:

       将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二  Flatten

    作用:

     输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023@author: chengxf2
"""import torch
from torch import optim,nnclass Flatten(nn.Module):def __init__(self):super(Flatten,self).__init__()def forward(self, x):a = torch.tensor(x.shape[1:])#dim 中 input 张量的每一行的乘积。shape = torch.prod(a).item()#print("\n ---new shape--- ",shape)return x.view(-1,shape)

三 迁移学习

   torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

   from torchvision.models import resnet152
  from torchvision.models import resnet18

   注意:

          现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

         

分类器分类类型
已有分类器[猫,狗,鸡,鸭】
新分类器[猫,狗]

     

   

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023@author: chengxf2
"""import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18from util import FlattenbatchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)root ='pokemon'
resize =224csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()def evalute(model, loader):total =len(loader.dataset)correct =0for x,y in loader:x = x.to(device)y = y.to(device)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)correct += torch.eq(pred, y).sum().float().item()acc = correct/totalreturn acc   def main():trained_model = resnet152(pretrained=True)model = nn.Sequential(*list(trained_model.children())[:-1],Flatten(),nn.Linear(in_features=2048, out_features=5))optimizer = optim.Adam(model.parameters(),lr =lr) criteon = nn.CrossEntropyLoss()best_epoch=0,best_acc=0viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))global_step =0for epoch in range(epochs):print("\n --main---: ",epoch)for step, (x,y) in enumerate(train_loader):#x:[b,3,224,224] y:[b]x = x.to(device)y = y.to(device)#print("\n --x---: ",x.shape)logits =model(x)loss = criteon(logits, y)#print("\n --loss---: ",loss.shape)optimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1if epoch %2 ==0:val_acc = evalute(model, val_loader)if val_acc>best_acc:best_acc = val_accbest_epoch =epochtorch.save(model.state_dict(),'best.mdl')print("\n val_acc ",val_acc)viz.line([val_acc],[global_step],win='val_loss',update='append')print('\n best acc',best_acc, "best_epoch: ",best_epoch)model.load_state_dict(torch.load('best.mdl'))print('loaded from ckpt')test_acc = evalute(model, test_loader)print('\n test acc',test_acc)if __name__ == "__main__":main()


参考:
https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

相关文章:

[PyTorch][chapter 52][迁移学习]

前言: 迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。 这里面主要结合前面ResNet18 例子,详细讲解一…...

Ceph如何操作底层对象数据

1.基本原理介绍 1.1 ceph中的对象(object) 在Ceph存储中,一切数据最终都会以对象(Object)的形式存储在硬盘(OSD)上,每个的Object默认大小为4M。 通过rados命令,可以查看一个存储池中的所有object信息,例如…...

sklearn机器学习库(二)sklearn中的随机森林

sklearn机器学习库(二)sklearn中的随机森林 集成算法会考虑多个评估器的建模结果,汇总之后得到一个综合的结果,以此来获取比单个模型更好的回归或分类表现。 多个模型集成成为的模型叫做集成评估器(ensemble estimator)&#xf…...

FlutterBoost 实现Flutter页面内嵌iOS view

在使用Flutter混合开发中会遇到一些原生比Flutter优秀的控件,不想使用Flutter的控件,想在Flutter中使用原生控件。这时就会用到 Flutter页面中内嵌 原生view,这里简单介绍一个 内嵌 iOS 的view。 注:这里使用了 FlutterBoost。网…...

走嵌入式还是纯软件?学长告诉你怎么选

最近有不少理工科的本科生问我,未来是走嵌入式还是纯软件好,究竟什么样的同学适合学习嵌入式呢?在这里我整合一下给他们的回答,根据自己的经验提供一些建议。 嵌入式领域也可以分为单片机方向、Linux方向和安卓方向。如果你的专业…...

【云计算原理及实战】初识云计算

该学习笔记取自《云计算原理及实战》一书,关于具体描述可以查阅原本书籍。 云计算被视为“革命性的计算模型”,因为它通过互联网自由流通使超级计算能力成为可能。 2006年8月,在圣何塞举办的SES(捜索引擎战略)大会上&a…...

Open3D (C++) 基于拟合高差的点云地面点提取

目录 一、算法原理1、原理概述2、参考文献二、代码实现三、结果展示1、原始点云2、提取结果四、相关链接系列文章(连载中。。。): Open3D (C++) 基于高程的点云地面点提取Open3D (C++) 基于拟合平面的点云地面点提取Open3D (C++) 基于拟合高差的点云地面点提取</...

认识Transformer:入门知识

视频链接&#xff1a; https://www.youtube.com/watch?vugWDIIOHtPA&listPLJV_el3uVTsOK_ZK5L0Iv_EQoL1JefRL4&index60 文章目录 Self-Attention layerMulti-head self-attentionPositional encodingSeq2Seq with AttentionTransformerUniversal Transformer Seq2Seq …...

《TCP IP网络编程》第二十四章

第 24 章 制作 HTTP 服务器端 24.1 HTTP 概要 本章将编写 HTTP&#xff08;HyperText Transfer Protocol&#xff0c;超文本传输协议&#xff09;服务器端&#xff0c;即 Web 服务器端。 理解 Web 服务器端&#xff1a; web服务器端就是要基于 HTTP 协议&#xff0c;将网页对…...

【AI】文心一言的使用

一、获得内测资格&#xff1a; 1、点击网页链接申请&#xff1a;https://yiyan.baidu.com/ 2、点击加入体验&#xff0c;等待通过 二、获得AI伙伴内测名额 1、收到短信通知&#xff0c;点击链接 网页Link&#xff1a;https://chat.baidu.com/page/launch.html?fa&sourc…...

CSAPP Lab2:Bomb Lab

说明 6关卡&#xff0c;每个关卡需要输入相应的内容&#xff0c;通过逆向工程来获取对应关卡的通过条件 准备工作 环境 需要用到gdb调试器 apt-get install gdb系统: Ubuntu 22.04 本实验会用到的gdb调试器的指令如下 r或者 run或者run filename 运行程序,run filename就…...

Java中使用流将两个集合根据某个字段进行过滤去重?

Java中使用流将两个集合根据某个字段进行过滤去重? 在Java中&#xff0c;您可以使用流(Stream)来过滤和去重两个集合。下面是一个示例代码&#xff0c;展示如何根据对象的某个字段进行过滤和去重操作&#xff1a; import java.util.ArrayList; import java.util.List; impor…...

自动驾驶HMI产品技术方案

版本变更 序号 日期 变更内容 编制人 审核人 文档版本 1 2 1....

Git判断本地是否最新

场景需求 需要判断是否有新内容更新,确定有更新之后执行pull操作&#xff0c;然后pull成功之后再将新内容进行复制到其他地方 pgit log -1 --prettyformat:"%H" HEAD -- . "origin/HEAD" rgit rev-parse origin/HEAD if [[ $p $r ]];thenecho "Is La…...

Spring 整合RabbitMQ,笔记整理

1.创建生产者工程 spring-rabbitmq-producer 2.pom.xml添加依赖 <dependencies><dependency><groupId>org.springframework</groupId><artifactId>spring-context</artifactId><version>5.1.7.RELEASE</version></dep…...

Lua 语言笔记(一)

1. 变量命名规范 弱类型语言(动态类型语言)&#xff0c;定义变量的时候&#xff0c;不需要类型修饰 而且&#xff0c;变量类型可以随时改变每行代码结束的时候&#xff0c;要不要分号都可以变量名 由数字&#xff0c;字母下划线组成&#xff0c;不能以数字开头&#xff0c;也不…...

【Redis】什么是缓存穿透,如何预防缓存穿透?

【Redis】什么是缓存穿透&#xff0c;如何预防缓存穿透&#xff1f; 缓存穿透是指查询一个一定不存在的数据&#xff0c;由于缓存中不存在&#xff0c;这时会去数据库查询查不到数据则不写入缓存&#xff0c;这将导致这个不存在的数据每次请求都要到数据库去查询&#xff0c;这…...

LeetCode128.最长连续序列

我这个方法有点投机取巧了&#xff0c;题目说时间复杂度最多O(n),而我调用了Arrays.sort(&#xff09;方法&#xff0c;他的时间复杂度是n*log(n)&#xff0c;但是AC了&#xff0c;这样的话这道题还是非常简单的&#xff0c;创建一个Hashmap&#xff0c;以nums数组的元素作为ke…...

Datawhale Django入门组队学习Task02

Task02 首先启动虚拟环境&#xff08;复习一下之前的&#xff09; 先退出conda的&#xff0c; conda deactivate然后cd到我的venv下面 &#xff0c;然后cd 到 scripts&#xff0c;再 activate &#xff08;powershell里面&#xff09; 创建admin管理员 首先cd到项目路径下&a…...

PCTA 认证考试高分通过经验分享

作者&#xff1a; msx-yzu 原文来源&#xff1a; https://tidb.net/blog/0b343c9f 序言 我在2023年8月10日&#xff0c;参加了 PingCAP 认证 TiDB 数据库专员 V6 考试 &#xff0c;并以 90分 的成绩通过考试。 考试总分是100分&#xff0c;超过60分就算通过考试。试卷…...

基于RIME-CNN-LSSVM回归模型的优化与预测应用——以MATLAB环境为例

RIME-CNN-LSSVM回归 基于霜冰优化算法优化卷积神经网络(CNN)结合最小二乘向量机(LSSVM)的数据回归预测(可以更换为分类/单、多变量时序预测/回归&#xff0c;前私我)&#xff0c;Matlab代码&#xff0c;可直接运行&#xff0c;适合小白新手 程序已经调试好&#xff0c;无需更改…...

3D Face HRN快速上手:无需代码,Gradio界面三步完成人脸重建

3D Face HRN快速上手&#xff1a;无需代码&#xff0c;Gradio界面三步完成人脸重建 1. 从一张照片到3D人脸&#xff0c;只需三步点击 你是否曾想过&#xff0c;将一张普通的自拍照或证件照&#xff0c;瞬间转化为一张可用于3D建模、游戏角色或虚拟形象的“皮肤地图”&#xf…...

HY-Motion 1.0应用案例:为AR试衣间生成‘转身→抬手→比划’交互动作流

HY-Motion 1.0应用案例&#xff1a;为AR试衣间生成转身→抬手→比划交互动作流 1. 项目背景与需求 AR试衣间正在改变传统购物体验&#xff0c;但如何让虚拟服装在用户身上自然流动&#xff0c;一直是个技术难题。传统方案要么动作生硬不连贯&#xff0c;要么需要复杂的动作捕…...

快速集成A2A Agent

面我们提到可以将MCP服务也封装为一个Tool&#xff08;AIFunction&#xff09;让Agent调用&#xff0c;这里A2A Agent也是一样的道理。 这样做的好处是&#xff1a;让MAF中的Agent像调用本地函数一样调用远程A2A Agent 或 MCP Server。 下面的代码展示了在MAF中将A2A Card转换…...

OpenClaw+Qwen3-32B自动化办公:会议纪要生成与飞书同步实战

OpenClawQwen3-32B自动化办公&#xff1a;会议纪要生成与飞书同步实战 1. 为什么需要自动化会议纪要 每次开完会最痛苦的事情是什么&#xff1f;对我来说就是整理会议纪要。作为技术负责人&#xff0c;每周要参加5-6个不同主题的会议&#xff0c;会后需要花大量时间回听录音、…...

SPI通信协议与菊花链模式应用解析

四线SPI通信协议与菊花链模式应用详解1. SPI接口基础1.1 四线SPI接口定义串行外设接口(SPI)是微控制器与外围IC之间最广泛使用的通信接口之一&#xff0c;具有同步、全双工、主从式架构特点。标准四线SPI接口包含以下信号线&#xff1a;SCLK(Serial Clock)&#xff1a;时钟信号…...

保姆级教程:在Windows上用PyTorch 2.0复现PointNet(含数据集下载与常见坑点修复)

Windows平台PyTorch 2.0实战&#xff1a;从零构建PointNet点云处理模型全指南 当3D点云处理遇上深度学习&#xff0c;PointNet无疑是这个领域的里程碑式架构。不同于传统CNN处理规则网格数据的方式&#xff0c;PointNet开创性地直接处理无序点云数据&#xff0c;在分类和分割任…...

计算机毕设 java 基于 Javaweb 的家教管理系统 智能家教匹配管理系统 家教服务综合平台

计算机毕设 java 基于 Javaweb 的家教管理系统 f7xm39&#xff08;配套有源码 程序 mysql 数据库 论文&#xff09;本套源码可以先看具体功能演示视频领取&#xff0c;文末有联 xi 可分享随着家庭教育需求的不断增长&#xff0c;家教市场规模持续扩大&#xff0c;但传统家教模式…...

单机游戏多人化:Nucleus Co-Op的技术突破与实践指南

单机游戏多人化&#xff1a;Nucleus Co-Op的技术突破与实践指南 【免费下载链接】nucleuscoop Starts multiple instances of a game for split-screen multiplayer gaming! 项目地址: https://gitcode.com/gh_mirrors/nu/nucleuscoop 你是否曾梦想在同一台电脑上与朋友…...

基于SPI硬件外设的NeoPixel高精度驱动方案

1. 项目概述neopixels_spi是一个专为 ARM Cortex-M 平台设计的轻量级、高可靠性 NeoPixel&#xff08;WS2812B 类&#xff09;驱动库&#xff0c;其核心创新在于完全摒弃传统 GPIO 模拟时序方案&#xff0c;转而采用硬件 SPI 外设配合 DMA 和精确时序控制机制实现单线协议物理层…...