当前位置: 首页 > news >正文

卷积神经网络实现彩色图像分类 - P2

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P2周:彩色识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 硬件设备
    • 数据准备
      • 数据集下载与加载
      • 数据集预览
      • 数据集准备
    • 模型设计
    • 模型训练
      • 超参数设置
      • helper函数
      • 正式训练
    • 结果呈现
  • 总结与心得体会

上周使用Pytorch构建卷积神经网络,实现了MNIST手写数字的识别,这周的目标是CIFAR10中复杂的彩色图像分类。


环境

  • 系统:Linux
  • 语言: Python 3.8.10
  • 深度学习框架:PyTorch 2.0.0+cu118

步骤

环境设置

包引用

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transformsimport numpy as np
import matplotlib.pyplot as plt
from torchinfo import summary # 方便像tensorflow一样打印模型

硬件设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

数据集下载与加载

train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor()) # 不要忘记这个transform
test_dataset = datasets.CIFAR10(root='data', train=False, download=True, transform=transforms.ToTensor())

数据集预览

image, label = train_dataset[0]
print(image.shape)
plt.figure(figsize=(20,4))
for i in range(20):image, label = train_dataset[i]plt.subplot(2, 10, i+1)plt.imshow(image.numpy().transpose(1,2,0)plt.axis('off')plt.title(label) # 加载的数据集没有对应的名称,暂时展示它们的id

数据集预览

数据集准备

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

class Model(nn.Module):def __init__(self, num_classes):super().__init__()# 3x3的卷积无padding每次宽高-2# 2x2的最大池化,每次宽高缩短为原来的一半# 32x32 -> conv1 -> 30x30 -> maxpool -> 15x15self.conv1 = nn.Conv2d(3, 64, kernel_size=3)# 15x15 -> conv2 -> 13x13 -> maxpool -> 6x6self.conv2 = nn.Conv2d(64, 64, kernel_size=3)# 6x6 -> conv3 -> 4x4 -> maxpool -> 2x2self.conv3 = nn.Conv2d(64, 128, kernel_size=3)self.maxpool = nn.MaxPool2d(2),self.flatten = nn.Flatten(),self.fc1 = nn.Linear(2*2*128, 256)self.fc2 = nn.Linear(256, num_classes)def forward(self, x):x = F.relu(self.conv1(x))x = self.maxpool(x)x = F.relu(self.conv2(x))x = self.maxpool(x)x = F.relu(self.conv3(x))x = self.maxpool(x)x = self.flatten(x)x = F.relu(self.fc1(x))x = self.fc2(x)return xmodel = Model(10).to(device)
summary(model, input_size=(1, 3, 32, 32))

模型结构图

模型训练

超参数设置

learning_rate = 1e-2
epochs = 10
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

helper函数

def train(train_loader, model, loss_fn, optimizer):size = len(train_loader.dataset)num_batches = len(train_loader)train_loss, train_acc = 0, 0for x, y in train_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= sizereturn train_loss, train_accdef test(test_loader, model, loss_fn):size = len(test_loader.dataset)num_batches = len(test_loader)test_loss, test_acc = 0, 0with torch.no_grad():for x, y in test_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)test_loss += loss.item()test_acc += (preds.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchestest_acc /= sizereturn test_loss, test_accdef fit(train_loader, test_loader, model, loss_fn, optimizer, epochs):train_loss, train_acc = [], []test_loss, test_acc = [], []for epoch in range(epochs):model.train()epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval()epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)train_loss.append(epoch_train_loss)train_acc.append(epoch_train_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)return train_loss, train_acc, test_loss, test_acc

正式训练

train_loss, train_acc, test_loss, test_acc = fit(train_loader, test_loader, model, loss_fn, optimizer, 20)

训练结果

结果呈现

series = range(len(train_loss))
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(series, train_loss, label='train loss')
plt.plot(series, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1,2,2)
plt.plot(series, train_acc, label='train accuracy')
plt.plot(series, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

实验结果
从结果图可以发现,模型应该还没收敛,将epoch设置为30,重新跑一遍模型。
实验结果2
可以看出20个epoch后,训练集上的正确率持续增长,在验证集上的正确率几乎就不再增长了,符合过拟合的特征。需要对模型进行改进才能提升正确率了。


总结与心得体会

通过本周的学习,掌握了使用pytorch编写一个完整深度学习的过程,包括环境的配置、数据的准备、模型定义与训练、结果分析呈现等步骤,并且掌握了通过pytorch的API组建一个简单的卷积神经网络的过程。

相关文章:

卷积神经网络实现彩色图像分类 - P2

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍦 参考文章:365天深度学习训练营-第P2周:彩色识别🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子…...

【博客694】k8s kubelet 状态更新机制

k8s kubelet 状态更新机制 场景: 当 Kubernetes 中 Node 节点出现状态异常的情况下,节点上的 Pod 会被重新调度到其他节点上去,但是有的时候我们会发现节点 Down 掉以后,Pod 并不会立即触发重新调度,这实际上就是和 K…...

【博客692】grafana如何解决step动态变化时可能出现range duration小于step

grafana如何解决step动态变化时可能出现range duration小于step 1、grafana中的step和resolution grafana中的 “step” grafana本身是没有提供step参数的,因为仪表盘根据查询数据区间以及仪表盘线条宽度等,对于不同查询,相同的step并不能…...

eNSP:ibgp的破水平切割练习

实验要求&#xff1a; 拓扑展示&#xff1a; 命令操作&#xff1a; R1&#xff1a; <Huawei>sys [Huawei]sys r1 [r1]int g 0/0/1 [r1-GigabitEthernet0/0/1]ip add 12.1.1.1 24 [r1-GigabitEthernet0/0/1]int lo0 [r1-LoopBack0]ip add 1.1.1.1 24 [r1-LoopBack0]osp…...

maven是什么?安装+配置

目录 1.什么是maven&#xff1f; 1.2.maven的核心功能是什么&#xff1f; 2.Maven安装配置 2.1Maven的安装 2.2Maven环境配置 1.配置 MAVEN_HOME &#xff0c;变量值就是你的 maven 安装的路径&#xff08;bin 目录之前一级目录&#xff09; 2.将MAVEN_HOME 添加到Path系…...

基于长短期神经网络LSTM的多分类代码

目录 背影 摘要 LSTM的基本定义 LSTM实现的步骤 基于长短期神经网络LSTM的股票预测 MATALB编程实现,附有代码:基于长短期神经网络LSTM的多分类代码,基于LSTM的多分类预测-深度学习文档类资源-CSDN文库 https://download.csdn.net/download/abc991835105/88184779 效果图 结果…...

利用爬虫爬取图片并保存

1 问题 在工作中&#xff0c;有时会遇到需要相当多的图片资源&#xff0c;可是如何才能在短时间内获得大量的图片资源呢&#xff1f; 2 方法 我们知道&#xff0c;网页中每一张图片都是一个连接&#xff0c;所以我们提出利用爬虫爬取网页图片并下载保存下来。 首先通过网络搜索…...

设计模式之Bridge模式的C++实现

目录 1、Bridge模式的提出 2、Bridge模式的定义 3、Bridge模式总结 4、需求描述 5、多继承方式实现 6、使用Bridge设计模式实现 1、Bridge模式的提出 在软件功能模块设计中&#xff0c;如果类的实现功能划分不清晰&#xff0c;使得继承得到的子类往往是随着需求的变化&am…...

springboot异步任务

在Service类声明一个注解Async作为异步方法的标识 package com.qf.sping09test.service;import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service;Service public class AsyncService {//告诉spring这是一个异步的方法Asyncp…...

Flutter父宽度自适应子控件的宽度

需求&#xff1a; 控件随着金币进行自适应宽度 image.png 步骤&#xff1a; 1、Container不设置宽度&#xff0c;需要设置约束padding&#xff1b; 2、文本使用Flexible形式&#xff1b; Container(height: 24.dp,padding: EdgeInsetsDirectional.only(start: 8.dp, end: 5.d…...

什么是 API 安全?学习如何防止攻击和保护数据

随着 API 技术的普及&#xff0c;API 安全成为了一个越来越重要的问题。本文将介绍什么是 API 安全&#xff0c;以及目前 API 面临的安全问题和相应的解决方案。 什么是 API 安全 API 安全是指保护 API 免受恶意攻击和滥用的安全措施。API 安全通常包括以下几个方面&#xff1…...

简述 TCP 和 UDP 的区别以及优缺点和使用场景?

一、TCP与UDP区别总结&#xff1a; 1、TCP面向连接&#xff08;如打电话要先拨号建立连接&#xff09;;UDP是无连接的&#xff0c;即发送数据之前不需要建立连接 2、TCP提供可靠的服务。也就是说&#xff0c;通过TCP连接传送的数据&#xff0c;无差错&#xff0c;不丢失&…...

react进阶

react-virtualized的高阶组件&#xff0c;Autosize可以使屏幕适配。使用render-props模式来获取到AutoSizer组件暴露的width和height属性。JSON.parse(JSON.stringify())不适用于有undefined的数据。 深拷贝的使用&#xff0c;不能使用在有undefined的数据中。有直接过滤undefi…...

使用windows搭建WebDAV服务,并内网穿透公网访问【无公网IP】

文章目录 1. 安装IIS必要WebDav组件2. 客户端测试3. 使用cpolar内网穿透&#xff0c;将WebDav服务暴露在公网3.1 打开Web-UI管理界面3.2 创建隧道3.3 查看在线隧道列表3.4 浏览器访问测试 4. 安装Raidrive客户端4.1 连接WebDav服务器4.2 连接成功4.2 连接成功 1. Linux(centos8…...

科技感响应式管理系统后台登录页ui设计html模板

做了一个科技感的后台管理系统登录页设计&#xff0c;并且尝试用响应式布局把前端html写了出来&#xff0c;发现并没有现象中的那么容易&#xff0c;chrome等标准浏览器都显示的挺好&#xff0c;但IE11下面却出现了很多错位&#xff0c;兼容起来还是挺费劲的&#xff0c;真心不…...

Lombok的使用及注解含义

文章目录 一、简介二、如何使用2.1、在IDEA中安装Lombok插件2.2、添加maven依赖 三、常用注解3.1、Getter / Setter3.2、ToString3.3、NoArgsConstructor / AllArgsConstructor3.4、EqualsAndHashCode3.5、Data3.6、Value3.7、Accessors3.7.1、Accessors(chain true)3.7.2、Ac…...

实时通信应用的开发:Vue.js、Spring Boot 和 WebSocket 整合实践

目录 1. 什么是webSocket 2. webSocket可以用来做什么? 3. webSocket协议 4. 服务器端 5. 客户端 6. 测试通讯 1. 什么是webSocket WebSocket是一种在单个TCP连接上进行全双工通信的协议。WebSocket使得客户端和服务器之间的数据交换变得更加简单&#xff0c;允许服务…...

【C++】C++异常

文章目录 1. C语言传统处理错误的方式2. C异常的概念3. 异常的使用3.1 异常的抛出和捕获3.2 异常的重新抛出3.3 异常安全3.4 异常规范 4. C标准库的异常体系5. 自定义的异常体系6. 异常的优缺点 1. C语言传统处理错误的方式 C语言传统的错误处理机制有两个&#xff1a; 终止程…...

学生成绩管理系统V2.0

某班有最多不超过30人&#xff08;具体人数由键盘输入&#xff09;参加某门课程的考试&#xff0c;参考前面章节的“学生成绩管理系统V1.0”&#xff0c;用一维数组和函数指针作函数参数编程实现如下菜单驱动的学生成绩管理系统&#xff0c;其中每位同学的学号和成绩等数据可以…...

【C++】开源:tinyxml2解析库配置使用

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍tinyxml2解析库配置使用。 无专精则不能成&#xff0c;无涉猎则不能通。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c;…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽&#xff0c;大家好&#xff0c;我是左手python&#xff01; Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库&#xff0c;用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

【磁盘】每天掌握一个Linux命令 - iostat

目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat&#xff08;I/O Statistics&#xff09;是Linux系统下用于监视系统输入输出设备和CPU使…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

GitHub 趋势日报 (2025年06月08日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

数据库分批入库

今天在工作中&#xff0c;遇到一个问题&#xff0c;就是分批查询的时候&#xff0c;由于批次过大导致出现了一些问题&#xff0c;一下是问题描述和解决方案&#xff1a; 示例&#xff1a; // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...