LeNet卷积神经网络-笔记
LeNet卷积神经网络-笔记

手写分析LeNet网三卷积运算和两池化加两全连接层计算分析

修正上图中H,W的计算公式为下面格式

基于paddle飞桨框架构建测试代码
#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。
详细源代码如下所示:
# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F# 定义 LeNet 网络结构
#==============================================================================
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet, self).__init__()# 创建卷积和池化层# 创建第1个卷积层self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)# 尺寸的逻辑:池化层未改变通道数;当前通道数为6# 创建第2个卷积层self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 创建第3个卷积层self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x
#==========================================================================================
# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)
# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
print(x.shape)
x = paddle.to_tensor(x)
print(x.shape)
for item in model.sublayers():# item是LeNet类中的一个子层# 查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)
#
'''
#显示子图层列表model.sublayers()
[Conv2D(1, 6, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(6, 16, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(16, 120, kernel_size=[4, 4], data_format=NCHW), Linear(in_features=120, out_features=64, dtype=float32), Linear(in_features=64, out_features=10, dtype=float32)
]
''' # -*- coding: utf-8 -*-
# LeNet 识别手写数字
import os
import random
import paddle
import numpy as np
import paddle
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST# 定义训练过程
def train(model, opt, train_loader, valid_loader):# 开启0号GPU训练use_gpu = Truepaddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')print('start training ... ')model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist_LeNet.pdparams')# 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。
相关文章:
LeNet卷积神经网络-笔记
LeNet卷积神经网络-笔记 手写分析LeNet网三卷积运算和两池化加两全连接层计算分析 修正上图中H,W的计算公式为下面格式 基于paddle飞桨框架构建测试代码 #输出结果为: #[validation] accuracy/loss: 0.9530/0.1516 #这里准确率为95.3% #通过运行结果可以看出&am…...
使用XMLHttpRequest实现文件异步下载
1、问题描述 我想通过异步的方式实现下载文化,请求为post请求。一开始我打算用ajax。 $.ajax({type:post,contentType:application/json,url:http://xxx/downloadExcel,data:{data:JSON.stringify(<%oJsonResponse.JSONoutput()%>)},}).success(function(dat…...
Lombok 的安装与使用
文章目录 一、什么是 Lombok1.1 Lombok 的概念1.2 为什么使用 Lombok1.3 Lombok 的相关注解 二、Lombok 的安装2.1 引入依赖2.2 安装插件 三、Lombok 的使用案例四、Lombok 的原理 一、什么是 Lombok 1.1 Lombok 的概念 Lombok(“Project Lombok”)是一…...
springBean生命周期解析
本文基于Spring5.3.7 参考: kykangyuky Spring中bean的生命周期 阿斌Java之路 SpringBean的生命周期, 杨开振 JavaEE互联网轻量级框架整合开发 黑马程序员 JavaEE企业级应用开发教程 马士兵 Spring源码讲解 一. SpringBean生命周期流程图 二. 示例代码 …...
人工智能轨道交通行业周刊-第54期(2023.7.31-8.6)
本期关键词:BIM智能运维、铁水联运、编组站美容、鸿蒙4.0、LK-99完全悬浮 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro轨道世界…...
Docker Compose 使用方法
目录 前言 安装 Docker Compose Ubuntu 安装与更新 Red Hat 安装与更新 验证是否安装 Docker Compose 创建 docker-compose.yml 文件 创建一个MySQL 与 tomcat 示例 使用Docker Compose启动服务 前言 Docker Compose 是一个工具,旨在帮助定义和 共享多容器…...
HTML 初
前言 HTML的基本骨架 HTML基本骨架是构建网页的最基本的结果。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">…...
IPv6地址分类,EUI-64转换规则
1、可聚合的单全球单播地址Global Unique Address: Aggregate global unicast address,前3位是001,即2000::/3,目前IANA已经将一部分可聚合全球单播进行了专门使用,如:2001::/16用于IPV6互联网,…...
Nginx安装部署
什么是Nginx? Nginx(发音同engine x)是一款由俄罗斯程序员Igor Sysoev所开发轻量级的网页服务器、反向代 理服务器以及电子邮件(IMAP/POP3)代理服务器。 Nginx 因具有高并发(特別是静态资源)、 占用系统资…...
物联网|按键实验---学习I/O的输入及中断的编程|读取I/O的输入信号|中断的编程方法|轮询实现按键捕获实验-学习笔记(13)
文章目录 实验目的了解擒键的工作原理及电原理图 STM32F407中如何读取I/O的输入信号STM32F407对中断的编程方法通过轮询实现按键捕获实验如何利用已有内工程创建新工程通过轮询实现按键捕获代码实现及分析1 代码的流程分析2 代码的实现 Tips:下载错误的解决 实验目的 了解擒键…...
Hadoop-HDFS的Namenode及Datanode(参考Hadoop官网)
HDFS有什么特点,被设计做什么 Hadoop分布式文件系统(HDFS)被设计成适合运行在通用硬件(commodity hardware)上的分布式文件系统。有一下几个特点: HDFS是一个高度容错性的系统,具有高容错、高可靠性、高扩展性的特点,适合部…...
C:通过alarm发送信号
可以通过alarm定时发送SIGALRM信号: #include <unistd.h> unsigned int alarm(unsigned int seconds); alarm()函数用来在seconds秒之后安排发送一个SIGALRM信号,如果seconds为0,将取消所有已设置的闹钟请求。alarm()函数的返回值是以前…...
如何将 dubbo filter 拦截器原理运用到日志拦截器中?
业务背景 我们希望可以在使用日志拦截器时,定义属于自己的拦截器方法。 实现的方式有很多种,我们分别来看一下。 拓展阅读 java 注解结合 spring aop 实现自动输出日志 java 注解结合 spring aop 实现日志traceId唯一标识 java 注解结合 spring ao…...
【java】【maven】【基础】MAVEN安装配置介绍
目录 1 下载 2 安装-windows为例 3 配置环境变量 3.1 JAVA_HOME 3.2 MAVEN_HOME 3.3 PATH 3.4 验证 4 MAVEN基础概念 4.1 仓库概念 4.2 坐标概念 4.2.1 打开网址 4.2.2 输入搜索内容junit 4.2.3 找到对应API名称点击 4.2.4 点击对应版本 4.2.5 复制MAVEN坐标 4.3 配置…...
【C语言进阶】指针的高级应用(下)
文章目录 一、指针数组与数组指针1.1 指针数组与数组指针的表达式 二、函数指针2.1 函数指针的书写方式 三、二重指针与一重指针3.1 二重指针的本质3.2 二重指针的用法3.3 二重指针与数组指针 总结 一、指针数组与数组指针 (1)指针数组的实质是一个数组,这个数组中存…...
【uniapp APP隐藏顶部的电量,无线,时间状态栏与导航栏】
uniapp APP隐藏顶部的电量,无线,时间状态栏 如下代码配置(在一个页面设置这个段代码,所有页面都会消失) onShow() {// #ifdef APP-PLUS// 隐藏顶部电池,时间等信息 plus.navigator.setFullscreen(true);//隐藏虚拟按…...
微信小程序前后页面传值
微信小程序前后页面传值 从前一个页面跳转到下一个页面,如何传递参数?从后一个页面返回前一个页面,如何回调参数? 向后传值 从前一个页面跳转到下一个页面并传值。 前页面:在跳转链接中添加参数并传递 wx.navigat…...
没有jodatime,rust里怎么比较两个日期(时间)的大小呢?
关注我,学习Rust不迷路!! 在 Rust 中,比较两个日期的大小有多种方法。以下是列举的四种常见方法: 1. 使用 PartialOrd trait: use chrono::prelude::*;fn main() {let date1 NaiveDate::from_ymd(2022,…...
【雕爷学编程】Arduino动手做(186)---WeMos ESP32开发板18
37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…...
C语言假期作业 DAY 14
一、选择题 1、有以下函数,该函数的功能是( ) int fun(char *s) {char *t s;while(*t);return(t-s); } A: 比较两个字符的大小 B: 计算s所指字符串占用内存字节的个数 C: 计算s所指字符串的长度 D: 将s所指字符串复制到字符串t中 答案解析 …...
Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...
STM32+rt-thread判断是否联网
一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...
相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...
《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...
Redis数据倾斜问题解决
Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...
Device Mapper 机制
Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...
SQL慢可能是触发了ring buffer
简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...
快刀集(1): 一刀斩断视频片头广告
一刀流:用一个简单脚本,秒杀视频片头广告,还你清爽观影体验。 1. 引子 作为一个爱生活、爱学习、爱收藏高清资源的老码农,平时写代码之余看看电影、补补片,是再正常不过的事。 电影嘛,要沉浸,…...
elementUI点击浏览table所选行数据查看文档
项目场景: table按照要求特定的数据变成按钮可以点击 解决方案: <el-table-columnprop"mlname"label"名称"align"center"width"180"><template slot-scope"scope"><el-buttonv-if&qu…...
