LeNet卷积神经网络-笔记
LeNet卷积神经网络-笔记

手写分析LeNet网三卷积运算和两池化加两全连接层计算分析

修正上图中H,W的计算公式为下面格式

基于paddle飞桨框架构建测试代码
#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。
详细源代码如下所示:
# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F# 定义 LeNet 网络结构
#==============================================================================
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet, self).__init__()# 创建卷积和池化层# 创建第1个卷积层self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)# 尺寸的逻辑:池化层未改变通道数;当前通道数为6# 创建第2个卷积层self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 创建第3个卷积层self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x
#==========================================================================================
# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)
# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
print(x.shape)
x = paddle.to_tensor(x)
print(x.shape)
for item in model.sublayers():# item是LeNet类中的一个子层# 查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)
#
'''
#显示子图层列表model.sublayers()
[Conv2D(1, 6, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(6, 16, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(16, 120, kernel_size=[4, 4], data_format=NCHW), Linear(in_features=120, out_features=64, dtype=float32), Linear(in_features=64, out_features=10, dtype=float32)
]
''' # -*- coding: utf-8 -*-
# LeNet 识别手写数字
import os
import random
import paddle
import numpy as np
import paddle
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST# 定义训练过程
def train(model, opt, train_loader, valid_loader):# 开启0号GPU训练use_gpu = Truepaddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')print('start training ... ')model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist_LeNet.pdparams')# 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。
相关文章:
LeNet卷积神经网络-笔记
LeNet卷积神经网络-笔记 手写分析LeNet网三卷积运算和两池化加两全连接层计算分析 修正上图中H,W的计算公式为下面格式 基于paddle飞桨框架构建测试代码 #输出结果为: #[validation] accuracy/loss: 0.9530/0.1516 #这里准确率为95.3% #通过运行结果可以看出&am…...
使用XMLHttpRequest实现文件异步下载
1、问题描述 我想通过异步的方式实现下载文化,请求为post请求。一开始我打算用ajax。 $.ajax({type:post,contentType:application/json,url:http://xxx/downloadExcel,data:{data:JSON.stringify(<%oJsonResponse.JSONoutput()%>)},}).success(function(dat…...
Lombok 的安装与使用
文章目录 一、什么是 Lombok1.1 Lombok 的概念1.2 为什么使用 Lombok1.3 Lombok 的相关注解 二、Lombok 的安装2.1 引入依赖2.2 安装插件 三、Lombok 的使用案例四、Lombok 的原理 一、什么是 Lombok 1.1 Lombok 的概念 Lombok(“Project Lombok”)是一…...
springBean生命周期解析
本文基于Spring5.3.7 参考: kykangyuky Spring中bean的生命周期 阿斌Java之路 SpringBean的生命周期, 杨开振 JavaEE互联网轻量级框架整合开发 黑马程序员 JavaEE企业级应用开发教程 马士兵 Spring源码讲解 一. SpringBean生命周期流程图 二. 示例代码 …...
人工智能轨道交通行业周刊-第54期(2023.7.31-8.6)
本期关键词:BIM智能运维、铁水联运、编组站美容、鸿蒙4.0、LK-99完全悬浮 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro轨道世界…...
Docker Compose 使用方法
目录 前言 安装 Docker Compose Ubuntu 安装与更新 Red Hat 安装与更新 验证是否安装 Docker Compose 创建 docker-compose.yml 文件 创建一个MySQL 与 tomcat 示例 使用Docker Compose启动服务 前言 Docker Compose 是一个工具,旨在帮助定义和 共享多容器…...
HTML 初
前言 HTML的基本骨架 HTML基本骨架是构建网页的最基本的结果。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">…...
IPv6地址分类,EUI-64转换规则
1、可聚合的单全球单播地址Global Unique Address: Aggregate global unicast address,前3位是001,即2000::/3,目前IANA已经将一部分可聚合全球单播进行了专门使用,如:2001::/16用于IPV6互联网,…...
Nginx安装部署
什么是Nginx? Nginx(发音同engine x)是一款由俄罗斯程序员Igor Sysoev所开发轻量级的网页服务器、反向代 理服务器以及电子邮件(IMAP/POP3)代理服务器。 Nginx 因具有高并发(特別是静态资源)、 占用系统资…...
物联网|按键实验---学习I/O的输入及中断的编程|读取I/O的输入信号|中断的编程方法|轮询实现按键捕获实验-学习笔记(13)
文章目录 实验目的了解擒键的工作原理及电原理图 STM32F407中如何读取I/O的输入信号STM32F407对中断的编程方法通过轮询实现按键捕获实验如何利用已有内工程创建新工程通过轮询实现按键捕获代码实现及分析1 代码的流程分析2 代码的实现 Tips:下载错误的解决 实验目的 了解擒键…...
Hadoop-HDFS的Namenode及Datanode(参考Hadoop官网)
HDFS有什么特点,被设计做什么 Hadoop分布式文件系统(HDFS)被设计成适合运行在通用硬件(commodity hardware)上的分布式文件系统。有一下几个特点: HDFS是一个高度容错性的系统,具有高容错、高可靠性、高扩展性的特点,适合部…...
C:通过alarm发送信号
可以通过alarm定时发送SIGALRM信号: #include <unistd.h> unsigned int alarm(unsigned int seconds); alarm()函数用来在seconds秒之后安排发送一个SIGALRM信号,如果seconds为0,将取消所有已设置的闹钟请求。alarm()函数的返回值是以前…...
如何将 dubbo filter 拦截器原理运用到日志拦截器中?
业务背景 我们希望可以在使用日志拦截器时,定义属于自己的拦截器方法。 实现的方式有很多种,我们分别来看一下。 拓展阅读 java 注解结合 spring aop 实现自动输出日志 java 注解结合 spring aop 实现日志traceId唯一标识 java 注解结合 spring ao…...
【java】【maven】【基础】MAVEN安装配置介绍
目录 1 下载 2 安装-windows为例 3 配置环境变量 3.1 JAVA_HOME 3.2 MAVEN_HOME 3.3 PATH 3.4 验证 4 MAVEN基础概念 4.1 仓库概念 4.2 坐标概念 4.2.1 打开网址 4.2.2 输入搜索内容junit 4.2.3 找到对应API名称点击 4.2.4 点击对应版本 4.2.5 复制MAVEN坐标 4.3 配置…...
【C语言进阶】指针的高级应用(下)
文章目录 一、指针数组与数组指针1.1 指针数组与数组指针的表达式 二、函数指针2.1 函数指针的书写方式 三、二重指针与一重指针3.1 二重指针的本质3.2 二重指针的用法3.3 二重指针与数组指针 总结 一、指针数组与数组指针 (1)指针数组的实质是一个数组,这个数组中存…...
【uniapp APP隐藏顶部的电量,无线,时间状态栏与导航栏】
uniapp APP隐藏顶部的电量,无线,时间状态栏 如下代码配置(在一个页面设置这个段代码,所有页面都会消失) onShow() {// #ifdef APP-PLUS// 隐藏顶部电池,时间等信息 plus.navigator.setFullscreen(true);//隐藏虚拟按…...
微信小程序前后页面传值
微信小程序前后页面传值 从前一个页面跳转到下一个页面,如何传递参数?从后一个页面返回前一个页面,如何回调参数? 向后传值 从前一个页面跳转到下一个页面并传值。 前页面:在跳转链接中添加参数并传递 wx.navigat…...
没有jodatime,rust里怎么比较两个日期(时间)的大小呢?
关注我,学习Rust不迷路!! 在 Rust 中,比较两个日期的大小有多种方法。以下是列举的四种常见方法: 1. 使用 PartialOrd trait: use chrono::prelude::*;fn main() {let date1 NaiveDate::from_ymd(2022,…...
【雕爷学编程】Arduino动手做(186)---WeMos ESP32开发板18
37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…...
C语言假期作业 DAY 14
一、选择题 1、有以下函数,该函数的功能是( ) int fun(char *s) {char *t s;while(*t);return(t-s); } A: 比较两个字符的大小 B: 计算s所指字符串占用内存字节的个数 C: 计算s所指字符串的长度 D: 将s所指字符串复制到字符串t中 答案解析 …...
java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别
UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...
《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
GC1808高性能24位立体声音频ADC芯片解析
1. 芯片概述 GC1808是一款24位立体声音频模数转换器(ADC),支持8kHz~96kHz采样率,集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器,适用于高保真音频采集场景。 2. 核心特性 高精度:24位分辨率,…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...
《Docker》架构
文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器,docker,镜像,k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...
运行vue项目报错 errors and 0 warnings potentially fixable with the `--fix` option.
报错 找到package.json文件 找到这个修改成 "lint": "eslint --fix --ext .js,.vue src" 为elsint有配置结尾换行符,最后运行:npm run lint --fix...
