ResNet18果蔬图像识别分类
关于深度实战社区
我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万+粉丝,拥有2篇国家级人工智能发明专利。
社区特色:深度实战算法创新
获取全部完整项目数据集、代码、视频教程,请进入官网:zzgcz.com。竞赛/论文/毕设项目辅导答疑,v:zzgcz_com
1. 项目简介
本项目的目标是开发一个基于ResNet18深度学习模型的果蔬图像分类系统。随着现代农业与人工智能的结合,智能果蔬分类技术在供应链、生产和销售管理中扮演了越来越重要的角色。本项目的背景源于提升果蔬分类效率的需求,通过使用计算机视觉技术自动识别和分类不同种类的果蔬。项目使用了经典的卷积神经网络ResNet18,它在图像识别领域表现出色,尤其适合处理果蔬这种复杂且多样化的视觉数据。ResNet18凭借其深度残差结构,能够在保留模型性能的前提下有效减少梯度消失问题,使其在实际应用中稳定高效。通过训练大量果蔬图像数据,模型可以准确区分不同类别,从而实现智能化的自动分类,提升效率并减少人工误差。本项目的应用场景广泛,包括农业自动化、智能超市货架、果蔬质量检测等领域。
2.技术创新点摘要
数据处理的精细化调整:在数据集的处理方面,项目通过自定义数据预处理脚本(如split_dataset.py
和statistic_mean_std.py
),进一步优化了图像的输入。在statistic_mean_std.py
中,项目统计了训练集图像的每个通道的均值和标准差,用于后续数据归一化操作,这种归一化能显著提高模型的收敛速度和预测精度。这种针对特定领域图像(果蔬图像)的数据标准化处理,为模型提供了更具鲁棒性的输入数据。
自定义学习率调度和LARS优化器:项目在训练策略上使用了自定义学习率衰减策略和LARS优化器(在lars.py
和lr_sched.py
中实现)。LARS优化器(Layer-wise Adaptive Rate Scaling)针对大批量训练进行了优化,特别适用于高维度数据和大规模训练任务。结合学习率衰减策略,可以在训练过程中动态调整学习率,有效避免模型陷入局部最优解并加速收敛。这种策略不仅提高了模型的训练效率,还能进一步提升模型的泛化性能。
位置嵌入技术的引入:在pos_embed.py
文件中,项目引入了二维正弦-余弦位置嵌入技术(sine-cosine position embedding),这种技术常用于Transformer模型,但在此被应用于卷积神经网络中。这一创新点可能是为了增加模型对图像空间位置信息的敏感性,尤其是在处理具有一定几何形状和空间结构的果蔬图像时,能有效提升模型的感知能力。
数据增强和裁剪技术:crop.py
文件中实现了对输入图像的多种裁剪操作,保证了模型在处理不同尺寸、比例的图像时,仍能保持高精度的分类性能。这种多样化的数据增强方式,能够增加数据的多样性,提升模型的鲁棒性。
3. 数据集与预处理
本项目使用的果蔬图像数据集来源于公开的农业领域图像数据集,包含了多种不同种类的果蔬图像。该数据集的特点是图像种类丰富,覆盖了常见的蔬菜和水果类别,图像质量较高且具备良好的多样性,包括不同光照条件、角度和背景的变化。这样的数据集不仅能够训练出准确的分类模型,还可以通过增强模型的泛化能力,使其在处理未见过的果蔬图像时依然保持良好的表现。
在数据预处理过程中,项目首先对图像进行了归一化操作,利用statistic_mean_std.py
脚本计算了数据集中所有图像的每个通道的均值和标准差。通过对图像进行归一化处理,将像素值调整到相同的尺度范围内(通常是[0, 1]或[-1, 1]),从而提高模型的训练效率和收敛速度。
此外,项目引入了数据增强技术,以增加模型的鲁棒性。具体操作包括图像的随机裁剪、旋转、缩放以及色彩调整等,这些操作能够有效地增加训练数据的多样性,防止模型过拟合。在crop.py
中,图像被调整到统一的尺寸,确保输入网络的图像具有一致的维度。此外,利用随机裁剪技术,生成不同大小和比例的图像,从而增加模型在处理不同视角和尺度图像时的适应能力。
特征工程方面,项目主要依靠深度学习模型自动提取特征,并没有进行传统的手动特征提取。然而,通过自定义的归一化和数据增强步骤,确保了输入模型的数据质量,提升了模型的学习效率和泛化能力。
4. 模型架构
模型结构的逻辑: 本项目使用了经典的深度残差网络ResNet18作为基础架构,适用于果蔬图像分类任务。ResNet18由多个残差模块组成,这些模块允许信息通过跳跃连接(skip connections)在网络中传播,从而避免了深层网络中常见的梯度消失问题。其核心结构包括:
输入层:处理输入图像(通常是RGB图像,尺寸为224x224x3),将其传递到卷积层。
卷积层1:第一层是7x7的卷积核,步长为2,输出一个经过空间降采样的特征图。数学公式如下:
Z ( 1 ) = W ( 1 ) ∗ X + b ( 1 ) Z^{(1)} = W^{(1)} * X + b^{(1)} Z(1)=W(1)∗X+b(1)
其中,X是输入图像,W(1)是卷积核,∗表示卷积运算。
最大池化层:紧接着卷积层的是3x3的最大池化层,进一步减少图像尺寸并保留显著特征。
残差模块:ResNet18由4组残差块组成,每个块包含两个卷积层和一条跳跃连接。跳跃连接的引入使得输出为:
Z ( l + 2 ) = Z ( l ) + f ( W ( l + 1 ) ∗ Z ( l ) + b ( l + 1 ) ) Z^{(l+2)} = Z^{(l)} + f(W^{(l+1)} * Z^{(l)} + b^{(l+1)}) Z(l+2)=Z(l)+f(W(l+1)∗Z(l)+b(l+1))
这里,fff是激活函数(ReLU),而Z(l+2)是通过跳跃连接后的输出。这种结构允许网络层数加深的同时保持信息流动。
全连接层:经过所有卷积和池化操作后,特征图被展平,传递到全连接层,进行分类。假设输入有n个类别,输出为n维的向量,表示每个类别的预测概率。
模型的整体训练流程:
- 训练数据加载:通过
DataLoader
加载经过数据增强处理的果蔬图像,并将其传递给模型进行训练。
- 损失函数:使用交叉熵损失函数(CrossEntropyLoss): L = − ∑ i = 1 n y i log ( p i ) L = -\sum_{i=1}^{n} y_i \log(p_i) L=−i=1∑nyilog(pi)其中,yi是真实标签,pi是模型预测的概率。
- 优化器:项目使用了AdamW优化器,结合自定义的学习率衰减策略进行梯度更新。
- 训练循环:每个epoch内,模型通过前向传播计算输出,使用损失函数计算误差,反向传播更新权重。
- 评估指标:训练结束后,通过准确率(top-1和top-5 accuracy)来评估模型在验证集上的表现。准确率的计算公式为: Accuracy = 正确预测的样本数 总样本数 \text{Accuracy} = \frac{\text{正确预测的样本数}}{\text{总样本数}} Accuracy=总样本数正确预测的样本数
5. 核心代码详细讲解
1. 数据增强与裁剪 (crop.py)
该代码实现了随机尺寸裁剪功能,确保输入图像的多样性,从而提高模型的鲁棒性。
class RandomResizedCrop(transforms.RandomResizedCrop):"""RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.""" @staticmethoddef get_params(img, scale, ratio):width, height = F._get_image_size(img)area = height * width
- 解释:此函数定义了随机裁剪图像的参数,
scale
确定图像缩放的范围,ratio
确定宽高比。图像大小和面积先通过此函数计算。
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
log_ratio = torch.log(torch.tensor(ratio))
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
- 解释:此部分代码根据输入图像的面积和随机缩放比例,计算出目标裁剪区域和宽高比。
torch.empty(1).uniform_()
用于生成随机数以确定新的宽高比。
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
- 解释:这里计算了裁剪区域的宽度和高度,确保符合目标裁剪比例。
2. 统计数据均值和标准差 (statistic_mean_std.py)
此脚本用于计算训练集中所有图像的均值和标准差,用于后续数据归一化操作。
train_files = glob.glob(os.path.join('train', '*', '*.jpg'))
print(f'total {len(train_files)} files for training')
- 解释:首先,该代码使用
glob
函数查找训练集中所有图像文件,并打印出文件总数。
result = []
for file in train_files:img = Image.open(file).convert('RGB')img = np.array(img).astype(np.uint8)img = img / 255.result.append(img)
- 解释:这段代码加载每个图像,将其转换为RGB格式并归一化到0-1范围,然后将归一化后的图像数据存储在
result
列表中。
mean = np.mean(result, axis=(0, 1, 2))
std = np.std(result, axis=(0, 1, 2))
print(mean)
print(std)
- 解释:最后,计算训练集中图像的每个通道的均值和标准差,用于后续归一化处理。
3. 自定义学习率调度 (lr_sched.py)
此代码实现了自定义学习率衰减策略,使用余弦退火和预热技术。
def adjust_learning_rate(optimizer, epoch, args):"""Decay the learning rate with half-cycle cosine after warmup"""if epoch < args.warmup_epochs:lr = args.lr * epoch / args.warmup_epochselse:lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
- 解释:这里采用了一种带有预热期的学习率衰减策略。在预热阶段,学习率线性上升,之后通过余弦函数在剩余训练阶段逐渐衰减。
for param_group in optimizer.param_groups:if "lr_scale" in param_group:param_group["lr"] = lr * param_group["lr_scale"]else:param_group["lr"] = lr
- 解释:该部分代码为优化器的每个参数组更新学习率,确保不同参数组可以应用不同的学习率缩放比例。
4. LARS优化器 (lars.py)
这是一个自定义实现的LARS优化器,适用于大规模训练任务。
class LARS(torch.optim.Optimizer):def init(self, params, lr=1e-3, momentum=0.9, weight_decay=0, dampening=0, nesterov=False):defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov)super(LARS, self).__init__(params, defaults)
- 解释:LARS优化器初始化时接受学习率、动量和权重衰减等参数,并使用这些默认设置初始化优化器。
for p in group['params']:if p.grad is None:continuedp = p.grad.add(p, alpha=group['weight_decay'])dp = dp.mul(1.0 / torch.norm(p))param_state = self.state[p]if 'mu' not in param_state:param_state['mu'] = torch.zeros_like(p)mu = param_state['mu']mu.mul_(group['momentum']).add_(dp)p.add_(mu, alpha=-group['lr'])
- 解释:这里实现了LARS的核心部分:首先计算权重的梯度更新,然后根据权重的范数对梯度进行缩放,最后利用动量更新模型参数。
6. 模型优缺点评价
优点:
- ResNet18架构的强大性能:ResNet18通过引入残差模块,有效解决了深层网络中梯度消失的问题,使得该模型在处理大规模图像数据时能够实现高效的训练和准确的分类。对于果蔬图像分类任务,这种深度残差网络能够捕捉到图像中的细节特征,提升模型的识别效果。
- 数据增强与归一化处理:项目中采用了随机裁剪、图像缩放等数据增强方法,有效增加了训练数据的多样性,防止过拟合。归一化处理(均值和标准差的计算与应用)确保了输入数据的尺度统一,进一步加速了模型收敛。
- 自定义优化器与学习率调度:项目使用了LARS优化器,这对于大规模训练任务非常有利。此外,自定义的余弦退火学习率调度与预热策略,帮助模型在不同训练阶段动态调整学习率,提升训练效率并防止陷入局部最优。
缺点:
- 模型复杂度较高:虽然ResNet18在分类任务上表现优秀,但其复杂的残差结构增加了模型的计算成本,对硬件资源要求较高,可能不适用于计算资源有限的场景。
- 依赖大规模数据:模型的性能在大规模训练数据上表现良好,但如果训练数据有限,模型可能无法充分学习,导致表现不佳。
- 缺乏实时性:虽然模型可以处理果蔬图像分类任务,但并未针对实时性进行优化,可能在一些实时应用场景下存在延迟。
改进方向:
- 模型结构优化:可以尝试更轻量级的网络,如MobileNet或EfficientNet,以减少计算开销并保持较高的准确率,特别是针对移动设备或资源有限的场景。
- 超参数调整:进一步优化学习率、权重衰减、批大小等超参数,有助于提升训练速度和模型表现。
- 数据增强扩展:可以引入更多的数据增强方法,如对比度调整、颜色抖动等,以进一步提升模型的鲁棒性和泛化能力。
↓↓↓更多热门推荐:
LSTM预测未来30天销售额
基于小波变换与稀疏表示优化的RIE数据深度学习预测模型
全部项目数据集、代码、教程进入官网zzgcz.com
相关文章:

ResNet18果蔬图像识别分类
关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝,拥有2篇国家级人工智能发明专利。 社区特色…...
深度强化学习中收敛图的横坐标是steps还是episode?
在深度强化学习(Deep Reinforcement Learning, DRL)的收敛图中,横坐标选择 steps 或者 episodes 主要取决于算法的设计和实验的需求,两者的差异和使用场景如下: Steps(步数): 定义&a…...

一个真实可用的登录界面!
需要工具: MySQL数据库、vscode上的php插件PHP Server等 项目结构: login | --backend | --database.sql |--login.php |--welcome.php |--index.html |--script.js |--style.css 项目开展 index.html: 首先需要一个静态网页&#x…...
Vue中watch监听属性的一些应用总结
【1】vue2中watch的应用 ① 简单监视 在 Vue 2 中,如果你不需要深度监视,即只需监听顶层属性的变化,可以使用简写形式来定义 watch。这种方式更加简洁,适用于大多数基本场景。 示例代码 假设你有一个 Vue 组件,其中…...

MongoDB-aggregate流式计算:带条件的关联查询使用案例分析
在数据库的查询中,是一定会遇到表关联查询的。当两张大表关联时,时常会遇到性能和资源问题。这篇文章就是用一个例子来分享MongoDB带条件的关联查询发挥的作用。 假设工作环境中有两张MongoDB集合:SC_DATA(学生基本信息集合&…...

Redis数据库与GO(一):安装,string,hash
安装包地址:https://github.com/tporadowski/redis/releases 建议下载zip版本,解压即可使用。解压后,依次打开目录下的redis-server.exe和redis-cli.exe,redis-cli.exe用于输入指令。 一、基本结构 如图,redis对外有个…...
expressjs,实现上传图片,返回图片链接
在 Express.js 中实现图片上传并返回图片链接,你通常需要使用一个中间件来处理文件上传,比如 multer。multer 是一个 node.js 的中间件,用于处理 multipart/form-data 类型的表单数据,主要用于上传文件。 以下是一个简单的示例&a…...

爬虫——XPath基本用法
第一章XML 一、xml简介 1.什么是XML? 1,XML指可扩展标记语言 2,XML是一种标记语言,类似于HTML 3,XML的设计宗旨是传输数据,而非显示数据 4,XML标签需要我们自己自定义 5,XML被…...
常见排序算法汇总
排序算法汇总 这篇文章说明下排序算法,直接开始。 1.冒泡排序 最简单直观的排序算法了,新手入门的第一个排序算法,也非常直观,最大的数字像泡泡一样一个个的“冒”到数组的最后面。 算法思想:反复遍历要排序的序列…...

Golang | Leetcode Golang题解之第459题重复的子字符串
题目: 题解: func repeatedSubstringPattern(s string) bool {return kmp(s s, s) }func kmp(query, pattern string) bool {n, m : len(query), len(pattern)fail : make([]int, m)for i : 0; i < m; i {fail[i] -1}for i : 1; i < m; i {j : …...
0.计网和操作系统
0.计网和操作系统 熟悉计算机网络和操作系统知识,包括 TCP/IP、UDP、HTTP、DNS 协议等。 常见的页面置换算法: 先进先出(FIFO)算法:将最早进入内存的页面替换出去。最近最少使用(LRU)算法&am…...

探索Prompt Engineering:开启大型语言模型潜力的钥匙
前言 什么是Prompt?Prompt Engineering? Prompt可以理解为向语言模型提出的问题或者指令,它是激发模型产生特定类型响应的“触发器”。 Prompt Engineering,即提示工程,是近年来随着大型语言模型(LLM,Larg…...
滚雪球学Oracle[3.3讲]:数据定义语言(DDL)
全文目录: 前言一、约束的高级使用1.1 主键(Primary Key)案例演示:定义主键 1.2 唯一性约束(Unique)案例演示:定义唯一性约束 1.3 外键(Foreign Key)案例演示:…...

ssrf学习(ctfhub靶场)
ssrf练习 目录 ssrf类型 漏洞形成原理(来自网络) 靶场题目 第一题(url探测网站下文件) 第二关(使用伪协议) 关于http和file协议的理解 file协议 http协议 第三关(端口扫描)…...
ElasticSearch之网络配置
对官方文档Networking的阅读笔记。 ES集群中的节点,支持处理两类通信平面 集群内节点之间的通信,官方文档称之为transport layer。集群外的通信,处理客户端下发的请求,比如数据的CRUD,检索等,官方文档称之…...
【C语言进阶】系统测试与调试
1. 引言 在开始本教程的深度学习之前,我们需要了解整个教程的目标及其结构,以及为何进阶学习是提升C语言技能的关键。 目标和结构: 教程目标:本教程旨在通过系统化的学习,从单元测试、系统集成测试到调试技巧…...
多个单链表的合成
建立两个非递减有序单链表,然后合并成一个非递增有序的单链表。 注意:建立非递减有序的单链表,需要采用创建单链表的算法 输入格式: 1 9 5 7 3 0 2 8 4 6 0 输出格式: 9 8 7 6 5 4 3 2 1 输入样例: 在这里给出一组输入。例如…...

『建议收藏』ChatGPT Canvas功能进阶使用指南!
大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,专注于分享AI全维度知识,包括但不限于AI科普,AI工…...

Ollama 运行视觉语言模型LLaVA
Ollama的LLaVA(大型语言和视觉助手)模型集已更新至 1.6 版,支持: 更高的图像分辨率:支持高达 4 倍的像素,使模型能够掌握更多细节。改进的文本识别和推理能力:在附加文档、图表和图表数据集上进…...

gdb 调试 linux 应用程序的技巧介绍
使用 gdb 来调试 Linux 应用程序时,可以显著提高开发和调试的效率。gdb(GNU 调试器)是一款功能强大的调试工具,适用于调试各类 C、C 程序。它允许我们在运行程序时检查其状态,设置断点,跟踪变量值的变化&am…...

基于当前项目通过npm包形式暴露公共组件
1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹,并新增内容 3.创建package文件夹...

《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...
Xen Server服务器释放磁盘空间
disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...

Golang——6、指针和结构体
指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

pgsql:还原数据库后出现重复序列导致“more than one owned sequence found“报错问题的解决
问题: pgsql数据库通过备份数据库文件进行还原时,如果表中有自增序列,还原后可能会出现重复的序列,此时若向表中插入新行时会出现“more than one owned sequence found”的报错提示。 点击菜单“其它”-》“序列”,…...

WinUI3开发_使用mica效果
简介 Mica(云母)是Windows10/11上的一种现代化效果,是Windows10/11上所使用的Fluent Design(设计语言)里的一个效果,Windows10/11上所使用的Fluent Design皆旨在于打造一个人类、通用和真正感觉与 Windows 一样的设计。 WinUI3就是Windows10/11上的一个…...

开源 vGPU 方案:HAMi,实现细粒度 GPU 切分
本文主要分享一个开源的 GPU 虚拟化方案:HAMi,包括如何安装、配置以及使用。 相比于上一篇分享的 TimeSlicing 方案,HAMi 除了 GPU 共享之外还可以实现 GPU core、memory 得限制,保证共享同一 GPU 的各个 Pod 都能拿到足够的资源。…...
OCC笔记:TDF_Label中有多个相同类型属性
注:OCCT版本:7.9.1 TDF_Label中有多个相同类型的属性的方案 OCAF imposes the restriction that only one attribute type may be allocated to one label. It is necessary to take into account the design of the application data tree. For exampl…...