基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
1. 引言
在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试集上进行预测,最终生成提交文件。

2. 环境准备
首先,确保已安装以下 Python 库:
pip install torch torchvision pandas d2l
torch:PyTorch 核心库。torchvision:提供计算机视觉相关的工具。pandas:用于处理 CSV 文件。d2l:深度学习工具库,提供辅助函数。
3. 数据准备
竞赛链接:https://www.kaggle.com/competitions/classify-leaves/leaderboard?tab=public
3.1 数据集结构
假设数据集位于 classify-leaves 目录下,包含以下文件:
classify-leaves/
├── train.csv
├── test.csv
├── images/├── image1.jpg├── image2.jpg...
train.csv:包含训练图像的路径和标签。test.csv:包含测试图像的路径。
3.2 数据加载与预处理
import os
import pandas as pd
import randomimgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i
num2name:获取所有类别标签,并按类别数量排序。name2num:将类别名称映射到数字编号。
4. 自定义数据集类
为了加载数据,我们需要定义一个自定义数据集类 Leaf_data:
from torch.utils.data import Dataset
from d2l import torch as d2lclass Leaf_data(Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)
__getitem__:根据索引返回一个样本,包括图像和标签。__len__:返回数据集的长度。
5. 模型定义与初始化
我们使用预训练的 ResNet34 模型,并修改最后一层以适应分类任务:
import torch
import torchvision
from torch import nndef init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())
init_weight:使用 Xavier 初始化方法初始化全连接层的权重。net:加载预训练的 ResNet34 模型,并修改最后一层全连接层。
6. 训练过程
6.1 优化器与损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')
trainer:使用 Adam 优化器,对全连接层使用更高的学习率。LR_con:使用余弦退火学习率调度器。loss:使用交叉熵损失函数。
6.2 训练函数
def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean() # 计算损失# 反向传播和优化trainer.zero_grad() # 梯度清零l.backward() # 反向传播trainer.step() # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc
train:训练模型,记录损失和准确率,并在验证集上评估模型。
7. 测试与结果保存
在测试集上进行预测,并保存结果到 CSV 文件:
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
test_dataloader:加载测试数据。res:保存预测结果到 CSV 文件。
8. 总结
本文详细介绍了如何使用 PyTorch 实现一个树叶分类任务,包括数据准备、模型定义、训练、验证和测试。通过本文,您可以掌握以下技能:
- 自定义数据集类的实现。
- 使用预训练模型进行迁移学习。
- 训练模型并保存最佳模型。
- 在测试集上进行预测并生成提交文件。
希望本文对您有所帮助!如果有任何问题,欢迎在评论区留言讨论。😊
完整代码
import os
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
from d2l import torch as d2l
import pandas as pd
import random# 数据准备
imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i# GPU 检查
def try_gpu():if torch.cuda.device_count() > 0:return torch.device('cuda')return torch.device('cpu')# 模型保存路径
model_dir = './models'
if not os.path.exists(model_dir):os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'pre_res_model.ckpt')def save_model(net):torch.save(net.state_dict(), model_path)# 自定义数据集类
class Leaf_data(Data.Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean() # 计算损失# 反向传播和优化trainer.zero_grad() # 梯度清零l.backward() # 反向传播trainer.step() # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()# 训练函数
def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc# 模型初始化
def init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())# 优化器和损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')# 数据增强和数据加载
batch = 64
num_epochs = 10
norm = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.ToTensor(), norm
])
train_data, valid_data = Data.random_split(dataset=Leaf_data(imgpath, True, augs),lengths=[0.8, 0.2]
)
train_dataloder = Data.DataLoader(train_data, batch, True)
valid_dataloder = Data.DataLoader(valid_data, batch, True)# 训练模型
train(train_dataloder, valid_dataloder, net, loss, trainer, num_epochs)# 测试模型
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
参考链接:
- PyTorch 官方文档
- torchvision 官方文档
- d2l 深度学习工具库
相关文章:
基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试 1. 引言 在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试…...
算法之 数论
文章目录 质数判断质数3115.质数的最大距离 质数筛选204.计数质数2761.和等于目标值的质数对 2521.数组乘积中的不同质因数数目 质数 质数的定义:除了本身和1,不能被其他小于它的数整除,最小的质数是 2 求解质数的几种方法 法1,根…...
Java 大视界 -- 人工智能驱动下 Java 大数据的技术革新与应用突破(83)
💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…...
【04】RUST特性
文章目录 隐藏shadowing所有权ownership堆区&栈区所有权规则变量&数据Copy Trait与Drop TraitCopy TraitDrop Trait移动克隆函数参数与返回值的所有权参数引用可变引用悬垂引用slice生命周期隐藏shadowing 有点像同名覆盖 let mut guess = String::new();let guess: u3…...
PlantUml常用语法
PlantUml常用语法,将从类图、流程图和序列图这三种最常用的图表类型开始。 类图 基础语法 在 PlantUML 中创建类图时,你可以定义类(Class)、接口(Interface)以及它们之间的关系,如继承&#…...
保存字典类型的文件用什么格式比较好
保存 Python 字典类型的数据时,有几个常见的格式可以选择,这些格式都具有良好的可读性和提取内容的便利性。以下是几种推荐的格式: JSON 格式: 优点:JSON 格式非常适合存储和传输结构化数据,具有良好的跨平…...
开源模型应用落地-Qwen1.5-MoE-A2.7B-Chat与vllm实现推理加速的正确姿势(一)
一、前言 在人工智能技术蓬勃发展的当下,大语言模型的性能与应用不断突破边界,为我们带来前所未有的体验。Qwen1.5-MoE-A2.7B-Chat 作为一款备受瞩目的大语言模型,以其独特的架构和强大的能力,在自然语言处理领域崭露头角。而 vllm 作为高效的推理库,为模型的部署与推理提…...
一竞技瓦拉几亚S4预选:YB 2-0击败GG
在2月11号进行的PGL瓦拉几亚S4西欧区预选赛上,留在欧洲训练的YB战队以2-0击败GG战队晋级下一轮。双方对阵第二局:对线期YB就打出了优势,中期依靠卡尔带队进攻不断扩大经济优势,最终轻松碾压拿下比赛胜利,以下是对决战报。 YB战队在天辉。阵容是潮汐、卡尔、沙王、隐刺、发条。G…...
deepseek+kimi一键生成PPT
1、deepseek生成大纲内容 访问deepseek官方网站:https://www.deepseek.com/ 将你想要编写的PPT内容输入到对话框,点击【蓝色】发送按钮,让deepseek生成内容大纲,并以markdown形式输出。 等待deepseek生成内容完毕后,…...
mybatis 是否支持延迟加载?延迟加载的原理是什么?
1. MyBatis 是否支持延迟加载? 是的,MyBatis 支持延迟加载。延迟加载的主要功能是推迟数据加载的时机,直到真正需要时再去加载。这种方式能提高性能,尤其是在处理关系型数据时,可以避免不必要的数据库查询。 具体来说…...
【Android开发】安卓手机APP拍照并使用机器学习进行OCR文字识别
前言:点击手机APP上的拍照后,调取手机设备相机拍照并获取图片显示到手机APP页面,进行提取照片内的文字,并将识别结果显示在界面上,在离线模式下也可用。文末工程链接下载 演示视频: 目录 1.新建java项目 2.添加依赖 3. MainActivity.java文件 4.activity_main.xml 文…...
力扣 15.三数之和
题目: 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k,同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。 注意:答案中不可以包含重复的…...
机器学习:二分类和多分类
1. 二分类(Binary Classification) 定义 二分类是指将输入样本分成两个互斥的类别。例如: 邮件 spam 或不是 spam。病人是有病或健康。物品是正品或假货。实现方法 二分类任务可以通过多种算法实现,包括: 逻辑回归(Logistic Regression):通过sigmoid函数将输出值映射…...
安科瑞光伏发电防逆流解决方案——守护电网安全,提升能源效率
安科瑞 华楠 18706163979 在当今大力发展清洁能源的时代背景下,光伏发电作为一种可持续的能源解决方案, 正得到越来越广泛的应用。然而,光伏发电过程中出现的逆流问题,给电网的安全稳定 运行带来了诸多挑战。若不能有效解决&…...
ml5.js框架实现AI图片识别
ml5.js ml5.js 提供了简单的接口来加载和使用机器学习模型,如图像分类、文本生成、姿态估计等,不需要深入理解底层的数学原理或复杂的编程技巧 ml5.js 构建在 TensorFlow.js 之上,提供了一系列预训练模型和简易的 API 接口 图片识别 先进行一…...
HDFS应用-后端存储cephfs-文件存储和对象存储数据双向迁移
DistCp(分布式拷贝)是用于大规模集群内部和集群之间拷贝的工具。 它使用Map/Reduce实现文件分发,错误处理和恢复,以及报告生成。 它把文件和目录的列表作为map任务的输入,每个任务会完成源列表中部分文件的拷贝 配置/…...
关于atomic 是否是线程安全的问题
在 Objective - C 里,atomic 特性并不能保证对象是完全线程安全的,下面从其基本原理、部分线程安全场景以及局限性来详细说明: 先看一个例子 #import <Foundation/Foundation.h>interface MyClass : NSObject property (atomic, assi…...
在实体机和wsl2中安装docker、使用GPU
正常使用docker和gpu,直接命令行安装dcoker和,nvidia-container-toolkit。区别在于,后者在于安装驱动已经cuda加速时存在系统上的差异。 1、安装gpu驱动 在实体机中,安装cuda加速包,我们直接安装 driver 和 cuda 即可…...
HTTP3.0:QUIC协议详解
文章目录 HTTP3.0:QUIC协议详解QUIC是什么QUIC为什么这么快**连接建立快:一见钟情型协议****拥抱UDP:轻装上阵****多路复用:一条路走到黑****更智能的丢包处理****内置加密****网络切换无压力****拥塞控制更智能** QUIC的应用场景QUIC未来会取…...
【EXCEL】【VBA】处理GI Log获得Surf格式的CONTOUR DATA
【EXCEL】【VBA】处理GI Log获得Surf格式的CONTOUR DATA data source1: BH coordination tabledata source2:BH layer tableprocess 1:Collect BH List To Layer Tableprocess 2:match Reduced Level from "Layer"+"BH"data source1: BH coordination…...
wordpress后台更新后 前端没变化的解决方法
使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...
Java 语言特性(面试系列2)
一、SQL 基础 1. 复杂查询 (1)连接查询(JOIN) 内连接(INNER JOIN):返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...
逻辑回归:给不确定性划界的分类大师
想象你是一名医生。面对患者的检查报告(肿瘤大小、血液指标),你需要做出一个**决定性判断**:恶性还是良性?这种“非黑即白”的抉择,正是**逻辑回归(Logistic Regression)** 的战场&a…...
在鸿蒙HarmonyOS 5中实现抖音风格的点赞功能
下面我将详细介绍如何使用HarmonyOS SDK在HarmonyOS 5中实现类似抖音的点赞功能,包括动画效果、数据同步和交互优化。 1. 基础点赞功能实现 1.1 创建数据模型 // VideoModel.ets export class VideoModel {id: string "";title: string ""…...
Module Federation 和 Native Federation 的比较
前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...
Java入门学习详细版(一)
大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...
Rapidio门铃消息FIFO溢出机制
关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
均衡后的SNRSINR
本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt 根发送天线, n r n_r nr 根接收天线的 MIMO 系…...
