当前位置: 首页 > news >正文

动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

目录

  • LeNet
  • 模型训练

LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:
卷积编码器:由两个卷积层组成;
全连接层密集块:由三个全连接层组成。
图片来源https://baike.baidu.com/item/LeNet-5/61427772?fr=ge_ala每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的LeNet-5一致
将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,可以检查模型。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

模型训练

看看LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

为了使用GPU,还需要一点小改动。 与学习记录10节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,需要将每一小批量数据移动到我们指定的设备(例如GPU)上。

如下所示,训练函数train_ch6也类似于 学习记录10中定义的train_ch3。 主要使用高级API实现多层神经网络。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用Xavier随机初始化模型参数。 与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降。

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

相关文章:

动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

目录 LeNet模型训练 LeNet 总体来看,LeNet(LeNet-5)由两个部分组成: 卷积编码器:由两个卷积层组成; 全连接层密集块:由三个全连接层组成。 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均…...

log4j 和 java.lang.OutOfMemoryError PermGen space

还是OneCoder在项目中沙箱的问题,用classloader隔离做的沙箱,反复运行用户的任务,出现永生区内存溢出: java.lang.OutOfMemoryError: PermGen space 这个问题在tomcat重复热部署的时候其实比较常见。其道理也和我们沙箱的道理基本…...

2024.9.9营养小题【2】

营养: 1、什么数是丑数? 2、数学数学,丑数的数学意义,哎,数学思维我是忘干净了。 3、可以把while循环换成for循环。由此又想到了一点,三个循环结构各有使用场景。 for(;n%factors[i]0;n/factors[i]){}...

uniapp的barcode组件去掉自动放大功能

autoZoom“false” <barcode id1 class"barcode" autoZoom"false" autostart"false" ref"barcode" background"rgb(0,0,0)" frameColor"#1C86EE"scanbarColor"#1C86EE" :filters"fil" ma…...

H5接入Steam 获取用户数据案例

官方文档地址 1.注册 Steam API Key&#xff1a; 你需要一个 Steam Web API Key&#xff0c;可以在 Steam API Key 页面 获取。https://steamcommunity.com/dev/apikey 2.使用 OpenID 登录&#xff1a; 实现 Steam OpenID 登录&#xff0c;以便用户通过 Steam 账户登录你的…...

《A Few Useful Things to Know about Machine Learning》论文导读

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl机器学习作为人工智能领域的重要分支,近年来得到了广泛的关注和应用。Pedro Domingos的经典论文《A Few Useful Things to Know about Machine Learning》为我们提供了对机器学习深入且全面的理解…...

隔壁老樊2024全国巡回演唱会重磅来袭,首站广州正式官宣!

汹涌人潮将城市填满&#xff0c;斑驳心绪漂浮在时间之隙&#xff0c;当生活的喜悲逐渐演化成歌&#xff0c;天空将自己负载的缄默倾泻&#xff0c;那些或酸涩、或热烈的点滴滑落心海&#xff0c;那层悬挂在「我」与世界分野的无形壁垒&#xff0c;渐也被曙光渗透消融。 提炼生…...

【C++】list(下)

个人主页~ list&#xff08;上&#xff09;~ list 四、模拟实现1、list.h&#xff08;1&#xff09;关于整个list的搭建①节点②迭代器③接口 &#xff08;2&#xff09;自定义类型实例化 2、test.cpp&#xff08;1&#xff09;test1&#xff08;2&#xff09;test2 五、额外小…...

千云物流 -低代码平台MySQL备份数据

windows备份 全量备份 创建备份目录 需要在安装数据库的服务器上创建备份目录,所有如果要做备份至少需要两倍的硬盘空间, mkdir D:\mysql_backup\full_backup准备备份脚本 创建一个windows批处理文件&#xff08;例如 full_backup.bat&#xff09;&#xff0c;用来执行全量…...

MySQL:进阶巩固-视图

目录 一、视图的概述二、视图的基本使用2.1 创建视图2.2 查询视图2.3 修改视图2.4 删除视图 一、视图的概述 视图是一种虚拟存在的表&#xff0c;视图中的数据并不在数据库中实际的存在&#xff0c;行列数据来自于视图中查询的表&#xff0c;并且是在使用视图时动态生成的。 通…...

分布式事务Seata原理及其项目使用

0.Seata官方文档 1.Seata概念及原理 Seata是什么 Seata 是一款开源的分布式事务解决方案&#xff0c;致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式&#xff0c;为用户打造一站式的分布式解决方案。 Seata主要由三个重要组…...

JS_分支结构

if结构 这里的if结构几乎和JAVA中的一样,需要注意的是 if()中的非空字符串会被认为是trueif()中的非零数字会被认为是trueif()中的非空对象会被认为是true <script> if(false){// 非空字符串 if判断为true console.log(true) }else{ console.log(false) } if(){// 长度…...

决策树(Decison Tree)—有监督学习方法、概率模型、生成模型、非线性模型、非参数化模型、批量学习

定义 ID3算法 输入:训练数据集(T= { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯   , ( x N , y N ) } \left\{(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_N)\right\} {(x1​,y1​),(x2​,y2​),⋯,(xN​,yN​)}),特征集A阀值 ε \varepsilon ε 输出:决策树T (1)若D中所有实例属于同一…...

java 自定义注解校验实体类属性

直接上代码 1.是否启用参数校验注解 Target({ElementType.TYPE}) Retention(RetentionPolicy.RUNTIME) Documented public interface EnableArgumentsCheck {/*** 是否启用*/boolean enable() default true;} 2.参数校验自定义注解 /*** 参数校验自定义注解* 属性定义&#…...

光伏并网发电系统中电能质量监测与优化技术探讨

0引言 随着清洁能源技术的持续进步与广泛应用&#xff0c;光伏并网发电系统亦逐步崭露头角。作为一种关键的电力供应方式&#xff0c;其受到了广泛的关注。然而&#xff0c;由于天气等外部条件的影响&#xff0c;光伏发电系统面临若干挑战。电能质量问题&#xff0c;诸如电压波…...

网页解析的那些事

Vue 方面 模板语法理解 熟悉 Vue 的模板语法&#xff0c;包括插值表达式&#xff08;如{{ message }}&#xff09;、指令&#xff08;如v-if、v-for、v-bind等&#xff09;。理解这些语法元素如何将数据与 DOM 元素进行绑定和交互。例如&#xff0c;v-for指令用于循环渲染列表数…...

从文字到世界:2024外语阅读大赛报名开启,赛氪网全程护航

中国外文局CATTI项目管理中心与中国外文界联合宣布&#xff0c;将举办2024年外语阅读大赛&#xff0c;旨在激发外语学习兴趣&#xff0c;选拔并培养优秀的语言应用人才&#xff0c;同时向世界展示和传播具有中国特色的优秀文化。此次大赛旨在激发外语学习兴趣&#xff0c;选拔优…...

微信小程序知识点(二)

1.下拉刷新事件 如果页面需要下拉刷新功能&#xff0c;则在页面对应的json配置文件中&#xff0c;将enablePullDownRefresh配置设置为true&#xff0c;如下 {"usingComponents": {},"enablePullDownRefresh": true } 2.上拉触底事件 在很多时候&#x…...

Springcould -第一个Eureka应用 --- day02

标题 Eureka工作原理Spring Cloud框架下的服务发现Eureka包含两个组件&#xff0c;分别是&#xff1a;Eureka Server与Eureka Client。Eureka Server&#xff1a;Eureka Client&#xff1a; 搭建Eureka Server步骤&#xff1a;步骤1&#xff1a;创建项目&#xff0c;引入依赖步…...

RedissonClient 分布式队列工具类

注意&#xff1a;轻量级队列可以使用工具类&#xff0c;重量级数据量 请使用 MQ 本文章基于redis使用redisson客户端实现轻量级队列&#xff0c;以及代码、执行结果演示 一、常见队列了解 普通队列&#xff1a;先进先出&#xff08;FIFO&#xff09;&#xff0c;只能在一端添…...

protobuf使用

我下载的是 protobuf-27.4 以下使用vs2022 根据readme&#xff0c;执行如下命令 "C:\Program Files\CMake\bin\cmake.exe" -G "Visual Studio 17 2022" -DCMAKE_INSTALL_PREFIXC:\Users\x\Downloads\install C:\Users\x\Downloads\protobuf-27.4 -D…...

【微处理器系统原理与应用设计第十二讲】通用定时器设计二之PWM波实现呼吸灯的功能

一、基础知识 1、寄存器的配置 &#xff08;1&#xff09;GPIOX_AFRL&#xff1a;GPIO复用功能低位寄存器 GPIOX_AFRH&#xff1a;GPIO复用功能高位寄存器 &#xff08;2&#xff09;配置PA5 GPIOA->MODER&#xff08;端口模式寄存器&#xff09;&#xff0c;10为复用功…...

2025秋招NLP算法面试真题(十九)-大模型分布式训练题目

目录: 理论篇 1.1 训练大语言模型存在问题? 1.2 什么是点对点通信? 1.3 什么是集体通信? 1.4 什么是数据并行? 1.5 数据并行如何提升效率? 1.6 什么是流水线并行? 1.7 什么是张量并行 (intra-layer)? 1.8 数据并行 vs 张量并行 vs 流水线并行? 1.9 什么是3D并行? 1.1…...

线程池的应用

1.线程的执行机制 线程分为用户线程 和 内核线程 内核线程就是系统级别的线程&#xff0c;与cpu逻辑处理器数量对应的 用户线程就是使用java代码创建的Thread对象 用户线程必须与内核线程关联&#xff08;映射&#xff09;&#xff0c;才能执行任务 当用户线程多于内核线程时…...

OPenCV结构分析与形状描述符(5)查找图像中的连通组件的函数connectedComponents()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 connectedComponents 函数计算布尔图像的连通组件标签图像。 该函数接受一个具有4或8连通性的二值图像&#xff0c;并返回 N&#xff0c;即标签…...

HCIA--实验十三:VLAN间通信子接口实验/双单臂路由实验

一、实验内容 1.需求/要求&#xff1a; 将两个单臂路由通过两台交换机连接起来&#xff0c;成为双臂路由&#xff0c;并探讨这么做的原因。实现全网通&#xff0c;让任何一台主机之间都可以通信。 二、实验过程 1.拓扑图&#xff1a; 2.步骤&#xff1a; 1.给PC配置ip地址…...

AIStarter市场指南:项目分享与框架优化【AI绘画、写作、对话、办公、设计】

随着人工智能技术的飞速发展&#xff0c;越来越多的开发者和爱好者希望能够将自己的创意和项目分享给更多人。AIStarter作为一个专注于AI领域的平台&#xff0c;正致力于打造一个开放的应用市场&#xff0c;让创作者能够轻松分享他们的项目&#xff0c;同时也方便其他用户下载和…...

机器学习第8章 集成学习

目录 个体与集成BoostingBagging与随机森林Bagging随机森林 结合策略平均法投票法学习法 个体与集成 定义&#xff1a;集成学习&#xff0c;也叫多分类器系统、基于委员会的学习等&#xff0c;它是一种通过结合多个学习器来构建一个更强大的学习器的技术。如下图所示 在这里&a…...

京东鸿蒙上线前瞻——使用 Taro 打造高性能原生应用

背景 2024 年 1 月&#xff0c;京东正式启动鸿蒙原生应用开发&#xff0c;基于 HarmonyOS NEXT 的全场景、原生智能、原生安全等优势特性&#xff0c;为消费者打造更流畅、更智能、更安全的购物体验。同年 6 月&#xff0c;京东鸿蒙原生应用尝鲜版上架华为应用市场&#xff0c…...

day2 QT

作业 2> 在登录界面的登录取消按钮进行以下设置&#xff1a; 使用手动连接&#xff0c;将登录框中的取消按钮使用第2种方式的连接到自定义的槽函数中&#xff0c;在自定义的槽函数中调用关闭函数 将登录按钮使用qt4版本的连接到自定义的槽函数中&#xff0c;在槽函数中判断…...