当前位置: 首页 > article >正文

【动手学深度学习】现代卷积神经网络:ALexNet

【动手学深度学习】现代卷积神经网络:ALexNet

  • 1,ALexNet简介
  • 2,AlexNet和LeNet的对比
  • 3, AlexNet模型详细设计
  • 4,AlexNet采用ReLU激活函数
    • 4.1,ReLU激活函数
    • 4.2,sigmoid激活函数
    • 4.3,为什么采用ReLu做激活函数
  • 5,AlexNet的实现
    • 5.1,定义AlexNet模型
    • 5.2,卷积层输出矩阵形状计算
    • 5.3,池化层输出矩阵形状计算
    • 5.4,校验每层形状
    • 5.5,读取Fashion-MNIST数据集
    • 5.6,训练AlexNet模型


1,ALexNet简介

AlexNet是由多伦多大学的 Alex Krizhevsky 等人在2012年首次提出的。AlexNet不仅利用了GPU的强大计算能力,还引入了ReLU激活函数、Dropout正则化和重叠池化等创新技术,显著提升了模型的训练效率和准确性。

在AlexNet之前,尽管已经存在卷积神经网络(如LeNet),但它们的应用范围相对有限,且未能充分展示出深度学习相对于传统机器学习方法的优势。AlexNet的成功展示了深层网络的强大能力,特别是在处理复杂图像识别任务上的优越性能,在2012年的ImageNet大规模视觉识别挑战赛中,AlexNet以压倒性的优势夺冠,错误率远低于其他参赛者。这一胜利标志着深度学习时代的到来,激发了全球对AI研究的新一轮热潮,并催生了一系列基于其架构改进的先进模型。


2,AlexNet和LeNet的对比

如下图所示,AlexNet(右)本质上像是一个更深更大的LeNet(左),AlexNet在LeNet基础上做出的主要改进有:

  • 引入效果更好的ReLu激活函数(LeNet采用Sigmoid激活函数);
  • 池化层采用了最大池化(LeNet采用平均池化);
  • AlexNet采用Dropout正则化减少过拟合;
  • AlexNet的输入是像素更大的三通道彩色图片(LeNet是单通道灰度图片);

在这里插入图片描述


3, AlexNet模型详细设计

通过上图AlexNet的架构可以看到:

  • AlexNet的第一层,卷积核的形状是 11 × 11 11\times11 11×11。由于ImageNet数据集中大多数图像的宽和高比MNIST图像的多10倍以上,因此,需要一个更大的卷积窗口来捕获目标;
  • AlexNet第二层中的卷积核形状被缩减为 5 × 5 5\times5 5×5,然后是 3 × 3 3\times3 3×3。此外,在第一层、第二层和第五层卷积层之后,都加入窗口形状为 3 × 3 3\times3 3×3、步幅为2的最大池化层;
  • 由于AlexNet的输入数据更复杂,需要识别的模式更多,因此AlexNet的卷积输出通道数目远多于LeNet。比如第一个卷积层AleNet的输出通道数是96,LeNet的输出通道数是6;
  • 在最后一个卷积层后有两个全连接的隐藏层,分别有4096个输出;
  • 输出层也是一个全连接层。使用Softmax回归做输出,用于输出每个类别的概率分布。1000个输出对应ImageNet数据集的1000个类别;

4,AlexNet采用ReLU激活函数

AlexNet使用了ReLU激活函数。而不是采用LeNet中的Sigmoid激活函数


4.1,ReLU激活函数

ReLU函数被定义为该元素与0中的最大者。 ReLU函数的数学表达式如下:

ReLU(x) = max(x,0)​

ReLU函数图像为:

在这里插入图片描述


4.2,sigmoid激活函数

sigmoid函数能够将输入的实数值压缩到0和1之间。 sigmoid函数数学表达式为:

σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1

sigmoid函数图像为:

在这里插入图片描述


4.3,为什么采用ReLu做激活函数

AlexNet使用ReLU激活函数,原因如下:

  • ReLU激活函数的计算更简单,它不需要如sigmoid激活函数那般复杂的求幂运算;
  • ReLU激活函数使训练模型更加容易。当sigmoid激活函数的输出非常接近于0或1时,这些区域的梯度几乎为0,从而使模型无法得到有效的训练。 因此反向传播无法继续更新一些模型参数。相反,ReLU激活函数在正区间的梯度总是1,可以有效缓解梯度消失问题

5,AlexNet的实现

接下来使用深度学习框架实现AlexNet模型,结合Fashion-MNIST图像分类数据集进行训练和测试。

Fashion-MNIST数据集基本情况如下:

  • 训练集:有60,000张图像,用于模型训练;
  • 测试集:有10,000张图像,用于评估模型性能;
  • 数据集由灰度图像组成,其通道数为1;
  • 每个图像的高度和宽度均为28像素;
  • 调用load_data_fashion_mnist()函数加载数据集;
  • Fashion-MNIST中包含的10个类别;

5.1,定义AlexNet模型

定义AlexNet模型

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(# 第一个卷积层:使用11*11的卷积核,步幅为4(以减少输出的高度和宽度)。# 输入通道为1(Fashion-MNIST数据集为灰度图像),输出通道的数目为96# 采用ReLU激活函数# 注意:padding=1表示上下左右各填充1nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),# 采用步幅为2,形状3×3的最大池化层nn.MaxPool2d(kernel_size=3, stride=2),# 第二个卷积层:使用5×5卷积核,填充为2保证输入与输出的高和宽一致# 输入通道为96,输出通道数增至256# 采用ReLU激活函数nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),# 依然采用步幅为2,形状3×3的最大池化层nn.MaxPool2d(kernel_size=3, stride=2),# 三个连续的卷积层:均使用和3×3的卷积核,采用填充为1 保证输入与输出的高和宽一致# 前两个卷积层的输出通道数为384,最后一个卷积层的输出通道数为256# 采用ReLU激活函数nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),# 三个卷积层之后,依然采用步幅为2,形状3×3的最大池化层nn.MaxPool2d(kernel_size=3, stride=2),# 将多维张量展平成一维向量,以便输入到全连接层。nn.Flatten(),# 这里,全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合# 设置p=0.5表示训练过程中,每层神经元有50%的概率被暂时从网络中丢弃(即其输出被置为零)。可强制模型学习更加鲁棒的特征,并减少对特定神经元的依赖nn.Linear(6400, 4096), nn.ReLU(), nn.Dropout(p=0.5),nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(p=0.5),# 最后是输出层。由于这里使用Fashion-MNIST,所以用类别数为10nn.Linear(4096, 10))

5.2,卷积层输出矩阵形状计算

假设输入形状为 n h × n w n_h\times n_w nh×nw,卷积核形状为 k h × k w k_h\times k_w kh×kw

无填充、步幅默认为1时,输出矩阵形状为:

( n h − k h + 1 ) × ( n w − k w + 1 ) (n_h-k_h+1) \times (n_w-k_w+1) (nhkh+1)×(nwkw+1)

添加 p h p_h ph行填充(大约一半在顶部,一半在底部)和 p w p_w pw列填充(左侧大约一半,右侧一半),步幅为1时。输出矩阵形状为:

( n h − k h + p h + 1 ) × ( n w − k w + p w + 1 ) 。 (n_h-k_h+p_h+1)\times(n_w-k_w+p_w+1)。 (nhkh+ph+1)×(nwkw+pw+1)

添加 p h p_h ph行、 p w p_w pw列填充,设置垂直步幅 s h s_h sh,水平步幅 s w s_w sw时,输出矩阵形状为:

⌊ n h − k h + p h s h + 1 ⌋ × ⌊ n w − k w + p w s w + 1 ⌋ \lfloor\frac{n_h-k_h+p_h}{s_h}+1\rfloor \times \lfloor\frac{n_w-k_w+p_w}{s_w}+1\rfloor shnhkh+ph+1×swnwkw+pw+1

此时输出形状可换算为:

⌊ ( n h − k h + p h + s h ) / s h ⌋ × ⌊ ( n w − k w + p w + s w ) / s w ⌋ \lfloor(n_h-k_h+p_h+s_h)/s_h\rfloor \times \lfloor(n_w-k_w+p_w+s_w)/s_w\rfloor ⌊(nhkh+ph+sh)/sh×⌊(nwkw+pw+sw)/sw

如果我们设置了 p h = k h − 1 p_h=k_h-1 ph=kh1 p w = k w − 1 p_w=k_w-1 pw=kw1,则输出形状将简化为: ⌊ ( n h + s h − 1 ) / s h ⌋ × ⌊ ( n w + s w − 1 ) / s w ⌋ \lfloor(n_h+s_h-1)/s_h\rfloor \times \lfloor(n_w+s_w-1)/s_w\rfloor ⌊(nh+sh1)/sh×⌊(nw+sw1)/sw

更进一步,如果输入的高度和宽度可以被垂直和水平步幅整除,则输出形状将为: ( n h / s h ) × ( n w / s w ) (n_h/s_h) \times (n_w/s_w) (nh/sh)×(nw/sw)


5.3,池化层输出矩阵形状计算

假设输入形状为 n h × n w n_h\times n_w nh×nw,卷积核形状为 k h × k w k_h\times k_w kh×kw,池化操作一般无填充,池化输出的矩阵形状如下:
输出的高为: ⌊ n h − k h s h + 1 ⌋ 输出的高为:\lfloor\frac{n_h-k_h}{s_h}+1\rfloor 输出的高为:shnhkh+1
输出的宽为: ⌊ n w − k w s w + 1 ⌋ 输出的宽为:\lfloor\frac{n_w-k_w}{s_w}+1\rfloor 输出的宽为:swnwkw+1


5.4,校验每层形状

输出每一层的形状做检查

# 模拟形状为 (1, 1, 224, 224)的输入张量X,X的元素是从标准正态分布(均值为0,标准差为1)中随机抽取的
X = torch.randn(1, 1, 224, 224)
for layer in net:# 对于每一层,使用当前层处理输入张量 XX=layer(X)print(layer.__class__.__name__,'output shape:\t',X.shape)

运行结果如下:

在这里插入图片描述


5.5,读取Fashion-MNIST数据集

d2l包内部定义的load_data_fashion_mnist加载数据集

"""
下载Fashion-MNIST数据集,然后将其加载到内存中
参数resize表示调整图片大小
"""
def load_data_fashion_mnist(batch_size, resize=None): # trans是一个用于转换的 *列表*trans = [transforms.ToTensor()]if resize:    # resize不为空,表示需要调整图片大小trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

直接调用API加载数据

batch_size = 128
# 直接调用API获取
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

5.6,训练AlexNet模型

# 设置学习率核训练轮数
lr, num_epochs = 0.01, 10
# 训练模型
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

运行结果如下:

在这里插入图片描述

相关文章:

【动手学深度学习】现代卷积神经网络:ALexNet

【动手学深度学习】现代卷积神经网络:ALexNet 1,ALexNet简介2,AlexNet和LeNet的对比3, AlexNet模型详细设计4,AlexNet采用ReLU激活函数4.1,ReLU激活函数4.2,sigmoid激活函数4.3,为什…...

PyTorch深度学习框架60天进阶学习计划 - 第37天:元学习框架

PyTorch深度学习框架60天进阶学习计划 - 第37天:元学习框架 嘿,朋友们!欢迎来到我们PyTorch进阶之旅的第37天。今天我们将深入探索一个非常有趣且强大的领域——元学习(Meta-Learning),也被称为"学会学习"(Learning to…...

【中检在线-注册安全分析报告】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 1. 暴力破解密码,造成用户信息泄露 2. 短信盗刷的安全问题,影响业务及导致用户投诉 3. 带来经济损失,尤其是后付费客户,风险巨大,造…...

UE5 运行时动态将玩家手部模型设置为相机的子物体

在编辑器里,我们虽然可以手动添加相机,但是无法将网格体设置为相机的子物体,只能将相机设置为网格体的子物体 但是为了使用方便,我们希望将网格体设置为相机的子物体,这样我们直接旋转相机就可以旋转网格体&#xff0…...

EasyExcel-一款好用的excel生成工具

EasyExcel是一款处理excel的工具类,主要特点如下(官方): 特点 高性能读写:FastExcel 专注于性能优化,能够高效处理大规模的 Excel 数据。相比一些传统的 Excel 处理库,它能显著降低内存占用。…...

WEB攻防-Java安全JNDIRMILDAP五大不安全组件RCE执行不出网不回显

目录 1. RCE执行-5大类函数调用 1.1 Runtime方式 1.2 Groovy执行命令 1.3 脚本引擎代码注入 1.4 ProcessImpl 1.5 ProcessBuilder 2. JNDI注入(RCE)-RMI&LDAP&高版本 2.1 RMI服务中的JNDI注入场景 2.2 LDAP服务中的JNDI注入场景 攻击路径示例&#…...

UML组件图

一、UML 组件图 组件图(Component Diagram)主要用于描述系统的物理结构,用于展示可独立部署的软件模块(如微服务、动态链接库、API网关)及其交互关系。组件图中的主要元素包括: 组件(Component…...

DrissionPage移动端自动化:从H5到原生App的跨界测试

一、移动端自动化测试的挑战与机遇 移动端测试面临多维度挑战: 设备碎片化:Android/iOS版本、屏幕分辨率差异 混合应用架构:H5页面与原生组件的深度耦合 交互复杂性:多点触控、手势操作、传感器模拟 性能监控:内存…...

从 Excel 到你的表格应用:条件格式功能的嵌入实践指南

一、引言 在日常工作中,面对海量数据时,如何快速识别关键信息、发现数据趋势或异常值,是每个数据分析师面临的挑战。Excel的条件格式功能通过自动化的视觉标记,帮助用户轻松应对这一难题。 本文将详细介绍条件格式的应用场景&am…...

redis 和 MongoDB都可以存储键值对,并且值可以是复杂json,用完整例子分别展示说明两者在存储json键值对上的使用对比

Redis 存储 JSON 键值对示例 存储操作: // 存储用户信息(键:user:1001,值:JSON对象) SET user:1001 {"name":"Alice", "age":30, "address":"New York&quo…...

SQLI打靶

文章目录 一、DVWA0. Mysql与Mariasql1. 单/双引号 - 十六进制编码绕过**原理:** 2. limit 1的绕过3. 参数化查询绕过一、介绍二、PDO是一种PHP实现参数化查询的机制 三、预编译绕过 之 结构化参数 4. 反自动化手段 之 Anti-CSRF token静态:动态&#xf…...

STM32单片机入门学习——第22节: [7-2] AD单通道AD多通道

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难,但我还是想去做! 本文写于:2025.04.07 STM32开发板学习——第22节: [7-2] AD单通道&AD多通道 前言开发板说明引用解…...

python基础语法1:输入输出

1. 输出 (Output) 1.1 print() 基础 Python 使用 print() 函数向控制台输出内容。 # 输出字符串 print("Hello, World!") # 输出多个值(自动用空格分隔) print("Name:", "Alice", "Age:", 25) # 修改分隔符&…...

对Android中zygote的理解

1. Zygote的作用 Zygote是Android系统的核心进程,核心作用可归纳为以下三点: 核心作用详细说明进程孵化器作为所有应用进程的父进程,通过fork快速创建新进程(避免重复初始化虚拟机)。(system server也由z…...

【Survival Analysis】【机器学习】【1】

前言: 今年在做的一个博士课题项目,主要是利用病人的数据,训练出一个AI模型,做因果分析, 以及个性化治疗。自己一直是做通讯AI方向的,这个系列主要参考卡梅隆大学的教程,以及临床医生的角度 了…...

WebShell详解:原理、分类、攻击与防御

目录 一、WebShell的定义与核心概念 二、WebShell的分类 三、WebShell的攻击原理与常见手法 1. 攻击原理 2. 常见攻击路径 四、WebShell的危害 五、防御与检测策略 六、总结 一、WebShell的定义与核心概念 ​​WebShell​​是一种以ASP、PHP、JSP等网页脚本形式存在的恶…...

JavaScript---原型和原型链

目录 一、引用类型皆为对象 二、原型和原型链是什么 三、__proto__与prototype 总结 四、原型链顶层 五、constructor 六、函数对象的原型链 一、引用类型皆为对象 原型和原型链都是来源于对象而服务于对象: JavaScript中一切引用类型都是对象,…...

离散数学问题集--问题5.9

问题 5.9 综合了计算机组成原理、数字逻辑和离散数学中的关键概念,旨在帮助学生理解二进制算术运算的硬件实现、逻辑门与算术运算的关系,以及如何使用数学方法来验证数字系统的正确性。它强调了从规范到实现再到验证的完整过程。 思想 函数抽象&#xf…...

手游防DDoS攻击SDK接入

在手游中集成防DDoS攻击SDK是抵御流量型和应用层攻击的核心手段之一。以下从​​SDK选型、接入流程、防护策略优化​​三个维度提供完整指南,并附关键代码示例: ​​一、SDK选型与核心能力对比​​ ​​服务商​​​​优势​​​​劣势​​​​适用场景…...

Java—HTML:CSS选择器

今天我要介绍的知识点内容是Java HTML中的CSS选择器; CSS选择器用于定位HTML元素并为其添加样式。它允许我们控制网页的颜色、字体、布局和其他视觉元素。通过分离内容与样式。 下面我将介绍CSS中选择器的使用,并作举例说明; 选择器基本语…...

如何将/dev/ubuntu-vg/lv-data的空间扩展到/dev/ubuntu-vg/ubuntu-lv的空间上

要将 /dev/ubuntu-vg/lv-data 的空间扩展到 /dev/ubuntu-vg/ubuntu-lv 上,实际上是将 lv-data 的空间释放出来,并将其分配给 ubuntu-lv。以下是详细的步骤和操作说明: 已知信息 你有两个逻辑卷: /dev/ubuntu-vg/lv-data/dev/ubun…...

SSM阶段性总结

0 Pojo类 前端给后端:DTO 后端给前端:VO 数据库:PO/VO 业务处理逻辑:BO 统称pojo 1 代理模式 实现静态代理: 1定义接口2实现类3写一个静态代理类4这样在调用时就可以使用这个静态代理类来实现某些功能 实现动态代…...

Qt 5.14.2入门(一)写个Hello Qt!程序

目录 参考链接:一、新建项目二、直接运行三、修改代码增加窗口内容1、Qt 显示一个 QLabel 标签控件窗口2、添加按键 参考链接: Qt5教程(一):Hello World 程序 Qt 编程指南 一、新建项目 1、新建一个项目&#xff08…...

Jmeter分布式测试启动

代理客户端配置 打开jmeter.properties文件,取消注释并设置端口(如server_port1099), 并添加server.rmi.ssl.disabletrue禁用SSL加密。 (Linux系统)修改jmeter-server文件中的RMI_HOST_DEF为代理机实际IP。…...

redis itheima

缓存问题 核心是如何避免大量请求到达数据库 缓存穿透 既不存在于 redis,也不存在于 mysql 的key,被重复请求 public Result queryById(Long id) {String key CACHE_SHOP_KEYid;// 1. redis & mysqlString shopJson stringRedisTemplate.opsFo…...

mysql 执行计划中eq_ref是什么意思?

在 MySQL 的执行计划中,eq_ref 是一种连接类型(type),表示查询优化器在使用**主键(PRIMARY KEY)或唯一索引(UNIQUE INDEX)**进行等值匹配()时,对表…...

QT 调用动态链接库

引入QT提供的动态加载库的类 #include <QLibrary>定义函数指针类型 typedef void (*GetResFunction)(uint8_t*, uint8_t*, int);定义函数指针的主要目的是为了解析和调用动态链接库中的函数。如果你不定义函数指针&#xff0c;就无法直接调用动态链接库中的函数 加载动…...

100天精通Python(爬虫篇)——第122天:基于selenium接管已启动的浏览器(反反爬策略)

文章目录 1、问题描述2、问题推测3、解决方法3.1 selenium自动启动浏览器3.2 selenium接管已启动的浏览器3.3 区别总结 4、代码实战4.1 手动方法&#xff08;手动打开浏览器输入账号密码&#xff09;4.2 自动方法&#xff08;.bat文件启动的浏览器&#xff09; 1、问题描述 使用…...

MPP 架构解析:原理、核心优势与对比指南

一、引言&#xff1a;大数据时代的数据处理挑战 全球数据量正以指数级增长。据 Statista 统计&#xff0c;2010 年全球数据量仅 2ZB&#xff0c;2025 年预计达 175ZB。企业面临的核心挑战已从“如何存储数据”转向“如何快速分析数据”。传统架构在处理海量数据时暴露明显瓶颈…...

GitHub 趋势日报 (2025年04月06日)

GitHub 趋势日报 (2025年04月06日) 本日报由 TrendForge 系统生成 https://trendforge.devlive.org/ &#x1f4c8; 今日整体趋势 Top 10 排名项目名称项目描述今日获星语言1microsoft/markitdownPython tool for converting files and office documents to Markdown.⭐ 548Py…...