pytorch迁移学习训练图像分类
pytorch迁移学习训练图像分类
- 一、环境配置
- 二、迁移学习关键代码
- 三、完整代码
- 四、结果对比
代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:Pytorch迁移学习训练自己的图像分类模型
一、环境配置
1,安装所需的包
pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
2,安装Pytorch
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
3,创建目录
import os
# 存放训练得到的模型权重
os.mkdir('checkpoint')
4,下载数据集压缩包(下载之后需要解压数据集)
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip
二、迁移学习关键代码
以下是迁移学习的三种选择,根据训练的需求选择不同的迁移方法:
- 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与 当前数据集类别数n_class 对应
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
- 选择二:微调训练所有层。
适用于训练数据集与预训练模型相差大时,可以选择微调训练所有层,此时只使用预训练模型的部分权重和特征,例如原始模型为imageNet,而训练数据为医疗相关
model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
- 选择三:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
三、完整代码
import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn# 忽略出现的红色提示
import warnings
warnings.filterwarnings("ignore")# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 数据集文件夹路径
dataset_dir = 'fruit30_split'
train_path = os.path.join(dataset_dir, 'train') # 测试集路径
test_path = os.path.join(dataset_dir, 'val') # 测试集路径from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)# 定义数据加载器DataLoader
from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)from torchvision import models
import torch.optim as optim# 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True,指定张量需要梯度计算
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc # 查看全连接层
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters()) # optim 是 PyTorch 的一个优化器模块,用于实现各种梯度下降算法的优化方法# 选择二:微调训练所有层
# 训练数据集与预训练模型相差大时,可以选择微调训练所有层,只使用预训练模型的部分权重和特征,例如原始模型为imageNet,训练数据为医疗相关
# model = models.resnet18(pretrained=True) # 载入预训练模型
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 选择三:随机初始化模型全部权重,从头训练所有层
# model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 训练配置
model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 训练轮次 Epoch
EPOCHS = 30# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model.train()for images, labels in train_loader: # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward() # 损失函数对神经网络权重反向传播求梯度optimizer.step() # 优化更新神经网络权重# 测试集上初步测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1) # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum() # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))# 保存模型
torch.save(model, 'checkpoint/fruit30_pytorch_A1.pth') # 选择一:微调全连接层
# torch.save(model, 'checkpoint/fruit30_pytorch_A2.pth') # 选择二:微调所有层
# torch.save(model, 'checkpoint/fruit30_pytorch_A3.pth') # 选择三:随机权重
四、结果对比
调用不同迁移学习得到的模型对比测试集准确率
# 测试集导入和图像预处理等代码和上述完整代码中一致,此处省略……# 调用自己训练的模型
model = torch.load('checkpoint/fruit30_pytorch_A1.pth')# 测试集上进行测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1) # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum() # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))
结果如下:
对于微调全连接层的选择一,测试集准确率为 72.078%

而所有权重随机的选择三测试集准确率为 43.228%

总体而言,迁移学习能够利用已有的知识和经验,加速模型的训练过程,提高模型的性能。
相关文章:
pytorch迁移学习训练图像分类
pytorch迁移学习训练图像分类 一、环境配置二、迁移学习关键代码三、完整代码四、结果对比 代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄 讲解视频:Pytorch迁移学习训练自己的图像分类模型 一、环境配置 1,安装所需的包 pip install …...
SQL 如何提取多级分类目录
前言 POI数据处理,原始数据为csv格式,整理入库至PostGreSQL,本例使用PostGreSQL13版本。 一、POI POI(一般作为Point of Interest的缩写,也有Point of Information的说法),通常称作兴趣点&am…...
从中序遍历和后序遍历构建二叉树
题目描述 106. 从中序与后序遍历序列构造二叉树 中等 1.1K 相关企业 给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入࿱…...
《计算机视觉中的多视图几何》笔记(11)
11 Computation of the Fundamental Matrix F F F 本章讲述如何用数值方法在已知若干对应点的情况下求解基本矩阵 F F F。 文章目录 11 Computation of the Fundamental Matrix F F F11.1 Basic equations11.1.1 The singularity constraint11.1.2 The minimum case – sev…...
UE5 ChaosVehicles载具研究
一、基本组成 载具Actor类名称:WheeledVehiclePawn Actor最原始的结构 官方增加了两个摇臂相机,可以像驾驶游戏那样切换多机位、旋转观察 选择骨骼网格体、动画蓝图类、开启物理模拟 二、SportsCar_Pawn 角阻尼:物体旋转的阻力。数值越大…...
数据通信——应用层(域名系统)
引言 TCP到此就告一段落,这也意味着传输层结束了,紧随其后的就是TCP/IP五层架构的应用层。操作系统、编程语言、用户的可视化界面等等都要通过应用层来体现。应用层和我们息息相关,我们使用电子设备娱乐或办公时,接触到的就是应用…...
Visual Studio 更新:远程文件管理器
Visual Studio 中的远程文件管理器可以用来访问远程机器上的文件和文件夹,通过 Visual Studio 自带的连接管理器,可以实现不离开开发环境直接访问远程系统,这确实十分方便。 自从此功能发布以来,VS 开发团队努力工作,…...
ChatGPT追祖寻宗:GPT-3技术报告要点解读
论文地址:Language Models are Few-Shot Learners 往期相关文章: ChatGPT追祖寻宗:GPT-1论文要点解读_五点钟科技的博客-CSDN博客ChatGPT追祖寻宗:GPT-2论文要点解读_五点钟科技的博客-CSDN博客 本文的标题之所以取名技术报告而不…...
java easyexcel 导出多级表头
maven <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>${easyexcel.version}</version> </dependency> 导出行的对象 import com.alibaba.excel.annotation.ExcelIgnore; import …...
rar格式转换zip格式,如何做?
平时大家压缩文件时对压缩包格式可能没有什么要求,但是,可能因为工作需要,我们要将压缩包格式进行转换,那么我们如何将rar格式转换为其他格式呢?方法如下: 工具:WinRAR 打开WinRAR,…...
Java中的构造方法
在Java中,构造方法是类的特殊方法,用于初始化对象的实例变量和执行其他必要的操作,以便使对象能够正确地工作。构造方法与类同名,没有返回类型,并且在创建对象时自动调用。 以下是构造方法的一些基本特性:…...
【Java】fastjson
Fastjson简介 Fastjson是阿里巴巴的团队开发的一款Java语言实现的JSON解析器和生成器,它具有简单易用、高性能、高可用性等优点,适用于Java开发中的数据解析和生成。Fastjson的主要特点包括: 简单易用:Fastjson提供了简单易用的…...
JMeter之脚本录制
【软件测试面试突击班】如何逼自己一周刷完软件测试八股文教程,刷完面试就稳了,你也可以当高薪软件测试工程师(自动化测试) 前言: 对于一些JMeter初学者来说,录制脚本可能是最容易掌握的技能之一。…...
计算机网络的相关知识点总结
1.谈一谈对OSI七层模型和TCP/IP四层模型的理解? 不管是OSI七层模型亦或是TCP/IP四层模型,它们的提出都有一个共同的目的:通过分层来将复杂问题细化,通过各个层级之间的相互配合来更好的解决计算机中出现的问题。 说到分层…...
WPF实现轮播图(图片、视屏)
✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…...
【Vue.js】使用Element搭建首页导航左侧菜单
目录 Mock.js 是什么 有什么好处 安装mockjs 编辑 引入mockjs mockjs使用 login-mock Bus事物总线 首页导航栏与左侧菜单搭建 结合总线完成组件通讯 Mock.js 是什么 Mock.js是一个用于生成随机数据的模拟数据生成器。它可以帮助开发人员模拟接口请求,生…...
Spring MVC常见面试题
Spring MVC简介 Spring MVC框架是以请求为驱动,围绕Servlet设计,将请求发给控制器,然后通过模型对象,分派器来展示请求结果视图。简单来说,Spring MVC整合了前端请求的处理及响应。 Servlet 是运行在 Web 服务器或应用…...
Java基础面试题精选:深入探讨哈希表、链表和接口等
目录 1.ArrayList和LinkedList有什么区别?🔒 2.ArrayList和Vector有什么区别?🔒 3.抽象类和普通类有什么区别?🔒 4.抽象类和接口有什么区别?🔒 5.HashMap和Hashtable有什么区别&…...
Spark计算框架
Spark计算框架 一、Spark概述二、Spark的安装部署(安装部署Spark的Cluster Manager-资源调度管理器的)1、Spark的安装模式1.1、Spark(单节点)本地安装1.2 Spark的Standalone部署模式的伪分布式安装1.3Spark的YARN部署模式1.4Spark…...
mybatis缓存源码分析
mybatis缓存源码分析 背景 在java程序与数据库交互的过程中永远存在着性能瓶颈,所以需要一直进行优化.而我们大部分会直接将目标放到数据库优化,其实我们应该先从宏观上去解决问题进而再去解决微观上的问题.性能瓶颈体现在什么地方呢?第一网络通信开销,网络数据传输通信.…...
模型参数、模型存储精度、参数与显存
模型参数量衡量单位 M:百万(Million) B:十亿(Billion) 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的,但是一个参数所表示多少字节不一定,需要看这个参数以什么…...
ServerTrust 并非唯一
NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用
1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)
目录 一、👋🏻前言 二、😈sinx波动的基本原理 三、😈波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、🌊波动优化…...
Python Einops库:深度学习中的张量操作革命
Einops(爱因斯坦操作库)就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库,用类似自然语言的表达式替代了晦涩的API调用,彻底改变了深度学习工程…...
群晖NAS如何在虚拟机创建飞牛NAS
套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...
基于单片机的宠物屋智能系统设计与实现(论文+源码)
本设计基于单片机的宠物屋智能系统核心是实现对宠物生活环境及状态的智能管理。系统以单片机为中枢,连接红外测温传感器,可实时精准捕捉宠物体温变化,以便及时发现健康异常;水位检测传感器时刻监测饮用水余量,防止宠物…...
如何通过git命令查看项目连接的仓库地址?
要通过 Git 命令查看项目连接的仓库地址,您可以使用以下几种方法: 1. 查看所有远程仓库地址 使用 git remote -v 命令,它会显示项目中配置的所有远程仓库及其对应的 URL: git remote -v输出示例: origin https://…...
