第P2周:Pytorch实现CIFAR10彩色图片识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
- 实现CIFAR-10的彩色图片识别
- 实现比P1周更复杂一点的CNN网络
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1
(二)具体步骤
1.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision # 第一步:设置GPU
def USE_GPU(): if torch.cuda.is_available(): print('CUDA is available, will use GPU') device = torch.device("cuda") else: print('CUDA is not available. Will use CPU') device = torch.device("cpu") return device device = USE_GPU()
输出:CUDA is available, will use GPU
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor()) batch_size = 32
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size) # 取一个批次查看数据格式
# 数据的shape为:[batch_size, channel, height, weight]
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。
imgs, labels = next(iter(train_dataload))
print(imgs.shape) # 查看一下图片
import numpy as np
plt.figure(figsize=(20, 5))
for i, images in enumerate(imgs[:20]): # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理 npimg = imgs.numpy().transpose((1, 2, 0)) # 将整个figure分成2行10列,并绘制第i+1个子图 plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')
plt.show()
输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])

# 第三步,构建CNN网络
import torch.nn.functional as F num_classes = 10 # 因为CIFAR-10是10种类型
class Model(nn.Module): def __init__(self): super(Model, self).__init__() # 提取特征网络 self.conv1 = nn.Conv2d(3, 64, 3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, 3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, 3) self.pool3 = nn.MaxPool2d(kernel_size=2) # 分类网络 self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) # 前向传播 def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

# 训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 设置学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate) # 设置优化器 # 编写训练函数
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片 num_batches = len(dataloader) # 批次大小,这里是1875(60000/32=1875) train_acc, train_loss = 0, 0 # 初始化训练正确率和损失率都为0 for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值) X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出预测值 loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 第一步自动更新 # 记录正确率和损失率 train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss # 测试函数
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片 num_batches = len(dataloader) # 批次大小 ,这里312,即10000/32=312.5,向上取整 test_acc, test_loss = 0, 0 # 因为是测试,因此不用训练,梯度也不用计算不用更新 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss # 正式训练
epochs = 10
train_acc, train_loss, test_acc, test_loss = [], [], [], [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}' print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done') # 结果可视化
# 隐藏警告
import warnings
warnings.filterwarnings('ignore') # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 正常显示+/-号
plt.rcParams['figure.dpi'] = 100 # 分辨率 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) # 第一张子图
plt.plot(epochs_range, train_acc, label='训练正确率')
plt.plot(epochs_range, test_acc, label='测试正确率')
plt.legend(loc='lower right')
plt.title('训练和测试正确率比较') plt.subplot(1, 2, 2) # 第二张子图
plt.plot(epochs_range, train_loss, label='训练损失率')
plt.plot(epochs_range, test_loss, label='测试损失率')
plt.legend(loc='upper right')
plt.title('训练和测试损失率比较') plt.show()# 保存模型
torch.save(model, './models/cnn-cifar10.pth')

再次设置epochs为50训练结果:

epochs增加到100,训练结果:

可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:
data_transforms= { 'train': transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]), 'test': transforms.Compose([ transforms.ToTensor(), ])
}
在dataset中:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transforms['train'])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])
运行结果:


比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。

batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。

(三)总结
- epochs并不是越多越好。batch_size同样的道理
- 数据增强确实可以提高模型训练的准确性。
相关文章:
第P2周:Pytorch实现CIFAR10彩色图片识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 实现CIFAR-10的彩色图片识别实现比P1周更复杂一点的CNN网络 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: …...
CTFHub 命令注入-综合练习(学习记录)
综合过滤练习 命令分隔符的绕过姿势 ; %0a %0d & 那我们使用%0a试试,发现ls命令被成功执行 /?ip127.0.0.1%0als 发现一个名为flag_is_here的文件夹和index.php的文件,那么我们还是使用cd命令进入到文件夹下 http://challenge-438c1c1fb670566b.sa…...
OpenCV目标检测 级联分类器 C++实现
一.目标检测技术 目前常用实用性目标检测与跟踪的方法有以下两种: 帧差法 识别原理:基于前后两帧图像之间的差异进行对比,获取图像画面中正在运动的物体从而达到目标检测 缺点:画面中所有运动中物体都能识别 举个例子…...
QT6 Socket通讯封装(TCP/UDP)
为大家分享一下最近封装的以太网socket通讯接口 效果演示 如图,界面还没优化,后续更新 废话不多说直接上教程 添加库 如果为qmake项目中,在.pro文件添加 QT network QT core gui QT networkgreaterThan(QT_MAJOR_VERS…...
elasticsearch设置密码访问
1 用户认证介绍 默认ES是没有设置用户认证访问的,所以每次访问时,直接调相关API就能查询和写入数据。现在做一个认证,只有通过认证的用户才能访问和操作ES。 2 开启加密设置 1.生成证书文件 /usr/share/elasticsearch/bin/elasticsearch-…...
彻底理解如何优化接口性能
作为后端研发,必须要掌握怎么优化接口的性能或者说是响应时间,这样才能提高系统的系能,本文通过如下两个方面进行分析: 一.后端代码 有如下几步: 1.缓存机制 这是最场景的方式,当使用了缓存后,…...
C# 位运算
一、数据大小对应关系 说明: 将一个数据每左移一位,相当于乘以2。因此,左移8位就是乘以2的8次方,即256。 二、转换 1、 10进制转2进制字符串 #region 10进制转2进制字符串int number1 10;string binary Convert.ToString(num…...
【Flink-scala】DataStream编程模型之状态编程
DataStream编程模型之状态编程 参考: 1.【Flink-Scala】DataStream编程模型之数据源、数据转换、数据输出 2.【Flink-scala】DataStream编程模型之 窗口的划分-时间概念-窗口计算程序 3.【Flink-scala】DataStream编程模型之窗口计算-触发器-驱逐器 4.【Flink-scal…...
RabbitMQ的核心组件有哪些?
大家好,我是锋哥。今天分享关于【RabbitMQ的核心组件有哪些?】面试题。希望对大家有帮助; RabbitMQ的核心组件有哪些? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 RabbitMQ是一个开源的消息代理(Messag…...
【Linux基础】基本开发工具的使用
目录 一、编译器——gcc/g的使用 gcc/g的安装 gcc的安装: g的安装: gcc/g的基本使用 gcc的使用 g的使用 动态链接与静态链接 程序的翻译过程 1. 一个C/C程序的构建过程,程序从源代码到可执行文件必须经历四个阶段 2. 理解选项的含…...
常见的数据结构和应用场景
数据结构是计算机科学中的基础概念,用于组织和存储数据,以便能够高效地访问和修改。下面是几种常见数据结构及其代表性应用场景: 1. 数组(Array) 问题解决:数组是一种线性数据结构,用于存储相…...
爬虫基础学习
爬虫概念与工作原理 爬虫是什么:爬虫(Web Scraping)是自动化地访问网站并提取数据的技术。它模拟用户浏览器的行为,通过HTTP请求访问网页,解析HTML文档并提取有用信息。 爬虫的基本工作流程: 发送HTTP请求…...
C++对象数组对象指针对象指针数组
一、对象数组 对象数组中的每一个元素都是同类的对象; 例1 对象数组成员的初始化 #include<iostream> using namespace std;class Student { public:Student( ){ };Student(int n,string nam,char s):num(n),name(nam),sex(s){};void display(){cout<&l…...
D96【python 接口自动化学习】- pytest进阶之fixture用法
day96 pytest的fixture详解(三) 学习日期:20241211 学习目标:pytest基础用法 -- pytest的fixture详解(三) 学习笔记: fixture(scop"class") (scop"class") 每一个类调…...
【算法】动态规划中01背包问题解析
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...
选择WordPress和Shopify:搭建对谷歌SEO友好的网站
在建设网站时,不仅要考虑它的美观和功能性,还要关注它是否对谷歌SEO友好。如果你希望网站能够获得更好的搜索排名,WordPress和Shopify是两个值得推荐的建站平台。 WordPress作为最流行的内容管理系统,其强大的灵活性和丰富的插件…...
代理IP与生成式AI:携手共创未来
目录 代理IP:网络世界的“隐形斗篷” 1. 隐藏真实IP,保护隐私 2. 突破网络限制,访问更多资源 生成式AI:创意与效率的“超级大脑” 1. 提高创作效率 2. 个性化定制 代理IP与生成式AI的协同作用 1. 网络安全 2. 内容创作与…...
iOS 应用的生命周期
Managing your app’s life cycle | Apple Developer Documentation Performance and metrics | Apple Developer Documentation iOS 应用的生命周期状态是理解应用如何在不同状态下运行和管理资源的基础。在 iOS 开发中,应用生命周期管理的是应用从启动到终止的整…...
Elasticsearch 集群快照的定期备份设置指南
Elasticsearch 集群快照的定期备份设置指南 概述 快照: 在给定时刻对整个集群或者单个索引进行备份,以便在之后出现故障时可以基于之前备份的快照进行快速恢复。 前提条件: 准备一个备份存储盘,本指南采用的是AWS EFS文件系统做…...
Docker--Docker Image(镜像)
什么是Docker Image? Docker镜像(Docker Image)是Docker容器技术的核心组件之一,它包含了运行应用程序所需的所有依赖、库、代码、运行时环境以及配置文件等。 简单来说,Docker镜像是一个轻量级、可执行的软件包&…...
忍者像素绘卷:天界画坊Java面试题精讲:AI绘画服务的高并发设计
忍者像素绘卷:天界画坊Java面试题精讲:AI绘画服务的高并发设计 1. 高并发AI绘画服务的挑战与价值 在数字艺术创作领域,AI绘画服务正经历爆发式增长。以"忍者像素绘卷:天界画坊"为例,这款融合传统忍者文化与…...
电商人必看!RMBG-2.0轻量抠图实战:证件照换背景+短视频素材一键生成
电商人必看!RMBG-2.0轻量抠图实战:证件照换背景短视频素材一键生成 还在为商品图片抠图发愁吗?每天处理几十张产品图,用PS一点点抠边缘,既费时间又费眼睛?或者需要给员工批量制作证件照,但换背…...
InfluxDB新手必看:从安装到基本操作的完整指南(Windows版)
InfluxDB Windows实战指南:从零搭建时序数据库系统 时序数据正成为物联网、DevOps和业务监控领域的核心资产。想象一下,您需要每秒处理数千台设备的温度读数,或者分析应用程序每分钟的性能指标——传统关系型数据库在这种高频写入场景下往往…...
ESP32组件化开发实战:从零构建高效项目结构
1. 为什么需要组件化开发? 第一次接触ESP32开发时,我习惯把所有代码都塞进main文件夹里。结果项目稍微复杂点就乱成一锅粥,每次修改都要在几十个文件里翻找,不同功能模块互相纠缠,想复用某个传感器驱动都得连带着拷贝…...
用STM32F103C8和5路红外模块,我花了一个周末做了个能自己拐弯的小车(附完整代码)
从零打造智能循迹小车:STM32F103C8与红外模块的实战指南 看着桌上散落的电子元件逐渐组合成一个能自主行动的小车,这种成就感是任何现成玩具都无法比拟的。本文将带你完整经历一次基于STM32F103C8和五路红外模块的智能小车开发过程,无需复杂算…...
Voron 2.4 3D打印机进阶调试与故障排除指南
Voron 2.4 3D打印机进阶调试与故障排除指南 【免费下载链接】Voron-2 Voron 2 CoreXY 3D Printer design 项目地址: https://gitcode.com/gh_mirrors/vo/Voron-2 机械系统精调:从结构应力到运动精度 问题导向:框架组装后出现对角线偏差超过2mm&a…...
ClickHouse数据报表实战:如何把分组后的明细‘压缩’成一行摘要(附完整SQL)
ClickHouse数据报表实战:高效聚合多行文本的工程化解决方案 在数据分析与报表生成的实际业务场景中,我们经常遇到这样的需求:需要将同一维度下的多条文本明细(如用户行为日志、错误信息、月份列表等)合并成一条简洁的摘…...
Phi-4-mini-reasoning企业级落地:金融风控规则推理引擎构建案例
Phi-4-mini-reasoning企业级落地:金融风控规则推理引擎构建案例 1. 项目背景与模型介绍 在金融风控领域,规则推理引擎是核心决策系统的重要组成部分。传统规则引擎往往面临维护成本高、灵活性差、难以应对复杂场景等问题。Phi-4-mini-reasoning作为一款…...
Omni-Vision Sanctuary在嵌入式边缘设备上的轻量化部署思考
Omni-Vision Sanctuary在嵌入式边缘设备上的轻量化部署思考 1. 嵌入式视觉的挑战与机遇 在智能摄像头、工业质检设备、无人机等嵌入式场景中,视觉模型的部署一直面临特殊挑战。传统方案要么性能不足,要么功耗过高,难以平衡实时性与能效比。…...
Lychee Rerank在遥感影像分析中的应用:多源地理数据关联
Lychee Rerank在遥感影像分析中的应用:多源地理数据关联 1. 引言 每天,卫星和无人机都在产生海量的遥感影像数据。地质勘探团队需要从数万张卫星图片中找出可能的矿藏迹象,环境监测人员要追踪森林覆盖变化,城市规划者则要分析城…...
