第P2周:Pytorch实现CIFAR10彩色图片识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
- 实现CIFAR-10的彩色图片识别
- 实现比P1周更复杂一点的CNN网络
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1
(二)具体步骤
1.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision # 第一步:设置GPU
def USE_GPU(): if torch.cuda.is_available(): print('CUDA is available, will use GPU') device = torch.device("cuda") else: print('CUDA is not available. Will use CPU') device = torch.device("cpu") return device device = USE_GPU()
输出:CUDA is available, will use GPU
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor()) batch_size = 32
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size) # 取一个批次查看数据格式
# 数据的shape为:[batch_size, channel, height, weight]
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。
imgs, labels = next(iter(train_dataload))
print(imgs.shape) # 查看一下图片
import numpy as np
plt.figure(figsize=(20, 5))
for i, images in enumerate(imgs[:20]): # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理 npimg = imgs.numpy().transpose((1, 2, 0)) # 将整个figure分成2行10列,并绘制第i+1个子图 plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')
plt.show()
输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])

# 第三步,构建CNN网络
import torch.nn.functional as F num_classes = 10 # 因为CIFAR-10是10种类型
class Model(nn.Module): def __init__(self): super(Model, self).__init__() # 提取特征网络 self.conv1 = nn.Conv2d(3, 64, 3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, 3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, 3) self.pool3 = nn.MaxPool2d(kernel_size=2) # 分类网络 self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) # 前向传播 def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

# 训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 设置学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate) # 设置优化器 # 编写训练函数
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片 num_batches = len(dataloader) # 批次大小,这里是1875(60000/32=1875) train_acc, train_loss = 0, 0 # 初始化训练正确率和损失率都为0 for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值) X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出预测值 loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 第一步自动更新 # 记录正确率和损失率 train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss # 测试函数
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片 num_batches = len(dataloader) # 批次大小 ,这里312,即10000/32=312.5,向上取整 test_acc, test_loss = 0, 0 # 因为是测试,因此不用训练,梯度也不用计算不用更新 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss # 正式训练
epochs = 10
train_acc, train_loss, test_acc, test_loss = [], [], [], [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}' print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done') # 结果可视化
# 隐藏警告
import warnings
warnings.filterwarnings('ignore') # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 正常显示+/-号
plt.rcParams['figure.dpi'] = 100 # 分辨率 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) # 第一张子图
plt.plot(epochs_range, train_acc, label='训练正确率')
plt.plot(epochs_range, test_acc, label='测试正确率')
plt.legend(loc='lower right')
plt.title('训练和测试正确率比较') plt.subplot(1, 2, 2) # 第二张子图
plt.plot(epochs_range, train_loss, label='训练损失率')
plt.plot(epochs_range, test_loss, label='测试损失率')
plt.legend(loc='upper right')
plt.title('训练和测试损失率比较') plt.show()# 保存模型
torch.save(model, './models/cnn-cifar10.pth')

再次设置epochs为50训练结果:

epochs增加到100,训练结果:

可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:
data_transforms= { 'train': transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]), 'test': transforms.Compose([ transforms.ToTensor(), ])
}
在dataset中:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transforms['train'])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])
运行结果:


比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。

batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。

(三)总结
- epochs并不是越多越好。batch_size同样的道理
- 数据增强确实可以提高模型训练的准确性。
相关文章:
第P2周:Pytorch实现CIFAR10彩色图片识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 实现CIFAR-10的彩色图片识别实现比P1周更复杂一点的CNN网络 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: …...
CTFHub 命令注入-综合练习(学习记录)
综合过滤练习 命令分隔符的绕过姿势 ; %0a %0d & 那我们使用%0a试试,发现ls命令被成功执行 /?ip127.0.0.1%0als 发现一个名为flag_is_here的文件夹和index.php的文件,那么我们还是使用cd命令进入到文件夹下 http://challenge-438c1c1fb670566b.sa…...
OpenCV目标检测 级联分类器 C++实现
一.目标检测技术 目前常用实用性目标检测与跟踪的方法有以下两种: 帧差法 识别原理:基于前后两帧图像之间的差异进行对比,获取图像画面中正在运动的物体从而达到目标检测 缺点:画面中所有运动中物体都能识别 举个例子…...
QT6 Socket通讯封装(TCP/UDP)
为大家分享一下最近封装的以太网socket通讯接口 效果演示 如图,界面还没优化,后续更新 废话不多说直接上教程 添加库 如果为qmake项目中,在.pro文件添加 QT network QT core gui QT networkgreaterThan(QT_MAJOR_VERS…...
elasticsearch设置密码访问
1 用户认证介绍 默认ES是没有设置用户认证访问的,所以每次访问时,直接调相关API就能查询和写入数据。现在做一个认证,只有通过认证的用户才能访问和操作ES。 2 开启加密设置 1.生成证书文件 /usr/share/elasticsearch/bin/elasticsearch-…...
彻底理解如何优化接口性能
作为后端研发,必须要掌握怎么优化接口的性能或者说是响应时间,这样才能提高系统的系能,本文通过如下两个方面进行分析: 一.后端代码 有如下几步: 1.缓存机制 这是最场景的方式,当使用了缓存后,…...
C# 位运算
一、数据大小对应关系 说明: 将一个数据每左移一位,相当于乘以2。因此,左移8位就是乘以2的8次方,即256。 二、转换 1、 10进制转2进制字符串 #region 10进制转2进制字符串int number1 10;string binary Convert.ToString(num…...
【Flink-scala】DataStream编程模型之状态编程
DataStream编程模型之状态编程 参考: 1.【Flink-Scala】DataStream编程模型之数据源、数据转换、数据输出 2.【Flink-scala】DataStream编程模型之 窗口的划分-时间概念-窗口计算程序 3.【Flink-scala】DataStream编程模型之窗口计算-触发器-驱逐器 4.【Flink-scal…...
RabbitMQ的核心组件有哪些?
大家好,我是锋哥。今天分享关于【RabbitMQ的核心组件有哪些?】面试题。希望对大家有帮助; RabbitMQ的核心组件有哪些? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 RabbitMQ是一个开源的消息代理(Messag…...
【Linux基础】基本开发工具的使用
目录 一、编译器——gcc/g的使用 gcc/g的安装 gcc的安装: g的安装: gcc/g的基本使用 gcc的使用 g的使用 动态链接与静态链接 程序的翻译过程 1. 一个C/C程序的构建过程,程序从源代码到可执行文件必须经历四个阶段 2. 理解选项的含…...
常见的数据结构和应用场景
数据结构是计算机科学中的基础概念,用于组织和存储数据,以便能够高效地访问和修改。下面是几种常见数据结构及其代表性应用场景: 1. 数组(Array) 问题解决:数组是一种线性数据结构,用于存储相…...
爬虫基础学习
爬虫概念与工作原理 爬虫是什么:爬虫(Web Scraping)是自动化地访问网站并提取数据的技术。它模拟用户浏览器的行为,通过HTTP请求访问网页,解析HTML文档并提取有用信息。 爬虫的基本工作流程: 发送HTTP请求…...
C++对象数组对象指针对象指针数组
一、对象数组 对象数组中的每一个元素都是同类的对象; 例1 对象数组成员的初始化 #include<iostream> using namespace std;class Student { public:Student( ){ };Student(int n,string nam,char s):num(n),name(nam),sex(s){};void display(){cout<&l…...
D96【python 接口自动化学习】- pytest进阶之fixture用法
day96 pytest的fixture详解(三) 学习日期:20241211 学习目标:pytest基础用法 -- pytest的fixture详解(三) 学习笔记: fixture(scop"class") (scop"class") 每一个类调…...
【算法】动态规划中01背包问题解析
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...
选择WordPress和Shopify:搭建对谷歌SEO友好的网站
在建设网站时,不仅要考虑它的美观和功能性,还要关注它是否对谷歌SEO友好。如果你希望网站能够获得更好的搜索排名,WordPress和Shopify是两个值得推荐的建站平台。 WordPress作为最流行的内容管理系统,其强大的灵活性和丰富的插件…...
代理IP与生成式AI:携手共创未来
目录 代理IP:网络世界的“隐形斗篷” 1. 隐藏真实IP,保护隐私 2. 突破网络限制,访问更多资源 生成式AI:创意与效率的“超级大脑” 1. 提高创作效率 2. 个性化定制 代理IP与生成式AI的协同作用 1. 网络安全 2. 内容创作与…...
iOS 应用的生命周期
Managing your app’s life cycle | Apple Developer Documentation Performance and metrics | Apple Developer Documentation iOS 应用的生命周期状态是理解应用如何在不同状态下运行和管理资源的基础。在 iOS 开发中,应用生命周期管理的是应用从启动到终止的整…...
Elasticsearch 集群快照的定期备份设置指南
Elasticsearch 集群快照的定期备份设置指南 概述 快照: 在给定时刻对整个集群或者单个索引进行备份,以便在之后出现故障时可以基于之前备份的快照进行快速恢复。 前提条件: 准备一个备份存储盘,本指南采用的是AWS EFS文件系统做…...
Docker--Docker Image(镜像)
什么是Docker Image? Docker镜像(Docker Image)是Docker容器技术的核心组件之一,它包含了运行应用程序所需的所有依赖、库、代码、运行时环境以及配置文件等。 简单来说,Docker镜像是一个轻量级、可执行的软件包&…...
【Linux】shell脚本忽略错误继续执行
在 shell 脚本中,可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行,可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令,并忽略错误 rm somefile…...
ESP32读取DHT11温湿度数据
芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...
2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面
代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...
HBuilderX安装(uni-app和小程序开发)
下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...
Swagger和OpenApi的前世今生
Swagger与OpenAPI的关系演进是API标准化进程中的重要篇章,二者共同塑造了现代RESTful API的开发范式。 本期就扒一扒其技术演进的关键节点与核心逻辑: 🔄 一、起源与初创期:Swagger的诞生(2010-2014) 核心…...
大数据学习(132)-HIve数据分析
🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言Ǵ…...
Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信
文章目录 Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket(服务端和客户端都要)2. 绑定本地地址和端口&#x…...
第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10+pip3.10)
第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10pip3.10) 一:前言二:安装编译依赖二:安装Python3.10三:安装PIP3.10四:安装Paddlepaddle基础框架4.1…...
深入解析光敏传感技术:嵌入式仿真平台如何重塑电子工程教学
一、光敏传感技术的物理本质与系统级实现挑战 光敏电阻作为经典的光电传感器件,其工作原理根植于半导体材料的光电导效应。当入射光子能量超过材料带隙宽度时,价带电子受激发跃迁至导带,形成电子-空穴对,导致材料电导率显著提升。…...
代理服务器-LVS的3种模式与调度算法
作者介绍:简历上没有一个精通的运维工程师。请点击上方的蓝色《运维小路》关注我,下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 我们上一章介绍了Web服务器,其中以Nginx为主,本章我们来讲解几个代理软件:…...
