当前位置: 首页 > news >正文

加载预训练模型,模型微调,在自己的数据集上快速出效果

  • 针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune)。 PyTorch里面提供的经典的网络模型都是官方通过Imagenet的数据集与训练好的数据,如果我们的数据训练数据不够,这些数据是可以作为基础模型来使用的。(Fine tuning 模型微调)

  • Fine tuning 模型微调的好处

    • 对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠微调已经训练好的模型。

    • 可以降低训练成本:如果使用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用 CPU 都完全无压力,没有深度学习机器也可以做。

    • 前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的模型要强悍,没有必要重复造轮子

  • 迁移学习初衷是节省人工标注样本的时间,让模型可以通过一个已有的标记数据的领域向未标记数据领域进行迁移从而训练出适用于该领域的模型,直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽快地学习新知识。把统一的概念抽象出来,只学习不同的内容。迁移学习按照学习方式可以分为基于样本的迁移,基于特征的迁移,基于模型的迁移,以及基于关系的迁移。

  • 微调应该是迁移学习中的一部分。微调只能说是一个trick,一种技术;迁移学习是一个更宏大的概念

  • Pytorch模型保存、加载与预训练

  • 保存和加载整个模型和参数:这种方式会保存整个模型的结构以及参数,会占用较大的磁盘空间, 通常不采用这种方式

  • torch.save(model, 'model.pkl')  #保存
    model = torch.load('model.pkl') # 加载
    
  • 保存和加载模型的参数, 优点是速度快,占用的磁盘空间少, 是最常用的模型保存方法。load_state_dict有一个strict参数,该参数默认是True, 表示预训练模型的网络结构与自定义的网络结构严格相同(包括名字和维度)。 如果自定义网络和预训练网络不严格相同时, 需要将不属于自定义网络的key去掉

  • torch.save(model.state_dict(), 'model_state_dict.pkl')
    model = model.load_state_dict(torch.load(model_state_dict.pkl))
    
  • 在实际场景中, 我们往往需要保存更多的信息,如优化器的参数, 那么可以通过字典的方式进行存储

  • # 保存
    torch.save({'epoch': epochId,'state_dict': model.state_dict,'best_acc': best_acc,'optimizer': optimizer.state_dict()}, checkpoint_path + "/m-" + timestamp + str("%.4f" % best_acc) + ".pth.tar")
    # 加载
    def load_model(model, checkpoint, optimizer):model_CKPT = torch.load(checkpoint)model.load_state_dict(model_CKPT['state_dict'])optimizer.load_state_dict(model_CKPT['optimizer'])return model, optimizer
    
  • 加载部分预训练模型: 如果我们修改了网络, 那么就需要将这部分参数过滤掉:(值得注意的是,当两个网络的结构相同, 但是结构的命名不同时, 直接加载会报错。因此需要修改结构的key值)

  • def load_model(model, chinkpoint, optimizer):model_CKPT = torch.load(checkpoint)model_dict = model.state_dict()pretrained_dict = model_CKPT['state_dict']# 将不在model中的参数过滤掉new_dict = {k, v for k, v in pretrained_dict.items() if k in model_dict.keys()}model_dict.update(new_dict)model.load_state_dict(model_dict)# 加载优化器参数optimizer.load_state_dict(model_CKPT['optimizer'])return model, optimizer
    
  • 冻结网络的部分参数, 训练另一部分参数(注意,必须同时在优化器中将这些参数过滤掉, 否则会报错。因为optimizer里面的参数要求required_grad为Ture)

    • 当输入给模型的数据集形式相似或者相同时,常见的是利用现有的经典模型(如Residual Network、 GoogleNet等)作为backbone来提取特征,那么这些经典模型已经训练好的模型参数可以直接拿过来使用。通常情况下, 我们希望将这些经典网络模型的参数固定下来, 不进行训练,只训练后面我们添加的和具体任务相关的网络参数。

      • 新数据集和原始数据集合类似,那么直接可以微调一个最后的FC层或者重新指定一个新的分类器

      • 新数据集比较小和原始数据集合差异性比较大,那么可以使用从模型的中部开始训练,只对最后几层进行fine-tuning

      • 新数据集比较小和原始数据集合差异性比较大,如果上面方法还是不行的化那么最好是重新训练,只将预训练的模型作为一个新模型初始化的数据

      • 新数据集的大小一定要与原始数据集相同,比如CNN中输入的图片大小一定要相同,才不会报错

      • 对于不同的层可以设置不同的学习率,一般情况下建议,对于使用的原始数据做初始化的层设置的学习率要小于(一般可设置小于10倍)初始化的学习率,这样保证对于已经初始化的数据不会扭曲的过快,而使用初始化学习率的新层可以快速的收敛。

  • # 以ResNet网络为例
    # 当我们加载ResNet预训练模型之后,在ResNet的基础上连接了新的网络模块, ResNet那部分网络参数先冻结不更新
    # 只更新新引入网络结构的参数
    class Net(torch.nn.Module):def __init__(self, model, pretrained):super(Net, self).__init__()self.resnet = model(pretained)for p in self.parameters():p.requires_grad = Falseself.conv1 = torch.nn.Conv2d(2048, 1024, 1)self.conv2 = torch.nn.Conv2d(1024, 1024, 1)
    
  • 参数修改: resnet网络的最后一层对应1000个类别, 如果我们自己的数据只有10个类别, 那么可以进行如下修改

  • import torch
    import torchvision.models as models
    model = models.resnet50(pretrained=True)
    fc_inDim = model.fc.in_features
    # 修改为10个类别
    model.fc = torch.nn.Linear(fc_inDim, 10)
    
  • Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

相关文章:

加载预训练模型,模型微调,在自己的数据集上快速出效果

针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune&#xff…...

VScode远程连接服务器-过程试图写入的管道不存在-could not establist connection to【已解决】

问题描述 使用服务器的过程中突然与服务器断连,报错如下:could not establist connection to [20:23:39.487] > ssh: connect to host 10.201.0.131 port 22: Connection timed out > [20:23:39.495] > 过程试图写入的管道不存在。 > [20…...

电子技术——B类输出阶

电子技术——B类输出阶 下图展示了一个B类输出阶的原理图,B类输出阶由两个互补的BJT组成,不同时导通。 原理 当输入电压 vI0v_I 0vI​0 的时候,两个晶体管都截止输出电压为零。当 vIv_IvI​ 上升至超过0.5V的时候,此时 QNQ_NQN…...

【老卫搬砖】034期:HarmonyOS 3.1 Beta 1初体验,我在本地模拟器里面刷短视频

今天啊打开这个DevEco Studio的话,已经提示有3.1Beta1版本的一个更新啊。然后看一下它的一些特性。本文也演示了如何在本地模拟器里面运行HarmonyOS版短视频。 主要特性 新特性包括: Added support for Windows 11 64-bit and macOS 13.x OSs, as well…...

Day901.内部临时表 -MySQL实战

内部临时表 Hi,我是阿昌,今天学习记录的是关于内部临时表的内容。 sort buffer、内存临时表和 join buffer。这三个数据结构都是用来存放语句执行过程中的中间数据,以辅助 SQL 语句的执行的。 其中,在排序的时候用到了 sort bu…...

jstatd的启动方式与关闭方式

启动方式与注意事项: 启动方式: 前台启动不打印日志: jstatd -J-Djava.security.policyjstatd.all.policy -J-Djava.rmi.server.hostname服务器IP 前台启动并打印日志: ./jstatd -J-Djava.security.policyjstatd.all.policy -…...

_improve-3

createElement过程 React.createElement(): 根据指定的第一个参数创建一个React元素 React.createElement(type,[props],[...children] )第一个参数是必填,传入的是似HTML标签名称,eg: ul, li第二个参数是选填,表示的是属性&#…...

C++——异常

目录 C语言传统的处理错误的方式 C异常概念 异常的使用 异常的抛出和匹配原则 在函数调用链中异常栈展开匹配原则 自定义异常体系 异常的重新抛出 ​编辑 异常安全 异常规范 C标准库的异常体系 异常的优缺点 C语言传统的处理错误的方式 传统的错误处理机制: …...

MVVM 架构进阶:MVI 架构详解

前言Android开发发展到今天已经相当成熟了,各种架构大家也都耳熟能详,如MVC,MVP,MVVM等,其中MVVM更是被官方推荐,成为Android开发中的显学。不过软件开发中没有银弹,MVVM架构也不是尽善尽美的,在使用过程中…...

有没有必要考PMP证书?

其实针对有没有必要考试吗,这个可以根本不同行业的人来决定的。 1.高等教育项目管理专业科班出身的人员。 在我国本科学历和硕士研究生学历中,项目管理也有开设。不管以后从事的工作是否为项目管理或其他管理,作为本专业的同学,…...

1 机器学习基础

1 机器学习概述 1.1 数据驱动的问题求解 大数据-Big Data 大数据的多面性 1.2 数据分析 机器学习:海量的数据,获取有用的信息 专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之…...

java基础系列(六) sleep()和wait() 区别

一.前言 关于并发编程这块, 线程的一些基础知识我们得搞明白, 本篇文章来说一下这两个方法的区别,对Android中的HandlerThread机制原理可以有更深的理解, HandlerThread源码理解,请查看笔者的这篇博客: HandlerThread源码理解_handlerthread 源码_broadview_java的博客-CSDN博…...

Urho3D序列化

从Serializable派生的类可以通过定义属性将其自动序列化为二进制或XML格式。属性存储到每个类的上下文中。场景加载/保存和网络复制都是通过从Serializable派生Node和Component类来实现的。 支持的属性类型是Variant支持的所有属性类型,不包括指针和自定义值。 属性…...

企业级信息系统开发学习1.3——利用注解配置取代Spring配置文件

文章目录一、利用注解配置类取代Spring配置文件(一)打开项目(二)创建新包(三)拷贝类与接口(四)创建注解配置类(五)创建测试类(六)运行…...

VUE DIFF算法之快速DIFF

VUE DIFF算法系列讲解 VUE 简单DIFF算法 VUE 双端DIFF算法 文章目录VUE DIFF算法系列讲解前言一、快速DIFF的代码实现二、实践练习1练习2总结前言 本节我们来写一下VUE3中新的DIFF算法-快速DIFF,顾名思义,也就是目前最快的DIFF算法(在VUE中&…...

一文掌握如何轻松稿定项目风险管理【静说】

风险管理对于每个项目经理和PMO都非常重要,如果管理不当会出现很多问题,咱们以前分享过很多风险管理的内容: 风险无处不在,一旦发生,会对一个或多个项目目标产生积极或消极影响的确定事件或条件。那么接下来介绍下五大…...

操作系统权限提升(十四)之绕过UAC提权-基于白名单AutoElevate绕过UAC提权

系列文章 操作系统权限提升(十二)之绕过UAC提权-Windows UAC概述 操作系统权限提升(十三)之绕过UAC提权-MSF和CS绕过UAC提权 注:阅读本编文章前,请先阅读系列文章,以免造成看不懂的情况!! 基于白名单AutoElevate绕过…...

ecology9-谷歌浏览器下-pdf.js在渲染时部分发票丢失文字 问题定位及解决

问题 问题描述 : 在谷歌浏览器下,pdf.js在渲染时部分发票丢失文字;360浏览器兼容模式不存在此问题 排查思路:1、对比谷歌浏览器的css样式和360浏览器兼容模式下的样式,没有发现关键差别 2、✔使用Fiddler修改网页js D…...

JavaScript Window Navigator

文章目录JavaScript Window NavigatorWindow Navigator警告!!!浏览器检测JavaScript Window Navigator window.navigator 对象包含有关访问者浏览器的信息。 Window Navigator window.navigator 对象在编写时可不使用 window 这个前缀。 实例 <div id"example"…...

Linux基础命令-du查看文件的大小

文章目录 du 命令介绍 语法格式 基本参数 参考实例 1&#xff09;以人类可读形式显示指定的文件大小 2&#xff09;显示当前目录下所有文件大小 3&#xff09;只显示目录的大小 4&#xff09;显示根下哪个目录文件最大 5&#xff09;显示所有文件的大小 6&#xff0…...

《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》

在注意力分散、内容高度同质化的时代&#xff0c;情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现&#xff0c;消费者对内容的“有感”程度&#xff0c;正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中&#xff0…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...

【生成模型】视频生成论文调研

工作清单 上游应用方向&#xff1a;控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...

nnUNet V2修改网络——暴力替换网络为UNet++

更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...

uniapp 实现腾讯云IM群文件上传下载功能

UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中&#xff0c;群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS&#xff0c;在uniapp中实现&#xff1a; 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...

针对药品仓库的效期管理问题,如何利用WMS系统“破局”

案例&#xff1a; 某医药分销企业&#xff0c;主要经营各类药品的批发与零售。由于药品的特殊性&#xff0c;效期管理至关重要&#xff0c;但该企业一直面临效期问题的困扰。在未使用WMS系统之前&#xff0c;其药品入库、存储、出库等环节的效期管理主要依赖人工记录与检查。库…...

Vue3 PC端 UI组件库我更推荐Naive UI

一、Vue3生态现状与UI库选择的重要性 随着Vue3的稳定发布和Composition API的广泛采用&#xff0c;前端开发者面临着UI组件库的重新选择。一个好的UI库不仅能提升开发效率&#xff0c;还能确保项目的长期可维护性。本文将对比三大主流Vue3 UI库&#xff08;Naive UI、Element …...

CSS3相关知识点

CSS3相关知识点 CSS3私有前缀私有前缀私有前缀存在的意义常见浏览器的私有前缀 CSS3基本语法CSS3 新增长度单位CSS3 新增颜色设置方式CSS3 新增选择器CSS3 新增盒模型相关属性box-sizing 怪异盒模型resize调整盒子大小box-shadow 盒子阴影opacity 不透明度 CSS3 新增背景属性ba…...