第P2周:Pytorch实现CIFAR10彩色图片识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
- 实现CIFAR-10的彩色图片识别
- 实现比P1周更复杂一点的CNN网络
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1
(二)具体步骤
1.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision # 第一步:设置GPU
def USE_GPU(): if torch.cuda.is_available(): print('CUDA is available, will use GPU') device = torch.device("cuda") else: print('CUDA is not available. Will use CPU') device = torch.device("cpu") return device device = USE_GPU()
输出:CUDA is available, will use GPU
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor()) batch_size = 32
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size) # 取一个批次查看数据格式
# 数据的shape为:[batch_size, channel, height, weight]
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。
imgs, labels = next(iter(train_dataload))
print(imgs.shape) # 查看一下图片
import numpy as np
plt.figure(figsize=(20, 5))
for i, images in enumerate(imgs[:20]): # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理 npimg = imgs.numpy().transpose((1, 2, 0)) # 将整个figure分成2行10列,并绘制第i+1个子图 plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')
plt.show()
输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])

# 第三步,构建CNN网络
import torch.nn.functional as F num_classes = 10 # 因为CIFAR-10是10种类型
class Model(nn.Module): def __init__(self): super(Model, self).__init__() # 提取特征网络 self.conv1 = nn.Conv2d(3, 64, 3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, 3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, 3) self.pool3 = nn.MaxPool2d(kernel_size=2) # 分类网络 self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) # 前向传播 def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

# 训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 设置学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate) # 设置优化器 # 编写训练函数
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片 num_batches = len(dataloader) # 批次大小,这里是1875(60000/32=1875) train_acc, train_loss = 0, 0 # 初始化训练正确率和损失率都为0 for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值) X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出预测值 loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 第一步自动更新 # 记录正确率和损失率 train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss # 测试函数
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片 num_batches = len(dataloader) # 批次大小 ,这里312,即10000/32=312.5,向上取整 test_acc, test_loss = 0, 0 # 因为是测试,因此不用训练,梯度也不用计算不用更新 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss # 正式训练
epochs = 10
train_acc, train_loss, test_acc, test_loss = [], [], [], [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}' print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done') # 结果可视化
# 隐藏警告
import warnings
warnings.filterwarnings('ignore') # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 正常显示+/-号
plt.rcParams['figure.dpi'] = 100 # 分辨率 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) # 第一张子图
plt.plot(epochs_range, train_acc, label='训练正确率')
plt.plot(epochs_range, test_acc, label='测试正确率')
plt.legend(loc='lower right')
plt.title('训练和测试正确率比较') plt.subplot(1, 2, 2) # 第二张子图
plt.plot(epochs_range, train_loss, label='训练损失率')
plt.plot(epochs_range, test_loss, label='测试损失率')
plt.legend(loc='upper right')
plt.title('训练和测试损失率比较') plt.show()# 保存模型
torch.save(model, './models/cnn-cifar10.pth')

再次设置epochs为50训练结果:

epochs增加到100,训练结果:

可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:
data_transforms= { 'train': transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]), 'test': transforms.Compose([ transforms.ToTensor(), ])
}
在dataset中:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transforms['train'])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])
运行结果:


比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。

batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。

(三)总结
- epochs并不是越多越好。batch_size同样的道理
- 数据增强确实可以提高模型训练的准确性。
相关文章:
第P2周:Pytorch实现CIFAR10彩色图片识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 实现CIFAR-10的彩色图片识别实现比P1周更复杂一点的CNN网络 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: …...
CTFHub 命令注入-综合练习(学习记录)
综合过滤练习 命令分隔符的绕过姿势 ; %0a %0d & 那我们使用%0a试试,发现ls命令被成功执行 /?ip127.0.0.1%0als 发现一个名为flag_is_here的文件夹和index.php的文件,那么我们还是使用cd命令进入到文件夹下 http://challenge-438c1c1fb670566b.sa…...
OpenCV目标检测 级联分类器 C++实现
一.目标检测技术 目前常用实用性目标检测与跟踪的方法有以下两种: 帧差法 识别原理:基于前后两帧图像之间的差异进行对比,获取图像画面中正在运动的物体从而达到目标检测 缺点:画面中所有运动中物体都能识别 举个例子…...
QT6 Socket通讯封装(TCP/UDP)
为大家分享一下最近封装的以太网socket通讯接口 效果演示 如图,界面还没优化,后续更新 废话不多说直接上教程 添加库 如果为qmake项目中,在.pro文件添加 QT network QT core gui QT networkgreaterThan(QT_MAJOR_VERS…...
elasticsearch设置密码访问
1 用户认证介绍 默认ES是没有设置用户认证访问的,所以每次访问时,直接调相关API就能查询和写入数据。现在做一个认证,只有通过认证的用户才能访问和操作ES。 2 开启加密设置 1.生成证书文件 /usr/share/elasticsearch/bin/elasticsearch-…...
彻底理解如何优化接口性能
作为后端研发,必须要掌握怎么优化接口的性能或者说是响应时间,这样才能提高系统的系能,本文通过如下两个方面进行分析: 一.后端代码 有如下几步: 1.缓存机制 这是最场景的方式,当使用了缓存后,…...
C# 位运算
一、数据大小对应关系 说明: 将一个数据每左移一位,相当于乘以2。因此,左移8位就是乘以2的8次方,即256。 二、转换 1、 10进制转2进制字符串 #region 10进制转2进制字符串int number1 10;string binary Convert.ToString(num…...
【Flink-scala】DataStream编程模型之状态编程
DataStream编程模型之状态编程 参考: 1.【Flink-Scala】DataStream编程模型之数据源、数据转换、数据输出 2.【Flink-scala】DataStream编程模型之 窗口的划分-时间概念-窗口计算程序 3.【Flink-scala】DataStream编程模型之窗口计算-触发器-驱逐器 4.【Flink-scal…...
RabbitMQ的核心组件有哪些?
大家好,我是锋哥。今天分享关于【RabbitMQ的核心组件有哪些?】面试题。希望对大家有帮助; RabbitMQ的核心组件有哪些? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 RabbitMQ是一个开源的消息代理(Messag…...
【Linux基础】基本开发工具的使用
目录 一、编译器——gcc/g的使用 gcc/g的安装 gcc的安装: g的安装: gcc/g的基本使用 gcc的使用 g的使用 动态链接与静态链接 程序的翻译过程 1. 一个C/C程序的构建过程,程序从源代码到可执行文件必须经历四个阶段 2. 理解选项的含…...
常见的数据结构和应用场景
数据结构是计算机科学中的基础概念,用于组织和存储数据,以便能够高效地访问和修改。下面是几种常见数据结构及其代表性应用场景: 1. 数组(Array) 问题解决:数组是一种线性数据结构,用于存储相…...
爬虫基础学习
爬虫概念与工作原理 爬虫是什么:爬虫(Web Scraping)是自动化地访问网站并提取数据的技术。它模拟用户浏览器的行为,通过HTTP请求访问网页,解析HTML文档并提取有用信息。 爬虫的基本工作流程: 发送HTTP请求…...
C++对象数组对象指针对象指针数组
一、对象数组 对象数组中的每一个元素都是同类的对象; 例1 对象数组成员的初始化 #include<iostream> using namespace std;class Student { public:Student( ){ };Student(int n,string nam,char s):num(n),name(nam),sex(s){};void display(){cout<&l…...
D96【python 接口自动化学习】- pytest进阶之fixture用法
day96 pytest的fixture详解(三) 学习日期:20241211 学习目标:pytest基础用法 -- pytest的fixture详解(三) 学习笔记: fixture(scop"class") (scop"class") 每一个类调…...
【算法】动态规划中01背包问题解析
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...
选择WordPress和Shopify:搭建对谷歌SEO友好的网站
在建设网站时,不仅要考虑它的美观和功能性,还要关注它是否对谷歌SEO友好。如果你希望网站能够获得更好的搜索排名,WordPress和Shopify是两个值得推荐的建站平台。 WordPress作为最流行的内容管理系统,其强大的灵活性和丰富的插件…...
代理IP与生成式AI:携手共创未来
目录 代理IP:网络世界的“隐形斗篷” 1. 隐藏真实IP,保护隐私 2. 突破网络限制,访问更多资源 生成式AI:创意与效率的“超级大脑” 1. 提高创作效率 2. 个性化定制 代理IP与生成式AI的协同作用 1. 网络安全 2. 内容创作与…...
iOS 应用的生命周期
Managing your app’s life cycle | Apple Developer Documentation Performance and metrics | Apple Developer Documentation iOS 应用的生命周期状态是理解应用如何在不同状态下运行和管理资源的基础。在 iOS 开发中,应用生命周期管理的是应用从启动到终止的整…...
Elasticsearch 集群快照的定期备份设置指南
Elasticsearch 集群快照的定期备份设置指南 概述 快照: 在给定时刻对整个集群或者单个索引进行备份,以便在之后出现故障时可以基于之前备份的快照进行快速恢复。 前提条件: 准备一个备份存储盘,本指南采用的是AWS EFS文件系统做…...
Docker--Docker Image(镜像)
什么是Docker Image? Docker镜像(Docker Image)是Docker容器技术的核心组件之一,它包含了运行应用程序所需的所有依赖、库、代码、运行时环境以及配置文件等。 简单来说,Docker镜像是一个轻量级、可执行的软件包&…...
<6>-MySQL表的增删查改
目录 一,create(创建表) 二,retrieve(查询表) 1,select列 2,where条件 三,update(更新表) 四,delete(删除表…...
YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...
LeetCode - 394. 字符串解码
题目 394. 字符串解码 - 力扣(LeetCode) 思路 使用两个栈:一个存储重复次数,一个存储字符串 遍历输入字符串: 数字处理:遇到数字时,累积计算重复次数左括号处理:保存当前状态&a…...
【机器视觉】单目测距——运动结构恢复
ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛…...
Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
网络编程(UDP编程)
思维导图 UDP基础编程(单播) 1.流程图 服务器:短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...
AI书签管理工具开发全记录(十九):嵌入资源处理
1.前言 📝 在上一篇文章中,我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源,方便后续将资源打包到一个可执行文件中。 2.embed介绍 🎯 Go 1.16 引入了革命性的 embed 包,彻底改变了静态资源管理的…...
ubuntu22.04有线网络无法连接,图标也没了
今天突然无法有线网络无法连接任何设备,并且图标都没了 错误案例 往上一顿搜索,试了很多博客都不行,比如 Ubuntu22.04右上角网络图标消失 最后解决的办法 下载网卡驱动,重新安装 操作步骤 查看自己网卡的型号 lspci | gre…...
云原生周刊:k0s 成为 CNCF 沙箱项目
开源项目推荐 HAMi HAMi(原名 k8s‑vGPU‑scheduler)是一款 CNCF Sandbox 级别的开源 K8s 中间件,通过虚拟化 GPU/NPU 等异构设备并支持内存、计算核心时间片隔离及共享调度,为容器提供统一接口,实现细粒度资源配额…...
第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10+pip3.10)
第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10pip3.10) 一:前言二:安装编译依赖二:安装Python3.10三:安装PIP3.10四:安装Paddlepaddle基础框架4.1…...
