LeNet卷积神经网络-笔记
LeNet卷积神经网络-笔记

手写分析LeNet网三卷积运算和两池化加两全连接层计算分析

修正上图中H,W的计算公式为下面格式

基于paddle飞桨框架构建测试代码
#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。
详细源代码如下所示:
# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F# 定义 LeNet 网络结构
#==============================================================================
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet, self).__init__()# 创建卷积和池化层# 创建第1个卷积层self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)# 尺寸的逻辑:池化层未改变通道数;当前通道数为6# 创建第2个卷积层self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 创建第3个卷积层self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x
#==========================================================================================
# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)
# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
print(x.shape)
x = paddle.to_tensor(x)
print(x.shape)
for item in model.sublayers():# item是LeNet类中的一个子层# 查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)
#
'''
#显示子图层列表model.sublayers()
[Conv2D(1, 6, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(6, 16, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(16, 120, kernel_size=[4, 4], data_format=NCHW), Linear(in_features=120, out_features=64, dtype=float32), Linear(in_features=64, out_features=10, dtype=float32)
]
''' # -*- coding: utf-8 -*-
# LeNet 识别手写数字
import os
import random
import paddle
import numpy as np
import paddle
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST# 定义训练过程
def train(model, opt, train_loader, valid_loader):# 开启0号GPU训练use_gpu = Truepaddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')print('start training ... ')model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist_LeNet.pdparams')# 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。
相关文章:
LeNet卷积神经网络-笔记
LeNet卷积神经网络-笔记 手写分析LeNet网三卷积运算和两池化加两全连接层计算分析 修正上图中H,W的计算公式为下面格式 基于paddle飞桨框架构建测试代码 #输出结果为: #[validation] accuracy/loss: 0.9530/0.1516 #这里准确率为95.3% #通过运行结果可以看出&am…...
使用XMLHttpRequest实现文件异步下载
1、问题描述 我想通过异步的方式实现下载文化,请求为post请求。一开始我打算用ajax。 $.ajax({type:post,contentType:application/json,url:http://xxx/downloadExcel,data:{data:JSON.stringify(<%oJsonResponse.JSONoutput()%>)},}).success(function(dat…...
Lombok 的安装与使用
文章目录 一、什么是 Lombok1.1 Lombok 的概念1.2 为什么使用 Lombok1.3 Lombok 的相关注解 二、Lombok 的安装2.1 引入依赖2.2 安装插件 三、Lombok 的使用案例四、Lombok 的原理 一、什么是 Lombok 1.1 Lombok 的概念 Lombok(“Project Lombok”)是一…...
springBean生命周期解析
本文基于Spring5.3.7 参考: kykangyuky Spring中bean的生命周期 阿斌Java之路 SpringBean的生命周期, 杨开振 JavaEE互联网轻量级框架整合开发 黑马程序员 JavaEE企业级应用开发教程 马士兵 Spring源码讲解 一. SpringBean生命周期流程图 二. 示例代码 …...
人工智能轨道交通行业周刊-第54期(2023.7.31-8.6)
本期关键词:BIM智能运维、铁水联运、编组站美容、鸿蒙4.0、LK-99完全悬浮 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro轨道世界…...
Docker Compose 使用方法
目录 前言 安装 Docker Compose Ubuntu 安装与更新 Red Hat 安装与更新 验证是否安装 Docker Compose 创建 docker-compose.yml 文件 创建一个MySQL 与 tomcat 示例 使用Docker Compose启动服务 前言 Docker Compose 是一个工具,旨在帮助定义和 共享多容器…...
HTML 初
前言 HTML的基本骨架 HTML基本骨架是构建网页的最基本的结果。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">…...
IPv6地址分类,EUI-64转换规则
1、可聚合的单全球单播地址Global Unique Address: Aggregate global unicast address,前3位是001,即2000::/3,目前IANA已经将一部分可聚合全球单播进行了专门使用,如:2001::/16用于IPV6互联网,…...
Nginx安装部署
什么是Nginx? Nginx(发音同engine x)是一款由俄罗斯程序员Igor Sysoev所开发轻量级的网页服务器、反向代 理服务器以及电子邮件(IMAP/POP3)代理服务器。 Nginx 因具有高并发(特別是静态资源)、 占用系统资…...
物联网|按键实验---学习I/O的输入及中断的编程|读取I/O的输入信号|中断的编程方法|轮询实现按键捕获实验-学习笔记(13)
文章目录 实验目的了解擒键的工作原理及电原理图 STM32F407中如何读取I/O的输入信号STM32F407对中断的编程方法通过轮询实现按键捕获实验如何利用已有内工程创建新工程通过轮询实现按键捕获代码实现及分析1 代码的流程分析2 代码的实现 Tips:下载错误的解决 实验目的 了解擒键…...
Hadoop-HDFS的Namenode及Datanode(参考Hadoop官网)
HDFS有什么特点,被设计做什么 Hadoop分布式文件系统(HDFS)被设计成适合运行在通用硬件(commodity hardware)上的分布式文件系统。有一下几个特点: HDFS是一个高度容错性的系统,具有高容错、高可靠性、高扩展性的特点,适合部…...
C:通过alarm发送信号
可以通过alarm定时发送SIGALRM信号: #include <unistd.h> unsigned int alarm(unsigned int seconds); alarm()函数用来在seconds秒之后安排发送一个SIGALRM信号,如果seconds为0,将取消所有已设置的闹钟请求。alarm()函数的返回值是以前…...
如何将 dubbo filter 拦截器原理运用到日志拦截器中?
业务背景 我们希望可以在使用日志拦截器时,定义属于自己的拦截器方法。 实现的方式有很多种,我们分别来看一下。 拓展阅读 java 注解结合 spring aop 实现自动输出日志 java 注解结合 spring aop 实现日志traceId唯一标识 java 注解结合 spring ao…...
【java】【maven】【基础】MAVEN安装配置介绍
目录 1 下载 2 安装-windows为例 3 配置环境变量 3.1 JAVA_HOME 3.2 MAVEN_HOME 3.3 PATH 3.4 验证 4 MAVEN基础概念 4.1 仓库概念 4.2 坐标概念 4.2.1 打开网址 4.2.2 输入搜索内容junit 4.2.3 找到对应API名称点击 4.2.4 点击对应版本 4.2.5 复制MAVEN坐标 4.3 配置…...
【C语言进阶】指针的高级应用(下)
文章目录 一、指针数组与数组指针1.1 指针数组与数组指针的表达式 二、函数指针2.1 函数指针的书写方式 三、二重指针与一重指针3.1 二重指针的本质3.2 二重指针的用法3.3 二重指针与数组指针 总结 一、指针数组与数组指针 (1)指针数组的实质是一个数组,这个数组中存…...
【uniapp APP隐藏顶部的电量,无线,时间状态栏与导航栏】
uniapp APP隐藏顶部的电量,无线,时间状态栏 如下代码配置(在一个页面设置这个段代码,所有页面都会消失) onShow() {// #ifdef APP-PLUS// 隐藏顶部电池,时间等信息 plus.navigator.setFullscreen(true);//隐藏虚拟按…...
微信小程序前后页面传值
微信小程序前后页面传值 从前一个页面跳转到下一个页面,如何传递参数?从后一个页面返回前一个页面,如何回调参数? 向后传值 从前一个页面跳转到下一个页面并传值。 前页面:在跳转链接中添加参数并传递 wx.navigat…...
没有jodatime,rust里怎么比较两个日期(时间)的大小呢?
关注我,学习Rust不迷路!! 在 Rust 中,比较两个日期的大小有多种方法。以下是列举的四种常见方法: 1. 使用 PartialOrd trait: use chrono::prelude::*;fn main() {let date1 NaiveDate::from_ymd(2022,…...
【雕爷学编程】Arduino动手做(186)---WeMos ESP32开发板18
37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…...
C语言假期作业 DAY 14
一、选择题 1、有以下函数,该函数的功能是( ) int fun(char *s) {char *t s;while(*t);return(t-s); } A: 比较两个字符的大小 B: 计算s所指字符串占用内存字节的个数 C: 计算s所指字符串的长度 D: 将s所指字符串复制到字符串t中 答案解析 …...
项目介绍 MATLAB实现基于概率路图法(PRM)进行无人机三维路径规划的详细项目实例(含模型描述及部分示例代码) 专栏近期有大量优惠 还请多多点一下关注 加油 谢谢 你的鼓励是我前行的动力 谢谢支持
MATLAB实现基于概率路图法(PRM)进行无人机三维路径规划的详细项目实例 更多详细内容可直接联系博主本人 或者访问对应标题的完整博客或者文档下载页面(含完整的程序,GUI设计和代码详解) 随着无人机技术的快速发展&…...
G-Helper完整指南:三步掌握华硕笔记本性能优化神器
G-Helper完整指南:三步掌握华硕笔记本性能优化神器 【免费下载链接】g-helper Lightweight, open-source control tool for ASUS laptops and ROG Ally. Manage performance modes, fans, GPU, battery, and RGB lighting across Zephyrus, Flow, TUF, Strix, Scar,…...
效率倍增:借助快马ai智能生成与管理系统化java面试题库
作为一名经常需要准备Java面试的开发者,我深刻体会到传统刷题方式的低效——手动收集题目、整理答案、标注重点不仅耗时,还容易遗漏关键知识点。最近尝试用InsCode(快马)平台的AI功能搭建了一个智能题库工具,效率提升超乎想象。以下是具体实现…...
Libsvm 编译mex不同平台兼容性问题 Application not supported on glnxa64 due to platform dependencies. Intended pl
matlab线上算法执行报错:Application not supported on glnxa64 due to platform dependencies. Intended platforms include: win64 排查后发现是使用了libsvm-3.3, 而libsvm编译的时候是基于win64编译的导致出现此bug.(因为libsvm的开源代码不是matlab࿰…...
告别while循环轮询!用STM32 HAL库定时器中断实现按键扫描(附状态机源码)
STM32高效按键处理实战:定时器中断与状态机的完美结合 在嵌入式开发中,按键处理看似简单却暗藏玄机。传统while循环轮询方式不仅占用CPU资源,还容易导致代码结构混乱。本文将带你用STM32 HAL库的定时器中断和状态机,实现一套高效、…...
[火]图像数据增强 支持目标检测数据集图像增强 标注框信息同步增强 支持以下图像增强方式HSV-Hue 增强HSV-Saturation 增强 HSV-Value 增强图像旋转 (+/
[火]图像数据增强 支持目标检测数据集图像增强 标注框信息同步增强 支持以下图像增强方式 HSV-Hue 增强 HSV-Saturation 增强 HSV-Value 增强 图像旋转 (/- degrees) 图像平移 (/- 分数) 图像缩放 (/- 增益) 图像错切 (/- 分数) 图像透视 (/- 分数), 范围:0-0.00…...
CN3881-规格书 如韵电子 10A 降压型同步单节锂电池充电管理集成电路
概述: CN3881 是一款可使用太阳能供电的 PWM 降压模式单节锂电池充电管理集成电路,可独立对单 节锂电池充电进行管理,具有封装外形小,外围元器件少和使用简单等优点。 CN3881 采用涓流,恒流和恒压充电模式,非常适合单节…...
利用快马平台自动化生成contextmenumanager提升前端开发效率
最近在开发一个后台管理系统时,遇到了一个很常见的需求:需要为表格、图表等元素添加右键菜单功能。这种需求看似简单,但实际开发中却要花费不少时间在重复的配置工作上。经过一番摸索,我发现利用InsCode(快马)平台可以大幅提升这类…...
华为OD机考双机位C卷 - 数字游戏 (Java)
# 数字游戏 2026华为OD机试双机位C卷 - 华为OD上机考试双机位C卷 华为OD机试双机位C卷真题目录(Java)点击查看: 【全网首发】2026华为OD机位C卷 机考真题题库含考点说明以及在线OJ(Java题解) 题目描述 小明玩一个游戏。 系统发1+n张牌,每张牌上有一个整数。 第一张给…...
AI写专著超实用攻略:精选工具推荐,提升写作效率与质量
第一次尝试写学术专著的挑战与AI写作工具介绍 对于第一次尝试写学术专著的研究者来说,写作的过程就像是一场充满挑战的冒险之旅,伴随着许多不确定的困难。在选题方面常常陷入困扰,难以在“具有价值”和“可行性”之间找到合适的平衡。有时选…...
