当前位置: 首页 > news >正文

pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)

目录

LeNet-5

LeNet-5 结构

CIFAR-10

pytorch实现

lenet模型

训练模型

1.导入数据

2.训练模型

3.测试模型

测试单张图片

代码

运行结果


LeNet-5

LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。

LeNet-5 结构

LeNet-5 的结构包括以下几个层次:

  1. 输入层: 32x32 的灰度图像。
  2. 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
  3. 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
  4. 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
  5. 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
  6. 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
  7. 全连接层 F6: 包含 84 个神经元。
  8. 输出层: 包含 10 个神经元,对应于 10 个类别。

CIFAR-10

CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。

1. 标注数据量训练集:50000张图像测试集:10000张图像

2. 标注类别数据集共有10个类别。具体分类见图1。

3. 可视化

pytorch实现

lenet模型

  • 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
  • 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。

在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。

import torch.nn as nn
import torch.nn.functional as funcclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=5)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = func.relu(self.conv1(x))x = func.max_pool2d(x, 2)x = func.relu(self.conv2(x))x = func.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = func.relu(self.fc1(x))x = func.relu(self.fc2(x))x = self.fc3(x)return x

训练模型

1.导入数据

导入训练数据和测试数据

    def load_data(self):#transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])test_transform = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)# shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

2.训练模型

    def train(self):print("train:")self.model.train()train_loss = 0train_correct = 0total = 0for batch_num, (data, target) in enumerate(self.train_loader):data, target = data.to(self.device), target.to(self.device)self.optimizer.zero_grad()output = self.model(data)loss = self.criterion(output, target)loss.backward()self.optimizer.step()train_loss += loss.item()prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reducedtotal += target.size(0)# train_correct incremented by one if predicted righttrain_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))return train_loss, train_correct / total

3.测试模型

    def test(self):print("test:")self.model.eval()test_loss = 0test_correct = 0total = 0with torch.no_grad():for batch_num, (data, target) in enumerate(self.test_loader):data, target = data.to(self.device), target.to(self.device)output = self.model(data)loss = self.criterion(output, target)test_loss += loss.item()prediction = torch.max(output, 1)total += target.size(0)test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))return test_loss, test_correct / total

测试单张图片

网上随便下载一个图片

然后使用图片编辑工具,把图片设置为32x32大小

通过导入模型,然后测试一下

代码

import torch
import cv2
import torch.nn.functional as F
#from model import Net  ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('lenet.pth')  # 加载模型model = model.to(device)model.eval()  # 把模型转为test模式img = cv2.imread("bird1.png")  # 读取要预测的图片trans = transforms.Compose([transforms.ToTensor()])img = trans(img)img = img.to(device)img = img.unsqueeze(0)  # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 扩展后,为[1,1,28,28]output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print(prob)value, predicted = torch.max(output.data, 1)print(predicted.item())print(value)pred_class = classes[predicted.item()]print(pred_class)

运行结果

tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird

从结果看,效果还不错。记录一下

相关文章:

pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)

目录 LeNet-5 LeNet-5 结构 CIFAR-10 pytorch实现 lenet模型 训练模型 1.导入数据 2.训练模型 3.测试模型 测试单张图片 代码 运行结果 LeNet-5 LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络&#xff08;CNN&#xff09;模型&#xff0c;主要…...

Word卡顿的处理方法

1. 检查和关闭后台程序 关闭不必要的后台程序,释放系统资源。使用任务管理器(Ctrl + Shift + Esc)查看占用CPU和内存较高的应用,并关闭它们。2. 更新Microsoft Office 确保你的Microsoft Office软件是最新版本。新版本通常修复了已知的性能问题。打开Word,点击文件 > 账…...

在 Linux上常见的10大压缩格式解压命令和它们对应的压缩格式

文章目录 前言一、解压 .zip 文件二、解压 .tar.gz 或 .tgz 文件三、解压 .tar 文件四、解压 .tar.bz2 文件五、解压 .tar.xz 文件六、解压 .gz 文件七、解压 .bz2 文件八、解压 .xz 文件九、解压 .7z 文件十、解压 .rar 文件总结 前言 Linux 命令可以解压不同格式的压缩文件。…...

【数据结构】三、栈和队列:6.链队列、双端队列、队列的应用(树的层次遍历、广度优先BFS、先来先服务FCFS)

文章目录 2.链队列2.1初始化&#xff08;带头结点&#xff09;不带头结点 2.2入队&#xff08;带头结点&#xff09;2.3出队&#xff08;带头结点&#xff09;❗2.4链队列c实例 3.双端队列考点:输出序列合法性栈双端队列 队列的应用1.树的层次遍历2.图的广度优先遍历3.操作系统…...

技术速递|使用 Native Library Interop 为 .NET MAUI 创建绑定

作者&#xff1a;Rachel Kang 排版&#xff1a;Alan Wang 在当今的应用开发领域&#xff0c;通过利用本机功能来扩展 .NET 应用程序的能力非常宝贵。.NET MAUI 处理程序架构使开发人员能够使用 .NET 代码直接操作本机控件&#xff0c;甚至允许无缝创建跨平台自定义控件。然而&a…...

Linux笔记 --- 标准IO

系统IO的最大特点一个是更具通用性&#xff0c;不管是普通文件、管道文件、设备节点文件、接字文件等等都可以使用&#xff0c;另一个是他的简约性&#xff0c;对文件内数据的读写在任何情况下都是带任何格式的&#xff0c;而且数据的读写也都没有经过任何缓冲处理&#xff0c;…...

洛谷:B3625 迷宫寻路

迷宫寻路 题目描述 机器猫被困在一个矩形迷宫里。 迷宫可以视为一个 n m n\times m nm 矩阵&#xff0c;每个位置要么是空地&#xff0c;要么是墙。机器猫只能从一个空地走到其上、下、左、右的空地。 机器猫初始时位于 ( 1 , 1 ) (1, 1) (1,1) 的位置&#xff0c;问能否…...

【C#】explicit、implicit与operator

字面解释 explicit&#xff1a;清楚明白的;易于理解的;(说话)清晰的&#xff0c;明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit&#xff1a;含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator&#xff1a;操作人员;技工;电话员;接线员;…...

Vue:Vuex-Store使用指南

一、简介 1.1Vuex 是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态&#xff0c;并以相应的规则保证状态以一种可预测的方式发生变化。Vuex 也集成到 Vue 的官方调试工具 devtools extension (opens new window)&#xf…...

对经典动态规划问题【爬台阶】的一些思考

背景 今天在做Leetcode题目时&#xff0c;做到了一道经典的动态规划问题&#xff1a;爬楼梯&#xff0c;题目的大致意思很简单&#xff0c;有个小孩正在上楼梯&#xff0c;楼梯有n阶台阶&#xff0c;小孩一次可以上1阶、2阶或3阶。实现一种方法&#xff0c;计算小孩有多少种上…...

开发一个能打造虚拟带货直播间的工具!

在当今数字化时代&#xff0c;直播带货已成为电商领域的一股强劲力量&#xff0c;其直观、互动性强的特点极大地提升了消费者的购物体验。 然而&#xff0c;随着技术的不断进步&#xff0c;传统直播带货模式正逐步向更加智能化、虚拟化的方向演进&#xff0c;本文将深入探讨如…...

汽车补光照明实验太阳光模拟器光源

汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节&#xff0c;它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果&#xff0c;以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…...

MediaPipe人体姿态、手指关键点检测

MediaPipe人体姿态、手指关键点检测 文章目录 MediaPipe人体姿态、手指关键点检测前言一、手指关键点检测二、姿态检测三、3D物体案例检测案例 前言 Mediapipe是google的一个开源项目&#xff0c;用于构建机器学习管道。   提供了16个预训练模型的案例&#xff1a;人脸检测、…...

树上dp之换根dp

基本概念&#xff1a; 换根dp是树上dp的一种 我们在什么时候需要用到换根dp呢&#xff1f; 当题目询问的属性&#xff0c;是需要当前结点为根时的属性&#xff0c;这个时候&#xff0c;我们就要使用换根dp 换根dp的基本思路&#xff1a; 假设题目询问的的属性为x 通常我们…...

2024/8/13 英语每日一段

Mackey says while Whole Foods has become more homogenized under Amazon, it did enable the store to do what it couldn’t have done independently. “People saw us as too expensive and out of touch with our customers,” he says. “The main thing Whole Foods n…...

Java多线程练习(1)

MultiProcessingExercise package MultiProcessingExercise120240813;public class MultiProcessingExercise {public static void main(String[] args) {/*需求&#xff1a;一共有1000张电影票,可以在两个窗口领取,假设每次领取的时间为3000毫秒,请用多线程模拟卖票过程并打印…...

AI高级肖像动画神器LivePortrait

文章目录 前言一、安装1.1 源码安装1.2 windows一键启动包 二、人像生成2.1 浏览器2.2 输入图像2.3 选择驱动视频2.4 生成2.5 结果 三、动物生成3.1 浏览器3.2 输入图片3.3 选择视频3.4 生成3.5 最终结果 四、软件获取 前言 最近&#xff0c;快手可灵大模型团队、中国科学技术…...

Java反射机制深度解析与实践应用

Java反射机制深度解析与实践应用 引言 Java反射是Java语言提供的一种能力&#xff0c;允许程序在运行时访问、检测和修改其自身的属性和行为。反射机制是Java面向对象编程的一大亮点&#xff0c;也是Java框架和库常用的技术之一。 反射的基本概念 反射的核心是java.lang.re…...

Oracle递归查询层级及路径

一、建表及插入数据 ocation_idlocation_nameparent_location_id1广东省NULL2广州市13深圳市14天河区25番禺区26南山区37宝安区3 建表sql&#xff1a; CREATE TABLE locations (location_id NUMBER PRIMARY KEY,location_name VARCHAR2(100),parent_location_id NUMBER ); I…...

leetcode300. 最长递增子序列,动态规划附状态转移方程

leetcode300. 最长递增子序列 给你一个整数数组 nums &#xff0c;找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c;[3,6,2,7] 是数组 [0,3,1,6,2,2…...

龙虎榜——20250610

上证指数放量收阴线&#xff0c;个股多数下跌&#xff0c;盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型&#xff0c;指数短线有调整的需求&#xff0c;大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的&#xff1a;御银股份、雄帝科技 驱动…...

LeetCode - 394. 字符串解码

题目 394. 字符串解码 - 力扣&#xff08;LeetCode&#xff09; 思路 使用两个栈&#xff1a;一个存储重复次数&#xff0c;一个存储字符串 遍历输入字符串&#xff1a; 数字处理&#xff1a;遇到数字时&#xff0c;累积计算重复次数左括号处理&#xff1a;保存当前状态&a…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1

每日一言 生活的美好&#xff0c;总是藏在那些你咬牙坚持的日子里。 硬件&#xff1a;OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写&#xff0c;"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

k8s业务程序联调工具-KtConnect

概述 原理 工具作用是建立了一个从本地到集群的单向VPN&#xff0c;根据VPN原理&#xff0c;打通两个内网必然需要借助一个公共中继节点&#xff0c;ktconnect工具巧妙的利用k8s原生的portforward能力&#xff0c;简化了建立连接的过程&#xff0c;apiserver间接起到了中继节…...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...

如何更改默认 Crontab 编辑器 ?

在 Linux 领域中&#xff0c;crontab 是您可能经常遇到的一个术语。这个实用程序在类 unix 操作系统上可用&#xff0c;用于调度在预定义时间和间隔自动执行的任务。这对管理员和高级用户非常有益&#xff0c;允许他们自动执行各种系统任务。 编辑 Crontab 文件通常使用文本编…...

【网络安全】开源系统getshell漏洞挖掘

审计过程&#xff1a; 在入口文件admin/index.php中&#xff1a; 用户可以通过m,c,a等参数控制加载的文件和方法&#xff0c;在app/system/entrance.php中存在重点代码&#xff1a; 当M_TYPE system并且M_MODULE include时&#xff0c;会设置常量PATH_OWN_FILE为PATH_APP.M_T…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...