当前位置: 首页 > news >正文

LeNet-5(论文复现)

LeNet-5(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • LeNet-5(论文复现)
        • 概述
        • LeNet-5网络架构介绍
        • 训练过程
        • 测试过程
        • 使用方式
        • 说明

概述

LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上,在手写数字识别任务中取得了巨大成功。LeNet通过连续使用卷积和池化层的组合提取图像特征。
出自论文《Gradient-Based Learning Applied to Document Recognition》

LeNet-5网络架构介绍

在这里插入图片描述

  • 输入层

    输入32×32通道数为1的图片

  • C1层(卷积层)

    使用6个5×5大小的卷积核,padding=0,stride=1,得到6个28×28大小的特征图

    激活函数: ReLU

    **可训练参数:**6×(5×5+1)=1566×(5×5+1)=156

  • S2层(池化层)

    最大池化,池化窗大小2×2,stride=2

    **可训练参数:**6×(1+1)=126×(1+1)=12,其中第一个 1 为池化对应的 2*2 感受野中最大的那个数的权重 w,第二个 1 为偏置 b。

  • C3层(卷积层)

    使用16个5×5大小的卷积核,padding=0,stride=1,得到16个10×10大小的特征图

    激活函数: ReLu

    **可训练参数:**6×(5×5×3+1)+6×(5×5×4+1)+3×(5×5×4+1)+1×(5×5×6+1)=15166×(5×5×3+1)+6×(5×5×4+1)+3×(5×5×4+1)+1×(5×5×6+1)=1516

    16 个卷积核并不是都与 S2 的 6 个通道层进行卷积操作,如下图所示,C3 的前六个特征图(0,1,2,3,4,5)由 S2 的相邻三个特征图作为输入,对应的卷积核尺寸为:5x5x3;接下来的 6 个特征图(6,7,8,9,10,11)由 S2 的相邻四个特征图作为输入对应的卷积核尺寸为:5x5x4;接下来的 3 个特征图(12,13,14)号特征图由 S2 间断的四个特征图作为输入对应的卷积核尺寸为:5x5x4;最后的 15 号特征图由 S2 全部(6 个)特征图作为输入,对应的卷积核尺寸为:5x5x6

  • S4层(池化层)

    最大池化,池化窗大小2×2,stride=2

    **可训练参数:**16×(1+1)=3216×(1+1)=32

  • C5层(卷积层/全连接层)

    由于该层卷积核的大小与输入图像相同,故也可认为是全连接层。

    C5 层是卷积层,使用 120 个 5×5x16 大小的卷积核,padding=0,stride=1进行卷积,得到 120 个 1×1 大小的特征图:5-5+1=1。即相当于 120 个神经元的全连接层。

    值得注意的是,与C3层不同,这里120个卷积核都与S4的16个通道层进行卷积操作。

    激活函数: ReLU

    **可训练参数:**120×(5×5×16+1)=48120120×(5×5×16+1)=48120

  • F6层(全连接层)

    F6 是全连接层,共有 84 个神经元,与 C5 层进行全连接,即每个神经元都与 C5 层的 120 个特征图相连。计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过 sigmoid 函数输出。

    **可训练参数:**84×(120+1)84×(120+1)

  • OUTPUT层(全连接层)

    最后的 Output 层也是全连接层,是 Gaussian Connections,采用了 RBF 函数(即径向欧式距离函数),计算输入向量和参数向量之间的欧式距离(目前已经被Softmax 取代)。

    **可训练参数:**84×1084×10

使用 LeNet-5 网络结构创建 MNIST 手写数字识别分类器

MNIST是一个非常有名的手写体数字识别数据集,训练样本:共60000个,其中55000个用于训练,另外5000个用于验证;测试样本:共10000个。MNIST数据集每张图片是单通道的,大小为28x28

在这里插入图片描述

下载并加载数据,并对数据进行预处理

# 下载MNIST数据集train_set = datasets.MNIST(root = "./data", train = True, download = True, transform = pipline_train)test_set = datasets.MNIST(root = "./data", train = False, download = True, transform = pipline_test)# 加载数据集train_data = torch.utils.data.DataLoader(train_set, batch_size = opt.batch_size, shuffle = True)test_data = torch.utils.data.DataLoader(test_set, batch_size = opt.batch_size, shuffle = False)train_data_size = len(train_data)test_data_size = len(test_data)print("训练数据集长度:{}\n测试数据集长度:{}".format(train_data_size, test_data_size))

若本地无MNIST数据集,会在当前目录下新建一个data文件夹存放数据

在这里插入图片描述

MNIST数据集中的图片数据以ubyte格式存储,ubyte是一种无符号字节类型,取值范围在0~255之间。MNIST数据集的图像数据文件为"train-images-idx3-ubyte.gz"和"t10k-images-idx3-ubyte.gz",其中前者存储了训练数据,后者存储了测试数据。

由于 MNIST 数据集图片尺寸是 28x28 单通道的,而 LeNet-5 网络输入 Input 图片尺寸是 32x32,使用 transforms.Resize 将输入图片尺寸调整为 32x32

pipline_train = transforms.Compose([# 随机旋转图片transforms.RandomHorizontalFlip(),# 将图片尺寸resize到32x32transforms.Resize((32, 32)),# 将图片转化为Tensor格式transforms.ToTensor(),# 正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)transforms.Normalize((0.1307,), (0.3081,))
])
pipline_test = transforms.Compose([# 将图片尺寸resize到32x32transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])

搭建 LeNet-5 神经网络结构,并定义前向传播的过程

# 搭建 LeNet-5 神经网络结构,并定义前向传播的过程
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)  # 输入通道的数量;输出通道的数量(也就是卷积核的数量);卷积核的大小self.relu = nn.ReLU()self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.maxpool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool1(x)x = self.conv2(x)x = self.relu(x)x = self.maxpool2(x)x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)output = F.log_softmax(x, dim = 1)return output
训练过程
def train_runner(model, device, trainloader, optimizer):# 训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为Truemodel.train()total = 0correct = 0.0# enumerate迭代已加载的数据集,同时获取数据和数据下标for i, data in enumerate(trainloader, 0):inputs, labels = data# 把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)# 保存训练结果outputs = model(inputs)# 计算损失和# 多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmodloss = F.cross_entropy(outputs, labels)# 初始化梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()# 获取最大概率的预测结果# dim=1表示返回每一行的最大值对应的列下标predict = outputs.argmax(dim = 1)total += labels.size(0)correct += (predict == labels).sum().item()if i % 1000 == 0:# loss.item()表示当前loss的数值print("Train Loss: {:.4f}, Accuracy: {:.2f}%".format(loss.item(), 100 * (correct / total)))return loss.item(), 100 * (correct / total)
测试过程
def val_runner(model, device, testloader):# 模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值# 因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为Falsemodel.eval()# 统计模型正确率, 设置初始值correct = 0.0test_loss = 0.0total = 0best_acc = 0.0# torch.no_grad将不会计算梯度, 也不会进行反向传播with torch.no_grad():for data, label in testloader:data, label = data.to(device), label.to(device)output = model(data)test_loss += F.cross_entropy(output, label).item()predict = output.argmax(dim = 1)# 计算正确数量total += label.size(0)correct += (predict == label).sum().item()# 计算损失值val_acc = correct / totalprint("Test loss: {:.4f}, Accuracy: {:.2f}%".format(test_loss / total, 100 * val_acc))# 保存最佳模型if val_acc > best_acc:best_acc = val_acctorch.save(model, './model-mnist_best.pth')  # 保存模型return test_loss / total, 100 * val_acc
使用方式

可直接在IDLE中运行代码,其中train.py文件用于训练网络,model.py文件用于定义网络,test.py文件用来对训练完的模型做一个测试推理。
也可直接调用命令行实现,如

python train.py --epochs 100 --lr 0.001 --batch_size 64

若不指定相关参数,train.py默认为训练10轮,学习率0.001,batch_size为64

说明

本项目的文件夹架构如下

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

代码中还使用了tensorboard可视化工具,以下是tensorboard可视化结果

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

最终在测试样本上,average_loss降到了 0.00129,accuracy 达到了 97.28%。可以说 LeNet-5 的效果非常好!

使用test.py进行测试推理时,由于MNIST数据集中的图片数据以ubyte格式存储,需要转成图片的格式,具体转换脚本参照mnist2jpg.py

# 获取图像数据和标签img, label = mnist_train[i]# 转换图像数据为numpy数组img_np = np.squeeze(img.numpy())# 展示图像plt.imshow(img_np, cmap = 'gray')plt.axis('off')  # 关闭坐标轴显示plt.savefig('{}/mnist_image_{}.jpg'.format(save_dir, label), bbox_inches = 'tight', pad_inches = 0)plt.close()

测试图片

在这里插入图片描述

文章代码资源点击附件获取

相关文章:

LeNet-5(论文复现)

LeNet-5(论文复现) 本文所涉及所有资源均在传知代码平台可获取 文章目录 LeNet-5(论文复现)概述LeNet-5网络架构介绍训练过程测试过程使用方式说明 概述 LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeN…...

基于SpringBoot+Vue+Uniapp汽车保养系统小程序的设计与实现

详细视频演示 请联系我获取更详细的演示视频 项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念,提供了一套默认的配置,让开发者可以更专注于业务逻辑而…...

【问题实战】Jmeter中jtl格式转换图片后如何分开展示各个性能指标?

【问题实战】Jmeter中jtl格式转换图片后如何分开展示各个性能指标? 遇到的问题解决方法查看修改效果 遇到的问题 JMeter测试计划中只设置了一个性能监控器jpgc - PerfMon Metrics Collector;在这个监控器中设置几个性能监控指标,比如CPU、Di…...

解决 MySQL 连接数过多导致的 SQLNonTransientConnectionException 问题

这里写目录标题 解决 MySQL 连接数过多导致的 SQLNonTransientConnectionException 问题1. 概述2. 问题描述异常日志的关键部分: 3. 原因分析3.1. MySQL 连接数配置3.2. 连接池配置问题3.3. 代码中未正确关闭连接3.4. 高并发导致连接需求激增 4. 解决方案4.1. 增加 …...

猫头虎分享:什么是 ChatGPT 4o Canvas?

猫头虎是谁? 大家好,我是 猫头虎,猫头虎技术团队创始人,也被大家称为猫哥。我目前是COC北京城市开发者社区主理人、COC西安城市开发者社区主理人,以及云原生开发者社区主理人,在多个技术领域如云原生、前端…...

qiankun 主项目和子项目都是 vue2,部署在同一台服务器上,nginx 配置

1、主项目配置 1.1 micro.vue 组件 <template><div id"container-sub-app"></div> </template><script> import { loadMicroApp } from qiankun; import actions from /utils/actions.js;export default {name: microApp,mixins: [ac…...

深入浅出MongoDB(七)

深入浅出MongoDB&#xff08;七&#xff09; 文章目录 深入浅出MongoDB&#xff08;七&#xff09;查询优化创建索引以支持读取操作查询选择性覆盖查询 分析性能使用数据库分析器评估对数据库的操作使用db.currentOp()评估mongod操作使用explain评估查询性能 优化查询性能创建索…...

【华为】配置NAT访问互联网

1.AR1&#xff1a; int g0/0/0 ip ad 64.1.1.2 255.255.255.0 int g0/0/1 ip ad 110.242.68.1 255.255.255.02.AR2: (1)配置端口ip: int g0/0/1 ip ad 10.3.1.2 255.255.255.0 int g0/0/0 ip ad 64.1.1.1 255.255.255.0(2)配置默认路由&#xff1a; ip route-static 0.0.0.0 0.…...

Spring Boot项目使用多线程执行定时任务

我在一个Spring Boot项目中&#xff0c;采用定时器执行一些操作&#xff0c;比如10秒就发送一次数据。这些操作有2个&#xff0c;如下所示。我就想&#xff0c;虽然这两个操作各自指定了时间频率&#xff0c;但如果其中一个操作非常耗时&#xff0c;会不会影响其他操作呢&#…...

【安装JDK和Android SDK】

安装JDK和Android SDK 1 前言2 下载2.1 下载途径2.2 JDK下载和安装2.2.1 下载2.2.2 安装并配置环境变量2.2.3 验证 2.3 SDK下载和安装2.3.1 下载2.3.2 安装2.3.3 环境变量配置2.3.4 验证 1 前言 在软件开发中&#xff0c;Android应用开发通常使用Android Studio&#xff0c;但…...

汇总10个AI免费一键生成PPT的网站

一、前言 PPT幻灯片是现代办公和学习中的重要组成部分。它在工作、研究或培训中扮演着重要角色&#xff0c;并能够让观众更好地理解信息。随着当今人工智能技术的快速发展&#xff0c;现在有很多免费的AI PPT生成器可供选择&#xff0c;帮助用户更加便捷地制作出高效且具有较强…...

超材料光子晶体和禁带分析实例_CST电磁仿真教程

光子晶体是由周期性排列的不同折射率的介质制造的光学结构&#xff0c;可被视为广义超材料metamaterial的一种。本期我们演示设计一个基于光频能带(PBG,photonics band gap) 的二维光子晶体波导&#xff0c;能带分析方法也可适用于微波波段&#xff08;EBG,electromagetic band…...

关于OceanBase数据库的poc测试连接经验(by liuhui)

poc客户给了OceanBase数据库实例如下 ob实例&#xff1a; ip:1xx.xx.xx 端口&#xff1a;2883 实例名&#xff1a;obm_xczjj_1_poc#cs_pool_1 用户名&#xff1a;root 密码&#xff1a;xxxxxx 问题出现&#xff1a;根据客户提供的OceanBase数据库配置报错。配置如下 查询数据…...

Docker部署如何修改本地mysql,redis连接信息

要修改数据库 MySQL 和缓存 Redis 的地址为 ruoyi-mysql 和 ruoyi-redis&#xff0c;通常需要在 Spring Boot 项目的配置文件中进行相应的修改。 ### 修改 MySQL 数据库地址为 ruoyi-mysql 1. **在 Spring Boot 项目中找到 application.properties 或 application.yml 文件**…...

PHP中的ReflectionClass常见用法

ReflectionClass是 PHP 中的一个类&#xff0c;它提供了有关类的信息的反射。 使用ReflectionClass可以在运行时获取关于类的各种信息&#xff0c;例如类的名称、方法、属性、注释等。 以下是一些常见的用法&#xff1a; 获取类的名称&#xff1a; $reflection new Reflec…...

processing像素画教程

前提&#xff1a;各位已经安装了processing 第一步&#xff1a;创建一个简单的网格 我们首先创建一个网格来定义我们作品的像素画布。网格将帮助您在适当的位置绘制每个像素。 int gridSize 20; // 每个像素的大小 int cols, rows; void setup() {size(400, 400); // 设置画…...

【秋招笔试】10.13字节跳动(已改编)秋招-三语言题解

🍭 大家好这里是 春秋招笔试突围,一起备战大厂笔试 💻 ACM金牌团队🏅️ | 多次AK大厂笔试 | 大厂实习经历 ✨ 本系列打算持续跟新 春秋招笔试题 👏 感谢大家的订阅➕ 和 喜欢💗 和 手里的小花花🌸 ✨ 笔试合集传送们 -> 🧷春秋招笔试合集 本次的三题全部上线…...

牛客网上最全的Java八股文整理,涵盖Java全栈技术点

Java 面试 “金九银十”这个字眼对于程序员应该是再熟悉不过的了&#xff0c;每年的金九银十都会有很多程序员找工作、跳槽等一系列的安排。说实话&#xff0c;面试中 7 分靠能力&#xff0c;3 分靠技能&#xff1b;在刚开始的时候介绍项目都是技能中的重中之重&#xff0c;它…...

Skyeye 云智能制造 v3.14.9 发布,ERP 商城 + AI

Skyeye 云智能制造&#xff0c;采用 Springboot winUI 的低代码平台、移动端采用 UNI-APP。包含 30 多个应用模块、50 多种电子流程&#xff0c;CRM、PM、ERP、MES、ADM、EHR、笔记、知识库、项目、门店、商城、财务、多班次考勤、薪资、招聘、云售后、论坛、公告、问卷、报表…...

Element-快速入门

什么是 Element 在现代前端开发中&#xff0c;组件化的思想日益盛行&#xff0c;Element组件库作为一款流行的UI组件库&#xff0c;特别适用于基于Vue.js的项目&#xff0c;它为开发者提供了丰富的组件和良好的开发体验。 想要使用Element的组件库&#xff0c;我们需要完成下面…...

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)

文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

面试高频问题

文章目录 &#x1f680; 消息队列核心技术揭秘&#xff1a;从入门到秒杀面试官1️⃣ Kafka为何能"吞云吐雾"&#xff1f;性能背后的秘密1.1 顺序写入与零拷贝&#xff1a;性能的双引擎1.2 分区并行&#xff1a;数据的"八车道高速公路"1.3 页缓存与批量处理…...

动态规划-1035.不相交的线-力扣(LeetCode)

一、题目解析 光看题目要求和例图&#xff0c;感觉这题好麻烦&#xff0c;直线不能相交啊&#xff0c;每个数字只属于一条连线啊等等&#xff0c;但我们结合题目所给的信息和例图的内容&#xff0c;这不就是最长公共子序列吗&#xff1f;&#xff0c;我们把最长公共子序列连线起…...

LUA+Reids实现库存秒杀预扣减 记录流水 以及自己的思考

目录 lua脚本 记录流水 记录流水的作用 流水什么时候删除 我们在做库存扣减的时候&#xff0c;显示基于Lua脚本和Redis实现的预扣减 这样可以在秒杀扣减的时候保证操作的原子性和高效性 lua脚本 // ... 已有代码 ...Overridepublic InventoryResponse decrease(Inventor…...

FTXUI::Dom 模块

DOM 模块定义了分层的 FTXUI::Element 树&#xff0c;可用于构建复杂的终端界面&#xff0c;支持响应终端尺寸变化。 namespace ftxui {...// 定义文档 定义布局盒子 Element document vbox({// 设置文本 设置加粗 设置文本颜色text("The window") | bold | color(…...