当前位置: 首页 > news >正文

动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet

LeNet由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述

  1. 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;
  2. 每个卷积层使用5×5卷积核和一个sigmoid激活函数;
  3. 这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;
  4. 每个4×4池操作(步幅2)通过空间下采样将维数减少4倍。
import torch
from torch import nn
from d2l import torch as d2l# 定义模型net
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

该模型去掉了最后一层的高斯激活,下面将一个大小为28×28的单通道(黑白)图像通过LeNet,打印每一层输出的形状。

# 观察各层的输入输出通道数,宽度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

  1. 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少;
  2. 第二个卷积层没有填充,因此高度和宽度都减少了4个像素;
  3. 随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个;
  4. 每个汇聚层的高度和宽度都减半;
  5. 每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

2. 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

3. 小结

  1. 卷积神经网络(CNN)是一类使用卷积层的网络;
  2. 卷积神经网络中,可以组合使用卷积层、非线性激活函数和汇聚层;
  3. 为了构造高性能的卷积神经网络,通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数;
  4. 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

相关文章:

动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet LeNet由两个部分组成: 卷积编码器:由两个卷积层组成;全连接层密集块:由三个全连接层组成。 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;每个卷积层使用55卷积核和一个sigmoid激…...

腾讯面经总结

最近在准备面试,看了很多大厂的面经,抽空将腾讯面试的题目整理了一下,希望对大家有所帮助~ 一面 1、mysql索引结构? 2、redis持久化策略? 3、zookeeper节点类型说一下; 4、zookeeper选举机制&#xff…...

matlab机器人工具箱基础使用

资料:https://blog.csdn.net/huangjunsheng123/article/details/110630665 用vscode直接看工具箱api代码比较方便,代码说明很多 一、模型设置 1、基础效果 %采用机器人工具箱进行正逆运动学验证 a[0,-0.3,-0.3,0,0,0];%DH参数 d[0.05,0,0,0.06,0.05,…...

利用WonderLeak进行内存泄露检测【一】

1、下载地址: WonderLeak - Visual Studio Marketplace https://www.relyze.com/ 2、WonderLeak支持vs2017 2019扩展,或者单独启动 3、https://www.relyze.com/docs/wonderleak/help/w/overview/msvc_extension1.png 4、对于二进制程序来说支持以下…...

二刷LeetCode--155. 最小栈(C++版本),思维题

思路:本题需要使用两个栈,一个就是正常栈,执行出入操作,另一个栈只负责将对应的最小值进行保存即可.每次入栈的时候,最小值栈的栈顶也需要入栈元素,不过这个元素是最小值,那么就需要进行比较,因此在getmin()的时候只需要将最小值栈的栈顶元素弹出即可.初始化的时候只需要将最小…...

进程的状态与转换

进程在其生命周期内,由于系统中各进程之间的相互制约及系统的运行环境的变化,使得进程的状态也在不断地发生变化。通常进程有以下5种状态,前三种是基础讷航的基本状态 1)运行态。进程正在处理机上运行。在单处理机机中&#xff0…...

用MariaDB创建数据库,SQL练习,MarialDB安装和使用

前言:MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可 MariaDB的目的是完全兼容MySQL,包括API和命令行,使之能轻松成为MySQL的代替品。在存储引擎方面,使用XtraDB来代替MySQ…...

【Docker】 使用Docker-Compose 搭建基于 WordPress 的博客网站

引 本文将使用流行的博客搭建工具 WordPress 搭建一个私人博客站点。部署过程中使用到了 Docker 、MySQL 。站点搭建完成后经行了发布文章的体验。 WordPress WordPress 是一个广泛使用的开源内容管理系统(CMS),用于构建和管理网站、博客和…...

Hlang社区-前端社区宣传首页实现

文章目录 前言页面结构固定钉头部轮播JS特效完整代码总结前言 这里的话,博主其实也是今年参与考研的大军之一,所以的话,是抽空去完成这个项目的,当然这个项目的肯定是可以在较短的时间内完成的。 那么废话不多说,昨天也是干到1点多,把这个首页写出来了。先看看看效果吧:…...

【LeetCode-Medium】833. 字符串中的查找与替换

题目链接 833. 字符串中的查找与替换 标签 字符串 步骤 Step1. 初始化 ans[]&#xff1a; for (int i 0; i < s.length(); i) { // 初始化ansans[i] s[i]; }Step2. 根据 index, source, target 查找&#xff1b;如果找到&#xff0c;那么将 ans[i] 更改为 target&am…...

数据结构中公式前中后缀表达式-二叉树应用

目录 数据结构中公式前中后缀表达式-二叉树应用 数据结构中公式前中后缀表达式-二叉树应用 什么是前缀表达式、中缀表达式、后缀表达式 前缀表达式、中缀表达式、后缀表达式&#xff0c;是通过树来存储和计算表达式的三种不同方式 以如下公式为例 通过树来存储该公式&#x…...

Visual Studio 2022连接远程系统进行C/C++开发

Visual Studio被称为是宇宙最强IDE&#xff0c;以前开发Linux C/C服务器程序&#xff0c;基本上都是在Windows上使用VS编写跨平台的C/C代码&#xff0c;然后先在VS中编译、链接、调试&#xff0c;然后在Linux下编译、链接&#xff0c;再针对Linux下的特定代码进行调试。后面Vis…...

TiDB数据库从入门到精通系列之二:TiDB数据库的简介

TiDB数据库从入门到精通系列之二&#xff1a;TiDB数据库的简介 一、TiDB数据库的简介二、五大核心特性三、四大核心应用场景四、TiDB数据库与MySQL数据库的兼容性 一、TiDB数据库的简介 TiDB是开源分布式关系型数据库&#xff0c;是一款同时支持在线事务处理与在线分析处理 (H…...

opencv视频截取每一帧并保存为图片python代码CV2实现练习

当涉及到视频处理时&#xff0c;Python中的OpenCV库提供了强大的功能&#xff0c;可以方便地从视频中截取每一帧并将其保存为图片。这是一个很有趣的练习&#xff0c;可以让你更深入地了解图像处理和多媒体操作。 使用OpenCV库&#xff0c;你可以轻松地读取视频文件&#xff0…...

虹科方案 | 汽车总线协议转换解决方案(二)

上期说到&#xff0c;虹科的PCAN-LIN网关在CAN、LIN总线转换方面有显著的作用&#xff0c;尤其是为BMS电池通信的测试提供了优秀的解决方案。假如您感兴趣&#xff0c;可以点击文末相关链接进行回顾&#xff01; 而今天&#xff0c;虹科将继续给大家带来Router系列在各个领域的…...

[Android] 通过JNI 让 JAVA 调用 android native 接口

前言&#xff1a; JNI (java native interface) 是一个库&#xff0c;可以让 java 代码和其他语言互动&#xff0c;比如 java 通过 JNI 调用融合了 jni库的 c/c 代码&#xff0c;注意&#xff0c;这里要求 c/c代码中必须通过链接 jni 库并按照 JNI 规范定义一套可供 JAVA 调用…...

MySQL高可用MHA

目录 前言 一、概述 二、配置免密、组从复制 三、MHA配置 四、测试 总结 前言 MySQL高可用管理工具&#xff08;MHA&#xff0c;Master High Availability&#xff09;是一个用于自动管理MySQL主从复制的工具&#xff0c;它可以提供高可用性和自动故障转移。MHA由原版的MHA工具…...

DoIP学习笔记系列:(五)“安全认证”的.dll从何而来?

文章目录 1. “安全认证”的.dll从何而来?1.1 .dll文件base1.2 增加客户需求算法传送门 DoIP学习笔记系列:导航篇 1. “安全认证”的.dll从何而来? 无论是用CANoe还是VFlash,亦或是编辑cdd文件,都需要加载一个与$27服务相关的.dll(Windows的动态库文件),这个文件是从哪…...

205、仿真-51单片机直流数字电流表多档位切换Proteus仿真设计(程序+Proteus仿真+原理图+流程图+元器件清单+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、原理图 五、程序源码 资料包括&#xff1a; 方案选择 单片机的选择 方案一&#xff1a;STM32系列单片机控制&#xff0c;该型号单片机为LQFP44封装&#xff0c;内部资源…...

服务器如何防止cc攻击

对于搭载网站运行的服务器来说&#xff0c;cc攻击应该并不陌生&#xff0c;特别是cc攻击的攻击门槛非常低&#xff0c;有个代理IP工具&#xff0c;有个cc攻击软件就可以轻易对任何网站发起攻击&#xff0c;那么服务器如何防止cc攻击?请看下面的介绍。 服务器如何防止cc攻击&a…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中&#xff0c;iftop是网络管理的得力助手&#xff0c;能实时监控网络流量、连接情况等&#xff0c;帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子&#xff0c;用于处理异步操作&#xff08;如数据加载&#xff09;中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误&#xff1a;捕获在 loader 或 action 中发生的异步错误替…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动&#xff08;如演唱会、马拉松赛事、高考中考等&#xff09;期间&#xff0c;城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例&#xff0c;暖城商圈曾因观众集中离场导致周边…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

GitHub 趋势日报 (2025年06月08日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

C#学习第29天:表达式树(Expression Trees)

目录 什么是表达式树&#xff1f; 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持&#xff1a; 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...

【JVM】Java虚拟机(二)——垃圾回收

目录 一、如何判断对象可以回收 &#xff08;一&#xff09;引用计数法 &#xff08;二&#xff09;可达性分析算法 二、垃圾回收算法 &#xff08;一&#xff09;标记清除 &#xff08;二&#xff09;标记整理 &#xff08;三&#xff09;复制 &#xff08;四&#xff…...