当前位置: 首页 > article >正文

ResNet50实战:用Fruits-360数据集训练自己的水果分类模型(附完整代码)

ResNet50实战用Fruits-360数据集训练自己的水果分类模型附完整代码在计算机视觉领域图像分类是最基础也最实用的任务之一。无论是工业质检、医疗影像分析还是零售商品识别都需要可靠的分类模型作为支撑。而水果分类作为典型的细粒度图像识别问题对模型的特征提取能力提出了更高要求。本文将带您从零开始使用经典的ResNet50架构和Fruits-360数据集构建一个专业级的水果分类系统。1. 环境准备与数据探索1.1 开发环境配置首先需要搭建支持GPU加速的深度学习环境。推荐使用conda创建独立的Python环境conda create -n fruit_classifier python3.8 conda activate fruit_classifier pip install torch torchvision pillow pandas tqdm matplotlib对于硬件配置建议至少满足GPUNVIDIA GTX 1060 6GB或更高内存16GB以上存储空间至少20GB可用空间数据集解压后约14GB1.2 Fruits-360数据集解析Fruits-360是一个专业的水果蔬菜图像数据集包含131种不同类别的水果和蔬菜超过8.3万张高质量图像统一背景所有样本都在白色背景下拍摄多角度拍摄每个水果都有不同角度的照片数据集目录结构如下Fruits-360/ ├── Training/ │ ├── Apple Braeburn/ │ ├── Apple Crimson Snow/ │ └── ...其他类别 └── Test/ ├── Apple Braeburn/ ├── Apple Crimson Snow/ └── ...其他类别提示数据集可从Kaggle直接下载解压后确保Training和Test目录位于同一父目录下2. 数据预处理与增强2.1 自定义数据加载器使用PyTorch的Dataset类创建自定义数据加载器from torchvision import transforms from torch.utils.data import Dataset from PIL import Image import os class FruitsDataset(Dataset): def __init__(self, root_dir, transformNone, trainTrue): self.root_dir os.path.join(root_dir, Training if train else Test) self.transform transform self.classes sorted(os.listdir(self.root_dir)) self.class_to_idx {cls: i for i, cls in enumerate(self.classes)} self.images self._load_images() def _load_images(self): images [] for cls in self.classes: cls_dir os.path.join(self.root_dir, cls) for img_name in os.listdir(cls_dir): img_path os.path.join(cls_dir, img_name) images.append((img_path, self.class_to_idx[cls])) return images def __len__(self): return len(self.images) def __getitem__(self, idx): img_path, label self.images[idx] image Image.open(img_path).convert(RGB) if self.transform: image self.transform(image) return image, label2.2 数据增强策略针对水果分类任务设计以下增强策略train_transform transforms.Compose([ transforms.Resize(256), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) val_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])关键增强技术说明增强技术参数设置作用RandomResizedCrop224x224模拟不同拍摄距离ColorJitterbrightness0.2应对光照变化RandomRotation15度增强角度不变性3. ResNet50模型定制3.1 模型架构调整加载预训练ResNet50并修改最后一层import torch.nn as nn from torchvision import models class FruitResNet(nn.Module): def __init__(self, num_classes131): super(FruitResNet, self).__init__() self.base_model models.resnet50(pretrainedTrue) in_features self.base_model.fc.in_features self.base_model.fc nn.Linear(in_features, num_classes) def forward(self, x): return self.base_model(x)3.2 迁移学习策略采用分阶段训练方法冻结阶段只训练最后的全连接层微调阶段解冻所有层进行整体微调def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: for param in model.parameters(): param.requires_grad False model FruitResNet(num_classes131) set_parameter_requires_grad(model, feature_extractingTrue) # 仅优化最后一层 optimizer torch.optim.Adam(model.fc.parameters(), lr0.001)4. 模型训练与评估4.1 训练循环实现完整的训练流程包含以下关键组件from tqdm import tqdm def train_model(model, dataloaders, criterion, optimizer, num_epochs25): best_acc 0.0 for epoch in range(num_epochs): print(fEpoch {epoch}/{num_epochs-1}) print(- * 10) for phase in [train, val]: if phase train: model.train() else: model.eval() running_loss 0.0 running_corrects 0 for inputs, labels in tqdm(dataloaders[phase]): inputs inputs.to(device) labels labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase train): outputs model(inputs) _, preds torch.max(outputs, 1) loss criterion(outputs, labels) if phase train: loss.backward() optimizer.step() running_loss loss.item() * inputs.size(0) running_corrects torch.sum(preds labels.data) epoch_loss running_loss / len(dataloaders[phase].dataset) epoch_acc running_corrects.double() / len(dataloaders[phase].dataset) print(f{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}) if phase val and epoch_acc best_acc: best_acc epoch_acc torch.save(model.state_dict(), best_model.pth) return model4.2 学习率调度策略采用余弦退火学习率调度from torch.optim.lr_scheduler import CosineAnnealingLR optimizer torch.optim.Adam(model.parameters(), lr0.001) scheduler CosineAnnealingLR(optimizer, T_max10, eta_min1e-5)典型训练过程参数配置参数值说明Batch Size32平衡内存和稳定性初始学习率0.001微调常用初始值Epochs50包含冻结和解冻阶段权重衰减1e-4防止过拟合5. 模型部署与优化5.1 模型量化与加速使用TorchScript导出优化后的模型model.eval() example_input torch.rand(1, 3, 224, 224).to(device) traced_script_module torch.jit.trace(model, example_input) traced_script_module.save(fruit_classifier.pt)5.2 构建推理API简单的Flask服务端实现from flask import Flask, request, jsonify from PIL import Image import io app Flask(__name__) model torch.jit.load(fruit_classifier.pt, map_locationcpu) model.eval() app.route(/predict, methods[POST]) def predict(): if file not in request.files: return jsonify({error: no file uploaded}), 400 file request.files[file].read() image Image.open(io.BytesIO(file)).convert(RGB) image val_transform(image).unsqueeze(0) with torch.no_grad(): output model(image) _, predicted torch.max(output, 1) return jsonify({class: classes[predicted.item()]}) if __name__ __main__: app.run(host0.0.0.0, port5000)5.3 性能优化技巧实际部署时可考虑以下优化TensorRT加速将模型转换为TensorRT引擎ONNX导出实现跨平台部署量化压缩8位整数量化减小模型体积缓存机制对高频访问类别实现结果缓存在NVIDIA T4 GPU上的性能对比优化方式推理延迟(ms)模型大小(MB)原始模型45.298.7TensorRT12.683.4INT8量化8.324.96. 常见问题与解决方案6.1 类别不平衡处理Fruits-360中各类别样本数量差异较大可采用以下策略from torch.utils.data import WeightedRandomSampler class_counts [len(os.listdir(fFruits-360/Training/{cls})) for cls in classes] class_weights 1. / torch.tensor(class_counts, dtypetorch.float) sample_weights class_weights[labels] sampler WeightedRandomSampler( weightssample_weights, num_sampleslen(sample_weights), replacementTrue )6.2 过拟合应对方案当验证集准确率停滞时可尝试增加正则化optimizer torch.optim.Adam(model.parameters(), lr0.001, weight_decay1e-4)早停机制patience 5 best_acc 0.0 epochs_no_improve 0 if val_acc best_acc: best_acc val_acc epochs_no_improve 0 else: epochs_no_improve 1 if epochs_no_improve patience: print(Early stopping!) break标签平滑criterion nn.CrossEntropyLoss(label_smoothing0.1)6.3 模型解释性分析使用Grad-CAM可视化模型关注区域from torchcam.methods import GradCAM cam_extractor GradCAM(model, base_model.layer4.2) with torch.no_grad(): out model(input_tensor) activation_map cam_extractor(out.squeeze(0).argmax().item(), out) # 叠加原始图像 result overlay_mask( to_pil_image(input_tensor.squeeze(0)), to_pil_image(activation_map[0].squeeze(0), modeF), alpha0.5 )在实际项目中我们发现模型对水果的纹理特征如苹果的条纹和形状轮廓如香蕉的弯曲度最为敏感。通过可视化分析可以验证模型是否学习了有意义的特征而非依赖背景等无关信息。

相关文章:

ResNet50实战:用Fruits-360数据集训练自己的水果分类模型(附完整代码)

ResNet50实战:用Fruits-360数据集训练自己的水果分类模型(附完整代码) 在计算机视觉领域,图像分类是最基础也最实用的任务之一。无论是工业质检、医疗影像分析还是零售商品识别,都需要可靠的分类模型作为支撑。而水果分…...

惊艳!Qwen3-4B-Instruct-2507文本生成效果实测:看看AI能写出什么

惊艳!Qwen3-4B-Instruct-2507文本生成效果实测:看看AI能写出什么 1. 开篇:认识这款强大的文本生成模型 Qwen3-4B-Instruct-2507是阿里开源的最新文本生成大模型,它在多个方面都有显著提升。简单来说,这个AI不仅能理解…...

QMCDecode:解放加密音乐的格式转换专家指南

QMCDecode:解放加密音乐的格式转换专家指南 【免费下载链接】QMCDecode QQ音乐QMC格式转换为普通格式(qmcflac转flac,qmc0,qmc3转mp3, mflac,mflac0等转flac),仅支持macOS,可自动识别到QQ音乐下载目录,默认转换结果存储…...

SecGPT-14B赋能教育行业:高校网络安全实验室AI教学平台搭建

SecGPT-14B赋能教育行业:高校网络安全实验室AI教学平台搭建 1. 引言:当网络安全教学遇上AI大模型 想象一下,在高校的网络安全实验室里,学生面对一个复杂的漏洞分析报告,不再需要花费数小时翻阅厚重的教材和零散的在线…...

PyTorch 2.8镜像实操手册:/workspace+/data+/output目录规范使用详解

PyTorch 2.8镜像实操手册:/workspace/data/output目录规范使用详解 1. 镜像环境概述 PyTorch 2.8深度学习镜像基于RTX 4090D 24GB显卡和CUDA 12.4深度优化,专为高性能计算任务设计。这个环境预装了完整的深度学习工具链,从基础框架到加速库…...

AI智能二维码工坊 vs 传统方案:OpenCV+QRCode性能对比评测

AI智能二维码工坊 vs 传统方案:OpenCVQRCode性能对比评测 二维码,这个黑白相间的小方块,早已渗透进我们生活的方方面面。从扫码支付到添加好友,从产品溯源到活动签到,它无处不在。作为开发者,我们经常需要…...

如何通过智能备份技术实现微信聊天记录的数据主权?本地化管理方案全解析

如何通过智能备份技术实现微信聊天记录的数据主权?本地化管理方案全解析 【免费下载链接】WeChatMsg 提取微信聊天记录,将其导出成HTML、Word、CSV文档永久保存,对聊天记录进行分析生成年度聊天报告 项目地址: https://gitcode.com/GitHub_…...

终极存储设备容量检测指南:如何用F3工具3分钟识别假冒U盘和SD卡

终极存储设备容量检测指南:如何用F3工具3分钟识别假冒U盘和SD卡 【免费下载链接】f3 F3 - Fight Flash Fraud 项目地址: https://gitcode.com/gh_mirrors/f3/f3 在数字存储时代,容量造假已成为困扰用户的普遍问题。F3(Fight Flash Fra…...

零成本商用开源字体解决方案:思源宋体全面应用指南

零成本商用开源字体解决方案:思源宋体全面应用指南 【免费下载链接】source-han-serif-ttf Source Han Serif TTF 项目地址: https://gitcode.com/gh_mirrors/so/source-han-serif-ttf 如何在商业项目中避免字体侵权风险?怎样才能不花一分钱获得专…...

3分钟彻底解决Windows安装错误2502/2503:AtlasOS一键修复方案揭秘 [特殊字符]

3分钟彻底解决Windows安装错误2502/2503:AtlasOS一键修复方案揭秘 🚀 【免费下载链接】Atlas 🚀 An open and lightweight modification to Windows, designed to optimize performance, privacy and security. 项目地址: https://gitcode.…...

StarVCenter单机版安装避坑指南:从BIOS设置到虚拟机创建的完整流程

StarVCenter单机版安装全流程实战:从硬件准备到虚拟机管理的深度解析 在当今企业IT基础设施快速迭代的背景下,虚拟化技术已成为资源整合与管理的核心解决方案。StarVCenter作为一款国产化虚拟化管理平台,其单机版部署方案特别适合中小型业务场…...

如何构建企业级中文大语言模型平台:3大核心策略与实战指南

如何构建企业级中文大语言模型平台:3大核心策略与实战指南 【免费下载链接】Awesome-Chinese-LLM 整理开源的中文大语言模型,以规模较小、可私有化部署、训练成本较低的模型为主,包括底座模型,垂直领域微调及应用,数据…...

终极指南:OpenAI Python SDK推理强度参数调优实战

终极指南:OpenAI Python SDK推理强度参数调优实战 【免费下载链接】openai-python The official Python library for the OpenAI API 项目地址: https://gitcode.com/GitHub_Trending/op/openai-python 掌握OpenAI Python SDK推理强度参数配置,让…...

AI大语言模型其实就是一个归纳与演绎的概率机器

您这句话精准地概括了当前主流人工智能(尤其是大语言模型)的核心本质。它确实是一个基于海量数据,通过统计归纳来学习模式,并通过概率演绎来生成输出的机器。 但这一定义既是其强大能力的根源,也是其根本局限的边界。我们可以从三个层面来理解: 一、这句话为什么是精准…...

次元画室赋能微信小程序:开发个人AI画室应用

次元画室赋能微信小程序:开发个人AI画室应用 你有没有过这样的经历?脑子里闪过一个绝妙的画面,可能是某个角色的形象,或是一个奇幻的场景,但苦于不会画画,只能任由灵感溜走。或者,你随手画了个…...

OpenClaw备份与迁移:GLM-4.7-Flash项目完整转移指南

OpenClaw备份与迁移:GLM-4.7-Flash项目完整转移指南 1. 为什么需要完整的迁移方案 上周我的主力开发机突然硬盘故障,导致所有数据丢失。虽然OpenClaw本身是开源工具可以重装,但那些精心调试的配置文件、自定义技能和对接好的GLM-4.7-Flash模…...

UMAP降维技术:拓扑数据分析驱动的高效可视化方案

UMAP降维技术:拓扑数据分析驱动的高效可视化方案 【免费下载链接】umap Uniform Manifold Approximation and Projection 项目地址: https://gitcode.com/gh_mirrors/um/umap 在高维数据可视化领域,研究者长期面临"鱼和熊掌不可兼得"的…...

Phi-3-Mini-128K高并发服务架构设计:负载均衡与自动扩缩容策略

Phi-3-Mini-128K高并发服务架构设计:负载均衡与自动扩缩容策略 你是不是也遇到过这种情况?自己部署的AI模型服务,平时用着挺好,一旦用户量稍微上来点,或者有人发了个长请求,服务就卡死甚至直接挂掉。然后就…...

大模型遇“知识盲区“?RAG让它秒变“开卷考试“学霸!

过去一年,在落地RAG过程中,发现一个有意思的现象:很多人把AI当成了"万能百科全书",结果一问企业内部数据就抓瞎。 你有没有遇到过这样的情况: 问ChatGPT:“我们公司去年的销售额是多少&#xff1…...

HsMod:炉石传说体验增强插件技术解析与应用指南

HsMod:炉石传说体验增强插件技术解析与应用指南 【免费下载链接】HsMod Hearthstone Modify Based on BepInEx 项目地址: https://gitcode.com/GitHub_Trending/hs/HsMod HsMod作为基于BepInEx框架开发的炉石传说插件,通过非侵入式技术手段重构游…...

有关数组的学习

数组的概念简介数组是编程中最基础也最常用的数据结构之一,理解它能帮你高效管理一组同类型的数据。1. 什么是数组?核心概念同类型:数组里的所有元素必须是相同的数据类型(如全是 int 或全是 float)。连续内存&#xf…...

Win10系统代理服务器拒绝连接?3步搞定网络恢复(附图文详解)

Win10代理服务器连接故障排查指南:从原理到实战解决方案 当Windows 10突然弹出"代理服务器拒绝连接"的错误提示时,很多用户会感到手足无措。这种情况通常发生在系统更新后、网络环境变更时,或是某些应用程序擅自修改了系统设置。本…...

Chandra AI性能调优:GPU显存优化全攻略

Chandra AI性能调优:GPU显存优化全攻略 1. 引言 跑大模型最头疼的是什么?对,就是那个让人又爱又恨的GPU显存!明明买了张不错的显卡,结果跑个模型就提示"Out of Memory",这种经历想必很多朋友都…...

解锁DeerFlow:零基础搭建智能研究环境完全指南

解锁DeerFlow:零基础搭建智能研究环境完全指南 【免费下载链接】deer-flow DeerFlow is a community-driven framework for deep research, combining language models with tools like web search, crawling, and Python execution, while contributing back to th…...

3分钟上手!FrankMocap让普通摄像头变身专业动捕设备

3分钟上手!FrankMocap让普通摄像头变身专业动捕设备 【免费下载链接】frankmocap A Strong and Easy-to-use Single View 3D HandBody Pose Estimator 项目地址: https://gitcode.com/gh_mirrors/fr/frankmocap 在数字内容创作与交互设计领域,3D动…...

如何快速上手艾尔登法环存档编辑器:新手完整指南

如何快速上手艾尔登法环存档编辑器:新手完整指南 【免费下载链接】ER-Save-Editor Elden Ring Save Editor. Compatible with PC and Playstation saves. 项目地址: https://gitcode.com/GitHub_Trending/er/ER-Save-Editor ER-Save-Editor是一款专为《艾尔登…...

电脑风扇智能控制完全指南:从噪音烦恼到散热优化

电脑风扇智能控制完全指南:从噪音烦恼到散热优化 【免费下载链接】FanControl.Releases This is the release repository for Fan Control, a highly customizable fan controlling software for Windows. 项目地址: https://gitcode.com/GitHub_Trending/fa/FanC…...

阿里云服务器上Certbot更新Let‘s Encrypt证书总超时?一个更换公网IP的实战解决记录

阿里云服务器Certbot更新Lets Encrypt证书超时问题深度解析与实战解决 最近在阿里云北京区域的服务器上更新Lets Encrypt证书时,遇到了一个看似简单却令人困扰的问题:Certbot在续签证书时频繁报错,提示acme-v02.api.letsencrypt.org连接超时。…...

硬件突破:用OpenCore Legacy Patcher实现旧Mac的焕新体验

硬件突破:用OpenCore Legacy Patcher实现旧Mac的焕新体验 【免费下载链接】OpenCore-Legacy-Patcher 体验与之前一样的macOS 项目地址: https://gitcode.com/GitHub_Trending/op/OpenCore-Legacy-Patcher OpenCore Legacy Patcher是一款强大的开源工具&#…...

C# rtwpriv Wi-Fi定频工具

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录一、使用简介,说明#前言 对于无线产品,很多需要做CE,FCC,SRRC等认证,需要测试RF,像Realtek方案的Wi-Fi用到rtwpriv工具…...