pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)
目录
LeNet-5
LeNet-5 结构
CIFAR-10
pytorch实现
lenet模型
训练模型
1.导入数据
2.训练模型
3.测试模型
测试单张图片
代码
运行结果
LeNet-5
LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。
LeNet-5 结构
LeNet-5 的结构包括以下几个层次:
- 输入层: 32x32 的灰度图像。
- 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
- 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
- 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
- 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
- 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
- 全连接层 F6: 包含 84 个神经元。
- 输出层: 包含 10 个神经元,对应于 10 个类别。
CIFAR-10
CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。
1. 标注数据量训练集:50000张图像测试集:10000张图像
2. 标注类别数据集共有10个类别。具体分类见图1。
3. 可视化
pytorch实现
lenet模型
- 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
- 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。
在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。
import torch.nn as nn
import torch.nn.functional as funcclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=5)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = func.relu(self.conv1(x))x = func.max_pool2d(x, 2)x = func.relu(self.conv2(x))x = func.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = func.relu(self.fc1(x))x = func.relu(self.fc2(x))x = self.fc3(x)return x
训练模型
1.导入数据
导入训练数据和测试数据
def load_data(self):#transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])test_transform = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)# shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)
2.训练模型
def train(self):print("train:")self.model.train()train_loss = 0train_correct = 0total = 0for batch_num, (data, target) in enumerate(self.train_loader):data, target = data.to(self.device), target.to(self.device)self.optimizer.zero_grad()output = self.model(data)loss = self.criterion(output, target)loss.backward()self.optimizer.step()train_loss += loss.item()prediction = torch.max(output, 1) # second param "1" represents the dimension to be reducedtotal += target.size(0)# train_correct incremented by one if predicted righttrain_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))return train_loss, train_correct / total
3.测试模型
def test(self):print("test:")self.model.eval()test_loss = 0test_correct = 0total = 0with torch.no_grad():for batch_num, (data, target) in enumerate(self.test_loader):data, target = data.to(self.device), target.to(self.device)output = self.model(data)loss = self.criterion(output, target)test_loss += loss.item()prediction = torch.max(output, 1)total += target.size(0)test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))return test_loss, test_correct / total
测试单张图片
网上随便下载一个图片
然后使用图片编辑工具,把图片设置为32x32大小
通过导入模型,然后测试一下
代码
import torch
import cv2
import torch.nn.functional as F
#from model import Net ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('lenet.pth') # 加载模型model = model.to(device)model.eval() # 把模型转为test模式img = cv2.imread("bird1.png") # 读取要预测的图片trans = transforms.Compose([transforms.ToTensor()])img = trans(img)img = img.to(device)img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 扩展后,为[1,1,28,28]output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print(prob)value, predicted = torch.max(output.data, 1)print(predicted.item())print(value)pred_class = classes[predicted.item()]print(pred_class)
运行结果
tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird
从结果看,效果还不错。记录一下
相关文章:

pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)
目录 LeNet-5 LeNet-5 结构 CIFAR-10 pytorch实现 lenet模型 训练模型 1.导入数据 2.训练模型 3.测试模型 测试单张图片 代码 运行结果 LeNet-5 LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要…...
Word卡顿的处理方法
1. 检查和关闭后台程序 关闭不必要的后台程序,释放系统资源。使用任务管理器(Ctrl + Shift + Esc)查看占用CPU和内存较高的应用,并关闭它们。2. 更新Microsoft Office 确保你的Microsoft Office软件是最新版本。新版本通常修复了已知的性能问题。打开Word,点击文件 > 账…...
在 Linux上常见的10大压缩格式解压命令和它们对应的压缩格式
文章目录 前言一、解压 .zip 文件二、解压 .tar.gz 或 .tgz 文件三、解压 .tar 文件四、解压 .tar.bz2 文件五、解压 .tar.xz 文件六、解压 .gz 文件七、解压 .bz2 文件八、解压 .xz 文件九、解压 .7z 文件十、解压 .rar 文件总结 前言 Linux 命令可以解压不同格式的压缩文件。…...

【数据结构】三、栈和队列:6.链队列、双端队列、队列的应用(树的层次遍历、广度优先BFS、先来先服务FCFS)
文章目录 2.链队列2.1初始化(带头结点)不带头结点 2.2入队(带头结点)2.3出队(带头结点)❗2.4链队列c实例 3.双端队列考点:输出序列合法性栈双端队列 队列的应用1.树的层次遍历2.图的广度优先遍历3.操作系统…...

技术速递|使用 Native Library Interop 为 .NET MAUI 创建绑定
作者:Rachel Kang 排版:Alan Wang 在当今的应用开发领域,通过利用本机功能来扩展 .NET 应用程序的能力非常宝贵。.NET MAUI 处理程序架构使开发人员能够使用 .NET 代码直接操作本机控件,甚至允许无缝创建跨平台自定义控件。然而&a…...

Linux笔记 --- 标准IO
系统IO的最大特点一个是更具通用性,不管是普通文件、管道文件、设备节点文件、接字文件等等都可以使用,另一个是他的简约性,对文件内数据的读写在任何情况下都是带任何格式的,而且数据的读写也都没有经过任何缓冲处理,…...
洛谷:B3625 迷宫寻路
迷宫寻路 题目描述 机器猫被困在一个矩形迷宫里。 迷宫可以视为一个 n m n\times m nm 矩阵,每个位置要么是空地,要么是墙。机器猫只能从一个空地走到其上、下、左、右的空地。 机器猫初始时位于 ( 1 , 1 ) (1, 1) (1,1) 的位置,问能否…...

【C#】explicit、implicit与operator
字面解释 explicit:清楚明白的;易于理解的;(说话)清晰的,明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit:含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator:操作人员;技工;电话员;接线员;…...

Vue:Vuex-Store使用指南
一、简介 1.1Vuex 是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。Vuex 也集成到 Vue 的官方调试工具 devtools extension (opens new window)…...
对经典动态规划问题【爬台阶】的一些思考
背景 今天在做Leetcode题目时,做到了一道经典的动态规划问题:爬楼梯,题目的大致意思很简单,有个小孩正在上楼梯,楼梯有n阶台阶,小孩一次可以上1阶、2阶或3阶。实现一种方法,计算小孩有多少种上…...

开发一个能打造虚拟带货直播间的工具!
在当今数字化时代,直播带货已成为电商领域的一股强劲力量,其直观、互动性强的特点极大地提升了消费者的购物体验。 然而,随着技术的不断进步,传统直播带货模式正逐步向更加智能化、虚拟化的方向演进,本文将深入探讨如…...

汽车补光照明实验太阳光模拟器光源
汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节,它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果,以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…...

MediaPipe人体姿态、手指关键点检测
MediaPipe人体姿态、手指关键点检测 文章目录 MediaPipe人体姿态、手指关键点检测前言一、手指关键点检测二、姿态检测三、3D物体案例检测案例 前言 Mediapipe是google的一个开源项目,用于构建机器学习管道。 提供了16个预训练模型的案例:人脸检测、…...
树上dp之换根dp
基本概念: 换根dp是树上dp的一种 我们在什么时候需要用到换根dp呢? 当题目询问的属性,是需要当前结点为根时的属性,这个时候,我们就要使用换根dp 换根dp的基本思路: 假设题目询问的的属性为x 通常我们…...

2024/8/13 英语每日一段
Mackey says while Whole Foods has become more homogenized under Amazon, it did enable the store to do what it couldn’t have done independently. “People saw us as too expensive and out of touch with our customers,” he says. “The main thing Whole Foods n…...
Java多线程练习(1)
MultiProcessingExercise package MultiProcessingExercise120240813;public class MultiProcessingExercise {public static void main(String[] args) {/*需求:一共有1000张电影票,可以在两个窗口领取,假设每次领取的时间为3000毫秒,请用多线程模拟卖票过程并打印…...

AI高级肖像动画神器LivePortrait
文章目录 前言一、安装1.1 源码安装1.2 windows一键启动包 二、人像生成2.1 浏览器2.2 输入图像2.3 选择驱动视频2.4 生成2.5 结果 三、动物生成3.1 浏览器3.2 输入图片3.3 选择视频3.4 生成3.5 最终结果 四、软件获取 前言 最近,快手可灵大模型团队、中国科学技术…...
Java反射机制深度解析与实践应用
Java反射机制深度解析与实践应用 引言 Java反射是Java语言提供的一种能力,允许程序在运行时访问、检测和修改其自身的属性和行为。反射机制是Java面向对象编程的一大亮点,也是Java框架和库常用的技术之一。 反射的基本概念 反射的核心是java.lang.re…...
Oracle递归查询层级及路径
一、建表及插入数据 ocation_idlocation_nameparent_location_id1广东省NULL2广州市13深圳市14天河区25番禺区26南山区37宝安区3 建表sql: CREATE TABLE locations (location_id NUMBER PRIMARY KEY,location_name VARCHAR2(100),parent_location_id NUMBER ); I…...

leetcode300. 最长递增子序列,动态规划附状态转移方程
leetcode300. 最长递增子序列 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2,7] 是数组 [0,3,1,6,2,2…...
Vue记事本应用实现教程
文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

RocketMQ延迟消息机制
两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数,对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后…...

大话软工笔记—需求分析概述
需求分析,就是要对需求调研收集到的资料信息逐个地进行拆分、研究,从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要,后续设计的依据主要来自于需求分析的成果,包括: 项目的目的…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql
智慧工地管理云平台系统,智慧工地全套源码,java版智慧工地源码,支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求,提供“平台网络终端”的整体解决方案,提供劳务管理、视频管理、智能监测、绿色施工、安全管…...

【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器
——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的一体化测试平台,覆盖应用全生命周期测试需求,主要提供五大核心能力: 测试类型检测目标关键指标功能体验基…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...
Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器
第一章 引言:语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域,文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量,支撑着搜索引擎、推荐系统、…...

如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...