李沐37_微调——自学笔记
标注数据集很贵
网络架构
1.一般神经网络分为两块,一是特征抽取原始像素变成容易线性分割的特征,二是线性分类器来做分类
微调
1.原数据集不能直接使用,因为标号发生改变,通过微调可以仍然对我数据集做特征提取
2.pre-train源数据集后,迁移模型中的架构,调整权重
3.是一个目标数据集上的正常训练任务,但使用更强的正则化(更小的学习率、更少的数据迭代)
4.源数据集远复杂于目标数据,通常微调效果更好
重用分类器权重
1.源数据集可能也有目标数据中的部分标号
2.可以使用预训练好的模型分类器中对应标号对应的向量来初始化
固定一些层
1.神经网络通常学习有层次的特征表达,低层次的特征更加通用,高层次的特征则和数据集相关
2.可以固定底部一些层的参数,不参与更新,更强的正则化
总结
1.微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来提升精度
2.预训练模型质量很重要
3.微调通常速度更快、精度更高
代码实现——热狗识别
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
获取数据集
解压下载热狗数据集,一共1400张热狗“正向”图片,其他尽可能多的其他食物“负向”图片,解压下载的数据集,我们获得了两个文件夹hotdog/train和hotdog/test。 这两个文件夹都有hotdog(有热狗)和not-hotdog(无热狗)两个子文件夹, 子文件夹内都包含相应类的图像。
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...
读取训练集和测试集的图片
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
前8个正类、后8个负类,图片的大小和高宽比不同。
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

训练:从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为224X224
输入图像。
测试:我们将图像的高度和宽度都缩放到256像素,然后裁剪中央224X224
区域作为输入。
此外,对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])
初始化模型
使用在ImageNet数据集上预训练的ResNet-18作为源模型。 在这里,我们指定pretrained=True以自动下载预训练的模型参数。 如果首次使用此模型,则需要连接互联网才能下载。
pretrained_net = torchvision.models.resnet18(pretrained=True)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 111MB/s]
fc:预训练的源模型实例包含许多特征层和一个输出层fc。 此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调。
# fc
pretrained_net.fc
Linear(in_features=512, out_features=1000, bias=True)
目标模型finetune_net中成员变量features的参数被初始化为源模型相应层的模型参数。 由于模型参数是在ImageNet数据集上预训练的,并且足够好,因此通常只需要较小的学习率即可微调这些参数。
成员变量output的参数是随机初始化的,通常需要更高的学习率才能从头开始训练。 假设Trainer实例中的学习率为n,我们将成员变量output中参数的学习率设置为10n。
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);
训练模型微调
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
使用较小的学习率
train_fine_tuning(finetune_net, 5e-5)
loss 0.238, train acc 0.910, test acc 0.941
360.1 examples/sec on [device(type='cuda', index=0)]

为了比较,定义相同的模型,但需较大学习率
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
loss 0.403, train acc 0.821, test acc 0.791
366.6 examples/sec on [device(type='cuda', index=0)]

相关文章:
李沐37_微调——自学笔记
标注数据集很贵 网络架构 1.一般神经网络分为两块,一是特征抽取原始像素变成容易线性分割的特征,二是线性分类器来做分类 微调 1.原数据集不能直接使用,因为标号发生改变,通过微调可以仍然对我数据集做特征提取 2.pre-train源…...
【小程序】生成短信中可点击的链接
文章目录 前言一、如何生成链接二、仔细拜读小程序开发文档文档说明1文档说明2 总结 前言 由于线上运营需求,需要给用户发送炮轰短信,用户通过短信点击链接直接跳转进入小程序 一、如何生成链接 先是找了一些三方的,生成的倒是快速…...
欧拉函数(模板题)
给定 n 个正整数 ai,请你求出每个数的欧拉函数。 欧拉函数的定义 输入格式 第一行包含整数 n。 接下来 n 行,每行包含一个正整数 ai。 输出格式 输出共 n 行,每行输出一个正整数 ai 的欧拉函数。 数据范围 1≤n≤100, 1≤ai≤2109 输…...
Thingsboard PE 白标的使用
只有专业版支持白标功能。 使用 ThingsBoard Cloud 或安装您自己的平台实例。 一、介绍 ThingsBoard Web 界面提供了简便的操作,让您能够轻松配置您的公司或产品标识和配色方案,无需进行编码工作或重新启动服务。 系统管理员、租户和客户管理员可以根据需要自定义配色方案、…...
智能物联网远传冷水表管理系统
智能物联网远传冷水表管理系统是一种基于物联网技术的先进系统,旨在实现对冷水表的远程监测、数据传输和智能化管理。本文将从系统特点、构成以及带来的效益三个方面展开介绍。 系统特点 1.远程监测:系统可以实现对冷水表数据的远程监测,无…...
Qt教程3-Ubuntu(x86_64)上配置arm64(aarch64)交叉编译环境及QT编译arm64架构工程
汇创慧玩 写在前面1. 查看系统架构相关指令2. ARM64交叉编译器环境搭建3. Qt编译arm64环境搭建4. 配置 Qt的本地aarch64交叉编译器5. 工程建立及编译验证 写在前面 苦辣酸甜时光八载,春夏秋冬志此一生 Qt简介: Qt(官方发音 [kju:t]ÿ…...
2024年华为OD机试真题-最长子字符串的长度(二)-Python-OD统一考试(C卷)
题目描述: 给你一个字符串 s,字符串s首尾相连成一个环形 ,请你在环中找出l、o、x 字符都恰好出现了偶数次最长子字符串的长度。 输入描述: 输入是一串小写的字母组成的字符串。 输出描述: 输出是一个整数 补充说明: 1 <= s.length <= 5 x 10^5 s 只包含小写英文字母…...
【24届数字IC秋招总结】正式批面试经验汇总5——蔚来、tp-link
文章目录 一、蔚来-数字芯片验证工程师1.1 一面面试问题1.2 二面面试问题二、tp-link-数字IC验证工程师2.1 面试问题一、蔚来-数字芯片验证工程师 面试时间:9.6 10.6 1.1 一面面试问题 1、 讲下项目结构 2、 scoreboard如何进行数据对比的 3、 golden 数据怎么产生的 4、 在…...
【JAVA基础篇教学】第八篇:Java中List详解说明
博主打算从0-1讲解下java基础教学,今天教学第八篇:Java中List详解说明。 在 Java 编程中,List 接口是一个非常常用的集合接口,它代表了一个有序的集合,可以包含重复的元素。List 接口提供了一系列操作方法,…...
RN向上向下滑动组件封装(带有渐变色)
这段组件代码逻辑是出事有一个View和下面的块,下面的块也就是红色区域可以按住向上向下滑动,当滑动到屏幕最上面则停止滑动,再向上滑动的过程中,上方的View的背景色也会有个渐变效果,大概逻辑就是这样 代码如下 import React, {useEffect, useRef, useState} from react; impo…...
27、Lua 学习笔记之五(Lua中的数学库)
Lua中的数学库 Lua5.1中数学库的所有函数如下表: math.pi 为圆周率常量 3.14159265358979323846 数学库说明例子方法abs取绝对值math.abs(-15)15acos反余弦函数math.acos(0.5)1.04719755asin反正弦函数math.asin(0.5)0.52359877atan2x / y的反正切值math.atan2(9…...
【C++成长记】C++入门 | 类和对象(中) |拷贝构造函数、赋值运算符重载、const成员函数、 取地址及const取地址操作符重载
🐌博主主页:🐌倔强的大蜗牛🐌 📚专栏分类:C❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、拷贝构造函数 1、概念 2、特征 二、赋值运算符重载 1、运算符重载 2、赋值运算符重载 3、前置…...
OpenHarmony实战开发-页面深色模式适配。
介绍 本示例介绍在开发应用以适应深色模式时,对于深色和浅色模式的适配方案,采取了多种策略如下: 1. 固定属性适配:对于部分组件的颜色属性,如背景色或字体颜色,若保持不变,可直接设定固定色值…...
域名解析出现错误,该如何解决?
域名作为网络地址,是我们访问网站的必经之路,域名解析就是把你的域名解析成一个ip地址,在使用的过程中遇到域名解析文件异常也是常有的事。如果域名解析出现错误,该怎么解决呢? 一、打开网页时,显示域名解析…...
从iPhone恢复已删除照片的最佳软件
本文分享了从iPhone恢复已删除照片的最佳软件。如果您正在寻找如何从iPhone恢复已删除的照片,请查看这篇文章。 为什么您需要软件从iPhone恢复已删除的照片? 没有什么比丢失iPhone上的重要数据更痛苦的了,尤其是一些具有珍贵回忆的照片。有时…...
MySQL模糊查询
一、MySQL通配符模糊查询(%,_) 1.1.通配符的分类 1.“%”百分号通配符:表示任何字符出现任意次数(可以是0次) 2.“_”下划线通配符:表示只能匹配单个字符,不能多也不能少,就是一个字符。当然…...
QEMU_v8搭建OP-TEE运行环境
文章目录 一、依赖下载二、设置网络三、安装下载四、运行OP-TEE 一、依赖下载 更新依赖包,下载一系列依赖。比如Python需要Python3.x版本,需要配置git的用户名和邮箱等。这里不详细展开了,很多博客都有涉及到。 二、设置网络 这一点非常重…...
C++11 设计模式0. 设计模式的基本概念,设计模式的准则,如何学习设计模式,24种设计模式的分为3大类
一 设计模式的基本概念: 模式:指事物的标准样式 或者 理解成 针对特定问题的可重用解决方案。 设计模式,是在特定问题发生时的可重用解决方案。 设计模式一般用于大型项目中。 大型项目中,设计模式保证所设计的模块之间代码的灵…...
(十)C++自制植物大战僵尸游戏设置功能实现
植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/m0EtD 游戏设置 游戏设置功能是一个允许玩家根据个人喜好和设备性能来调整游戏各项参数的重要工具。游戏设置功能是为了让玩家能够根据自己的需求和设备性能来调整游戏,以获得最佳的游戏体验。不同的游戏和平…...
数据结构——通讯录(顺序表的实战项目)
(—).通讯录的功能 大家应该都十分了解通讯录的功能吧,无非就是对联系人的增添删除,还有信息的修改,并且联系人信息要包含名字,电话,性别,地址等。我把通讯录的功能总结如下&#x…...
IDEA运行Tomcat出现乱码问题解决汇总
最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...
从零实现富文本编辑器#5-编辑器选区模型的状态结构表达
先前我们总结了浏览器选区模型的交互策略,并且实现了基本的选区操作,还调研了自绘选区的实现。那么相对的,我们还需要设计编辑器的选区表达,也可以称为模型选区。编辑器中应用变更时的操作范围,就是以模型选区为基准来…...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
基于matlab策略迭代和值迭代法的动态规划
经典的基于策略迭代和值迭代法的动态规划matlab代码,实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...
用机器学习破解新能源领域的“弃风”难题
音乐发烧友深有体会,玩音乐的本质就是玩电网。火电声音偏暖,水电偏冷,风电偏空旷。至于太阳能发的电,则略显朦胧和单薄。 不知你是否有感觉,近两年家里的音响声音越来越冷,听起来越来越单薄? —…...
【JVM面试篇】高频八股汇总——类加载和类加载器
目录 1. 讲一下类加载过程? 2. Java创建对象的过程? 3. 对象的生命周期? 4. 类加载器有哪些? 5. 双亲委派模型的作用(好处)? 6. 讲一下类的加载和双亲委派原则? 7. 双亲委派模…...
