当前位置: 首页 > news >正文

pytorch迁移学习训练图像分类

pytorch迁移学习训练图像分类

  • 一、环境配置
  • 二、迁移学习关键代码
  • 三、完整代码
  • 四、结果对比

代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:Pytorch迁移学习训练自己的图像分类模型

一、环境配置

1,安装所需的包

pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple

2,安装Pytorch

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3,创建目录

import os
# 存放训练得到的模型权重
os.mkdir('checkpoint')

4,下载数据集压缩包(下载之后需要解压数据集)

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip

二、迁移学习关键代码

以下是迁移学习的三种选择,根据训练的需求选择不同的迁移方法:

  • 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与 当前数据集类别数n_class 对应
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
  • 选择二:微调训练所有层。

适用于训练数据集与预训练模型相差大时,可以选择微调训练所有层,此时只使用预训练模型的部分权重和特征,例如原始模型为imageNet,而训练数据为医疗相关

model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
  • 选择三:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())

三、完整代码

import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn# 忽略出现的红色提示
import warnings
warnings.filterwarnings("ignore")# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 数据集文件夹路径
dataset_dir = 'fruit30_split'
train_path = os.path.join(dataset_dir, 'train')	# 测试集路径
test_path = os.path.join(dataset_dir, 'val')	# 测试集路径from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)# 定义数据加载器DataLoader
from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)from torchvision import models
import torch.optim as optim# 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True,指定张量需要梯度计算
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc	# 查看全连接层
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())    # optim 是 PyTorch 的一个优化器模块,用于实现各种梯度下降算法的优化方法# 选择二:微调训练所有层
# 训练数据集与预训练模型相差大时,可以选择微调训练所有层,只使用预训练模型的部分权重和特征,例如原始模型为imageNet,训练数据为医疗相关
# model = models.resnet18(pretrained=True) # 载入预训练模型
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 选择三:随机初始化模型全部权重,从头训练所有层
# model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 训练配置
model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 训练轮次 Epoch
EPOCHS = 30# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model.train()for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)           # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward()                   # 损失函数对神经网络权重反向传播求梯度optimizer.step()                  # 优化更新神经网络权重# 测试集上初步测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))# 保存模型
torch.save(model, 'checkpoint/fruit30_pytorch_A1.pth') # 选择一:微调全连接层
# torch.save(model, 'checkpoint/fruit30_pytorch_A2.pth') # 选择二:微调所有层
# torch.save(model, 'checkpoint/fruit30_pytorch_A3.pth') # 选择三:随机权重

四、结果对比

调用不同迁移学习得到的模型对比测试集准确率

# 测试集导入和图像预处理等代码和上述完整代码中一致,此处省略……# 调用自己训练的模型
model = torch.load('checkpoint/fruit30_pytorch_A1.pth')# 测试集上进行测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

结果如下:
对于微调全连接层的选择一,测试集准确率为 72.078%
在这里插入图片描述
而所有权重随机的选择三测试集准确率为 43.228%
43.228

总体而言,迁移学习能够利用已有的知识和经验,加速模型的训练过程,提高模型的性能。

相关文章:

pytorch迁移学习训练图像分类

pytorch迁移学习训练图像分类 一、环境配置二、迁移学习关键代码三、完整代码四、结果对比 代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄 讲解视频:Pytorch迁移学习训练自己的图像分类模型 一、环境配置 1,安装所需的包 pip install …...

SQL 如何提取多级分类目录

前言 POI数据处理,原始数据为csv格式,整理入库至PostGreSQL,本例使用PostGreSQL13版本。 一、POI POI(一般作为Point of Interest的缩写,也有Point of Information的说法),通常称作兴趣点&am…...

从中序遍历和后序遍历构建二叉树

题目描述 106. 从中序与后序遍历序列构造二叉树 中等 1.1K 相关企业 给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入&#xff1…...

《计算机视觉中的多视图几何》笔记(11)

11 Computation of the Fundamental Matrix F F F 本章讲述如何用数值方法在已知若干对应点的情况下求解基本矩阵 F F F。 文章目录 11 Computation of the Fundamental Matrix F F F11.1 Basic equations11.1.1 The singularity constraint11.1.2 The minimum case – sev…...

UE5 ChaosVehicles载具研究

一、基本组成 载具Actor类名称:WheeledVehiclePawn Actor最原始的结构 官方增加了两个摇臂相机,可以像驾驶游戏那样切换多机位、旋转观察 选择骨骼网格体、动画蓝图类、开启物理模拟 二、SportsCar_Pawn 角阻尼:物体旋转的阻力。数值越大…...

数据通信——应用层(域名系统)

引言 TCP到此就告一段落,这也意味着传输层结束了,紧随其后的就是TCP/IP五层架构的应用层。操作系统、编程语言、用户的可视化界面等等都要通过应用层来体现。应用层和我们息息相关,我们使用电子设备娱乐或办公时,接触到的就是应用…...

Visual Studio 更新:远程文件管理器

Visual Studio 中的远程文件管理器可以用来访问远程机器上的文件和文件夹,通过 Visual Studio 自带的连接管理器,可以实现不离开开发环境直接访问远程系统,这确实十分方便。 自从此功能发布以来,VS 开发团队努力工作,…...

ChatGPT追祖寻宗:GPT-3技术报告要点解读

论文地址:Language Models are Few-Shot Learners 往期相关文章: ChatGPT追祖寻宗:GPT-1论文要点解读_五点钟科技的博客-CSDN博客ChatGPT追祖寻宗:GPT-2论文要点解读_五点钟科技的博客-CSDN博客 本文的标题之所以取名技术报告而不…...

java easyexcel 导出多级表头

maven <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>${easyexcel.version}</version> </dependency> 导出行的对象 import com.alibaba.excel.annotation.ExcelIgnore; import …...

rar格式转换zip格式,如何做?

平时大家压缩文件时对压缩包格式可能没有什么要求&#xff0c;但是&#xff0c;可能因为工作需要&#xff0c;我们要将压缩包格式进行转换&#xff0c;那么我们如何将rar格式转换为其他格式呢&#xff1f;方法如下&#xff1a; 工具&#xff1a;WinRAR 打开WinRAR&#xff0c…...

Java中的构造方法

在Java中&#xff0c;构造方法是类的特殊方法&#xff0c;用于初始化对象的实例变量和执行其他必要的操作&#xff0c;以便使对象能够正确地工作。构造方法与类同名&#xff0c;没有返回类型&#xff0c;并且在创建对象时自动调用。 以下是构造方法的一些基本特性&#xff1a;…...

【Java】fastjson

Fastjson简介 Fastjson是阿里巴巴的团队开发的一款Java语言实现的JSON解析器和生成器&#xff0c;它具有简单易用、高性能、高可用性等优点&#xff0c;适用于Java开发中的数据解析和生成。Fastjson的主要特点包括&#xff1a; 简单易用&#xff1a;Fastjson提供了简单易用的…...

JMeter之脚本录制

【软件测试面试突击班】如何逼自己一周刷完软件测试八股文教程&#xff0c;刷完面试就稳了&#xff0c;你也可以当高薪软件测试工程师&#xff08;自动化测试&#xff09; 前言&#xff1a; 对于一些JMeter初学者来说&#xff0c;录制脚本可能是最容易掌握的技能之一。…...

计算机网络的相关知识点总结

1.谈一谈对OSI七层模型和TCP/IP四层模型的理解&#xff1f; 不管是OSI七层模型亦或是TCP/IP四层模型&#xff0c;它们的提出都有一个共同的目的&#xff1a;通过分层来将复杂问题细化&#xff0c;通过各个层级之间的相互配合来更好的解决计算机中出现的问题。 说到分层&#xf…...

WPF实现轮播图(图片、视屏)

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

【Vue.js】使用Element搭建首页导航左侧菜单

目录 Mock.js 是什么 有什么好处 安装mockjs ​编辑 引入mockjs mockjs使用 login-mock Bus事物总线 首页导航栏与左侧菜单搭建 结合总线完成组件通讯 Mock.js 是什么 Mock.js是一个用于生成随机数据的模拟数据生成器。它可以帮助开发人员模拟接口请求&#xff0c;生…...

Spring MVC常见面试题

Spring MVC简介 Spring MVC框架是以请求为驱动&#xff0c;围绕Servlet设计&#xff0c;将请求发给控制器&#xff0c;然后通过模型对象&#xff0c;分派器来展示请求结果视图。简单来说&#xff0c;Spring MVC整合了前端请求的处理及响应。 Servlet 是运行在 Web 服务器或应用…...

Java基础面试题精选:深入探讨哈希表、链表和接口等

目录 1.ArrayList和LinkedList有什么区别&#xff1f;&#x1f512; 2.ArrayList和Vector有什么区别&#xff1f;&#x1f512; 3.抽象类和普通类有什么区别&#xff1f;&#x1f512; 4.抽象类和接口有什么区别&#xff1f;&#x1f512; 5.HashMap和Hashtable有什么区别&…...

Spark计算框架

Spark计算框架 一、Spark概述二、Spark的安装部署&#xff08;安装部署Spark的Cluster Manager-资源调度管理器的&#xff09;1、Spark的安装模式1.1、Spark&#xff08;单节点&#xff09;本地安装1.2 Spark的Standalone部署模式的伪分布式安装1.3Spark的YARN部署模式1.4Spark…...

mybatis缓存源码分析

mybatis缓存源码分析 背景 ​ 在java程序与数据库交互的过程中永远存在着性能瓶颈,所以需要一直进行优化.而我们大部分会直接将目标放到数据库优化,其实我们应该先从宏观上去解决问题进而再去解决微观上的问题.性能瓶颈体现在什么地方呢?第一网络通信开销,网络数据传输通信.…...

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…...

装饰模式(Decorator Pattern)重构java邮件发奖系统实战

前言 现在我们有个如下的需求&#xff0c;设计一个邮件发奖的小系统&#xff0c; 需求 1.数据验证 → 2. 敏感信息加密 → 3. 日志记录 → 4. 实际发送邮件 装饰器模式&#xff08;Decorator Pattern&#xff09;允许向一个现有的对象添加新的功能&#xff0c;同时又不改变其…...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

多模态大语言模型arxiv论文略读(108)

CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题&#xff1a;CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者&#xff1a;Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

AspectJ 在 Android 中的完整使用指南

一、环境配置&#xff08;Gradle 7.0 适配&#xff09; 1. 项目级 build.gradle // 注意&#xff1a;沪江插件已停更&#xff0c;推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码&#xff0c;因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存&#xff0c;无论是github还是gittee&#xff0c;都是一种基于git去保存代码的形式&#xff0c;这样保存代码…...

Unity UGUI Button事件流程

场景结构 测试代码 public class TestBtn : MonoBehaviour {void Start(){var btn GetComponent<Button>();btn.onClick.AddListener(OnClick);}private void OnClick(){Debug.Log("666");}}当添加事件时 // 实例化一个ButtonClickedEvent的事件 [Formerl…...

保姆级【快数学会Android端“动画“】+ 实现补间动画和逐帧动画!!!

目录 补间动画 1.创建资源文件夹 2.设置文件夹类型 3.创建.xml文件 4.样式设计 5.动画设置 6.动画的实现 内容拓展 7.在原基础上继续添加.xml文件 8.xml代码编写 (1)rotate_anim (2)scale_anim (3)translate_anim 9.MainActivity.java代码汇总 10.效果展示 逐帧…...