pytorch迁移学习训练图像分类
pytorch迁移学习训练图像分类
- 一、环境配置
- 二、迁移学习关键代码
- 三、完整代码
- 四、结果对比
代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:Pytorch迁移学习训练自己的图像分类模型
一、环境配置
1,安装所需的包
pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
2,安装Pytorch
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
3,创建目录
import os
# 存放训练得到的模型权重
os.mkdir('checkpoint')
4,下载数据集压缩包(下载之后需要解压数据集)
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip
二、迁移学习关键代码
以下是迁移学习的三种选择,根据训练的需求选择不同的迁移方法:
- 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与 当前数据集类别数n_class 对应
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
- 选择二:微调训练所有层。
适用于训练数据集与预训练模型相差大时,可以选择微调训练所有层,此时只使用预训练模型的部分权重和特征,例如原始模型为imageNet,而训练数据为医疗相关
model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
- 选择三:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
三、完整代码
import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn# 忽略出现的红色提示
import warnings
warnings.filterwarnings("ignore")# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 数据集文件夹路径
dataset_dir = 'fruit30_split'
train_path = os.path.join(dataset_dir, 'train') # 测试集路径
test_path = os.path.join(dataset_dir, 'val') # 测试集路径from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)# 定义数据加载器DataLoader
from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)from torchvision import models
import torch.optim as optim# 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True,指定张量需要梯度计算
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc # 查看全连接层
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters()) # optim 是 PyTorch 的一个优化器模块,用于实现各种梯度下降算法的优化方法# 选择二:微调训练所有层
# 训练数据集与预训练模型相差大时,可以选择微调训练所有层,只使用预训练模型的部分权重和特征,例如原始模型为imageNet,训练数据为医疗相关
# model = models.resnet18(pretrained=True) # 载入预训练模型
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 选择三:随机初始化模型全部权重,从头训练所有层
# model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 训练配置
model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 训练轮次 Epoch
EPOCHS = 30# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model.train()for images, labels in train_loader: # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward() # 损失函数对神经网络权重反向传播求梯度optimizer.step() # 优化更新神经网络权重# 测试集上初步测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1) # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum() # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))# 保存模型
torch.save(model, 'checkpoint/fruit30_pytorch_A1.pth') # 选择一:微调全连接层
# torch.save(model, 'checkpoint/fruit30_pytorch_A2.pth') # 选择二:微调所有层
# torch.save(model, 'checkpoint/fruit30_pytorch_A3.pth') # 选择三:随机权重
四、结果对比
调用不同迁移学习得到的模型对比测试集准确率
# 测试集导入和图像预处理等代码和上述完整代码中一致,此处省略……# 调用自己训练的模型
model = torch.load('checkpoint/fruit30_pytorch_A1.pth')# 测试集上进行测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1) # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum() # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))
结果如下:
对于微调全连接层的选择一,测试集准确率为 72.078%

而所有权重随机的选择三测试集准确率为 43.228%

总体而言,迁移学习能够利用已有的知识和经验,加速模型的训练过程,提高模型的性能。
相关文章:
pytorch迁移学习训练图像分类
pytorch迁移学习训练图像分类 一、环境配置二、迁移学习关键代码三、完整代码四、结果对比 代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄 讲解视频:Pytorch迁移学习训练自己的图像分类模型 一、环境配置 1,安装所需的包 pip install …...
SQL 如何提取多级分类目录
前言 POI数据处理,原始数据为csv格式,整理入库至PostGreSQL,本例使用PostGreSQL13版本。 一、POI POI(一般作为Point of Interest的缩写,也有Point of Information的说法),通常称作兴趣点&am…...
从中序遍历和后序遍历构建二叉树
题目描述 106. 从中序与后序遍历序列构造二叉树 中等 1.1K 相关企业 给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入࿱…...
《计算机视觉中的多视图几何》笔记(11)
11 Computation of the Fundamental Matrix F F F 本章讲述如何用数值方法在已知若干对应点的情况下求解基本矩阵 F F F。 文章目录 11 Computation of the Fundamental Matrix F F F11.1 Basic equations11.1.1 The singularity constraint11.1.2 The minimum case – sev…...
UE5 ChaosVehicles载具研究
一、基本组成 载具Actor类名称:WheeledVehiclePawn Actor最原始的结构 官方增加了两个摇臂相机,可以像驾驶游戏那样切换多机位、旋转观察 选择骨骼网格体、动画蓝图类、开启物理模拟 二、SportsCar_Pawn 角阻尼:物体旋转的阻力。数值越大…...
数据通信——应用层(域名系统)
引言 TCP到此就告一段落,这也意味着传输层结束了,紧随其后的就是TCP/IP五层架构的应用层。操作系统、编程语言、用户的可视化界面等等都要通过应用层来体现。应用层和我们息息相关,我们使用电子设备娱乐或办公时,接触到的就是应用…...
Visual Studio 更新:远程文件管理器
Visual Studio 中的远程文件管理器可以用来访问远程机器上的文件和文件夹,通过 Visual Studio 自带的连接管理器,可以实现不离开开发环境直接访问远程系统,这确实十分方便。 自从此功能发布以来,VS 开发团队努力工作,…...
ChatGPT追祖寻宗:GPT-3技术报告要点解读
论文地址:Language Models are Few-Shot Learners 往期相关文章: ChatGPT追祖寻宗:GPT-1论文要点解读_五点钟科技的博客-CSDN博客ChatGPT追祖寻宗:GPT-2论文要点解读_五点钟科技的博客-CSDN博客 本文的标题之所以取名技术报告而不…...
java easyexcel 导出多级表头
maven <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>${easyexcel.version}</version> </dependency> 导出行的对象 import com.alibaba.excel.annotation.ExcelIgnore; import …...
rar格式转换zip格式,如何做?
平时大家压缩文件时对压缩包格式可能没有什么要求,但是,可能因为工作需要,我们要将压缩包格式进行转换,那么我们如何将rar格式转换为其他格式呢?方法如下: 工具:WinRAR 打开WinRAR,…...
Java中的构造方法
在Java中,构造方法是类的特殊方法,用于初始化对象的实例变量和执行其他必要的操作,以便使对象能够正确地工作。构造方法与类同名,没有返回类型,并且在创建对象时自动调用。 以下是构造方法的一些基本特性:…...
【Java】fastjson
Fastjson简介 Fastjson是阿里巴巴的团队开发的一款Java语言实现的JSON解析器和生成器,它具有简单易用、高性能、高可用性等优点,适用于Java开发中的数据解析和生成。Fastjson的主要特点包括: 简单易用:Fastjson提供了简单易用的…...
JMeter之脚本录制
【软件测试面试突击班】如何逼自己一周刷完软件测试八股文教程,刷完面试就稳了,你也可以当高薪软件测试工程师(自动化测试) 前言: 对于一些JMeter初学者来说,录制脚本可能是最容易掌握的技能之一。…...
计算机网络的相关知识点总结
1.谈一谈对OSI七层模型和TCP/IP四层模型的理解? 不管是OSI七层模型亦或是TCP/IP四层模型,它们的提出都有一个共同的目的:通过分层来将复杂问题细化,通过各个层级之间的相互配合来更好的解决计算机中出现的问题。 说到分层…...
WPF实现轮播图(图片、视屏)
✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…...
【Vue.js】使用Element搭建首页导航左侧菜单
目录 Mock.js 是什么 有什么好处 安装mockjs 编辑 引入mockjs mockjs使用 login-mock Bus事物总线 首页导航栏与左侧菜单搭建 结合总线完成组件通讯 Mock.js 是什么 Mock.js是一个用于生成随机数据的模拟数据生成器。它可以帮助开发人员模拟接口请求,生…...
Spring MVC常见面试题
Spring MVC简介 Spring MVC框架是以请求为驱动,围绕Servlet设计,将请求发给控制器,然后通过模型对象,分派器来展示请求结果视图。简单来说,Spring MVC整合了前端请求的处理及响应。 Servlet 是运行在 Web 服务器或应用…...
Java基础面试题精选:深入探讨哈希表、链表和接口等
目录 1.ArrayList和LinkedList有什么区别?🔒 2.ArrayList和Vector有什么区别?🔒 3.抽象类和普通类有什么区别?🔒 4.抽象类和接口有什么区别?🔒 5.HashMap和Hashtable有什么区别&…...
Spark计算框架
Spark计算框架 一、Spark概述二、Spark的安装部署(安装部署Spark的Cluster Manager-资源调度管理器的)1、Spark的安装模式1.1、Spark(单节点)本地安装1.2 Spark的Standalone部署模式的伪分布式安装1.3Spark的YARN部署模式1.4Spark…...
mybatis缓存源码分析
mybatis缓存源码分析 背景 在java程序与数据库交互的过程中永远存在着性能瓶颈,所以需要一直进行优化.而我们大部分会直接将目标放到数据库优化,其实我们应该先从宏观上去解决问题进而再去解决微观上的问题.性能瓶颈体现在什么地方呢?第一网络通信开销,网络数据传输通信.…...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...
手游刚开服就被攻击怎么办?如何防御DDoS?
开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
ETLCloud可能遇到的问题有哪些?常见坑位解析
数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...
JVM虚拟机:内存结构、垃圾回收、性能优化
1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...
人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...
前端工具库lodash与lodash-es区别详解
lodash 和 lodash-es 是同一工具库的两个不同版本,核心功能完全一致,主要区别在于模块化格式和优化方式,适合不同的开发环境。以下是详细对比: 1. 模块化格式 lodash 使用 CommonJS 模块格式(require/module.exports&a…...
react更新页面数据,操作页面,双向数据绑定
// 路由不是组件的直接跳转use client,useEffect,useRouter,需3个结合, use client表示客户端 use client; import { Button,Card, Space,Tag,Table,message,Input } from antd; import { useEffect,useState } from react; impor…...
Redis专题-实战篇一-基于Session和Redis实现登录业务
GitHub项目地址:https://github.com/whltaoin/redisLearningProject_hm-dianping 基于Session实现登录业务功能提交版本码:e34399f 基于Redis实现登录业务提交版本码:60bf740 一、导入黑马点评后端项目 项目架构图 1. 前期阶段2. 后续阶段导…...
Java高级 |【实验八】springboot 使用Websocket
隶属文章:Java高级 | (二十二)Java常用类库-CSDN博客 系列文章:Java高级 | 【实验一】Springboot安装及测试 |最新-CSDN博客 Java高级 | 【实验二】Springboot 控制器类相关注解知识-CSDN博客 Java高级 | 【实验三】Springboot 静…...
