当前位置: 首页 > news >正文

动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

目录

  • LeNet
  • 模型训练

LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:
卷积编码器:由两个卷积层组成;
全连接层密集块:由三个全连接层组成。
图片来源https://baike.baidu.com/item/LeNet-5/61427772?fr=ge_ala每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的LeNet-5一致
将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,可以检查模型。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

模型训练

看看LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

为了使用GPU,还需要一点小改动。 与学习记录10节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,需要将每一小批量数据移动到我们指定的设备(例如GPU)上。

如下所示,训练函数train_ch6也类似于 学习记录10中定义的train_ch3。 主要使用高级API实现多层神经网络。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用Xavier随机初始化模型参数。 与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降。

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

相关文章:

动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

目录 LeNet模型训练 LeNet 总体来看,LeNet(LeNet-5)由两个部分组成: 卷积编码器:由两个卷积层组成; 全连接层密集块:由三个全连接层组成。 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均…...

log4j 和 java.lang.OutOfMemoryError PermGen space

还是OneCoder在项目中沙箱的问题,用classloader隔离做的沙箱,反复运行用户的任务,出现永生区内存溢出: java.lang.OutOfMemoryError: PermGen space 这个问题在tomcat重复热部署的时候其实比较常见。其道理也和我们沙箱的道理基本…...

2024.9.9营养小题【2】

营养: 1、什么数是丑数? 2、数学数学,丑数的数学意义,哎,数学思维我是忘干净了。 3、可以把while循环换成for循环。由此又想到了一点,三个循环结构各有使用场景。 for(;n%factors[i]0;n/factors[i]){}...

uniapp的barcode组件去掉自动放大功能

autoZoom“false” <barcode id1 class"barcode" autoZoom"false" autostart"false" ref"barcode" background"rgb(0,0,0)" frameColor"#1C86EE"scanbarColor"#1C86EE" :filters"fil" ma…...

H5接入Steam 获取用户数据案例

官方文档地址 1.注册 Steam API Key&#xff1a; 你需要一个 Steam Web API Key&#xff0c;可以在 Steam API Key 页面 获取。https://steamcommunity.com/dev/apikey 2.使用 OpenID 登录&#xff1a; 实现 Steam OpenID 登录&#xff0c;以便用户通过 Steam 账户登录你的…...

《A Few Useful Things to Know about Machine Learning》论文导读

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl机器学习作为人工智能领域的重要分支,近年来得到了广泛的关注和应用。Pedro Domingos的经典论文《A Few Useful Things to Know about Machine Learning》为我们提供了对机器学习深入且全面的理解…...

隔壁老樊2024全国巡回演唱会重磅来袭,首站广州正式官宣!

汹涌人潮将城市填满&#xff0c;斑驳心绪漂浮在时间之隙&#xff0c;当生活的喜悲逐渐演化成歌&#xff0c;天空将自己负载的缄默倾泻&#xff0c;那些或酸涩、或热烈的点滴滑落心海&#xff0c;那层悬挂在「我」与世界分野的无形壁垒&#xff0c;渐也被曙光渗透消融。 提炼生…...

【C++】list(下)

个人主页~ list&#xff08;上&#xff09;~ list 四、模拟实现1、list.h&#xff08;1&#xff09;关于整个list的搭建①节点②迭代器③接口 &#xff08;2&#xff09;自定义类型实例化 2、test.cpp&#xff08;1&#xff09;test1&#xff08;2&#xff09;test2 五、额外小…...

千云物流 -低代码平台MySQL备份数据

windows备份 全量备份 创建备份目录 需要在安装数据库的服务器上创建备份目录,所有如果要做备份至少需要两倍的硬盘空间, mkdir D:\mysql_backup\full_backup准备备份脚本 创建一个windows批处理文件&#xff08;例如 full_backup.bat&#xff09;&#xff0c;用来执行全量…...

MySQL:进阶巩固-视图

目录 一、视图的概述二、视图的基本使用2.1 创建视图2.2 查询视图2.3 修改视图2.4 删除视图 一、视图的概述 视图是一种虚拟存在的表&#xff0c;视图中的数据并不在数据库中实际的存在&#xff0c;行列数据来自于视图中查询的表&#xff0c;并且是在使用视图时动态生成的。 通…...

分布式事务Seata原理及其项目使用

0.Seata官方文档 1.Seata概念及原理 Seata是什么 Seata 是一款开源的分布式事务解决方案&#xff0c;致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式&#xff0c;为用户打造一站式的分布式解决方案。 Seata主要由三个重要组…...

JS_分支结构

if结构 这里的if结构几乎和JAVA中的一样,需要注意的是 if()中的非空字符串会被认为是trueif()中的非零数字会被认为是trueif()中的非空对象会被认为是true <script> if(false){// 非空字符串 if判断为true console.log(true) }else{ console.log(false) } if(){// 长度…...

决策树(Decison Tree)—有监督学习方法、概率模型、生成模型、非线性模型、非参数化模型、批量学习

定义 ID3算法 输入:训练数据集(T= { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯   , ( x N , y N ) } \left\{(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_N)\right\} {(x1​,y1​),(x2​,y2​),⋯,(xN​,yN​)}),特征集A阀值 ε \varepsilon ε 输出:决策树T (1)若D中所有实例属于同一…...

java 自定义注解校验实体类属性

直接上代码 1.是否启用参数校验注解 Target({ElementType.TYPE}) Retention(RetentionPolicy.RUNTIME) Documented public interface EnableArgumentsCheck {/*** 是否启用*/boolean enable() default true;} 2.参数校验自定义注解 /*** 参数校验自定义注解* 属性定义&#…...

光伏并网发电系统中电能质量监测与优化技术探讨

0引言 随着清洁能源技术的持续进步与广泛应用&#xff0c;光伏并网发电系统亦逐步崭露头角。作为一种关键的电力供应方式&#xff0c;其受到了广泛的关注。然而&#xff0c;由于天气等外部条件的影响&#xff0c;光伏发电系统面临若干挑战。电能质量问题&#xff0c;诸如电压波…...

网页解析的那些事

Vue 方面 模板语法理解 熟悉 Vue 的模板语法&#xff0c;包括插值表达式&#xff08;如{{ message }}&#xff09;、指令&#xff08;如v-if、v-for、v-bind等&#xff09;。理解这些语法元素如何将数据与 DOM 元素进行绑定和交互。例如&#xff0c;v-for指令用于循环渲染列表数…...

从文字到世界:2024外语阅读大赛报名开启,赛氪网全程护航

中国外文局CATTI项目管理中心与中国外文界联合宣布&#xff0c;将举办2024年外语阅读大赛&#xff0c;旨在激发外语学习兴趣&#xff0c;选拔并培养优秀的语言应用人才&#xff0c;同时向世界展示和传播具有中国特色的优秀文化。此次大赛旨在激发外语学习兴趣&#xff0c;选拔优…...

微信小程序知识点(二)

1.下拉刷新事件 如果页面需要下拉刷新功能&#xff0c;则在页面对应的json配置文件中&#xff0c;将enablePullDownRefresh配置设置为true&#xff0c;如下 {"usingComponents": {},"enablePullDownRefresh": true } 2.上拉触底事件 在很多时候&#x…...

Springcould -第一个Eureka应用 --- day02

标题 Eureka工作原理Spring Cloud框架下的服务发现Eureka包含两个组件&#xff0c;分别是&#xff1a;Eureka Server与Eureka Client。Eureka Server&#xff1a;Eureka Client&#xff1a; 搭建Eureka Server步骤&#xff1a;步骤1&#xff1a;创建项目&#xff0c;引入依赖步…...

RedissonClient 分布式队列工具类

注意&#xff1a;轻量级队列可以使用工具类&#xff0c;重量级数据量 请使用 MQ 本文章基于redis使用redisson客户端实现轻量级队列&#xff0c;以及代码、执行结果演示 一、常见队列了解 普通队列&#xff1a;先进先出&#xff08;FIFO&#xff09;&#xff0c;只能在一端添…...

深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录

ASP.NET Core 是一个跨平台的开源框架&#xff0c;用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录&#xff0c;以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...

VB.net复制Ntag213卡写入UID

本示例使用的发卡器&#xff1a;https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件&#xff0c;常用于在两个集合之间进行数据转移&#xff0c;如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model&#xff1a;绑定右侧列表的值&…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题&#xff08;可多选&#xff09; 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘&#xff1a;专注于发现数据中…...

深入理解JavaScript设计模式之单例模式

目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式&#xff08;Singleton Pattern&#…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

如何在看板中有效管理突发紧急任务

在看板中有效管理突发紧急任务需要&#xff1a;设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP&#xff08;Work-in-Progress&#xff09;弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中&#xff0c;设立专门的紧急任务通道尤为重要&#xff0c;这能…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

【服务器压力测试】本地PC电脑作为服务器运行时出现卡顿和资源紧张(Windows/Linux)

要让本地PC电脑作为服务器运行时出现卡顿和资源紧张的情况&#xff0c;可以通过以下几种方式模拟或触发&#xff1a; 1. 增加CPU负载 运行大量计算密集型任务&#xff0c;例如&#xff1a; 使用多线程循环执行复杂计算&#xff08;如数学运算、加密解密等&#xff09;。运行图…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...