当前位置: 首页 > news >正文

动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

目录

  • LeNet
  • 模型训练

LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:
卷积编码器:由两个卷积层组成;
全连接层密集块:由三个全连接层组成。
图片来源https://baike.baidu.com/item/LeNet-5/61427772?fr=ge_ala每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的LeNet-5一致
将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,可以检查模型。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

模型训练

看看LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

为了使用GPU,还需要一点小改动。 与学习记录10节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,需要将每一小批量数据移动到我们指定的设备(例如GPU)上。

如下所示,训练函数train_ch6也类似于 学习记录10中定义的train_ch3。 主要使用高级API实现多层神经网络。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用Xavier随机初始化模型参数。 与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降。

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

相关文章:

动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

目录 LeNet模型训练 LeNet 总体来看,LeNet(LeNet-5)由两个部分组成: 卷积编码器:由两个卷积层组成; 全连接层密集块:由三个全连接层组成。 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均…...

log4j 和 java.lang.OutOfMemoryError PermGen space

还是OneCoder在项目中沙箱的问题,用classloader隔离做的沙箱,反复运行用户的任务,出现永生区内存溢出: java.lang.OutOfMemoryError: PermGen space 这个问题在tomcat重复热部署的时候其实比较常见。其道理也和我们沙箱的道理基本…...

2024.9.9营养小题【2】

营养: 1、什么数是丑数? 2、数学数学,丑数的数学意义,哎,数学思维我是忘干净了。 3、可以把while循环换成for循环。由此又想到了一点,三个循环结构各有使用场景。 for(;n%factors[i]0;n/factors[i]){}...

uniapp的barcode组件去掉自动放大功能

autoZoom“false” <barcode id1 class"barcode" autoZoom"false" autostart"false" ref"barcode" background"rgb(0,0,0)" frameColor"#1C86EE"scanbarColor"#1C86EE" :filters"fil" ma…...

H5接入Steam 获取用户数据案例

官方文档地址 1.注册 Steam API Key&#xff1a; 你需要一个 Steam Web API Key&#xff0c;可以在 Steam API Key 页面 获取。https://steamcommunity.com/dev/apikey 2.使用 OpenID 登录&#xff1a; 实现 Steam OpenID 登录&#xff0c;以便用户通过 Steam 账户登录你的…...

《A Few Useful Things to Know about Machine Learning》论文导读

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl机器学习作为人工智能领域的重要分支,近年来得到了广泛的关注和应用。Pedro Domingos的经典论文《A Few Useful Things to Know about Machine Learning》为我们提供了对机器学习深入且全面的理解…...

隔壁老樊2024全国巡回演唱会重磅来袭,首站广州正式官宣!

汹涌人潮将城市填满&#xff0c;斑驳心绪漂浮在时间之隙&#xff0c;当生活的喜悲逐渐演化成歌&#xff0c;天空将自己负载的缄默倾泻&#xff0c;那些或酸涩、或热烈的点滴滑落心海&#xff0c;那层悬挂在「我」与世界分野的无形壁垒&#xff0c;渐也被曙光渗透消融。 提炼生…...

【C++】list(下)

个人主页~ list&#xff08;上&#xff09;~ list 四、模拟实现1、list.h&#xff08;1&#xff09;关于整个list的搭建①节点②迭代器③接口 &#xff08;2&#xff09;自定义类型实例化 2、test.cpp&#xff08;1&#xff09;test1&#xff08;2&#xff09;test2 五、额外小…...

千云物流 -低代码平台MySQL备份数据

windows备份 全量备份 创建备份目录 需要在安装数据库的服务器上创建备份目录,所有如果要做备份至少需要两倍的硬盘空间, mkdir D:\mysql_backup\full_backup准备备份脚本 创建一个windows批处理文件&#xff08;例如 full_backup.bat&#xff09;&#xff0c;用来执行全量…...

MySQL:进阶巩固-视图

目录 一、视图的概述二、视图的基本使用2.1 创建视图2.2 查询视图2.3 修改视图2.4 删除视图 一、视图的概述 视图是一种虚拟存在的表&#xff0c;视图中的数据并不在数据库中实际的存在&#xff0c;行列数据来自于视图中查询的表&#xff0c;并且是在使用视图时动态生成的。 通…...

分布式事务Seata原理及其项目使用

0.Seata官方文档 1.Seata概念及原理 Seata是什么 Seata 是一款开源的分布式事务解决方案&#xff0c;致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式&#xff0c;为用户打造一站式的分布式解决方案。 Seata主要由三个重要组…...

JS_分支结构

if结构 这里的if结构几乎和JAVA中的一样,需要注意的是 if()中的非空字符串会被认为是trueif()中的非零数字会被认为是trueif()中的非空对象会被认为是true <script> if(false){// 非空字符串 if判断为true console.log(true) }else{ console.log(false) } if(){// 长度…...

决策树(Decison Tree)—有监督学习方法、概率模型、生成模型、非线性模型、非参数化模型、批量学习

定义 ID3算法 输入:训练数据集(T= { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯   , ( x N , y N ) } \left\{(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_N)\right\} {(x1​,y1​),(x2​,y2​),⋯,(xN​,yN​)}),特征集A阀值 ε \varepsilon ε 输出:决策树T (1)若D中所有实例属于同一…...

java 自定义注解校验实体类属性

直接上代码 1.是否启用参数校验注解 Target({ElementType.TYPE}) Retention(RetentionPolicy.RUNTIME) Documented public interface EnableArgumentsCheck {/*** 是否启用*/boolean enable() default true;} 2.参数校验自定义注解 /*** 参数校验自定义注解* 属性定义&#…...

光伏并网发电系统中电能质量监测与优化技术探讨

0引言 随着清洁能源技术的持续进步与广泛应用&#xff0c;光伏并网发电系统亦逐步崭露头角。作为一种关键的电力供应方式&#xff0c;其受到了广泛的关注。然而&#xff0c;由于天气等外部条件的影响&#xff0c;光伏发电系统面临若干挑战。电能质量问题&#xff0c;诸如电压波…...

网页解析的那些事

Vue 方面 模板语法理解 熟悉 Vue 的模板语法&#xff0c;包括插值表达式&#xff08;如{{ message }}&#xff09;、指令&#xff08;如v-if、v-for、v-bind等&#xff09;。理解这些语法元素如何将数据与 DOM 元素进行绑定和交互。例如&#xff0c;v-for指令用于循环渲染列表数…...

从文字到世界:2024外语阅读大赛报名开启,赛氪网全程护航

中国外文局CATTI项目管理中心与中国外文界联合宣布&#xff0c;将举办2024年外语阅读大赛&#xff0c;旨在激发外语学习兴趣&#xff0c;选拔并培养优秀的语言应用人才&#xff0c;同时向世界展示和传播具有中国特色的优秀文化。此次大赛旨在激发外语学习兴趣&#xff0c;选拔优…...

微信小程序知识点(二)

1.下拉刷新事件 如果页面需要下拉刷新功能&#xff0c;则在页面对应的json配置文件中&#xff0c;将enablePullDownRefresh配置设置为true&#xff0c;如下 {"usingComponents": {},"enablePullDownRefresh": true } 2.上拉触底事件 在很多时候&#x…...

Springcould -第一个Eureka应用 --- day02

标题 Eureka工作原理Spring Cloud框架下的服务发现Eureka包含两个组件&#xff0c;分别是&#xff1a;Eureka Server与Eureka Client。Eureka Server&#xff1a;Eureka Client&#xff1a; 搭建Eureka Server步骤&#xff1a;步骤1&#xff1a;创建项目&#xff0c;引入依赖步…...

RedissonClient 分布式队列工具类

注意&#xff1a;轻量级队列可以使用工具类&#xff0c;重量级数据量 请使用 MQ 本文章基于redis使用redisson客户端实现轻量级队列&#xff0c;以及代码、执行结果演示 一、常见队列了解 普通队列&#xff1a;先进先出&#xff08;FIFO&#xff09;&#xff0c;只能在一端添…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题&#xff08;可多选&#xff09; 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘&#xff1a;专注于发现数据中…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 &#xff08;FL&#xff09; 支持跨分布式客户端进行协作模型训练&#xff0c;而无需共享原始数据&#xff0c;这使其成为在互联和自动驾驶汽车 &#xff08;CAV&#xff09; 等领域保护隐私的机器学习的一种很有前途的方法。然而&#xff0c;最近的研究表明&…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

图表类系列各种样式PPT模版分享

图标图表系列PPT模版&#xff0c;柱状图PPT模版&#xff0c;线状图PPT模版&#xff0c;折线图PPT模版&#xff0c;饼状图PPT模版&#xff0c;雷达图PPT模版&#xff0c;树状图PPT模版 图表类系列各种样式PPT模版分享&#xff1a;图表系列PPT模板https://pan.quark.cn/s/20d40aa…...

使用 SymPy 进行向量和矩阵的高级操作

在科学计算和工程领域&#xff0c;向量和矩阵操作是解决问题的核心技能之一。Python 的 SymPy 库提供了强大的符号计算功能&#xff0c;能够高效地处理向量和矩阵的各种操作。本文将深入探讨如何使用 SymPy 进行向量和矩阵的创建、合并以及维度拓展等操作&#xff0c;并通过具体…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

【笔记】WSL 中 Rust 安装与测试完整记录

#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统&#xff1a;Ubuntu 24.04 LTS (WSL2)架构&#xff1a;x86_64 (GNU/Linux)Rust 版本&#xff1a;rustc 1.87.0 (2025-05-09)Cargo 版本&#xff1a;cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...