基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
1. 引言
在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试集上进行预测,最终生成提交文件。

2. 环境准备
首先,确保已安装以下 Python 库:
pip install torch torchvision pandas d2l
torch:PyTorch 核心库。torchvision:提供计算机视觉相关的工具。pandas:用于处理 CSV 文件。d2l:深度学习工具库,提供辅助函数。
3. 数据准备
竞赛链接:https://www.kaggle.com/competitions/classify-leaves/leaderboard?tab=public
3.1 数据集结构
假设数据集位于 classify-leaves 目录下,包含以下文件:
classify-leaves/
├── train.csv
├── test.csv
├── images/├── image1.jpg├── image2.jpg...
train.csv:包含训练图像的路径和标签。test.csv:包含测试图像的路径。
3.2 数据加载与预处理
import os
import pandas as pd
import randomimgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i
num2name:获取所有类别标签,并按类别数量排序。name2num:将类别名称映射到数字编号。
4. 自定义数据集类
为了加载数据,我们需要定义一个自定义数据集类 Leaf_data:
from torch.utils.data import Dataset
from d2l import torch as d2lclass Leaf_data(Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)
__getitem__:根据索引返回一个样本,包括图像和标签。__len__:返回数据集的长度。
5. 模型定义与初始化
我们使用预训练的 ResNet34 模型,并修改最后一层以适应分类任务:
import torch
import torchvision
from torch import nndef init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())
init_weight:使用 Xavier 初始化方法初始化全连接层的权重。net:加载预训练的 ResNet34 模型,并修改最后一层全连接层。
6. 训练过程
6.1 优化器与损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')
trainer:使用 Adam 优化器,对全连接层使用更高的学习率。LR_con:使用余弦退火学习率调度器。loss:使用交叉熵损失函数。
6.2 训练函数
def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean() # 计算损失# 反向传播和优化trainer.zero_grad() # 梯度清零l.backward() # 反向传播trainer.step() # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc
train:训练模型,记录损失和准确率,并在验证集上评估模型。
7. 测试与结果保存
在测试集上进行预测,并保存结果到 CSV 文件:
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
test_dataloader:加载测试数据。res:保存预测结果到 CSV 文件。
8. 总结
本文详细介绍了如何使用 PyTorch 实现一个树叶分类任务,包括数据准备、模型定义、训练、验证和测试。通过本文,您可以掌握以下技能:
- 自定义数据集类的实现。
- 使用预训练模型进行迁移学习。
- 训练模型并保存最佳模型。
- 在测试集上进行预测并生成提交文件。
希望本文对您有所帮助!如果有任何问题,欢迎在评论区留言讨论。😊
完整代码
import os
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
from d2l import torch as d2l
import pandas as pd
import random# 数据准备
imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i# GPU 检查
def try_gpu():if torch.cuda.device_count() > 0:return torch.device('cuda')return torch.device('cpu')# 模型保存路径
model_dir = './models'
if not os.path.exists(model_dir):os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'pre_res_model.ckpt')def save_model(net):torch.save(net.state_dict(), model_path)# 自定义数据集类
class Leaf_data(Data.Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean() # 计算损失# 反向传播和优化trainer.zero_grad() # 梯度清零l.backward() # 反向传播trainer.step() # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()# 训练函数
def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc# 模型初始化
def init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())# 优化器和损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')# 数据增强和数据加载
batch = 64
num_epochs = 10
norm = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.ToTensor(), norm
])
train_data, valid_data = Data.random_split(dataset=Leaf_data(imgpath, True, augs),lengths=[0.8, 0.2]
)
train_dataloder = Data.DataLoader(train_data, batch, True)
valid_dataloder = Data.DataLoader(valid_data, batch, True)# 训练模型
train(train_dataloder, valid_dataloder, net, loss, trainer, num_epochs)# 测试模型
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
参考链接:
- PyTorch 官方文档
- torchvision 官方文档
- d2l 深度学习工具库
相关文章:
基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试 1. 引言 在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试…...
【STM32系列】利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程)
ps.源码放在最后面 设计FIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程) 设计IIR滤波器 MATLAB配置 设计步骤 首先在命令行窗口输入"filterDesigner",接着就会跳出以下界面…...
如何在本地部署deepseek?
1、打开ollama官网,点download(下载需要翻墙 https://ollama.com/ 2、双击下载好的OllamaSetup.exe,一直点下一步即可。 3、winR 输入cmd,打开命令提示符,输入ollama。有以下提示即安装完成。 4、可以根据 nvidia-…...
AJAX项目——数据管理平台
黑马程序员视频地址: 黑马程序员——数据管理平台 前言 功能: 1.登录和权限判断 2.查看文章内容列表(筛选,分页) 3.编辑文章(数据回显) 4.删除文章 5.发布文章(图片上传࿰…...
MarsCode AI插件在IntelliJ IDEA中使用
文章目录 前言一、MarsCode是什么?二、下载三、使用1、登录2、操作界面3、生成代码4、解释代码5、注释代码6、生成单测7、智能修复8、代码补全 总结 前言 随着 AI 技术浪潮席卷而来,各类 AI 工具呈爆发式涌现,深度融入我们的日常与职场&…...
如何将网站提交百度收录完整SEO教程
百度收录是中文网站获取流量的重要渠道。本文以我的网站,www.mnxz.fun(当然现在没啥流量) 为例,详细讲解从提交收录到自动化维护的全流程。 一、百度收录提交方法 1. 验证网站所有权 1、登录百度搜索资源平台 2、选择「用户中心…...
基于 SVPWM 的异步电机直接转矩控制系统的研究与仿真
标题:基于 SVPWM 的异步电机直接转矩控制系统的研究与仿真 内容:1.摘要 摘要:本文主要研究了基于 SVPWM 的异步电机直接转矩控制系统。首先,介绍了异步电机直接转矩控制的基本原理和 SVPWM 技术的特点。然后,详细阐述了系统的设计方法&#…...
C# OpenCV机器视觉:SoftNMS非极大值抑制
嘿,你知道吗?阿强最近可忙啦!他正在处理一个超级棘手的问题呢,就好像在一个混乱的战场里,到处都是乱糟糟的候选框,这些候选框就像一群调皮的小精灵,有的重叠在一起,让阿强头疼不已。…...
生信云服务器:让生物信息学分析更高效、更简单【附带西柚云优惠码】
随着生物信息学的快速发展,基因组测序、单细胞分析等复杂任务逐渐成为研究者们的日常工作。然而,个人电脑在处理这些任务时往往面临性能瓶颈,如内存不足、运算速度慢等问题,导致分析任务频繁失败或崩溃。为了解决这一难题…...
【清晰教程】通过Docker为本地DeepSeek-r1部署WebUI界面
【清晰教程】本地部署DeepSeek-r1模型-CSDN博客 目录 安装Docker 配置&检查 Open WebUI 部署Open WebUI 安装Docker 完成本地DeepSeek-r1的部署后【清晰教程】本地部署DeepSeek-r1模型-CSDN博客,通过Docker为本地DeepSeek-r1部署WebUI界面。 访问Docker官…...
AI知识库和全文检索的区别
1、AI知识库的作用 AI知识库是基于人工智能技术构建的智能系统,能够理解、推理和生成信息。它的核心作用包括: 1.1 语义理解 自然语言处理(NLP):AI知识库能够理解用户查询的语义,而不仅仅是关键词匹配。 …...
Flink-序列化
一、概述 几乎每个Flink作业都必须在其运算符之间交换数据,由于这些记录不仅可以发送到同一JVM中的另一个实例,还可以发送到单独的进程,因此需要先将记录序列化为字节。类似地,Flink的堆外状态后端基于本地嵌入式RocksDB实例&…...
快速部署 DeepSeek R1 模型
1. DeepSeek R1 模型的介绍 DeepSeek R1 模型是专为自然语言处理(NLP)和其他复杂任务设计的先进大规模深度学习模型 ,其高效的架构设计是一大亮点,能够更高效地提取特征,减少冗余计算。这意味着在处理海量数据时&…...
Java全栈项目实战:在线课程评价系统开发
一、项目概述 在线课程评价系统是一款基于Spring Boot Vue3的全栈应用,面向高校师生提供课程评价、教学反馈、数据可视化分析等功能。系统包含Web管理端和用户门户,日均承载10万课程数据,支持高并发访问和实时数据更新。 项目核心价值&…...
数据库系统概念第六版记录 四
1.sql组成 SQL 是最有影响力的商用市场化的关系查询语言。SQL 语言包括几个部分: 数据定义语言(DDL) ,它提供了定义关系模式、删除关系以及修改关系模式的命令。 数据操纵语言(DML) ,它包括查询语言,以及往数据库中插入元组、从数据库中删…...
无人机飞行试验大纲
无人机飞行试验大纲 编制日期:2025年02月11日 一、试验目的与背景 本次无人机飞行试验旨在验证无人机的飞行性能、控制系统稳定性、机体结构强度以及各项任务执行能力。随着无人机技术在各个领域的广泛应用,对其性能进行全面、系统的测试显得…...
DeepSeek在FPGA/IC开发中的创新应用与未来潜力
随着人工智能技术的飞速发展,以DeepSeek为代表的大语言模型(LLM)正在逐步渗透到传统硬件开发领域。在FPGA(现场可编程门阵列)和IC(集成电路)开发这一技术密集型行业中,DeepSeek凭借其…...
DeepSeek-V3 的核心技术创新
DeepSeek-V3 的核心技术创新 flyfish DeepSeek-V3 的核心技术创新主要体现在其架构设计和训练目标上,通过 多头潜在注意力(MLA)、DeepSeekMoE 架构、无辅助损失的负载均衡策略 和 多 Token 预测训练目标(MTP) 1. 多…...
函数指针(Function Pointer)与 typedef int (*FuncPtr)(int, int);typedef与using(更推荐)
C 函数指针(Function Pointer)详解 函数指针是指向函数的指针,它可以存储函数地址,并通过该指针调用函数。函数指针在回调函数、策略模式、动态函数调用等场景中非常有用。 1. 什么是函数指针? 函数指针是一个指向函…...
【AI时代】以聊天框的模式与本地部署DeepSeek交互 (Docker方式-Open WebUI)
一、本地部署DeepSeek 参考地址:(含资源下载) https://blog.csdn.net/Bjxhub/article/details/145536134二、安装Docker https://www.docker.com/ 三、拉取Open WebUI 镜像 docker pull ghcr.io/open-webui/open-webui:main 四、启动并验证 启动: docker run …...
【Elasticsearch】监控与管理:集群监控指标
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…...
鸿蒙接入支付宝SDK后模拟器无法运行,报错error: install parse native so failed.
鸿蒙项目接入支付宝后,运行提示error: install parse native so failed. 该问题可能由于设备支持的 Abi 类型与 C 工程中的不匹配导致. 官网error: install parse native so failed.错误解决办法 根据官网提示在模块build-profile.json5中添加“x86_64”依然报错 问…...
react redux用法学习
参考资料: https://www.bilibili.com/video/BV1ZB4y1Z7o8 https://cn.redux.js.org/tutorials/essentials/part-5-async-logic AI工具:deepseek,通义灵码 第一天 安装相关依赖: 使用redux的中间件: npm i react-redu…...
Maven 在 Eclipse 中的使用指南
Maven 在 Eclipse 中的使用指南 引言 Maven 是一个强大的项目管理和构建自动化工具,它可以帮助开发者更高效地管理项目依赖、构建和测试。Eclipse 作为一款流行的集成开发环境(IDE),与 Maven 的结合使用大大提高了 Java 项目的开发效率。本文将详细介绍如何在 Eclipse 中…...
【Matlab优化算法-第13期】基于多目标优化算法的水库流量调度
一、前言 水库流量优化是水资源管理中的一个重要环节,通过合理调度水库流量,可以有效平衡防洪、发电和水资源利用等多方面的需求。本文将介绍一个水库流量优化模型,包括其约束条件、目标函数以及应用场景。 二、模型概述 水库流量优化模型…...
Redis 集群(Cluster)和基础的操作 部署实操篇
三主三从 集群概念 Redis 的哨兵模式,提高了系统的可用性,但是正在用来存储数据的还是 master 和 slave 节点,所有的数据都需要存储在单个 master 和 salve 节点中。 如果数据量很大,接近超出了 master / slave 所在机器的物理内…...
[2025年最新]2024.3版本idea无法安装插件问题解决
背景 随着大模型的持续发展,特别年前年后deepseek的优异表现,编程过程中,需要解决ai来辅助编程,因此需要安装一些大模型插件 问题描述 在线安装插件的时候会遇到以下问题: 1.数据一直在加载,加载的很满 2.点…...
elasticsearch安装插件analysis-ik分词器(深度研究docker内elasticsearch安装插件的位置)
最近在学习使用elasticsearch,但是在安装插件ik的时候遇到许多问题。 所以在这里开始对elasticsearch做一个深度的研究。 首先提供如下链接: https://github.com/infinilabs/analysis-ik/releases 我们下载elasticsearch-7-17-2的Linux x86_64版本 …...
golang 开启HTTP代理认证
内部网路不能直接访问外网接口,可以通过代理发送HTTP请求。 HTTP代理服务需要进行认证。 package cmdimport ("fmt""io/ioutil""log""net/http""net/url""strings" )// 推送CBC07功能 func main() {l…...
【Unity3D】UGUI的anchoredPosition锚点坐标
本文直接以实战去理解锚点坐标,围绕着将一个UI移动到另一个UI位置的需求进行说明。 (anchoredPosition)UI锚点坐标,它是UI物体的中心点坐标,以UI物体锚点为中心的坐标系得来,UI锚点坐标受锚点(Anchors Min…...
