当前位置: 首页 > news >正文

第P2周:Pytorch实现CIFAR10彩色图片识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

  1. 实现CIFAR-10的彩色图片识别
  2. 实现比P1周更复杂一点的CNN网络

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1

(二)具体步骤
1.
import torch  
import torch.nn as nn  
import matplotlib.pyplot as plt  
import torchvision  # 第一步:设置GPU  
def USE_GPU():  if torch.cuda.is_available():  print('CUDA is available, will use GPU')  device = torch.device("cuda")  else:  print('CUDA is not available. Will use CPU')  device = torch.device("cpu")  return device  device = USE_GPU()  

输出:CUDA is available, will use GPU

  
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载  
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=torchvision.transforms.ToTensor())  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,  transform=torchvision.transforms.ToTensor())  batch_size = 32  
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)  # 取一个批次查看数据格式  
# 数据的shape为:[batch_size, channel, height, weight]  
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。  
imgs, labels = next(iter(train_dataload))  
print(imgs.shape)  # 查看一下图片  
import numpy as np  
plt.figure(figsize=(20, 5))  
for i, images in enumerate(imgs[:20]):  # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理  npimg = imgs.numpy().transpose((1, 2, 0))  # 将整个figure分成2行10列,并绘制第i+1个子图  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()  

输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])
image.png

# 第三步,构建CNN网络  
import torch.nn.functional as F  num_classes = 10  # 因为CIFAR-10是10种类型  
class Model(nn.Module):  def __init__(self):  super(Model, self).__init__()  # 提取特征网络  self.conv1 = nn.Conv2d(3, 64, 3)  self.pool1 = nn.MaxPool2d(kernel_size=2)  self.conv2 = nn.Conv2d(64, 64, 3)  self.pool2 = nn.MaxPool2d(kernel_size=2)  self.conv3 = nn.Conv2d(64, 128, 3)  self.pool3 = nn.MaxPool2d(kernel_size=2)  # 分类网络  self.fc1 = nn.Linear(512, 256)  self.fc2 = nn.Linear(256, num_classes)  # 前向传播  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = self.pool3(F.relu(self.conv3(x)))  x = torch.flatten(x, 1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  from torchinfo import summary  
# 将模型转移到GPU中  
model = Model().to(device)  
summary(model)  

image.png

# 训练模型  
loss_fn = nn.CrossEntropyLoss() # 创建损失函数  
learn_rate = 1e-2   # 设置学习率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)    # 设置优化器  # 编写训练函数  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片  num_batches = len(dataloader)   # 批次大小,这里是1875(60000/32=1875)  train_acc, train_loss = 0, 0    # 初始化训练正确率和损失率都为0  for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值)  X, y = X.to(device), y.to(device)  # 计算预测误差  pred = model(X) # 网络输出预测值  loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距  # 反向传播  optimizer.zero_grad()   # grad属性归零  loss.backward() # 反向传播  optimizer.step()    # 第一步自动更新  # 记录正确率和损失率  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 测试函数  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片  num_batches = len(dataloader)   # 批次大小 ,这里312,即10000/32=312.5,向上取整  test_acc, test_loss = 0, 0  # 因为是测试,因此不用训练,梯度也不用计算不用更新  with torch.no_grad():  for imgs, target in dataloader:  imgs, target = imgs.to(device), target.to(device)  # 计算loss  target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式训练  
epochs = 10  
train_acc, train_loss, test_acc, test_loss = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn)  train_acc.append(epoch_train_acc)  train_loss.append(epoch_train_loss)  test_acc.append(epoch_test_acc)  test_loss.append(epoch_test_loss)  template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}'  print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  print('Done')  # 结果可视化  
# 隐藏警告  
import warnings  
warnings.filterwarnings('ignore')   # 忽略警告信息  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 正常显示中文标签  
plt.rcParams['axes.unicode_minus'] = False  # 正常显示+/-号  
plt.rcParams['figure.dpi'] = 100    # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  plt.subplot(1, 2, 1)    # 第一张子图  
plt.plot(epochs_range, train_acc, label='训练正确率')  
plt.plot(epochs_range, test_acc, label='测试正确率')  
plt.legend(loc='lower right')  
plt.title('训练和测试正确率比较')  plt.subplot(1, 2, 2)    # 第二张子图  
plt.plot(epochs_range, train_loss, label='训练损失率')  
plt.plot(epochs_range, test_loss, label='测试损失率')  
plt.legend(loc='upper right')  
plt.title('训练和测试损失率比较')  plt.show()# 保存模型  
torch.save(model, './models/cnn-cifar10.pth')

image.png
再次设置epochs为50训练结果:
image.png
epochs增加到100,训练结果:
image.png
可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:

data_transforms= {  'train': transforms.Compose([  transforms.RandomHorizontalFlip(),  transforms.ToTensor(),  ]),  'test': transforms.Compose([  transforms.ToTensor(),  ])  
}

在dataset中:

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=data_transforms['train'])  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])

运行结果:
image.png
image.png
比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。
image.png
batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。
image.png

(三)总结
  1. epochs并不是越多越好。batch_size同样的道理
  2. 数据增强确实可以提高模型训练的准确性。

相关文章:

第P2周:Pytorch实现CIFAR10彩色图片识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 实现CIFAR-10的彩色图片识别实现比P1周更复杂一点的CNN网络 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: …...

CTFHub 命令注入-综合练习(学习记录)

综合过滤练习 命令分隔符的绕过姿势 ; %0a %0d & 那我们使用%0a试试,发现ls命令被成功执行 /?ip127.0.0.1%0als 发现一个名为flag_is_here的文件夹和index.php的文件,那么我们还是使用cd命令进入到文件夹下 http://challenge-438c1c1fb670566b.sa…...

OpenCV目标检测 级联分类器 C++实现

一.目标检测技术 目前常用实用性目标检测与跟踪的方法有以下两种: 帧差法 识别原理:基于前后两帧图像之间的差异进行对比,获取图像画面中正在运动的物体从而达到目标检测 缺点:画面中所有运动中物体都能识别 举个例子&#xf…...

QT6 Socket通讯封装(TCP/UDP)

为大家分享一下最近封装的以太网socket通讯接口 效果演示 如图,界面还没优化,后续更新 废话不多说直接上教程 添加库 如果为qmake项目中,在.pro文件添加 QT network QT core gui QT networkgreaterThan(QT_MAJOR_VERS…...

elasticsearch设置密码访问

1 用户认证介绍 默认ES是没有设置用户认证访问的,所以每次访问时,直接调相关API就能查询和写入数据。现在做一个认证,只有通过认证的用户才能访问和操作ES。 2 开启加密设置 1.生成证书文件 /usr/share/elasticsearch/bin/elasticsearch-…...

彻底理解如何优化接口性能

作为后端研发,必须要掌握怎么优化接口的性能或者说是响应时间,这样才能提高系统的系能,本文通过如下两个方面进行分析: 一.后端代码 有如下几步: 1.缓存机制 这是最场景的方式,当使用了缓存后,…...

C# 位运算

一、数据大小对应关系 说明: 将一个数据每左移一位,相当于乘以2。因此,左移8位就是乘以2的8次方,即256。 二、转换 1、 10进制转2进制字符串 #region 10进制转2进制字符串int number1 10;string binary Convert.ToString(num…...

【Flink-scala】DataStream编程模型之状态编程

DataStream编程模型之状态编程 参考: 1.【Flink-Scala】DataStream编程模型之数据源、数据转换、数据输出 2.【Flink-scala】DataStream编程模型之 窗口的划分-时间概念-窗口计算程序 3.【Flink-scala】DataStream编程模型之窗口计算-触发器-驱逐器 4.【Flink-scal…...

RabbitMQ的核心组件有哪些?

大家好,我是锋哥。今天分享关于【RabbitMQ的核心组件有哪些?】面试题。希望对大家有帮助; RabbitMQ的核心组件有哪些? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 RabbitMQ是一个开源的消息代理(Messag…...

【Linux基础】基本开发工具的使用

目录 一、编译器——gcc/g的使用 gcc/g的安装 gcc的安装: g的安装: gcc/g的基本使用 gcc的使用 g的使用 动态链接与静态链接 程序的翻译过程 1. 一个C/C程序的构建过程,程序从源代码到可执行文件必须经历四个阶段 2. 理解选项的含…...

常见的数据结构和应用场景

数据结构是计算机科学中的基础概念,用于组织和存储数据,以便能够高效地访问和修改。下面是几种常见数据结构及其代表性应用场景: 1. 数组(Array) 问题解决:数组是一种线性数据结构,用于存储相…...

爬虫基础学习

爬虫概念与工作原理 爬虫是什么:爬虫(Web Scraping)是自动化地访问网站并提取数据的技术。它模拟用户浏览器的行为,通过HTTP请求访问网页,解析HTML文档并提取有用信息。 爬虫的基本工作流程: 发送HTTP请求…...

C++对象数组对象指针对象指针数组

一、对象数组 对象数组中的每一个元素都是同类的对象&#xff1b; 例1 对象数组成员的初始化 #include<iostream> using namespace std;class Student { public:Student( ){ };Student(int n,string nam,char s):num(n),name(nam),sex(s){};void display(){cout<&l…...

D96【python 接口自动化学习】- pytest进阶之fixture用法

day96 pytest的fixture详解&#xff08;三&#xff09; 学习日期&#xff1a;20241211 学习目标&#xff1a;pytest基础用法 -- pytest的fixture详解&#xff08;三&#xff09; 学习笔记&#xff1a; fixture(scop"class") (scop"class") 每一个类调…...

【算法】动态规划中01背包问题解析

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &#x1f4e2;本文由 JohnKi 原创&#xff0c;首发于 CSDN&#x1f649; &#x1f4e2;未来很长&#…...

选择WordPress和Shopify:搭建对谷歌SEO友好的网站

在建设网站时&#xff0c;不仅要考虑它的美观和功能性&#xff0c;还要关注它是否对谷歌SEO友好。如果你希望网站能够获得更好的搜索排名&#xff0c;WordPress和Shopify是两个值得推荐的建站平台。 WordPress作为最流行的内容管理系统&#xff0c;其强大的灵活性和丰富的插件…...

代理IP与生成式AI:携手共创未来

目录 代理IP&#xff1a;网络世界的“隐形斗篷” 1. 隐藏真实IP&#xff0c;保护隐私 2. 突破网络限制&#xff0c;访问更多资源 生成式AI&#xff1a;创意与效率的“超级大脑” 1. 提高创作效率 2. 个性化定制 代理IP与生成式AI的协同作用 1. 网络安全 2. 内容创作与…...

iOS 应用的生命周期

Managing your app’s life cycle | Apple Developer Documentation Performance and metrics | Apple Developer Documentation iOS 应用的生命周期状态是理解应用如何在不同状态下运行和管理资源的基础。在 iOS 开发中&#xff0c;应用生命周期管理的是应用从启动到终止的整…...

Elasticsearch 集群快照的定期备份设置指南

Elasticsearch 集群快照的定期备份设置指南 概述 快照&#xff1a; 在给定时刻对整个集群或者单个索引进行备份&#xff0c;以便在之后出现故障时可以基于之前备份的快照进行快速恢复。 前提条件&#xff1a; 准备一个备份存储盘&#xff0c;本指南采用的是AWS EFS文件系统做…...

Docker--Docker Image(镜像)

什么是Docker Image&#xff1f; Docker镜像&#xff08;Docker Image&#xff09;是Docker容器技术的核心组件之一&#xff0c;它包含了运行应用程序所需的所有依赖、库、代码、运行时环境以及配置文件等。 简单来说&#xff0c;Docker镜像是一个轻量级、可执行的软件包&…...

基于数字孪生的水厂可视化平台建设:架构与实践

分享大纲&#xff1a; 1、数字孪生水厂可视化平台建设背景 2、数字孪生水厂可视化平台建设架构 3、数字孪生水厂可视化平台建设成效 近几年&#xff0c;数字孪生水厂的建设开展的如火如荼。作为提升水厂管理效率、优化资源的调度手段&#xff0c;基于数字孪生的水厂可视化平台的…...

Nginx server_name 配置说明

Nginx 是一个高性能的反向代理和负载均衡服务器&#xff0c;其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机&#xff08;Virtual Host&#xff09;。 1. 简介 Nginx 使用 server_name 指令来确定…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI&#xff08;https://spring.io/projects/spring-ai&#xff09;作为Spring生态中的AI集成框架&#xff0c;其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似&#xff0c;但特别为多语…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

Web后端基础(基础知识)

BS架构&#xff1a;Browser/Server&#xff0c;浏览器/服务器架构模式。客户端只需要浏览器&#xff0c;应用程序的逻辑和数据都存储在服务端。 优点&#xff1a;维护方便缺点&#xff1a;体验一般 CS架构&#xff1a;Client/Server&#xff0c;客户端/服务器架构模式。需要单独…...

沙箱虚拟化技术虚拟机容器之间的关系详解

问题 沙箱、虚拟化、容器三者分开一一介绍的话我知道他们各自都是什么东西&#xff0c;但是如果把三者放在一起&#xff0c;它们之间到底什么关系&#xff1f;又有什么联系呢&#xff1f;我不是很明白&#xff01;&#xff01;&#xff01; 就比如说&#xff1a; 沙箱&#…...

热烈祝贺埃文科技正式加入可信数据空间发展联盟

2025年4月29日&#xff0c;在福州举办的第八届数字中国建设峰会“可信数据空间分论坛”上&#xff0c;可信数据空间发展联盟正式宣告成立。国家数据局党组书记、局长刘烈宏出席并致辞&#xff0c;强调该联盟是推进全国一体化数据市场建设的关键抓手。 郑州埃文科技有限公司&am…...

门静脉高压——表现

一、门静脉高压表现 00:01 1. 门静脉构成 00:13 组成结构&#xff1a;由肠系膜上静脉和脾静脉汇合构成&#xff0c;是肝脏血液供应的主要来源。淤血后果&#xff1a;门静脉淤血会同时导致脾静脉和肠系膜上静脉淤血&#xff0c;引发后续系列症状。 2. 脾大和脾功能亢进 00:46 …...

Selenium 查找页面元素的方式

Selenium 查找页面元素的方式 Selenium 提供了多种方法来查找网页中的元素&#xff0c;以下是主要的定位方式&#xff1a; 基本定位方式 通过ID定位 driver.find_element(By.ID, "element_id")通过Name定位 driver.find_element(By.NAME, "element_name"…...

C++ Saucer 编写Windows桌面应用

文章目录 一、背景二、Saucer 简介核心特性典型应用场景 三、生成自己的项目四、以Win32项目方式构建Win32项目禁用最大化按钮 五、总结 一、背景 使用Saucer框架&#xff0c;开发Windows桌面应用&#xff0c;把一个html页面作为GUI设计放到Saucer里&#xff0c;隐藏掉运行时弹…...