当前位置: 首页 > news >正文

动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet

LeNet由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述

  1. 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;
  2. 每个卷积层使用5×5卷积核和一个sigmoid激活函数;
  3. 这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;
  4. 每个4×4池操作(步幅2)通过空间下采样将维数减少4倍。
import torch
from torch import nn
from d2l import torch as d2l# 定义模型net
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

该模型去掉了最后一层的高斯激活,下面将一个大小为28×28的单通道(黑白)图像通过LeNet,打印每一层输出的形状。

# 观察各层的输入输出通道数,宽度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

  1. 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少;
  2. 第二个卷积层没有填充,因此高度和宽度都减少了4个像素;
  3. 随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个;
  4. 每个汇聚层的高度和宽度都减半;
  5. 每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

2. 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

3. 小结

  1. 卷积神经网络(CNN)是一类使用卷积层的网络;
  2. 卷积神经网络中,可以组合使用卷积层、非线性激活函数和汇聚层;
  3. 为了构造高性能的卷积神经网络,通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数;
  4. 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

相关文章:

动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet LeNet由两个部分组成: 卷积编码器:由两个卷积层组成;全连接层密集块:由三个全连接层组成。 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;每个卷积层使用55卷积核和一个sigmoid激…...

腾讯面经总结

最近在准备面试,看了很多大厂的面经,抽空将腾讯面试的题目整理了一下,希望对大家有所帮助~ 一面 1、mysql索引结构? 2、redis持久化策略? 3、zookeeper节点类型说一下; 4、zookeeper选举机制&#xff…...

matlab机器人工具箱基础使用

资料:https://blog.csdn.net/huangjunsheng123/article/details/110630665 用vscode直接看工具箱api代码比较方便,代码说明很多 一、模型设置 1、基础效果 %采用机器人工具箱进行正逆运动学验证 a[0,-0.3,-0.3,0,0,0];%DH参数 d[0.05,0,0,0.06,0.05,…...

利用WonderLeak进行内存泄露检测【一】

1、下载地址: WonderLeak - Visual Studio Marketplace https://www.relyze.com/ 2、WonderLeak支持vs2017 2019扩展,或者单独启动 3、https://www.relyze.com/docs/wonderleak/help/w/overview/msvc_extension1.png 4、对于二进制程序来说支持以下…...

二刷LeetCode--155. 最小栈(C++版本),思维题

思路:本题需要使用两个栈,一个就是正常栈,执行出入操作,另一个栈只负责将对应的最小值进行保存即可.每次入栈的时候,最小值栈的栈顶也需要入栈元素,不过这个元素是最小值,那么就需要进行比较,因此在getmin()的时候只需要将最小值栈的栈顶元素弹出即可.初始化的时候只需要将最小…...

进程的状态与转换

进程在其生命周期内,由于系统中各进程之间的相互制约及系统的运行环境的变化,使得进程的状态也在不断地发生变化。通常进程有以下5种状态,前三种是基础讷航的基本状态 1)运行态。进程正在处理机上运行。在单处理机机中&#xff0…...

用MariaDB创建数据库,SQL练习,MarialDB安装和使用

前言:MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可 MariaDB的目的是完全兼容MySQL,包括API和命令行,使之能轻松成为MySQL的代替品。在存储引擎方面,使用XtraDB来代替MySQ…...

【Docker】 使用Docker-Compose 搭建基于 WordPress 的博客网站

引 本文将使用流行的博客搭建工具 WordPress 搭建一个私人博客站点。部署过程中使用到了 Docker 、MySQL 。站点搭建完成后经行了发布文章的体验。 WordPress WordPress 是一个广泛使用的开源内容管理系统(CMS),用于构建和管理网站、博客和…...

Hlang社区-前端社区宣传首页实现

文章目录 前言页面结构固定钉头部轮播JS特效完整代码总结前言 这里的话,博主其实也是今年参与考研的大军之一,所以的话,是抽空去完成这个项目的,当然这个项目的肯定是可以在较短的时间内完成的。 那么废话不多说,昨天也是干到1点多,把这个首页写出来了。先看看看效果吧:…...

【LeetCode-Medium】833. 字符串中的查找与替换

题目链接 833. 字符串中的查找与替换 标签 字符串 步骤 Step1. 初始化 ans[]&#xff1a; for (int i 0; i < s.length(); i) { // 初始化ansans[i] s[i]; }Step2. 根据 index, source, target 查找&#xff1b;如果找到&#xff0c;那么将 ans[i] 更改为 target&am…...

数据结构中公式前中后缀表达式-二叉树应用

目录 数据结构中公式前中后缀表达式-二叉树应用 数据结构中公式前中后缀表达式-二叉树应用 什么是前缀表达式、中缀表达式、后缀表达式 前缀表达式、中缀表达式、后缀表达式&#xff0c;是通过树来存储和计算表达式的三种不同方式 以如下公式为例 通过树来存储该公式&#x…...

Visual Studio 2022连接远程系统进行C/C++开发

Visual Studio被称为是宇宙最强IDE&#xff0c;以前开发Linux C/C服务器程序&#xff0c;基本上都是在Windows上使用VS编写跨平台的C/C代码&#xff0c;然后先在VS中编译、链接、调试&#xff0c;然后在Linux下编译、链接&#xff0c;再针对Linux下的特定代码进行调试。后面Vis…...

TiDB数据库从入门到精通系列之二:TiDB数据库的简介

TiDB数据库从入门到精通系列之二&#xff1a;TiDB数据库的简介 一、TiDB数据库的简介二、五大核心特性三、四大核心应用场景四、TiDB数据库与MySQL数据库的兼容性 一、TiDB数据库的简介 TiDB是开源分布式关系型数据库&#xff0c;是一款同时支持在线事务处理与在线分析处理 (H…...

opencv视频截取每一帧并保存为图片python代码CV2实现练习

当涉及到视频处理时&#xff0c;Python中的OpenCV库提供了强大的功能&#xff0c;可以方便地从视频中截取每一帧并将其保存为图片。这是一个很有趣的练习&#xff0c;可以让你更深入地了解图像处理和多媒体操作。 使用OpenCV库&#xff0c;你可以轻松地读取视频文件&#xff0…...

虹科方案 | 汽车总线协议转换解决方案(二)

上期说到&#xff0c;虹科的PCAN-LIN网关在CAN、LIN总线转换方面有显著的作用&#xff0c;尤其是为BMS电池通信的测试提供了优秀的解决方案。假如您感兴趣&#xff0c;可以点击文末相关链接进行回顾&#xff01; 而今天&#xff0c;虹科将继续给大家带来Router系列在各个领域的…...

[Android] 通过JNI 让 JAVA 调用 android native 接口

前言&#xff1a; JNI (java native interface) 是一个库&#xff0c;可以让 java 代码和其他语言互动&#xff0c;比如 java 通过 JNI 调用融合了 jni库的 c/c 代码&#xff0c;注意&#xff0c;这里要求 c/c代码中必须通过链接 jni 库并按照 JNI 规范定义一套可供 JAVA 调用…...

MySQL高可用MHA

目录 前言 一、概述 二、配置免密、组从复制 三、MHA配置 四、测试 总结 前言 MySQL高可用管理工具&#xff08;MHA&#xff0c;Master High Availability&#xff09;是一个用于自动管理MySQL主从复制的工具&#xff0c;它可以提供高可用性和自动故障转移。MHA由原版的MHA工具…...

DoIP学习笔记系列:(五)“安全认证”的.dll从何而来?

文章目录 1. “安全认证”的.dll从何而来?1.1 .dll文件base1.2 增加客户需求算法传送门 DoIP学习笔记系列:导航篇 1. “安全认证”的.dll从何而来? 无论是用CANoe还是VFlash,亦或是编辑cdd文件,都需要加载一个与$27服务相关的.dll(Windows的动态库文件),这个文件是从哪…...

205、仿真-51单片机直流数字电流表多档位切换Proteus仿真设计(程序+Proteus仿真+原理图+流程图+元器件清单+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、原理图 五、程序源码 资料包括&#xff1a; 方案选择 单片机的选择 方案一&#xff1a;STM32系列单片机控制&#xff0c;该型号单片机为LQFP44封装&#xff0c;内部资源…...

服务器如何防止cc攻击

对于搭载网站运行的服务器来说&#xff0c;cc攻击应该并不陌生&#xff0c;特别是cc攻击的攻击门槛非常低&#xff0c;有个代理IP工具&#xff0c;有个cc攻击软件就可以轻易对任何网站发起攻击&#xff0c;那么服务器如何防止cc攻击?请看下面的介绍。 服务器如何防止cc攻击&a…...

Vue作用域插槽

下面,我们来系统的梳理关于 **Vue 作用域插槽 ** 的基本知识点: 一、作用域插槽核心概念 1.1 什么是作用域插槽? 作用域插槽是 Vue 中一种反向数据流机制,允许子组件将数据传递给父组件中的插槽内容。这种模式解决了传统插槽中父组件无法访问子组件内部状态的限制。 1.2…...

【Elasticsearch】映射:fielddata 详解

映射&#xff1a;fielddata 详解 1.fielddata 是什么2.fielddata 的工作原理3.主要用法3.1 启用 fielddata&#xff08;通常在 text 字段上&#xff09;3.2 监控 fielddata 使用情况3.3 清除 fielddata 缓存 4.使用场景示例示例 1&#xff1a;对 text 字段进行聚合示例 2&#…...

SQL-事务(2025.6.6-2025.6.7学习篇)

1、简介 事务是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整体一起向系统提交或撤销操作请求&#xff0c;即这些操作要么同时成功&#xff0c;要么同时失败。 默认MySQL的事务是自动提交的&#xff0c;也就是说&#xff0…...

论文MR-SVD

每个像素 7 个 FLOPs意思&#xff1a; FLOPs&#xff08;浮点运算次数&#xff09;&#xff1a;衡量算法计算复杂度的指标&#xff0c;数值越小表示运算越高效。含义&#xff1a;对图像中每个像素进行处理时&#xff0c;仅需执行7 次浮点运算&#xff08;如加减乘除等&#xf…...

『React』Fragment的用法及简写形式

在 React 渲染组件时&#xff0c;每个组件只能返回一个根节点&#xff08;root element&#xff09;。传统上&#xff0c;如果我们需要渲染多条并列的元素&#xff0c;通常会使用一个多余的 <div> 或者其他容器标签将它们包裹起来。但是&#xff0c;这样会在最终的 HTML …...

Ubuntu 下开机自动执行命令的方法

Ubuntu 下开机自动执行命令的方法&#xff08;使用 crontab&#xff09; 在日常使用 Ubuntu 或其他 Linux 系统时&#xff0c;我们常常需要让某些程序或脚本在系统启动后自动运行。例如&#xff1a;启动 Clash 代理、初始化服务、定时同步数据等。 本文将介绍一种简单且常用的…...

Faiss向量数据库全面解析:从原理到实战

Faiss向量数据库全面解析&#xff1a;从原理到实战 引言&#xff1a;向量搜索的时代需求 在AI技术爆发的今天&#xff0c;向量数据已成为表示文本、图像、音视频等内容的核心形式。Facebook AI研究院开源的Faiss&#xff08;Facebook AI Similarity Search&#xff09;作为高…...

hadoop集群datanode启动显示init failed,不能解析hostname

三个datanode集群&#xff0c;有一个总是起不起来。去查看log显示 Initialization failed for Block pool BP-1920852191-192.168.115.154-1749093939738 (Datanode Uuid 89d9df36-1c01-4f22-9905-517fee205a8e) service to node154/192.168.115.154:8020 Datanode denied com…...

ASP.NET Core使用Quartz部署到IIS资源自动被回收解决方案

iis自动回收的原因 回收机制默认配置&#xff0c;间隔时间是1740分钟&#xff0c;意思是&#xff1a;默认情况下每1740分钟(29小时)回收一次&#xff0c;定期检查应用程序池中的工作进程&#xff0c;并终止那些已经存在很长时间或已经使用了太多资源的工作进程 进程模型默认配…...

AWS App Mesh实战:构建可观测、安全的微服务通信解决方案

摘要&#xff1a;本文详解如何利用AWS App Mesh统一管理微服务间通信&#xff0c;实现精细化流量控制、端到端可观测性与安全通信&#xff0c;提升云原生应用稳定性。 一、什么是AWS App Mesh&#xff1f; AWS App Mesh 是一种服务网格&#xff08;Service Mesh&#xff09;解…...