当前位置: 首页 > article >正文

# 基于PyTorch的食品图像分类系统:从训练到部署全流程指南

基于PyTorch的食品图像分类系统:从训练到部署全流程指南

本文将详细介绍如何使用PyTorch框架构建一个完整的食品图像分类系统,涵盖数据预处理、模型构建、训练优化以及模型保存与加载的全过程。

1. 系统概述

本系统实现了一个基于卷积神经网络(CNN)的食品图像分类器,主要特点包括:

  • 支持20种不同食品的分类
  • 使用数据增强提高模型泛化能力
  • 实现了完整的训练-验证-测试流程
  • 提供模型保存与加载功能

2. 数据准备与预处理

2.1 数据增强策略

在这里插入图片描述

我们为训练集和验证集分别设计了不同的数据增强策略:

data_transforms = {'train':  # 训练集  也可以使用PIL库  smote 训练集transforms.Compose([  # transforms.Compose用于将多个图像预处理操作整合在一起transforms.Resize([300, 300]),  # 使图像变换大小transforms.RandomRotation(45),  # 随机旋转,-42到45度之间随机选transforms.CenterCrop(256),  # 从中心开始裁剪[256.256]transforms.RandomHorizontalFlip(p=0.5),  # 随机水平旋转,随机概率为0.5transforms.RandomVerticalFlip(p=0.5),  # 随机垂直旋转,随机概率0.5transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 随机改变图像参数,参数分别表示 亮度、对比度、饱和度、色温transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),  # 将PIL图像或NumPy ndarray转换为tensor类型,并将像素值的范围从[0, 255]缩放到[0.0, 1.0],默认把通道维度放在前面transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 给定均值和标准差对图像进行标准化,前者为均值,后者为标准差,三个值表示三通道图像]),'valid':  # 验证集transforms.Compose([  # 整合图像处理的操作transforms.Resize([256, 256]),  # 缩放图像尺寸transforms.ToTensor(),  # 转换为torch类型transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化])
}

关键点说明

  • 训练集使用了丰富的数据增强来防止过拟合
  • 验证集只进行必要的尺寸调整和归一化
  • 使用ImageNet的均值和标准差进行归一化

2.2 自定义数据集类

我们创建了food_dataset类来管理数据:

class food_dataset(Dataset):  # food_dataset是自己创建的类名称,继承Dataset类def __init__(self, file_path, transform=None):  # 类的初始化,解析数据文件txt,file_path表示文件路径,transform可选的图像转换操作self.file_path = file_path  # 将文件地址传入self空间self.imgs = []self.labels = []self.transform = transform  # 将数据增强操作传入self空间with open(self.file_path) as f:  # 打开存放图片地址及其类别的文本文件train.txt,samples = [x.strip().split(' ') for x in f.readlines()]  # 遍历文件里的每一条数据,经过处理后存入sample列表,元祖的形式存放for img_path, label in samples:  # 遍历列表中的每个元组的每个元素self.imgs.append(img_path)  # 将图像的路径存入img列表self.labels.append(label)  # 将图片类别标签存入label列表# 初始化:把图片目录加载到self.def __len__(self):  # 类实例化对象后,可以使用len函数测量对象的个数return len(self.imgs)  # 返回数据集中样本的总数def __getitem__(self, idx):  # 关键,可通过索引idx的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx])  # 使用PIL库中的用法Image打开并识别图像,还不是tensorif self.transform:  # 判断是否有图像转换操作,上述定义默认为None,有则将pil图像数据转换为tensor类型image = self.transform(image)  # 图像处理为256*256,转换为tenorlabel = self.labels[idx]  # label还不是tensorlabel = torch.from_numpy(np.array(label, dtype=np.int64))  # 首先指定标签类型为int型,然后将其转换为numpy数组类型,然后再使用torch.from_numpy转换为torch类型return image, label  # 返回处理完的图片和标签

关键方法

  • __init__: 从文本文件加载图像路径和标签
  • __len__: 返回数据集大小
  • __getitem__: 按索引返回图像和标签

3. 模型架构设计

我们构建了一个三层的CNN模型:

class CNN(nn.Module):def __init__(self):  # 翰入大小 (3,256,256)super(CNN, self).__init__()self.conv1 = nn.Sequential(  # 将多个层组合成一起。nn.Conv2d(  # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序in_channels=3,  # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数)out_channels=16,  # 要得到几多少个特征图,卷积核的个数.kernel_size=5,  # 卷积核大小,5*5stride=1,  # 步长padding=2,  # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好。那p),  # 输出的特征图为 (16,256,256)nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域),输出结果为:(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),  # 输出(32,128,128)nn.ReLU(),nn.MaxPool2d(2)  # 输出)self.conv3 = nn.Sequential(nn.Conv2d(32, 128, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(128 * 64 * 64, 20)  # 全连接def forward(self, x):  # 前向传播x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)  # 输出(64,128,64,64)x = x.view(x.size(0), -1)output = self.out(x)return output  # 返回输出结果

架构特点

  1. 使用nn.Sequential组织网络层
  2. 每层包含卷积、ReLU激活和池化
  3. 最后一层全连接输出20个类别的概率

4. 模型训练与验证

4.1 训练流程

def train(dataloader, model, loss_fn, optimizer):  # 传入参数 打包的数据,卷积模型,损失函数,优化器model.train()  # 表示模型开始训练batch_size_num = 1for x, y in dataloader:  # 遍历打包的图片及其对应的标签,其中batch为每一个数据的编号x, y = x.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPUpred = model.forward(x)  # 自动初始化 W权值loss = loss_fn(pred, y)  # 传入模型训练结果的预测值和真实值,通过交叉熵损失函数计算损失值L0optimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度optimizer.step()  # 根据梯度更新网络参数loss = loss.item()  # 获取损失值if batch_size_num % 100 == 0:print(f"loss: {loss:>7f}[number:{batch_size_num}]")  # 打印损失值,右对齐,长度为7batch_size_num += 1  # 右下方传入的参数,表示训练轮数

4.2 验证流程

def test(dataloader, model, loss_fn):  # 定义一个test函数,用于测试模型性能global best_acc  # 定义一个全局变量size = len(dataloader.dataset)  # 返回打包的图片总数num_batches = len(dataloader)  # 返回打包的包的个数model.eval()  # 表示模型进入测试模式test_loss, correct = 0, 0  # 初始化两个值,一个用来存放总体损失值,一个存放预测准确的个数with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()时可以减少for x, y in dataloader:  # 遍历数据加载器中测试集图片的图片及其标签x, y = x.to(device), y.to(device)  # 传入GPUpred = model.forward(x)  # 前向传播,返回预测结果test_loss += loss_fn(pred, y).item()  # 计算所有的损失值的和,item表示将tensor类型值转化为python标量correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # 判断预测的值是等于真实值,返回布尔值,将其转换为0和1,然后求和# a = (pred.argmax(1)== y)  dim=1表示每一行中的最大值对应的索引号,dim=日表示每 b=(pred.argmax(1)==y).type(torch.float)test_loss /= num_batches  # 总体损失值除以数据条数得到平均损失值correct /= size  # 求准确率print(f"Test result:in Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")  # 表示准确率机器对应的损失值# acc_s.append(correct)# loss_s.append(test_loss)### 4.3 训练配置```python
# 初始化
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 数据加载
#training_data包含了本次需要训练的全部数据集
training_data = food_dataset(file_path=r'D:\Users\妄生\PycharmProjects\人工智能\深度学习\train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'D:\Users\妄生\PycharmProjects\人工智能\深度学习\test.txt', transform=data_transforms['valid'])train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)# 训练循环
epochs = 150  # 设置模型训练的轮数,不停更新模型参数,找到最优值
acc_s = []  # 初始化了两个空列表,用于存储模型在每个epoch结束时的准确率和损失值
loss_s = []
for t in range(epochs):  # 遍历轮数print(f"Epoch {t + 1}\n---------------------------")  # 表示轮数展示train(train_dataloader, model, loss_fn, optimizer)  # 调用函数train传入训练集数据加载器、初始化的模型、损失函数、优化器test(test_dataloader, model, loss_fn) 

运行结果

在这里插入图片描述

5. 模型保存与加载

5.1 保存模型

我们提供了两种保存方式:

# 方法1:仅保存模型参数(推荐)
torch.save(model.state_dict(), 'best.pth')# 方法2:保存整个模型
torch.save(model, 'best.pt')

5.2 加载模型

对应两种加载方式:

# 方法1:加载参数
model = CNN().to(device)
model.load_state_dict(torch.load('best.pth'))# 方法2:加载完整模型
model = torch.load('best.pt')

6. 模型测试与结果分析

我们实现了详细的测试函数:

def test_true(dataloader, model):correct = 0  # 正确预测的数量total = 0  # 总样本数量with torch.no_grad():  # 上下文管理器,关闭梯度运算for x, y in dataloader:  # 遍历打包好的图片及其标签x, y = x.to(device), y.to(device)  # 将其传入GPUpred = model.forward(x)  # 前向传播_, predicted = torch.max(pred, 1)  # 获取预测值的类别索引total += y.size(0)  # 累加总样本数量correct += (predicted == y).sum().item()  # 累加正确预测的数量result.append(predicted.item())  # 将预测值的结果转换成Python变量然后增加到列表labels.append(y.item())  # 同时将真实值的标签转变成Python标量然后存入labels列表accuracy = correct / total  # 计算准确率print(f'准确率: {accuracy:.4f}')  # 打印准确率# 调用测试函数
test_true(test_dataloader, model)  # 导入数据和模型
print('预测值:\t', result)
print('真实值:\t', labels)

运行结果

在这里插入图片描述

7. 总结

本文详细介绍了基于PyTorch的食品图像分类系统的完整实现流程,从数据准备到模型部署。该系统具有以下优势:

  1. 高效的数据处理:完善的数据增强和加载机制
  2. 可靠的模型架构:经过优化的CNN结构
  3. 完整的训练流程:包含训练、验证和测试
  4. 灵活的部署方案:提供多种模型保存方式

相关文章:

# 基于PyTorch的食品图像分类系统:从训练到部署全流程指南

基于PyTorch的食品图像分类系统:从训练到部署全流程指南 本文将详细介绍如何使用PyTorch框架构建一个完整的食品图像分类系统,涵盖数据预处理、模型构建、训练优化以及模型保存与加载的全过程。 1. 系统概述 本系统实现了一个基于卷积神经网络(CNN)的…...

v-html 显示富文本内容

返回数据格式&#xff1a; 只有图片名称 显示不出完整路径 解决方法&#xff1a;在接收数据后手动给img格式的拼接vite.config中的服务器地址 页面&#xff1a; <el-button click"">获取信息<el-button><!-- 弹出层 --> <el-dialog v-model&…...

【数学建模】孤立森林算法:异常检测的高效利器

孤立森林算法&#xff1a;异常检测的高效利器 文章目录 孤立森林算法&#xff1a;异常检测的高效利器1 引言2 孤立森林算法原理2.1 核心思想2.2 算法流程步骤一&#xff1a;构建孤立树(iTree)步骤二&#xff1a;构建孤立森林(iForest)步骤三&#xff1a;计算异常分数 3 代码实现…...

<项目代码>YOLO小船识别<目标检测>

项目代码下载链接 YOLOv8是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题&#xff0c;能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法&#xff08;如Faster R-CNN&#xff09;&#xff0…...

Crawl4AI:打破数据孤岛,开启大语言模型的实时智能新时代

当大语言模型遇见数据饥渴症 在人工智能的竞技场上&#xff0c;大语言模型&#xff08;LLMs&#xff09;正以惊人的速度进化&#xff0c;但其认知能力的跃升始终面临一个根本性挑战——如何持续获取新鲜、结构化、高相关性的数据。传统数据供给方式如同输血式营养支持&#xff…...

AI 技术发展:从起源到未来的深度剖析

一、AI 的起源与早期发展​ 人工智能&#xff08;AI&#xff09;作为计算机科学的重要分支&#xff0c;其诞生可以追溯到 20 世纪中叶。1943 年&#xff0c;艾伦・图灵提出图灵机的概念&#xff0c;为计算机科学和 AI 理论奠定了基础。1950 年&#xff0c;图灵又提出著名的图灵…...

jsconfig.json文件的作用

jsconfig.json文件的作用 ​ 为什么今天会谈到这个呢&#xff1f;有这么一个场景&#xff1a;我们每次开发项目时都会给路径配置别名&#xff0c;配完别名之后可以简化我们的开发&#xff0c;但是随之而来的就有一个问题&#xff0c;一般来说&#xff0c;当我们使用相对路径时…...

nodejs的包管理工具介绍,npm的介绍和安装,npm的初始化包 ,搜索包,下载安装包

nodejs的包管理工具介绍&#xff0c;npm的介绍和安装&#xff0c;npm的初始化包 &#xff0c;搜索包&#xff0c;下载安装包 &#x1f9f0; 一、Node.js 的包管理工具有哪些&#xff1f; 工具简介是否默认特点npmNode.js 官方的包管理工具&#xff08;Node Package Manager&am…...

常见的raid有哪些,使用场景是什么?

RAID&#xff08;Redundant Array of Independent Disks&#xff0c;独立磁盘冗余阵列&#xff09;是一种将多个物理硬盘组合成一个逻辑硬盘的技术&#xff0c;目的是通过数据冗余和/或并行访问提高性能、容错能力和存储容量。不同的 RAID 级别有不同的实现方式和应用场景。以下…...

【Spring Boot】MyBatis多表查询的操作:注解和XML实现SQL语句

1.准备工作 1.1创建数据库 &#xff08;1&#xff09;创建数据库&#xff1a; CREATE DATABASE mybatis_test DEFAULT CHARACTER SET utf8mb4;&#xff08;2&#xff09;使用数据库 -- 使⽤数据数据 USE mybatis_test;1.2 创建用户表和实体类 创建用户表 -- 创建表[⽤⼾表…...

金融数据分析(Python)个人学习笔记(12):网络爬虫

一、导入模块和函数 from bs4 import BeautifulSoup from urllib.request import urlopen import re from urllib.error import HTTPError from time import timebs4&#xff1a;用于解析HTML和XML文档的Python库。 BeautifulSoup&#xff1a;方便地从网页内容中提取和处理数据…...

[Android]豆包爱学v4.5.0小学到研究生 题目Ai解析

拍照解析答案 【应用名称】豆包爱学 【应用版本】4.5.0 【软件大小】95mb 【适用平台】安卓 【应用简介】豆包爱学&#xff0c;一般又称河马爱学教育平台app,河马爱学。 关于学习&#xff0c;你可能也需要一个“豆包爱学”这样的AI伙伴&#xff0c;它将为你提供全方位的学习帮助…...

Qt开发:软件崩溃时,如何生成dump文件

文章目录 一、程序崩溃时如何自动生成 Dump 文件二、支持多线程中的异常捕获三、在 DLL 中使用 Dump 捕获四、封装成可复用类五、MiniDumpWriteDump函数详解 一、程序崩溃时如何自动生成 Dump 文件 步骤一&#xff1a;包含必要的头文件 #include <Windows.h> #include …...

普罗米修斯Prometheus监控安装(mac)

普罗米修斯是后端数据监控平台&#xff0c;通过Node_exporter/mysql_exporter等收集数据&#xff0c;Grafana将数据用图形的方式展示出来 官网各平台下载 Prometheus安装&#xff08;mac&#xff09; &#xff08;1&#xff09;通过brew安装 brew install prometheus &…...

Python SQL 工具包:SQLAlchemy介绍

SQLAlchemy 是一个功能强大且灵活的 Python SQL 工具包和对象关系映射&#xff08;ORM&#xff09;库。它被广泛用于与关系型数据库进行交互&#xff0c;提供了从低级 SQL 表达式到高级 ORM 的完整工具链。SQLAlchemy 的设计目标是让开发者能够以 Pythonic 的方式操作数据库&am…...

Shader属性讲解+Cg语言讲解

CPU调用GPU传递数据 修改Render组件的material属性 在脚本中更改游戏物体材质颜色代码示例&#xff1a; using System.Collections; using System.Collections.Generic; using UnityEngine;public class TestFixedColor : MonoBehaviour {void Start(){//创建预制体GameObjec…...

基于LightGBM-TPE算法对交通事故严重程度的分析与可视化

基于LightGBM-TPE算法对交通事故严重程度的分析与可视化 原文&#xff1a; Analysis and visualization of accidents severity based on LightGBM-TPE 1. 引言部分 文章开篇强调了道路交通事故作为意外死亡的主要原因&#xff0c;引起了多学科领域的关注。分析事故严重性特…...

什么是CRM系统,它的作用是什么?CRM全面指南

CRM&#xff08;Customer Relationship Management&#xff0c;客户关系管理&#xff09;系统是一种专门用于集中管理客户信息、优化销售流程、提升客户满意度、支持精准营销、驱动数据分析决策、加强跨部门协同、提升客户生命周期价值的业务系统工具。其中&#xff0c;优化销售…...

MySQL 启动报错:InnoDB 表空间丢失问题及解决方法

MySQL 启动报错&#xff1a;InnoDB 表空间丢失问题及解决方法 在启动 MySQL 时&#xff0c;遇到了如下错误&#xff1a; 2025-01-16T12:43:28.341240Z 0 [ERROR] InnoDB: Tablespace 5975 was not found at ./my_jspt/sw_rtu_message_202408.ibd. 2025-01-16T12:43:28.341244…...

MYSQL之库的操作

创建数据库 语法很简单, 主要是看看选项(与编码相关的): CREATE DATABASE [IF NOT EXISTS] db_name [create_specification [, create_specification] ...] create_specification: [DEFAULT] CHARACTER SET charset_name [DEFAULT] COLLATE collation_name 1. 语句中大写的是…...

笔记本电脑研发笔记:BIOS,Driver,Preloader详记

在笔记本电脑的研发过程中&#xff0c;Driver&#xff08;驱动程序&#xff09;、BIOS&#xff08;基本输入输出系统&#xff09;和 Preloader&#xff08;预加载程序&#xff09;之间存在着密切的相互关系和影响&#xff0c;具体如下&#xff1a; 相互关系 BIOS 与 Preload…...

同样的html标记,不同语言的文本,显示的字体和粗细会不一样吗

同样的 HTML 标记&#xff0c;在不同语言的文本下&#xff0c;显示出来的字体和粗细确实可能会不一样&#xff0c;原因如下&#xff1a; &#x1f30d; 不同语言默认字体不同 浏览器字体回退机制 CSS 里写的字体如果当前系统不支持&#xff0c;就会回退到下一个&#xff0c;比如…...

JavaScript 笔记 --- part 5 --- Web API (part 3)

(webAPI part3) BOM 操作 JS 执行机制 javascript 是单线程的, 也就是说, 只能同时执行一个任务。 为了解决这个问题, 利用多核 CPU 的计算能力, HTML5 提出 Web Worker API, 允许 JavaScript 脚本创建多个线程, 并将任务分配给这些线程。 于是, JS 出现了同步和异步的概念。…...

Linux 下的网络管理(附加详细实验案例)

一、简单了解 NM&#xff08;NetworkManager&#xff09; 在 Linux 中&#xff0c;NM 是 NetworkManager 的缩写。它是一个用于管理网络连接的守护进程和工具集。 在 RHEL9 上&#xff0c;使用 NM 进行网络配置&#xff0c;ifcfg &#xff08;也称为文件&#xff09;将不再…...

基于SpringBoot的疫情居家检测管理系统(源码+数据库)

514基于SpringBoot的疫情居家检测管理系统&#xff0c;系统包含三种角色&#xff1a;管理员、用户、医生&#xff0c;主要功能如下。 【用户功能】 1. 首页&#xff1a;获取系统信息。 2. 论坛&#xff1a;参与居民讨论和分享信息。 3. 公告&#xff1a;查看社区发布的各类公告…...

关于系统架构思考,如何设计实现系统的高可用?

绪论、系统高可用的必要性 系统高可用为了保持业务连续性保障&#xff0c;以及停机成本量化&#xff0c;比如在以前的双十一当天如果出现宕机&#xff0c;那将会损失多少钱&#xff1f;比如最近几年Amazon 2021年30分钟宕机损失$5.6M。当然也有成功的案例&#xff0c;比如异地…...

MATLAB 控制系统设计与仿真 - 35

MATLAB鲁棒控制器分析 所谓鲁棒性是指控制系统在一定(结构&#xff0c;大小)的参数扰动下&#xff0c;维持某些性能的特征。 根据对性能的不同定义&#xff0c;可分为稳定鲁棒性(Robust stability)和性能鲁棒性(Robust performance)。 以闭环系统的鲁棒性作为目标设计得到的…...

性能比拼: Nginx vs Caddy

本内容是对知名性能评测博主 Anton Putra Nginx vs Caddy Performance 内容的翻译与整理, 有适当删减, 相关指标和结论以原作为准 引言 在本期视频中&#xff0c;我们将对比 Nginx 和 Caddy---一个用 Go 编写的 Web 服务器和反向代理。 在第一个测试中&#xff0c;我们会使用…...

C++项目-衡码云判项目演示

衡码云判项目是什么呢&#xff1f;简单来说就是这是一个类似于牛客、力扣等在线OJ系统&#xff0c;用户在网页编写代码&#xff0c;点击提交后传递给后端云服务器&#xff0c;云服务器将用户的代码和测试用例进行合并编译&#xff0c;返回结果到网页。 项目最大的两个亮点&…...

李宏毅NLP-6-seq2seqHMM

比较seq2seq和HMM Hidden Markov Model(HMM) 隐马尔可夫模型&#xff08;HMM&#xff09;在语音识别中的应用&#xff0c;具体内容如下&#xff1a; 整体流程&#xff1a; 左侧为语音信号&#xff08;标记为 “speech”&#xff09;&#xff0c;其特征表示为 X X X。中间蓝色模…...