当前位置: 首页 > news >正文

卷积神经网络实现运动鞋识别 - P5

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P5周:运动鞋识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 训练设备
    • 数据准备
      • 图像解压后的路径
      • 打印图像的参数
      • 展示图像
      • 图像的预处理
      • 创建数据集
      • 获取数据集的分类
      • 打乱数据的顺序,生成批次
    • 模型设计
    • 模型训练
      • 训练函数
      • 评估函数
      • 循环迭代部分
    • 模型效果展示
      • 训练过程图表展示
      • 载入最佳模式,随机选择图像进行预测
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118

步骤

环境设置

包引用

import torch
import torch.nn as nn 
import torch.optim as optim # 优化器
import torch.nn.functional as F # 可以静态调用的方法from torchvision import datasets, transforms # 数据集创建、数据预处理方法
from torch.utils.data import DataLoader # DataLoader可以将数据集封装成批次数据import matplotlib.pyplot as plt
import numpy as np
from PIL import Image # 加载图片预览使用的库
from torchinfo import summary # 可以打印模型实际运行时的图
import copy # 深拷贝使用的库
import pathlib, random # 文件夹遍历和随机数

训练设备

# 声明一个全局设备对象,方便后面将数据和模型拷贝到设备中
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

图像解压后的路径

train_path = 'train'
test_path = 'test'

打印图像的参数

train_pathlib = pathlib.Path(train_path)
train_image_list = list(train_pathlib.glob('*/*'))
for _ in range(5):print(np.array(Image.open(str(random.choice(train_image_list)))).shape)

图片的参数
重复执行了多次,返回结果都是(240, 240, 3),可以确定图像的大小统一为240,240,在数据加载的过程中可以不对图像做缩放处理。

展示图像

plt.figure(figsize=(20, 4))
for i in range(20):image = random.choice(train_image_list)plt.subplot(2, 10, i+1)plt.axis('off')plt.imshow(Image.open(str(image)))plt.title(image.parts[-2])

数据集预览
至此我们对数据集中的图像有了一个初步的了解。接下来就是准备训练数据。

图像的预处理

定义一些图像的预处理方法,例如将图像读取并转为pytorch的tensor对象,然后对图像的数值做归一化处理

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

创建数据集

train_dataset = datasets.ImageFolder(train_path, transform=transform)
test_dataset = datasets.ImageFolder(test_path, transform=transform)

获取数据集的分类

class_names = [key for key in train_dataset.class_to_idx]
print(class_names)

数据分类

打乱数据的顺序,生成批次

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

使用3x3的卷积核,最大的通道数到256,每次卷积操作后,就紧跟一个池化层,一共使用了4个卷积层和4个池化层。最后使用了三层全连接网络来做分类器。
模型结构图

class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.conv1 = nn.Conv2d(3, 64, 3)self.bn1  = nn.BatchNorm2d(64)self.conv2 = nn.Conv2d(64, 128, 3)self.bn2 = nn.BatchNorm2d(128)self.conv3 = nn.Conv2d(128, 256, 3)self.bn3 = nn.BatchNorm2d(256)self.conv4 = nn.Conv2d(256, 256, 3)self.bn4 = nn.BatchNorm2d(256)self.maxpool = nn.MaxPool2d(2)self.fc1 = nn.Linear(13*13*256, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, num_classes)self.dropout = nn.Dropout(0.5)def forward(self, x):# 240 -> 238x = F.relu(self.bn1(self.conv1(x)))# 238 -> 119x = self.maxpool(x)# 119 -> 117x = F.relu(self.bn2(self.conv2(x)))# 117 -> 58x = self.maxpool(x)# 58 -> 56x = F.relu(self.bn3(self.conv3(x)))# 56 -> 28x = self.maxpool(x)# 28 -> 26x = F.relu(self.bn4(self.conv4(x)))# 26 -> 13x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.dropout(x)x = F.relu(self.dropout(self.fc1(x)))x = F.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x
model = Network(len(class_names)).to(device)summary(model, input_size=(32, 3, 240, 240))

模型结构图

模型训练

模型训练过程中,每个epoch都会对全部的训练集进行一次完整的遍历,所以可以封装一些训练和评估方法,将业务逻辑和循环分开

训练函数

def train(train_loader, model, loss_fn, optimizer):size = len(train_loader.dataset)num_batches = len(train_loader)train_loss, train_acc = 0, 0for x, y in train_loader:x, y = x.to(device), y.to(device)pred = model(x)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= sizereturn train_loss, train_acc

评估函数

def test(test_loader, model, loss_fn):size = len(test_loader.dataset)num_batches = len(test_loader)test_loss, test_acc = 0, 0for x, y in test_loader:x, y = x.to(device), y.to(device)pred = model(x)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchestest_acc /= sizereturn test_loss, test_acc

循环迭代部分

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch:0.92**(epoch //2)) 
# 创建学习率的衰减
epochs = 50train_loss, train_acc = [], []
test_loss, test_acc = [], []
best_acc = 0
for epoch in range(epochs):model.train()epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval()with torch.no_grad():epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)scheduler.step() # 每次迭代调用一次,自动做学习率衰减# 如果当前评估的学习率更好,就保存当前模型if best_acc < epoch_test_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)# 记录历史记录train_loss.append(epoch_train_loss)train_acc.append(epoch_train_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 打印每个迭代的数据print(f"Epoch:{epoch+1}, TrainLoss: {epoch_train_loss:.3f}, TrainAcc: {epoch_train_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}")# 打印本次训练的最佳正确率
print(f'training finished, best_acc is {best_acc*100:.1f}')# 将最佳模型保存到文件中
torch.save(model.state_dict(), 'best_model.pth')

模型训练过程

模型效果展示

训练过程图表展示

画一个拆线图,观察训练过程中损失函数和正确率的变化趋势

plt.figure(figsize=(20,5))epoch_range = range(epochs)plt.subplot(1,2, 1)
plt.plot(epoch_range, train_loss, label='train loss')
plt.plot(epoch_range, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.subplot(1,2,2)
plt.plot(epoch_range, train_acc, label='train accuracy')
plt.plot(epoch_range, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练过程图示
可以看出模型在最后基本已经收敛,最佳准确率是88.2%,满足了挑战任务。

载入最佳模式,随机选择图像进行预测

model.load_state_dict(torch.load('best_model.pth'))
model = model.to(device)test_pathlib = pathlib.Path(test_path)image_list = list(test_pathlib.glob('*/*'))image_path = random.choice(image_list)
image = transform(Image.open(str(image_path)))
image = image.unsqueeze(0)
image = image.to(device)pred = model(image)plt.figure(figsize=(5,5))
plt.axis('off')
plt.imshow(Image.open(str(image_path)))
plt.title(f'real: {image_path.parts[-2]}, predict: {class_names[pred.argmax(1).item()]}')

预测结果
上次运行上面的预测任务,发现正确率还不错。

总结与心得体会

  1. 整个模型设计的思路其实是模仿了vgg16模型,在卷积层的数量和通道上做了简化。轻量级的任务可以首先试着减少池化层间的卷积次数,减少模型中最大的特征图的通道数
  2. 对图像的归一化操作很重要。在没有归一化前,模型的最佳正确率只能达到80%,推测可能是因为未做归一化的图像值域范围太大,不方便收敛,归一化后,原始图像中的输入特征值范围变成0~1,模型的权重变化更易作用到特征上。

相关文章:

卷积神经网络实现运动鞋识别 - P5

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f366; 参考文章&#xff1a;Pytorch实战 | 第P5周&#xff1a;运动鞋识别&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制&#x1f680; 文章来源&#xff1a;K同学的学习圈子 目录…...

C#安装“Windows 窗体应用(.NET Framework)”

目录 背景: 第一步: 第二步: 第三步&#xff1a; 总结: 背景: 如下图所示:在Visual Studio Installer创建新项目的时候&#xff0c;想要添加windows窗体应用程序&#xff0c;发现里面并没有找到Windows窗体应用(.NET Framework)模板&#xff0c;快捷搜索也没有发现&#…...

SQL高阶语句

目录 1、概念 1.1、概述 1.2、常见的MySQL高阶语句的概念&#xff1a; 1.3、 SQL高阶语句的作用 2、常用查询 2.1、按关键字排序 2.1.1、概述和作用 2.1.2、 &#xff08;1&#xff09;语法 2.1.3、模板表&#xff1a;ky30 ​编辑2.1.4、分数按降序排列 2.1.5、ORDER…...

【交换机】如何通过Web方式登陆交换机

一、华为交换机web登陆配置 Web网管是一种对交换机的管理方式&#xff0c;它利用交换机内置的Web服务器&#xff0c;为用户提供图形化的操作界面。用户可以从终端通过HTTPS登录到Web网管&#xff0c;对交换机进行管理和维护&#xff0c;同时也非常方便。 一、配置思路&#xff…...

Flink CDC学习笔记

第一章 CDC简介 1.1 什么是CDC ​ CDC (Change Data Capture 变更数据获取&#xff09;的简称。核心思想就是&#xff0c;检测并获取数据库的变动&#xff08;增删查改&#xff09;&#xff0c;将这些变更按发生的顺序记录下来&#xff0c;写入到消息中间件以供其它服务进行订…...

NEOVIM学习笔记

GitHub - blogercn/nvim-config: A pretty epic NeoVim setup 一直使用vim&#xff0c;每次到了新公司都要配置半天&#xff0c;而且常常配置失败&#xff0c;很多插件过期不好用。偶然看到别人的NEO VIM&#xff0c;就试着用了一下&#xff0c;感觉还不错。 用来开发和阅读C代…...

Docker三剑客之docker-compose

docker-compose 是 Docker 生态系统中的一个重要成员&#xff0c;它允许开发人员使用一个简单的配置文件来定义和运行多个 Docker 容器。通过 docker-compose&#xff0c;你可以定义应用程序的各个组件、容器之间的依赖关系以及网络配置&#xff0c;从而实现在一个命令中启动、…...

单调队列

目录 一&#xff0c;单调队列 二&#xff0c;模板实现 三&#xff0c;OJ实战 剑指 Offer 59 - I. 滑动窗口的最大值 一&#xff0c;单调队列 单调队列是双端队列的拓展&#xff0c;支持尾部插入&#xff0c;双端删除&#xff0c;其中的数据始终维持单调性&#xff0c;从而…...

effective c++ 笔记

TODO&#xff1a;还没看太懂的篇章 item25 item35 模板相关内容 文章目录 基础视C为一个语言联邦以const, enum, inline替换#define尽可能使用constconst成员函数 确定对象使用前已被初始化 构造、析构和赋值内含引用或常量成员的类的赋值操作需要自己重写不想使用自动生成的函…...

【送书活动】深入浅出SSD:固态存储核心技术、原理与实战

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 「推荐专栏」&#xff1a; ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄&#xff0c;vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄&#xff…...

GaussDB数据库SQL系列-行列转换

一、前言 二、简述 1、行转列概念 2、列转行概念 三、GaussDB数据库的行列转行实验示例 1、行转列示例 1&#xff09;创建实验表&#xff08;行存表&#xff09; 2&#xff09;静态行转列 3&#xff09;行转列&#xff08;结果值&#xff1a;拼接式&#xff09; 4&…...

美国陆军网络司令部利用人工智能增强网络攻防和作战决策能力

源自&#xff1a; 奇安网情局 声明:公众号转载的文章及图片出于非商业性的教育和科研目的供大家参考和探讨&#xff0c;并不意味着支持其观点或证实其内容的真实性。版权归原作者所有&#xff0c;如转载稿涉及版权等问题&#xff0c;请立即联系我们删除。 “人工智能技术与咨询…...

Notion团队协作魔法:如何玩转数字工作空间?

Notion简介 Notion已经成为现代团队协作的首选工具之一。它不仅仅是一个笔记应用&#xff0c;更是一个强大的团队协作平台&#xff0c;能够满足多种工作场景的需求。 Notion的核心功能 Notion提供了丰富的功能&#xff0c;如文档、数据库、看板、日历等&#xff0c;满足团队的…...

视频云存储/安防监控/AI视频智能分析平台新功能:人员倒地检测详解

人工智能技术已经越来越多地融入到视频监控领域中&#xff0c;近期我们也发布了基于AI智能视频云存储/安防监控视频智能分析平台的众多新功能&#xff0c;该平台内置多种AI算法&#xff0c;可对实时视频中的人脸、人体、物体等进行检测、跟踪与抓拍&#xff0c;支持口罩佩戴检测…...

解决RabbitMQ报错Stats in management UI are disabled on this node

文章目录 问题描述&#xff1a;解决步骤&#xff1a;进入容器后&#xff0c;cd到以下路径修改 management_agent.disable_metrics_collector false退出容器重启rabbitmq容器 问题描述&#xff1a; linux 部署 rabbitmq后&#xff0c;打开rabbitmq管理界面。点击channels&#…...

【重点】【NAND】聊聊固态硬盘SSD的寿命及其影响因素

固态硬盘是由主控芯片、存储颗粒芯片组成的闪存设备&#xff0c;固体硬盘的英文简称是SSD&#xff0c;如果是移动用的固态硬盘&#xff0c;则其英文简称为PSSD。 固态硬盘SSD分工业级和消费级等&#xff0c;目前&#xff0c;工业级固态硬盘SSD通常采用MLC闪存&#xff0c;而消…...

数据库约束

文章目录 1. 简介2. 代码演示3. 外键约束4. 外键删除和更新行为 1. 简介 概念&#xff1a;约束时作用于表中子段上的规则&#xff0c;用于限制存储在表中的shuju目的&#xff1a;保证数据库中数据的正确、有效性和完整性分类&#xff1a; 约束描述关键字非空约束限制该字段不…...

Unity实现MQTT服务器

首先下载MqttNet&#xff1a;MqttNet下载地址 解压好后使用vs打开&#xff0c;并生成.dll文件&#xff08;我这里下载的是4.1.2.350版本&#xff09; 然后再/Source/MQTTnet/bin/Debug/net452 文件夹中找到生成的文件 新建unity工程&#xff0c;创建Plugins文件夹&#xff0…...

Linux(centos) 下 Mysql 环境安装

linux 下进行环境安装相对比较简单&#xff0c;可还是会遇到各种奇奇怪怪的问题&#xff0c;我们来梳理一波 安装 mysql 我们会用到下地址&#xff1a; Mysql 官方文档的地址&#xff0c;可以参考&#xff0c;不要全部使用 https://dev.mysql.com/doc/refman/8.0/en/linux-i…...

决策树(Decision Tree)

决策树的定义: 分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点&#xff08;node&#xff09;和有向边&#xff08;directed edge&#xff09;组成。结点有两种类型: 内部结点&#xff08;internal node&#xff09;和叶结点&#xff08;leaf node&#xff0…...

树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频

使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展&#xff0c;光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域&#xff0c;IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选&#xff0c;但在长期运行中&#xff0c;例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...

华硕a豆14 Air香氛版,美学与科技的馨香融合

在快节奏的现代生活中&#xff0c;我们渴望一个能激发创想、愉悦感官的工作与生活伙伴&#xff0c;它不仅是冰冷的科技工具&#xff0c;更能触动我们内心深处的细腻情感。正是在这样的期许下&#xff0c;华硕a豆14 Air香氛版翩然而至&#xff0c;它以一种前所未有的方式&#x…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...

MFC 抛体运动模拟:常见问题解决与界面美化

在 MFC 中开发抛体运动模拟程序时,我们常遇到 轨迹残留、无效刷新、视觉单调、物理逻辑瑕疵 等问题。本文将针对这些痛点,详细解析原因并提供解决方案,同时兼顾界面美化,让模拟效果更专业、更高效。 问题一:历史轨迹与小球残影残留 现象 小球运动后,历史位置的 “残影”…...

MySQL 8.0 事务全面讲解

以下是一个结合两次回答的 MySQL 8.0 事务全面讲解&#xff0c;涵盖了事务的核心概念、操作示例、失败回滚、隔离级别、事务性 DDL 和 XA 事务等内容&#xff0c;并修正了查看隔离级别的命令。 MySQL 8.0 事务全面讲解 一、事务的核心概念&#xff08;ACID&#xff09; 事务是…...