当前位置: 首页 > news >正文

李沐37_微调——自学笔记

标注数据集很贵

网络架构

1.一般神经网络分为两块,一是特征抽取原始像素变成容易线性分割的特征,二是线性分类器来做分类

微调

1.原数据集不能直接使用,因为标号发生改变,通过微调可以仍然对我数据集做特征提取

2.pre-train源数据集后,迁移模型中的架构,调整权重

3.是一个目标数据集上的正常训练任务,但使用更强的正则化(更小的学习率、更少的数据迭代)

4.源数据集远复杂于目标数据,通常微调效果更好

重用分类器权重

1.源数据集可能也有目标数据中的部分标号

2.可以使用预训练好的模型分类器中对应标号对应的向量来初始化

固定一些层

1.神经网络通常学习有层次的特征表达,低层次的特征更加通用,高层次的特征则和数据集相关

2.可以固定底部一些层的参数,不参与更新,更强的正则化

总结

1.微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来提升精度

2.预训练模型质量很重要

3.微调通常速度更快、精度更高

代码实现——热狗识别

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

获取数据集

解压下载热狗数据集,一共1400张热狗“正向”图片,其他尽可能多的其他食物“负向”图片,解压下载的数据集,我们获得了两个文件夹hotdog/train和hotdog/test。 这两个文件夹都有hotdog(有热狗)和not-hotdog(无热狗)两个子文件夹, 子文件夹内都包含相应类的图像。

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

读取训练集和测试集的图片

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

前8个正类、后8个负类,图片的大小和高宽比不同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在这里插入图片描述

训练:从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为224X224
输入图像。
测试:我们将图像的高度和宽度都缩放到256像素,然后裁剪中央224X224
区域作为输入。

此外,对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。

# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

初始化模型

使用在ImageNet数据集上预训练的ResNet-18作为源模型。 在这里,我们指定pretrained=True以自动下载预训练的模型参数。 如果首次使用此模型,则需要连接互联网才能下载。

pretrained_net = torchvision.models.resnet18(pretrained=True)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 111MB/s]

fc:预训练的源模型实例包含许多特征层和一个输出层fc。 此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调。

# fc
pretrained_net.fc
Linear(in_features=512, out_features=1000, bias=True)

目标模型finetune_net中成员变量features的参数被初始化为源模型相应层的模型参数。 由于模型参数是在ImageNet数据集上预训练的,并且足够好,因此通常只需要较小的学习率即可微调这些参数。

成员变量output的参数是随机初始化的,通常需要更高的学习率才能从头开始训练。 假设Trainer实例中的学习率为n,我们将成员变量output中参数的学习率设置为10n。

finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

训练模型微调

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

使用较小的学习率

train_fine_tuning(finetune_net, 5e-5)
loss 0.238, train acc 0.910, test acc 0.941
360.1 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

为了比较,定义相同的模型,但需较大学习率

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
loss 0.403, train acc 0.821, test acc 0.791
366.6 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

相关文章:

李沐37_微调——自学笔记

标注数据集很贵 网络架构 1.一般神经网络分为两块&#xff0c;一是特征抽取原始像素变成容易线性分割的特征&#xff0c;二是线性分类器来做分类 微调 1.原数据集不能直接使用&#xff0c;因为标号发生改变&#xff0c;通过微调可以仍然对我数据集做特征提取 2.pre-train源…...

【小程序】生成短信中可点击的链接

文章目录 前言一、如何生成链接二、仔细拜读小程序开发文档文档说明1文档说明2 总结 前言 由于线上运营需求&#xff0c;需要给用户发送炮轰短信&#xff0c;用户通过短信点击链接直接跳转进入小程序 一、如何生成链接 先是找了一些三方的&#xff0c;生成的倒是快速&#xf…...

欧拉函数(模板题)

给定 n 个正整数 ai&#xff0c;请你求出每个数的欧拉函数。 欧拉函数的定义 输入格式 第一行包含整数 n。 接下来 n 行&#xff0c;每行包含一个正整数 ai。 输出格式 输出共 n 行&#xff0c;每行输出一个正整数 ai 的欧拉函数。 数据范围 1≤n≤100, 1≤ai≤2109 输…...

Thingsboard PE 白标的使用

只有专业版支持白标功能。 使用 ThingsBoard Cloud 或安装您自己的平台实例。 一、介绍 ThingsBoard Web 界面提供了简便的操作,让您能够轻松配置您的公司或产品标识和配色方案,无需进行编码工作或重新启动服务。 系统管理员、租户和客户管理员可以根据需要自定义配色方案、…...

智能物联网远传冷水表管理系统

智能物联网远传冷水表管理系统是一种基于物联网技术的先进系统&#xff0c;旨在实现对冷水表的远程监测、数据传输和智能化管理。本文将从系统特点、构成以及带来的效益三个方面展开介绍。 系统特点 1.远程监测&#xff1a;系统可以实现对冷水表数据的远程监测&#xff0c;无…...

Qt教程3-Ubuntu(x86_64)上配置arm64(aarch64)交叉编译环境及QT编译arm64架构工程

汇创慧玩 写在前面1. 查看系统架构相关指令2. ARM64交叉编译器环境搭建3. Qt编译arm64环境搭建4. 配置 Qt的本地aarch64交叉编译器5. 工程建立及编译验证 写在前面 苦辣酸甜时光八载&#xff0c;春夏秋冬志此一生 Qt简介&#xff1a; Qt&#xff08;官方发音 [kju:t]&#xff…...

2024年华为OD机试真题-最长子字符串的长度(二)-Python-OD统一考试(C卷)

题目描述: 给你一个字符串 s,字符串s首尾相连成一个环形 ,请你在环中找出l、o、x 字符都恰好出现了偶数次最长子字符串的长度。 输入描述: 输入是一串小写的字母组成的字符串。 输出描述: 输出是一个整数 补充说明: 1 <= s.length <= 5 x 10^5 s 只包含小写英文字母…...

【24届数字IC秋招总结】正式批面试经验汇总5——蔚来、tp-link

文章目录 一、蔚来-数字芯片验证工程师1.1 一面面试问题1.2 二面面试问题二、tp-link-数字IC验证工程师2.1 面试问题一、蔚来-数字芯片验证工程师 面试时间:9.6 10.6 1.1 一面面试问题 1、 讲下项目结构 2、 scoreboard如何进行数据对比的 3、 golden 数据怎么产生的 4、 在…...

【JAVA基础篇教学】第八篇:Java中List详解说明

博主打算从0-1讲解下java基础教学&#xff0c;今天教学第八篇&#xff1a;Java中List详解说明。 在 Java 编程中&#xff0c;List 接口是一个非常常用的集合接口&#xff0c;它代表了一个有序的集合&#xff0c;可以包含重复的元素。List 接口提供了一系列操作方法&#xff0c;…...

RN向上向下滑动组件封装(带有渐变色)

这段组件代码逻辑是出事有一个View和下面的块,下面的块也就是红色区域可以按住向上向下滑动,当滑动到屏幕最上面则停止滑动,再向上滑动的过程中,上方的View的背景色也会有个渐变效果,大概逻辑就是这样 代码如下 import React, {useEffect, useRef, useState} from react; impo…...

27、Lua 学习笔记之五(Lua中的数学库)

Lua中的数学库 Lua5.1中数学库的所有函数如下表&#xff1a; math.pi 为圆周率常量 3.14159265358979323846 数学库说明例子方法abs取绝对值math.abs(-15)15acos反余弦函数math.acos(0.5)1.04719755asin反正弦函数math.asin(0.5)0.52359877atan2x / y的反正切值math.atan2(9…...

【C++成长记】C++入门 | 类和对象(中) |拷贝构造函数、赋值运算符重载、const成员函数、 取地址及const取地址操作符重载

&#x1f40c;博主主页&#xff1a;&#x1f40c;​倔强的大蜗牛&#x1f40c;​ &#x1f4da;专栏分类&#xff1a;C❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、拷贝构造函数 1、概念 2、特征 二、赋值运算符重载 1、运算符重载 2、赋值运算符重载 3、前置…...

OpenHarmony实战开发-页面深色模式适配。

介绍 本示例介绍在开发应用以适应深色模式时&#xff0c;对于深色和浅色模式的适配方案&#xff0c;采取了多种策略如下&#xff1a; 1. 固定属性适配&#xff1a;对于部分组件的颜色属性&#xff0c;如背景色或字体颜色&#xff0c;若保持不变&#xff0c;可直接设定固定色值…...

域名解析出现错误,该如何解决?

域名作为网络地址&#xff0c;是我们访问网站的必经之路&#xff0c;域名解析就是把你的域名解析成一个ip地址&#xff0c;在使用的过程中遇到域名解析文件异常也是常有的事。如果域名解析出现错误&#xff0c;该怎么解决呢&#xff1f; 一、打开网页时&#xff0c;显示域名解析…...

从iPhone恢复已删除照片的最佳软件

本文分享了从iPhone恢复已删除照片的最佳软件。如果您正在寻找如何从iPhone恢复已删除的照片&#xff0c;请查看这篇文章。 为什么您需要软件从iPhone恢复已删除的照片&#xff1f; 没有什么比丢失iPhone上的重要数据更痛苦的了&#xff0c;尤其是一些具有珍贵回忆的照片。有时…...

MySQL模糊查询

一、MySQL通配符模糊查询(%&#xff0c;_) 1.1.通配符的分类 1.“%”百分号通配符&#xff1a;表示任何字符出现任意次数&#xff08;可以是0次&#xff09; 2.“_”下划线通配符&#xff1a;表示只能匹配单个字符&#xff0c;不能多也不能少&#xff0c;就是一个字符。当然…...

QEMU_v8搭建OP-TEE运行环境

文章目录 一、依赖下载二、设置网络三、安装下载四、运行OP-TEE 一、依赖下载 更新依赖包&#xff0c;下载一系列依赖。比如Python需要Python3.x版本&#xff0c;需要配置git的用户名和邮箱等。这里不详细展开了&#xff0c;很多博客都有涉及到。 二、设置网络 这一点非常重…...

C++11 设计模式0. 设计模式的基本概念,设计模式的准则,如何学习设计模式,24种设计模式的分为3大类

一 设计模式的基本概念&#xff1a; 模式&#xff1a;指事物的标准样式 或者 理解成 针对特定问题的可重用解决方案。 设计模式&#xff0c;是在特定问题发生时的可重用解决方案。 设计模式一般用于大型项目中。 大型项目中&#xff0c;设计模式保证所设计的模块之间代码的灵…...

(十)C++自制植物大战僵尸游戏设置功能实现

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/m0EtD 游戏设置 游戏设置功能是一个允许玩家根据个人喜好和设备性能来调整游戏各项参数的重要工具。游戏设置功能是为了让玩家能够根据自己的需求和设备性能来调整游戏&#xff0c;以获得最佳的游戏体验。不同的游戏和平…...

数据结构——通讯录(顺序表的实战项目)

&#xff08;—&#xff09;.通讯录的功能 大家应该都十分了解通讯录的功能吧&#xff0c;无非就是对联系人的增添删除&#xff0c;还有信息的修改&#xff0c;并且联系人信息要包含名字&#xff0c;电话&#xff0c;性别&#xff0c;地址等。我把通讯录的功能总结如下&#x…...

数据库-Redis(14)

目录 66.Redis为什么主从全量复制使用RDB而不是使用AOF? 67.Redis为什么还有无磁盘复制模式? 68.Redis为什么还会有从库的从库设计?...

Thinkphp5.0命令行创建验证器validate类

前言 最近接手了个用FastAdmin&#xff08;基于tp5&#xff09;写的项目&#xff0c;发现命令行只提供生成controller和model的命令&#xff0c;没有提供make:validate命令&#xff0c;而5.1及以上版本是有的&#xff0c;对于使用tp5.0框架或者基于tp5.0的第三框架&#xff08…...

人民网至顶科技:《开启智能新时代:2024中国AI大模型产业发展报告发布》

​3月26日&#xff0c;人民网财经研究院与至顶科技联合发布《开启智能新时代&#xff1a;2024年中国AI大模型产业发展报告》。该报告针对AI大模型产业发展背景、产业发展现状、典型案例、挑战及未来趋势等方面进行了系统全面的梳理&#xff0c;为政府部门、行业从业者以及社会公…...

AI大模型探索之路-应用篇13:企业AI大模型选型指南

目录 前言 一、概述 二、有哪些主流模型&#xff1f; 三、模型参数怎么选&#xff1f; 四、参数有什么作用&#xff1f; 五、CPU和GPU怎么选&#xff1f; 六、GPU和显卡有什么关系&#xff1f; 七、GPU主流厂商有哪些&#xff1f; 1、NVIDIA芯片怎么选&#xff1f; 2、…...

【安全】查杀linux上c3pool挖矿病毒xmrig

挖矿平台&#xff1a;猫池 病毒来源安装脚本 cat /root/c3pool/config.jsoncrontab -r cd /root/c3poolcurl -s -L http://download.c3pool.org/xmrig_setup/raw/master/setup_c3pool_miner.sh | LC_ALLen_US.UTF-8 bash -s 44SLpuV4U7gB6RNZMCweHxWug7b1YUir4jLr3RBaVX33Qxj…...

车载测试:UDS之BootLoader刷写

BootLoader刷写 本文章是花费3小时结合多个项目实践总结和整体出来的&#xff0c;欢迎大家交流&#xff01; BootLoader刷写章节 ①&#xff1a;预编程步骤流程流程图 1.1 概述 1.2 流程步骤描述 1&#xff09;整车ECU进入扩展会话 2&#xff09;刷…...

OpenHarmony实战开发-MpChart图表实现案例。

介绍 MpChart是一个包含各种类型图表的图表库&#xff0c;主要用于业务数据汇总&#xff0c;例如销售数据走势图&#xff0c;股价走势图等场景中使用&#xff0c;方便开发者快速实现图表UI。本示例主要介绍如何使用三方库MpChart实现柱状图UI效果。如堆叠数据类型显示&#xf…...

brpc: bthread使用

使用bthread并发编程 #include <gflags/gflags.h> #include <butil/logging.h> #include <bthread/bthread.h>static void* func(void* args) {std::string* num static_cast<std::string*>(args);for(int i 0; i < 5; i) {LOG(INFO) << *…...

H.265视频直播点播录像EasyPlayer.js流媒体播放器用户常见问题及解答

EasyPlayer属于一款高效、精炼、稳定且免费的流媒体播放器&#xff0c;可支持多种流媒体协议播放&#xff0c;无须安装任何插件&#xff0c;起播快、延迟低、兼容性强&#xff0c;使用非常便捷。 今天我们来汇总下用户常见的几个问题及解答。 1、EasyPlayer.js播放多路H.265视…...

蓝桥杯杂题选做

海盗分金币 题目链接&#xff1a;1.海盗分金币 - 蓝桥云课 (lanqiao.cn) 题解&#xff1a;海盗分金币-Cheery的代码 - 蓝桥云课 (lanqiao.cn) 思路&#xff1a;倒着想就行。 等腰三角形 题目链接&#xff1a;1.等腰三角形 - 蓝桥云课 (lanqiao.cn) 题解&#xff1a;等腰三…...