当前位置: 首页 > news >正文

CNN-day5-经典神经网络LeNets5

经典神经网络-LeNets5

1998年Yann LeCun等提出的第一个用于手写数字识别问题并产生实际商业(邮政行业)价值的卷积神经网络

参考:论文笔记:Gradient-Based Learning Applied to Document Recognition-CSDN博客

1 网络模型结构

整体结构解读:

输入图像:32×32×1

三个卷积层:

C1:输入图片32×32,6个5×5卷积核 ,输出特征图大小28×28(32-5+1)=28,一个bias参数;

可训练参数一共有:(5×5+1)×6=156

C3 :输入图片14×14,16个5×5卷积核,有6×3+6×4+3×4+1×6=60个通道,输出特征图大小10×10((14-5)/1+1),一个bias参数;

可训练参数一共有:6(3×5×5+1)+6×(4×5×5+1)+3×(4×5×5+1)+1×(6×5×5+1)=1516

C3的非密集的特征图连接:

C3的前6个特征图与S2层相连的3个特征图相连接,后面6个特征图与S2层相连的4个特征图相连 接,后面3个特征图与S2层部分不相连的4个特征图相连接,最后一个与S2层的所有特征图相连。 采用非密集连接的方式,打破对称性,同时减少计算量,共60组卷积核。主要是为了节省算力。

C5:输入图片5×5,16个5×5卷积核,包括120×16个5×5卷积核 ,输出特征图大小1×1(5-5+1),一个bias参数;

可训练参数一共有:120×(16×5×5+1)=48120

两个池化层S2和S4:

都是2×2的平均池化,并添加了非线性映射

S2(下采样层):输入28×28,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:14×14(28/2)

S4(下采样层):输入10×10,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:5×5(10/2)

两个全连接层:

第一个全连接层:输入120维向量,输出84个神经元,计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过sigmoid函数输出。84的原因是:字符编码是ASCII编码,用7×12大小的位图表示,-1白色1黑色,84可以用于对每一个像素点的值进行估计。

第二个全连接层(Output层-输出层):输出 10个神经元 ,共有10个节点,代表数字0-9。

所有激活函数采用Sigmoid

2 网络模型实现

2.1模型定义

import torch
import torch.nn as nn
​
​
class LeNet5s(nn.Module):def __init__(self):super(LeNet5s, self).__init__()  # 继承父类# 第一个卷积层self.C1 = nn.Sequential(nn.Conv2d(in_channels=1,  # 输入通道out_channels=6,  # 输出通道kernel_size=5,  # 卷积核大小),nn.ReLU(),)# 池化:平均池化self.S2 = nn.AvgPool2d(kernel_size=2)
​# C3:3通道特征融合单元self.C3_unit_6x3 = nn.Conv2d(in_channels=3,out_channels=1,kernel_size=5,)# C3:4通道特征融合单元self.C3_unit_6x4 = nn.Conv2d(in_channels=4,out_channels=1,kernel_size=5,)
​# C3:4通道特征融合单元,剔除中间的1通道self.C3_unit_3x4_pop1 = nn.Conv2d(in_channels=4,out_channels=1,kernel_size=5,)
​# C3:6通道特征融合单元self.C3_unit_1x6 = nn.Conv2d(in_channels=6,out_channels=1,kernel_size=5,)
​# S4:池化self.S4 = nn.AvgPool2d(kernel_size=2)# 全连接层self.fc1 = nn.Sequential(nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(in_features=120, out_features=84), nn.ReLU())self.fc3 = nn.Linear(in_features=84, out_features=10)
​def forward(self, x):# 训练数据批次大小batch_sizenum = x.shape[0]
​x = self.C1(x)x = self.S2(x)# 生成一个empty张量outchannel = torch.empty((num, 0, 10, 10))# 6个3通道的单元for i in range(6):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 3)])x1 = self.C3_unit_6x3(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​# 6个4通道的单元for i in range(6):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 4)])x1 = self.C3_unit_6x4(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​# 3个4通道的单元,先拿五个,干掉中那一个for i in range(3):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 5)])# 删除第三个元素channel_idx = channel_idx[:2] + channel_idx[3:]print(channel_idx)x1 = self.C3_unit_3x4_pop1(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​x1 = self.C3_unit_1x6(x)# 平均池化outchannel = torch.cat([outchannel, x1], dim=1)outchannel = nn.ReLU()(outchannel)
​x = self.S4(outchannel)# 对数据进行变形x = x.view(x.size(0), -1)# 全连接层x = self.fc1(x)x = self.fc2(x)# TODO:SOFTMAXoutput = self.fc3(x)
​return output
​
​
def test001():net = LeNet5s()# 随机一个测试数据input = torch.randn(128, 1, 32, 32)output = net(input)print(output.shape)pass
​
​
if __name__ == "__main__":test001()

2.2全局变量

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
​
dir = os.path.dirname(__file__)
modelpath = os.path.join(dir, "weight/model.pth")
datapath = os.path.join(dir, "data")
​
# 数据预处理和加载
transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整输入图像大小为32x32transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)),]
)
​

2.3模型训练

def train():
​trainset = torchvision.datasets.MNIST(root=datapath, train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
​# 实例化模型net = LeNet5()
​# 使用MSELoss作为损失函数criterion = nn.MSELoss()
​# 使用SGD优化器optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
​# 训练模型num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data
​# 将labels转换为one-hot编码labels_one_hot = torch.zeros(labels.size(0), 10).scatter_(1, labels.view(-1, 1), 1.0)labels_one_hot = labels_one_hot.to(torch.float32)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels_one_hot)loss.backward()optimizer.step()
​running_loss += loss.item()if i % 100 == 99:print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")running_loss = 0.0# 保存模型参数torch.save(net.state_dict(), modelpath)print("Finished Training")

2.4验证

def vaild():
​testset = torchvision.datasets.MNIST(root=datapath, train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)# 实例化模型net = LeNet5()net.load_state_dict(torch.load(modelpath))# 在测试集上测试模型correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
​print(f"验证集: {100 * correct / total:.2f}%")

相关文章:

CNN-day5-经典神经网络LeNets5

经典神经网络-LeNets5 1998年Yann LeCun等提出的第一个用于手写数字识别问题并产生实际商业(邮政行业)价值的卷积神经网络 参考:论文笔记:Gradient-Based Learning Applied to Document Recognition-CSDN博客 1 网络模型结构 …...

登录到docker里

在Docker中登录到容器通常有两种情况: 登录到正在运行的容器内部:如果你想要进入到正在运行的容器内部,可以使用docker exec命令。 登录到容器中并启动一个shell:如果你想要启动一个容器,并在其中启动一个shell&…...

利用PHP爬虫开发获取淘宝分类详情:解锁电商数据新视角

在电商领域,淘宝作为中国最大的电商平台之一,其分类详情数据对于市场分析、竞争策略制定以及电商运营优化具有极高的价值。通过PHP爬虫技术,我们可以高效地获取这些数据,为电商从业者提供强大的数据支持。本文将详细介绍如何使用P…...

LeetCode 142题解|环形链表II的快慢指针法(含数学证明)

题目如下: 解题过程如下: 思路:快慢指针在环里一定会相遇,相遇结点到入环起始结点的距离 链表头结点到入环起始结点的距离(距离看从左往右的方向,也就是单链表的方向),从链表头结点…...

[图文]课程讲解片段-Fowler分析模式的剖析和实现01

​ 解说: GJJ-004-1,分析模式高阶Fowler分析模式的剖析和实现,这个课是针对Martin Fowler的《分析模式》那本书里面的模式来讲解,对里面的模式来剖析,然后用代码来实现。 做到这一步的,我们这个是世界上独…...

Dify使用

1. 概述 官网:Dify.AI 生成式 AI 应用创新引擎 文档:欢迎使用 Dify | Dify GITHUB:langgenius/dify: Dify is an open-source LLM app development platform. Difys intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, ob…...

解锁 DeepSeek 模型高效部署密码:蓝耘平台全解析

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…...

7.PPT:“中国梦”学习实践活动【20】

目录 NO1234​ NO5678​ NO9\10\11 NO1234 考生文件夹下创建一个名为“PPT.pptx”的新演示文稿Word素材文档的文字:复制/挪动→“PPT.pptx”的新演示文稿(蓝色、黑色、红色) 视图→幻灯片母版→重命名:“中国梦母版1”→背景样…...

Linux系统-centos防火墙firewalld详解

Linux系统-centos7.6 防火墙firewalld详解 1 firewalld了解 CentOS 7.6默认的防火墙管理工具是firewalld,它取代了之前的iptables防火墙。firewalld属于典型的包过滤防火墙或称之为网络层防火墙,与iptables一样,都是用来管理防火墙的工具&a…...

零基础都可以本地部署Deepseek R1

文章目录 一、硬件配置需求二、详细部署步骤1. 安装 Ollama 工具2. 部署 DeepSeek-R1 模型3. API使用4. 配置图形化交互界面(可选)5. 使用与注意事项 一、硬件配置需求 不同版本的 DeepSeek-R1 模型参数量不同,对硬件资源的要求也不尽相同。…...

通过Ollama本地部署DeepSeek R1以及简单使用的教程(超详细)

本文介绍了在Windows环境下,通过Ollama来本地部署DeepSeek R1。该问包含了Ollama的下载、安装、安装目录迁移、大模型存储位置修改、下载DeepSeek以及通过Web UI来对话等相关内容。 1、🥇下载Ollama 首先我们到Ollama官网去下载安装包,此处我…...

css实现长尾箭头(夹角小于45度的)

1. 长尾夹角小于45度的箭头 代码 //h5<div class"singleArrow"></div>//css .singleArrow {width: 150px;height: 1px;position: relative;background-color: #15ff00;/* transform: rotate(-40deg); */ /* 旋转角度 */}.singleArrow::after{ // 成品-有…...

封装descriptions组件,描述,灵活

效果 1、组件1&#xff0c;dade-descriptions.vue <template><table><tbody><slot></slot></tbody> </table> </template><script> </script><style scoped>table {width: 100%;border-collapse: coll…...

OC-Block

关于OC中的block作为属性时&#xff0c;为什么要要用copy修饰 property (nonatomic, copy) void (^completionBlock)(void);很多文章包括AI都会给出类似结论 Block 默认分配在栈上&#xff0c;如果没有 copy&#xff0c;当方法退出后&#xff0c;Block 会被销毁。使用 copy 修…...

关于知识蒸馏的概念原理以及常见方法

1. 概念与原理 知识蒸馏的基本定义 知识蒸馏(Knowledge Distillation) 是一种将模型压缩与迁移学习结合的技术:它利用预先训练好的大模型(通常参数量大、精度高、计算开销大)指导一个更轻量(参数量小、推理速度快)的学生模型进行训练,从而在保持模型精度的同时显著减少…...

C++轻量级桌面GUI库FLTK

C轻量级桌面GUI库FLTK Screenshots - Fast Light Toolkit (FLTK) 这里写个备忘录,可以参考一下....

C++20导出模块及使用

1.模块声明 .ixx文件为导入模块文件 math_operations.ixx export module math_operations;//模块导出 //导出命名空间 export namespace math_ {//导出命名空间中函数int add(int a, int b);int sub(int a, int b);int mul(int a, int b);int div(int a, int b); } .cppm文件…...

PID 算法简介(C语言)

一、简介: PID是比例、积分、微分三个环节的组合,用来进行反馈控制。每个部分都有对应的系数,也就是Kp、Ki、Kd。PID 算法实现这三个部分的计算,然后综合起来得到控制输出。 二、PID控制器结构体: PID控制器结构体:包含PID参数(Kp, Ki, Kd);存储积分项和上一次误差;…...

Java中的继承及相关概念

在 Java 中&#xff0c;继承是一种允许一个类继承另一个类的特性。通过继承&#xff0c;子类可以获取父类的属性和方法&#xff0c;这有助于减少代码冗余并提高代码的可维护性。以下是关于文件内容的相关分析和知识点总结&#xff1a; 一、继承的核心概念 1.继承的语法 Java …...

语言月赛 202308【小粉兔做麻辣兔头】题解(AC)

》》》点我查看「视频」详解》》》 [语言月赛 202308] 小粉兔做麻辣兔头 题目描述 粉兔喜欢吃麻辣兔头&#xff0c;麻辣兔头的辣度分为若干级&#xff0c;用数字表示&#xff0c;数字越大&#xff0c;兔头越辣。为了庆祝粉兔专题赛 #1 的顺利举行&#xff0c;粉兔要做一些麻…...

Linux链表操作全解析

Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表&#xff1f;1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 笔者写过很多次这道题了&#xff0c;不想写题解了&#xff0c;大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

汽车生产虚拟实训中的技能提升与生产优化​

在制造业蓬勃发展的大背景下&#xff0c;虚拟教学实训宛如一颗璀璨的新星&#xff0c;正发挥着不可或缺且日益凸显的关键作用&#xff0c;源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例&#xff0c;汽车生产线上各类…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代&#xff0c;智能代理&#xff08;agents&#xff09;不再是孤立的个体&#xff0c;而是能够像一个数字团队一样协作。然而&#xff0c;当前 AI 生态系统的碎片化阻碍了这一愿景的实现&#xff0c;导致了“AI 巴别塔问题”——不同代理之间…...

k8s业务程序联调工具-KtConnect

概述 原理 工具作用是建立了一个从本地到集群的单向VPN&#xff0c;根据VPN原理&#xff0c;打通两个内网必然需要借助一个公共中继节点&#xff0c;ktconnect工具巧妙的利用k8s原生的portforward能力&#xff0c;简化了建立连接的过程&#xff0c;apiserver间接起到了中继节…...