pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)
目录
LeNet-5
LeNet-5 结构
CIFAR-10
pytorch实现
lenet模型
训练模型
1.导入数据
2.训练模型
3.测试模型
测试单张图片
代码
运行结果
LeNet-5
LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。
LeNet-5 结构
LeNet-5 的结构包括以下几个层次:
- 输入层: 32x32 的灰度图像。
- 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
- 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
- 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
- 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
- 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
- 全连接层 F6: 包含 84 个神经元。
- 输出层: 包含 10 个神经元,对应于 10 个类别。

CIFAR-10
CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。
1. 标注数据量训练集:50000张图像测试集:10000张图像
2. 标注类别数据集共有10个类别。具体分类见图1。
3. 可视化

pytorch实现
lenet模型
- 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
- 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。
在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。
import torch.nn as nn
import torch.nn.functional as funcclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=5)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = func.relu(self.conv1(x))x = func.max_pool2d(x, 2)x = func.relu(self.conv2(x))x = func.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = func.relu(self.fc1(x))x = func.relu(self.fc2(x))x = self.fc3(x)return x
训练模型
1.导入数据
导入训练数据和测试数据
def load_data(self):#transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])test_transform = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)# shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)
2.训练模型
def train(self):print("train:")self.model.train()train_loss = 0train_correct = 0total = 0for batch_num, (data, target) in enumerate(self.train_loader):data, target = data.to(self.device), target.to(self.device)self.optimizer.zero_grad()output = self.model(data)loss = self.criterion(output, target)loss.backward()self.optimizer.step()train_loss += loss.item()prediction = torch.max(output, 1) # second param "1" represents the dimension to be reducedtotal += target.size(0)# train_correct incremented by one if predicted righttrain_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))return train_loss, train_correct / total
3.测试模型
def test(self):print("test:")self.model.eval()test_loss = 0test_correct = 0total = 0with torch.no_grad():for batch_num, (data, target) in enumerate(self.test_loader):data, target = data.to(self.device), target.to(self.device)output = self.model(data)loss = self.criterion(output, target)test_loss += loss.item()prediction = torch.max(output, 1)total += target.size(0)test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))return test_loss, test_correct / total
测试单张图片
网上随便下载一个图片

然后使用图片编辑工具,把图片设置为32x32大小
![]()
通过导入模型,然后测试一下
代码
import torch
import cv2
import torch.nn.functional as F
#from model import Net ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('lenet.pth') # 加载模型model = model.to(device)model.eval() # 把模型转为test模式img = cv2.imread("bird1.png") # 读取要预测的图片trans = transforms.Compose([transforms.ToTensor()])img = trans(img)img = img.to(device)img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 扩展后,为[1,1,28,28]output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print(prob)value, predicted = torch.max(output.data, 1)print(predicted.item())print(value)pred_class = classes[predicted.item()]print(pred_class)
运行结果
tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird
从结果看,效果还不错。记录一下
相关文章:
pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)
目录 LeNet-5 LeNet-5 结构 CIFAR-10 pytorch实现 lenet模型 训练模型 1.导入数据 2.训练模型 3.测试模型 测试单张图片 代码 运行结果 LeNet-5 LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要…...
Word卡顿的处理方法
1. 检查和关闭后台程序 关闭不必要的后台程序,释放系统资源。使用任务管理器(Ctrl + Shift + Esc)查看占用CPU和内存较高的应用,并关闭它们。2. 更新Microsoft Office 确保你的Microsoft Office软件是最新版本。新版本通常修复了已知的性能问题。打开Word,点击文件 > 账…...
在 Linux上常见的10大压缩格式解压命令和它们对应的压缩格式
文章目录 前言一、解压 .zip 文件二、解压 .tar.gz 或 .tgz 文件三、解压 .tar 文件四、解压 .tar.bz2 文件五、解压 .tar.xz 文件六、解压 .gz 文件七、解压 .bz2 文件八、解压 .xz 文件九、解压 .7z 文件十、解压 .rar 文件总结 前言 Linux 命令可以解压不同格式的压缩文件。…...
【数据结构】三、栈和队列:6.链队列、双端队列、队列的应用(树的层次遍历、广度优先BFS、先来先服务FCFS)
文章目录 2.链队列2.1初始化(带头结点)不带头结点 2.2入队(带头结点)2.3出队(带头结点)❗2.4链队列c实例 3.双端队列考点:输出序列合法性栈双端队列 队列的应用1.树的层次遍历2.图的广度优先遍历3.操作系统…...
技术速递|使用 Native Library Interop 为 .NET MAUI 创建绑定
作者:Rachel Kang 排版:Alan Wang 在当今的应用开发领域,通过利用本机功能来扩展 .NET 应用程序的能力非常宝贵。.NET MAUI 处理程序架构使开发人员能够使用 .NET 代码直接操作本机控件,甚至允许无缝创建跨平台自定义控件。然而&a…...
Linux笔记 --- 标准IO
系统IO的最大特点一个是更具通用性,不管是普通文件、管道文件、设备节点文件、接字文件等等都可以使用,另一个是他的简约性,对文件内数据的读写在任何情况下都是带任何格式的,而且数据的读写也都没有经过任何缓冲处理,…...
洛谷:B3625 迷宫寻路
迷宫寻路 题目描述 机器猫被困在一个矩形迷宫里。 迷宫可以视为一个 n m n\times m nm 矩阵,每个位置要么是空地,要么是墙。机器猫只能从一个空地走到其上、下、左、右的空地。 机器猫初始时位于 ( 1 , 1 ) (1, 1) (1,1) 的位置,问能否…...
【C#】explicit、implicit与operator
字面解释 explicit:清楚明白的;易于理解的;(说话)清晰的,明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit:含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator:操作人员;技工;电话员;接线员;…...
Vue:Vuex-Store使用指南
一、简介 1.1Vuex 是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。Vuex 也集成到 Vue 的官方调试工具 devtools extension (opens new window)…...
对经典动态规划问题【爬台阶】的一些思考
背景 今天在做Leetcode题目时,做到了一道经典的动态规划问题:爬楼梯,题目的大致意思很简单,有个小孩正在上楼梯,楼梯有n阶台阶,小孩一次可以上1阶、2阶或3阶。实现一种方法,计算小孩有多少种上…...
开发一个能打造虚拟带货直播间的工具!
在当今数字化时代,直播带货已成为电商领域的一股强劲力量,其直观、互动性强的特点极大地提升了消费者的购物体验。 然而,随着技术的不断进步,传统直播带货模式正逐步向更加智能化、虚拟化的方向演进,本文将深入探讨如…...
汽车补光照明实验太阳光模拟器光源
汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节,它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果,以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…...
MediaPipe人体姿态、手指关键点检测
MediaPipe人体姿态、手指关键点检测 文章目录 MediaPipe人体姿态、手指关键点检测前言一、手指关键点检测二、姿态检测三、3D物体案例检测案例 前言 Mediapipe是google的一个开源项目,用于构建机器学习管道。 提供了16个预训练模型的案例:人脸检测、…...
树上dp之换根dp
基本概念: 换根dp是树上dp的一种 我们在什么时候需要用到换根dp呢? 当题目询问的属性,是需要当前结点为根时的属性,这个时候,我们就要使用换根dp 换根dp的基本思路: 假设题目询问的的属性为x 通常我们…...
2024/8/13 英语每日一段
Mackey says while Whole Foods has become more homogenized under Amazon, it did enable the store to do what it couldn’t have done independently. “People saw us as too expensive and out of touch with our customers,” he says. “The main thing Whole Foods n…...
Java多线程练习(1)
MultiProcessingExercise package MultiProcessingExercise120240813;public class MultiProcessingExercise {public static void main(String[] args) {/*需求:一共有1000张电影票,可以在两个窗口领取,假设每次领取的时间为3000毫秒,请用多线程模拟卖票过程并打印…...
AI高级肖像动画神器LivePortrait
文章目录 前言一、安装1.1 源码安装1.2 windows一键启动包 二、人像生成2.1 浏览器2.2 输入图像2.3 选择驱动视频2.4 生成2.5 结果 三、动物生成3.1 浏览器3.2 输入图片3.3 选择视频3.4 生成3.5 最终结果 四、软件获取 前言 最近,快手可灵大模型团队、中国科学技术…...
Java反射机制深度解析与实践应用
Java反射机制深度解析与实践应用 引言 Java反射是Java语言提供的一种能力,允许程序在运行时访问、检测和修改其自身的属性和行为。反射机制是Java面向对象编程的一大亮点,也是Java框架和库常用的技术之一。 反射的基本概念 反射的核心是java.lang.re…...
Oracle递归查询层级及路径
一、建表及插入数据 ocation_idlocation_nameparent_location_id1广东省NULL2广州市13深圳市14天河区25番禺区26南山区37宝安区3 建表sql: CREATE TABLE locations (location_id NUMBER PRIMARY KEY,location_name VARCHAR2(100),parent_location_id NUMBER ); I…...
leetcode300. 最长递增子序列,动态规划附状态转移方程
leetcode300. 最长递增子序列 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2,7] 是数组 [0,3,1,6,2,2…...
利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...
Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...
Linux链表操作全解析
Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表?1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...
边缘计算医疗风险自查APP开发方案
核心目标:在便携设备(智能手表/家用检测仪)部署轻量化疾病预测模型,实现低延迟、隐私安全的实时健康风险评估。 一、技术架构设计 #mermaid-svg-iuNaeeLK2YoFKfao {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg…...
线程同步:确保多线程程序的安全与高效!
全文目录: 开篇语前序前言第一部分:线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分:synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分ÿ…...
理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
短视频矩阵系统文案创作功能开发实践,定制化开发
在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...
CVPR2025重磅突破:AnomalyAny框架实现单样本生成逼真异常数据,破解视觉检测瓶颈!
本文介绍了一种名为AnomalyAny的创新框架,该方法利用Stable Diffusion的强大生成能力,仅需单个正常样本和文本描述,即可生成逼真且多样化的异常样本,有效解决了视觉异常检测中异常样本稀缺的难题,为工业质检、医疗影像…...
【安全篇】金刚不坏之身:整合 Spring Security + JWT 实现无状态认证与授权
摘要 本文是《Spring Boot 实战派》系列的第四篇。我们将直面所有 Web 应用都无法回避的核心问题:安全。文章将详细阐述认证(Authentication) 与授权(Authorization的核心概念,对比传统 Session-Cookie 与现代 JWT(JS…...
