第P2周:Pytorch实现CIFAR10彩色图片识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
- 实现CIFAR-10的彩色图片识别
- 实现比P1周更复杂一点的CNN网络
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1
(二)具体步骤
1.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision # 第一步:设置GPU
def USE_GPU(): if torch.cuda.is_available(): print('CUDA is available, will use GPU') device = torch.device("cuda") else: print('CUDA is not available. Will use CPU') device = torch.device("cpu") return device device = USE_GPU()
输出:CUDA is available, will use GPU
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor()) batch_size = 32
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size) # 取一个批次查看数据格式
# 数据的shape为:[batch_size, channel, height, weight]
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。
imgs, labels = next(iter(train_dataload))
print(imgs.shape) # 查看一下图片
import numpy as np
plt.figure(figsize=(20, 5))
for i, images in enumerate(imgs[:20]): # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理 npimg = imgs.numpy().transpose((1, 2, 0)) # 将整个figure分成2行10列,并绘制第i+1个子图 plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')
plt.show()
输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])
# 第三步,构建CNN网络
import torch.nn.functional as F num_classes = 10 # 因为CIFAR-10是10种类型
class Model(nn.Module): def __init__(self): super(Model, self).__init__() # 提取特征网络 self.conv1 = nn.Conv2d(3, 64, 3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, 3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, 3) self.pool3 = nn.MaxPool2d(kernel_size=2) # 分类网络 self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) # 前向传播 def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)
# 训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 设置学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate) # 设置优化器 # 编写训练函数
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片 num_batches = len(dataloader) # 批次大小,这里是1875(60000/32=1875) train_acc, train_loss = 0, 0 # 初始化训练正确率和损失率都为0 for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值) X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出预测值 loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 第一步自动更新 # 记录正确率和损失率 train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss # 测试函数
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片 num_batches = len(dataloader) # 批次大小 ,这里312,即10000/32=312.5,向上取整 test_acc, test_loss = 0, 0 # 因为是测试,因此不用训练,梯度也不用计算不用更新 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss # 正式训练
epochs = 10
train_acc, train_loss, test_acc, test_loss = [], [], [], [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}' print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done') # 结果可视化
# 隐藏警告
import warnings
warnings.filterwarnings('ignore') # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 正常显示+/-号
plt.rcParams['figure.dpi'] = 100 # 分辨率 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) # 第一张子图
plt.plot(epochs_range, train_acc, label='训练正确率')
plt.plot(epochs_range, test_acc, label='测试正确率')
plt.legend(loc='lower right')
plt.title('训练和测试正确率比较') plt.subplot(1, 2, 2) # 第二张子图
plt.plot(epochs_range, train_loss, label='训练损失率')
plt.plot(epochs_range, test_loss, label='测试损失率')
plt.legend(loc='upper right')
plt.title('训练和测试损失率比较') plt.show()# 保存模型
torch.save(model, './models/cnn-cifar10.pth')
再次设置epochs为50训练结果:
epochs增加到100,训练结果:
可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:
data_transforms= { 'train': transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]), 'test': transforms.Compose([ transforms.ToTensor(), ])
}
在dataset中:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transforms['train'])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])
运行结果:
比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。
batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。
(三)总结
- epochs并不是越多越好。batch_size同样的道理
- 数据增强确实可以提高模型训练的准确性。
相关文章:

第P2周:Pytorch实现CIFAR10彩色图片识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 实现CIFAR-10的彩色图片识别实现比P1周更复杂一点的CNN网络 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: …...

CTFHub 命令注入-综合练习(学习记录)
综合过滤练习 命令分隔符的绕过姿势 ; %0a %0d & 那我们使用%0a试试,发现ls命令被成功执行 /?ip127.0.0.1%0als 发现一个名为flag_is_here的文件夹和index.php的文件,那么我们还是使用cd命令进入到文件夹下 http://challenge-438c1c1fb670566b.sa…...

OpenCV目标检测 级联分类器 C++实现
一.目标检测技术 目前常用实用性目标检测与跟踪的方法有以下两种: 帧差法 识别原理:基于前后两帧图像之间的差异进行对比,获取图像画面中正在运动的物体从而达到目标检测 缺点:画面中所有运动中物体都能识别 举个例子…...

QT6 Socket通讯封装(TCP/UDP)
为大家分享一下最近封装的以太网socket通讯接口 效果演示 如图,界面还没优化,后续更新 废话不多说直接上教程 添加库 如果为qmake项目中,在.pro文件添加 QT network QT core gui QT networkgreaterThan(QT_MAJOR_VERS…...

elasticsearch设置密码访问
1 用户认证介绍 默认ES是没有设置用户认证访问的,所以每次访问时,直接调相关API就能查询和写入数据。现在做一个认证,只有通过认证的用户才能访问和操作ES。 2 开启加密设置 1.生成证书文件 /usr/share/elasticsearch/bin/elasticsearch-…...

彻底理解如何优化接口性能
作为后端研发,必须要掌握怎么优化接口的性能或者说是响应时间,这样才能提高系统的系能,本文通过如下两个方面进行分析: 一.后端代码 有如下几步: 1.缓存机制 这是最场景的方式,当使用了缓存后,…...

C# 位运算
一、数据大小对应关系 说明: 将一个数据每左移一位,相当于乘以2。因此,左移8位就是乘以2的8次方,即256。 二、转换 1、 10进制转2进制字符串 #region 10进制转2进制字符串int number1 10;string binary Convert.ToString(num…...

【Flink-scala】DataStream编程模型之状态编程
DataStream编程模型之状态编程 参考: 1.【Flink-Scala】DataStream编程模型之数据源、数据转换、数据输出 2.【Flink-scala】DataStream编程模型之 窗口的划分-时间概念-窗口计算程序 3.【Flink-scala】DataStream编程模型之窗口计算-触发器-驱逐器 4.【Flink-scal…...

RabbitMQ的核心组件有哪些?
大家好,我是锋哥。今天分享关于【RabbitMQ的核心组件有哪些?】面试题。希望对大家有帮助; RabbitMQ的核心组件有哪些? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 RabbitMQ是一个开源的消息代理(Messag…...

【Linux基础】基本开发工具的使用
目录 一、编译器——gcc/g的使用 gcc/g的安装 gcc的安装: g的安装: gcc/g的基本使用 gcc的使用 g的使用 动态链接与静态链接 程序的翻译过程 1. 一个C/C程序的构建过程,程序从源代码到可执行文件必须经历四个阶段 2. 理解选项的含…...

常见的数据结构和应用场景
数据结构是计算机科学中的基础概念,用于组织和存储数据,以便能够高效地访问和修改。下面是几种常见数据结构及其代表性应用场景: 1. 数组(Array) 问题解决:数组是一种线性数据结构,用于存储相…...

爬虫基础学习
爬虫概念与工作原理 爬虫是什么:爬虫(Web Scraping)是自动化地访问网站并提取数据的技术。它模拟用户浏览器的行为,通过HTTP请求访问网页,解析HTML文档并提取有用信息。 爬虫的基本工作流程: 发送HTTP请求…...

C++对象数组对象指针对象指针数组
一、对象数组 对象数组中的每一个元素都是同类的对象; 例1 对象数组成员的初始化 #include<iostream> using namespace std;class Student { public:Student( ){ };Student(int n,string nam,char s):num(n),name(nam),sex(s){};void display(){cout<&l…...

D96【python 接口自动化学习】- pytest进阶之fixture用法
day96 pytest的fixture详解(三) 学习日期:20241211 学习目标:pytest基础用法 -- pytest的fixture详解(三) 学习笔记: fixture(scop"class") (scop"class") 每一个类调…...

【算法】动态规划中01背包问题解析
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文由 JohnKi 原创,首发于 CSDN🙉 📢未来很长&#…...

选择WordPress和Shopify:搭建对谷歌SEO友好的网站
在建设网站时,不仅要考虑它的美观和功能性,还要关注它是否对谷歌SEO友好。如果你希望网站能够获得更好的搜索排名,WordPress和Shopify是两个值得推荐的建站平台。 WordPress作为最流行的内容管理系统,其强大的灵活性和丰富的插件…...

代理IP与生成式AI:携手共创未来
目录 代理IP:网络世界的“隐形斗篷” 1. 隐藏真实IP,保护隐私 2. 突破网络限制,访问更多资源 生成式AI:创意与效率的“超级大脑” 1. 提高创作效率 2. 个性化定制 代理IP与生成式AI的协同作用 1. 网络安全 2. 内容创作与…...

iOS 应用的生命周期
Managing your app’s life cycle | Apple Developer Documentation Performance and metrics | Apple Developer Documentation iOS 应用的生命周期状态是理解应用如何在不同状态下运行和管理资源的基础。在 iOS 开发中,应用生命周期管理的是应用从启动到终止的整…...

Elasticsearch 集群快照的定期备份设置指南
Elasticsearch 集群快照的定期备份设置指南 概述 快照: 在给定时刻对整个集群或者单个索引进行备份,以便在之后出现故障时可以基于之前备份的快照进行快速恢复。 前提条件: 准备一个备份存储盘,本指南采用的是AWS EFS文件系统做…...

Docker--Docker Image(镜像)
什么是Docker Image? Docker镜像(Docker Image)是Docker容器技术的核心组件之一,它包含了运行应用程序所需的所有依赖、库、代码、运行时环境以及配置文件等。 简单来说,Docker镜像是一个轻量级、可执行的软件包&…...

C++ 中的序列化和反序列化
一、C 中的序列化和反序列化 (一)基本概念 在 C 中,序列化是将对象转换为字节流的过程,反序列化则是从字节流重新构建对象的过程。这对于存储对象状态到文件、网络传输等场景非常有用。 (二)简单的序列化…...

我的Github学生认证申请过程
先说结论:很简单。 学生认证链接:GitHub Education GitHub 1. 首先你得绑定edu邮箱。这个应该没什么问题,Github也会提示。 2. 我是在学校里面、使用流量而非WiFi申请的,听说地理位置很重要,该给的权限(…...

信奥题解:勾股数计算中的浮点数精度问题
来源:GESP C++ 二级模拟题 本文给出官方参考答案的详细解析,包括每一部分的功能和关键点,以及与浮点数精度相关的问题的分析。 题目描述 勾股数是很有趣的数学概念。如果三个正整数a 、b 、c ,满足 a 2 + b 2 = c 2 a^2 + b^2 = c^2 a2+b2=c2 ,而且1 ≤ a ≤ b ≤ c ,…...

重生之我在学Vue--第2天 Vue 3 Composition API 与响应式系统
重生之我在学Vue–第2天 Vue 3 Composition API 与响应式系统 文章目录 重生之我在学Vue--第2天 Vue 3 Composition API 与响应式系统前言一、Composition API 核心概念1.1 什么是 Composition API?1.2 Composition API 的核心工具1.3 基础用法示例 二、响应式系统2…...

【AI知识】逻辑回归介绍+ 做二分类任务的实例(代码可视化)
1. 分类的基本概念 在机器学习的有监督学习中,分类一种常见任务,它的目标是将输入数据分类到预定的类别中。具体来说: 分类任务的常见应用: 垃圾邮件分类:判断一封电子邮件是否是垃圾邮件 。 医学诊断:…...

Mysql 笔记2 emp dept HRs
-- 注意事项 -- 1.给数据库和表起名字时尽量选择全小写 -- 2.作为筛选条件的字符串是否区分大小写看设置的校对规则utf8_bin 区分 drop database if exists hrs; create database hrs default charset utf8 collate utf8_general_ci;use hrs; drop table if exists tb_emp; dro…...

MySQL和Oracle的区别
MySQL和Oracle的区别 MySQL是轻量型数据库,并且免费,没有服务恢复数据。 Oracle是重量型数据库,收费,Oracle公司对Oracle数据库有任何服务。 1.对事务的提交 MySQL默认是自动提交,而Oracle默认不自动提交࿰…...

实验12 C语言连接和操作MySQL数据库
一、安装MySQL 1、使用包管理器安装MySQL sudo apt update sudo apt install mysql-server2、启动MySQL服务: sudo systemctl start mysql3、检查MySQL服务状态: sudo systemctl status mysql二、安装MySQL开发库 sudo apt-get install libmysqlcli…...

09篇--图片的水印添加(掩膜的运用)
如何添加水印? 添加水印其实可以理解为将一张图片中的某个物体或者图案提取出来,然后叠加到另一张图片上。具体的操作思想是通过将原始图片转换成灰度图,并进行二值化处理,去除背景部分,得到一个类似掩膜的图像。然后…...

sql-labs(21-25)
第21关 第一步 可以发现cookie是经过64位加密的 我们试试在这里注入 选择给他编码 发现可以成功注入 爆出表名 爆出字段 爆出数据 第22关 跟二十一关一模一样 闭合换成" 第 23 关 第二十三关重新回到get请求,会发现输入单引号报错,但是注释符…...