当前位置: 首页 > news >正文

pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)

目录

LeNet-5

LeNet-5 结构

CIFAR-10

pytorch实现

lenet模型

训练模型

1.导入数据

2.训练模型

3.测试模型

测试单张图片

代码

运行结果


LeNet-5

LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。

LeNet-5 结构

LeNet-5 的结构包括以下几个层次:

  1. 输入层: 32x32 的灰度图像。
  2. 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
  3. 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
  4. 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
  5. 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
  6. 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
  7. 全连接层 F6: 包含 84 个神经元。
  8. 输出层: 包含 10 个神经元,对应于 10 个类别。

CIFAR-10

CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。

1. 标注数据量训练集:50000张图像测试集:10000张图像

2. 标注类别数据集共有10个类别。具体分类见图1。

3. 可视化

pytorch实现

lenet模型

  • 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
  • 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。

在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。

import torch.nn as nn
import torch.nn.functional as funcclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=5)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = func.relu(self.conv1(x))x = func.max_pool2d(x, 2)x = func.relu(self.conv2(x))x = func.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = func.relu(self.fc1(x))x = func.relu(self.fc2(x))x = self.fc3(x)return x

训练模型

1.导入数据

导入训练数据和测试数据

    def load_data(self):#transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])test_transform = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)# shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

2.训练模型

    def train(self):print("train:")self.model.train()train_loss = 0train_correct = 0total = 0for batch_num, (data, target) in enumerate(self.train_loader):data, target = data.to(self.device), target.to(self.device)self.optimizer.zero_grad()output = self.model(data)loss = self.criterion(output, target)loss.backward()self.optimizer.step()train_loss += loss.item()prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reducedtotal += target.size(0)# train_correct incremented by one if predicted righttrain_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))return train_loss, train_correct / total

3.测试模型

    def test(self):print("test:")self.model.eval()test_loss = 0test_correct = 0total = 0with torch.no_grad():for batch_num, (data, target) in enumerate(self.test_loader):data, target = data.to(self.device), target.to(self.device)output = self.model(data)loss = self.criterion(output, target)test_loss += loss.item()prediction = torch.max(output, 1)total += target.size(0)test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))return test_loss, test_correct / total

测试单张图片

网上随便下载一个图片

然后使用图片编辑工具,把图片设置为32x32大小

通过导入模型,然后测试一下

代码

import torch
import cv2
import torch.nn.functional as F
#from model import Net  ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('lenet.pth')  # 加载模型model = model.to(device)model.eval()  # 把模型转为test模式img = cv2.imread("bird1.png")  # 读取要预测的图片trans = transforms.Compose([transforms.ToTensor()])img = trans(img)img = img.to(device)img = img.unsqueeze(0)  # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 扩展后,为[1,1,28,28]output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print(prob)value, predicted = torch.max(output.data, 1)print(predicted.item())print(value)pred_class = classes[predicted.item()]print(pred_class)

运行结果

tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird

从结果看,效果还不错。记录一下

相关文章:

pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)

目录 LeNet-5 LeNet-5 结构 CIFAR-10 pytorch实现 lenet模型 训练模型 1.导入数据 2.训练模型 3.测试模型 测试单张图片 代码 运行结果 LeNet-5 LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络&#xff08;CNN&#xff09;模型&#xff0c;主要…...

Word卡顿的处理方法

1. 检查和关闭后台程序 关闭不必要的后台程序,释放系统资源。使用任务管理器(Ctrl + Shift + Esc)查看占用CPU和内存较高的应用,并关闭它们。2. 更新Microsoft Office 确保你的Microsoft Office软件是最新版本。新版本通常修复了已知的性能问题。打开Word,点击文件 > 账…...

在 Linux上常见的10大压缩格式解压命令和它们对应的压缩格式

文章目录 前言一、解压 .zip 文件二、解压 .tar.gz 或 .tgz 文件三、解压 .tar 文件四、解压 .tar.bz2 文件五、解压 .tar.xz 文件六、解压 .gz 文件七、解压 .bz2 文件八、解压 .xz 文件九、解压 .7z 文件十、解压 .rar 文件总结 前言 Linux 命令可以解压不同格式的压缩文件。…...

【数据结构】三、栈和队列:6.链队列、双端队列、队列的应用(树的层次遍历、广度优先BFS、先来先服务FCFS)

文章目录 2.链队列2.1初始化&#xff08;带头结点&#xff09;不带头结点 2.2入队&#xff08;带头结点&#xff09;2.3出队&#xff08;带头结点&#xff09;❗2.4链队列c实例 3.双端队列考点:输出序列合法性栈双端队列 队列的应用1.树的层次遍历2.图的广度优先遍历3.操作系统…...

技术速递|使用 Native Library Interop 为 .NET MAUI 创建绑定

作者&#xff1a;Rachel Kang 排版&#xff1a;Alan Wang 在当今的应用开发领域&#xff0c;通过利用本机功能来扩展 .NET 应用程序的能力非常宝贵。.NET MAUI 处理程序架构使开发人员能够使用 .NET 代码直接操作本机控件&#xff0c;甚至允许无缝创建跨平台自定义控件。然而&a…...

Linux笔记 --- 标准IO

系统IO的最大特点一个是更具通用性&#xff0c;不管是普通文件、管道文件、设备节点文件、接字文件等等都可以使用&#xff0c;另一个是他的简约性&#xff0c;对文件内数据的读写在任何情况下都是带任何格式的&#xff0c;而且数据的读写也都没有经过任何缓冲处理&#xff0c;…...

洛谷:B3625 迷宫寻路

迷宫寻路 题目描述 机器猫被困在一个矩形迷宫里。 迷宫可以视为一个 n m n\times m nm 矩阵&#xff0c;每个位置要么是空地&#xff0c;要么是墙。机器猫只能从一个空地走到其上、下、左、右的空地。 机器猫初始时位于 ( 1 , 1 ) (1, 1) (1,1) 的位置&#xff0c;问能否…...

【C#】explicit、implicit与operator

字面解释 explicit&#xff1a;清楚明白的;易于理解的;(说话)清晰的&#xff0c;明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit&#xff1a;含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator&#xff1a;操作人员;技工;电话员;接线员;…...

Vue:Vuex-Store使用指南

一、简介 1.1Vuex 是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态&#xff0c;并以相应的规则保证状态以一种可预测的方式发生变化。Vuex 也集成到 Vue 的官方调试工具 devtools extension (opens new window)&#xf…...

对经典动态规划问题【爬台阶】的一些思考

背景 今天在做Leetcode题目时&#xff0c;做到了一道经典的动态规划问题&#xff1a;爬楼梯&#xff0c;题目的大致意思很简单&#xff0c;有个小孩正在上楼梯&#xff0c;楼梯有n阶台阶&#xff0c;小孩一次可以上1阶、2阶或3阶。实现一种方法&#xff0c;计算小孩有多少种上…...

开发一个能打造虚拟带货直播间的工具!

在当今数字化时代&#xff0c;直播带货已成为电商领域的一股强劲力量&#xff0c;其直观、互动性强的特点极大地提升了消费者的购物体验。 然而&#xff0c;随着技术的不断进步&#xff0c;传统直播带货模式正逐步向更加智能化、虚拟化的方向演进&#xff0c;本文将深入探讨如…...

汽车补光照明实验太阳光模拟器光源

汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节&#xff0c;它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果&#xff0c;以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…...

MediaPipe人体姿态、手指关键点检测

MediaPipe人体姿态、手指关键点检测 文章目录 MediaPipe人体姿态、手指关键点检测前言一、手指关键点检测二、姿态检测三、3D物体案例检测案例 前言 Mediapipe是google的一个开源项目&#xff0c;用于构建机器学习管道。   提供了16个预训练模型的案例&#xff1a;人脸检测、…...

树上dp之换根dp

基本概念&#xff1a; 换根dp是树上dp的一种 我们在什么时候需要用到换根dp呢&#xff1f; 当题目询问的属性&#xff0c;是需要当前结点为根时的属性&#xff0c;这个时候&#xff0c;我们就要使用换根dp 换根dp的基本思路&#xff1a; 假设题目询问的的属性为x 通常我们…...

2024/8/13 英语每日一段

Mackey says while Whole Foods has become more homogenized under Amazon, it did enable the store to do what it couldn’t have done independently. “People saw us as too expensive and out of touch with our customers,” he says. “The main thing Whole Foods n…...

Java多线程练习(1)

MultiProcessingExercise package MultiProcessingExercise120240813;public class MultiProcessingExercise {public static void main(String[] args) {/*需求&#xff1a;一共有1000张电影票,可以在两个窗口领取,假设每次领取的时间为3000毫秒,请用多线程模拟卖票过程并打印…...

AI高级肖像动画神器LivePortrait

文章目录 前言一、安装1.1 源码安装1.2 windows一键启动包 二、人像生成2.1 浏览器2.2 输入图像2.3 选择驱动视频2.4 生成2.5 结果 三、动物生成3.1 浏览器3.2 输入图片3.3 选择视频3.4 生成3.5 最终结果 四、软件获取 前言 最近&#xff0c;快手可灵大模型团队、中国科学技术…...

Java反射机制深度解析与实践应用

Java反射机制深度解析与实践应用 引言 Java反射是Java语言提供的一种能力&#xff0c;允许程序在运行时访问、检测和修改其自身的属性和行为。反射机制是Java面向对象编程的一大亮点&#xff0c;也是Java框架和库常用的技术之一。 反射的基本概念 反射的核心是java.lang.re…...

Oracle递归查询层级及路径

一、建表及插入数据 ocation_idlocation_nameparent_location_id1广东省NULL2广州市13深圳市14天河区25番禺区26南山区37宝安区3 建表sql&#xff1a; CREATE TABLE locations (location_id NUMBER PRIMARY KEY,location_name VARCHAR2(100),parent_location_id NUMBER ); I…...

leetcode300. 最长递增子序列,动态规划附状态转移方程

leetcode300. 最长递增子序列 给你一个整数数组 nums &#xff0c;找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c;[3,6,2,7] 是数组 [0,3,1,6,2,2…...

在rocky linux 9.5上在线安装 docker

前面是指南&#xff0c;后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...

JVM垃圾回收机制全解析

Java虚拟机&#xff08;JVM&#xff09;中的垃圾收集器&#xff08;Garbage Collector&#xff0c;简称GC&#xff09;是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象&#xff0c;从而释放内存空间&#xff0c;避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

HashMap中的put方法执行流程(流程图)

1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中&#xff0c;其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下&#xff1a; 初始判断与哈希计算&#xff1a; 首先&#xff0c;putVal 方法会检查当前的 table&#xff08;也就…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)

前言&#xff1a; 最近在做行为检测相关的模型&#xff0c;用的是时空图卷积网络&#xff08;STGCN&#xff09;&#xff0c;但原有kinetic-400数据集数据质量较低&#xff0c;需要进行细粒度的标注&#xff0c;同时粗略搜了下已有开源工具基本都集中于图像分割这块&#xff0c…...

代码随想录刷题day30

1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币&#xff0c;另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额&#xff0c;返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...

三分算法与DeepSeek辅助证明是单峰函数

前置 单峰函数有唯一的最大值&#xff0c;最大值左侧的数值严格单调递增&#xff0c;最大值右侧的数值严格单调递减。 单谷函数有唯一的最小值&#xff0c;最小值左侧的数值严格单调递减&#xff0c;最小值右侧的数值严格单调递增。 三分的本质 三分和二分一样都是通过不断缩…...

深入理解Optional:处理空指针异常

1. 使用Optional处理可能为空的集合 在Java开发中&#xff0c;集合判空是一个常见但容易出错的场景。传统方式虽然可行&#xff0c;但存在一些潜在问题&#xff1a; // 传统判空方式 if (!CollectionUtils.isEmpty(userInfoList)) {for (UserInfo userInfo : userInfoList) {…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...