当前位置: 首页 > news >正文

ResNet 原理剖析以及代码复现

原理

ResNet 解决了什么问题?

一言以蔽之:解决了深度的神经网络难以训练的问题。
具体的说,理论上神经网络的深度越深,其训练效果应该越好,但实际上并非如此,层数越深会导致越差的结果并且容易产生梯度爆炸或梯度消失等问题。

ResNet 怎么解决的?

提出了一个残差学习网络的框架,该框架解决了上述问题。

残差网络的架构

在这里插入图片描述

整个架构如上图所示。

首先我们要学习的东西是 H(x),假设现在已经有了一个浅的网络,然后我们要在上面新加一些层,让网络变得更深,如果按传统的做法那么新加的层就继续跟之前一样进行学习就行了。但是现在在新加的层中我们不直接去学 H(x),而是应该去学 H(x) - x。x 就是之前比较浅的网络已经学到的那个东西,也就是在新加的层中不去重新学个东西而只是学学到的东西和真实的东西二者之间的残差 H(x) - x,然后该层最后的输出结果 F(x) 再加上原始数据 x 就是最终结果也就是 F(x) + x ,此时优化的目标就不再是原始的 H(x),而是 H(x) - x 这个东西。

这就是 ResNet 的核心思想。

我感觉有一篇文章讲的很好,可以参考一下:ResNet网络详细讲解

下面是论文原文的描述:

在本文中,我们通过引入一个深度残差学习框架来解决退化问题。我们不希望每几个堆叠层直接拟合一个期望的底层映射,而是明确地让这些层拟合一个残差映射。在形式上,我们将期望的底层映射表示为H ( x ),并让堆叠的非线性层拟合F ( x )的另一个映射:= H ( x ) - x。原始映射被重铸成F ( x ) + x。我们假设优化残差映射比优化原始的、未引用的映射更容易。在极端情况下,如果一个恒等映射是最优的,那么将残差推到零比用一堆非线性层拟合一个恒等映射更容易。

F ( x ) + x的表达式可以通过具有"捷径连接"的前馈神经网络来实现(图2 )。快捷方式连接[ 2、33、48]是那些跳过一个或多个层的连接。在我们的例子中,快捷连接只是执行身份映射,它们的输出被添加到堆叠层的输出中(图2 )。身份捷径连接既不增加额外的参数,也不增加计算复杂度。整个网络仍然可以通过反向传播的SGD进行端到端的训练,并且可以很容易地使用公共库(例如, Caffe )实现,无需修改求解器。

我们在ImageNet [ 35 ]上进行了全面的实验来展示退化问题并评估我们的方法。研究表明:

1 )我们的深度残差网络易于优化,但对应的"普通"网络(简单地堆叠层)在深度增加时表现出更高的训练误差;
2 )我们的深度残差网络可以很容易地从大幅增加的深度中获得精度增益,产生的结果明显优于以前的网络。

在ImageNet分类数据集上[ 35 ],我们通过极深的残差网络获得了优异的结果。我们的152层残差网络是ImageNet上有史以来最深层的网络,但仍比VGG网络具有更低的复杂度[ 40 ]。我们的集成在ImageNet测试集上有3.57 %的top - 5误差,并在ILSVRC 2015分类竞赛中获得第一名。在其他识别任务上也具有出色的泛化性能,并引领我们在ILSVRC & COCO 2015竞赛中进一步获得第1名:ImageNet检测、ImageNet定位、COCO检测和COCO分割。这有力的证据表明,残差学习原理具有一般性,我们预期它在其他视觉和非视觉问题中也适用。

代码复现

这里给出我自己的模型代码:

import torch
from torch import nn# 基本残差块
class BasicBlock(nn.Module):expansion = 1"""参数解释:in_ch:输入通道数block_ch:输出通道数stride:步长,通过该参数我们就可以实现网络结构中特征图Size减半、通道数增加一倍的效果downSample:其本身也是一个网络,用来实现残差网络中的跳跃连接(也就是论文中虚线和实线)同时跳跃连接也是用来区别基本残差块和瓶颈残差块的,二者区别如下:基本残差块:输入输出通道数相同瓶颈残差块:输入输出通道数不同,需要进行升维操作才能对位相加另外二者的结构不同,可以通过论文看到"""def __init__(self, in_ch, block_ch, stride=1, downSample=None):super().__init__()self.downSample = downSample# 从网络结构图中可以看到,先进行第一层卷积self.conv1 = nn.Conv2d(in_ch, block_ch, kernel_size=3, stride=stride, padding=1, bias=False)# 在网络模型中添加一个二维批归一化(Batch Normalization)层。# 批归一化是一种用于加速神经网络训练并提高其性能的技术,类似于将上面所输出的数据进行了统一整理self.bn1 = nn.BatchNorm2d(block_ch)# 激活函数self.relu1 = nn.ReLU()# 第二层卷积self.conv2 = nn.Conv2d(block_ch, block_ch * self.expansion, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(block_ch * self.expansion)self.relu2 = nn.ReLU()def forward(self, x):identity = x# 如果downSample参数不为空,说明其需要升维(也就是论文中虚线的样子)if self.downSample is not None:# 升维,让输入输出的通道数对齐identity = self.downSample(x)out = self.relu1(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# 这里就是论文中的输出与原始输入进对位相加的步骤out += identity# 对位相加结束后再进行 relu 函数的激活,然后输出结果return self.relu2(out)# 瓶颈残差块
class Bottleneck(nn.Module):# 从论文的网络结构图中不难发现,瓶颈残差块在第三层卷积时通道数会放大四倍# 因此定义一个 expansion 变量expansion = 4"""参数解释:in_ch:输入通道数block_ch:输出通道数stride:步长,通过该参数我们就可以实现网络结构中特征图Size减半、通道数增加一倍的效果downSample:其本身也是一个网络,用来实现残差网络中的跳跃连接(也就是论文中虚线和实线)同时跳跃连接也是用来区别基本残差块和瓶颈残差块的,二者区别如下:基本残差块:输入输出通道数相同瓶颈残差块:输入输出通道数不同,需要进行升维操作才能对位相加另外二者的结构不同,可以通过论文看到"""def __init__(self, in_ch, block_ch, stride=1, downSample=None):super().__init__()self.downSample = downSample# 从网络结构图中可以看到,先进行第一层卷积self.conv1 = nn.Conv2d(in_ch, block_ch, kernel_size=1, stride=stride, bias=False)# 在网络模型中添加一个二维批归一化(Batch Normalization)层。# 批归一化是一种用于加速神经网络训练并提高其性能的技术,类似于将上面所输出的数据进行了统一整理self.bn1 = nn.BatchNorm2d(block_ch)# 激活函数self.relu1 = nn.ReLU()# 第二层卷积self.conv2 = nn.Conv2d(block_ch, block_ch, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(block_ch)self.relu2 = nn.ReLU()# 第三层卷积self.conv3 = nn.Conv2d(block_ch, block_ch * self.expansion, kernel_size=1, stride=1, bias=False)self.bn3 = nn.BatchNorm2d(block_ch * self.expansion)self.relu3 = nn.ReLU()def forward(self, x):identity = x# 如果downSample参数不为空,说明其需要升维(也就是论文中虚线的样子)if self.downSample is not None:# 升维,让输入输出的通道数对齐identity = self.downSample(x)out = self.relu1(self.bn1(self.conv1(x)))out = self.relu2(self.bn2(self.conv2(out)))out = self.bn3(self.conv3(out))# 这里就是论文中的输出与原始输入进对位相加的步骤out += identity# 对位相加结束后再进行 relu 函数的激活,然后输出结果return self.relu3(out)# 残差网络
class ResNet(nn.Module):"""in_ch: 默认为3,因为残差网络就是用来图片分类的,所以输入通道数默认为 3num_classes:分类的数量,默认设置为100,即 100 种分类block:用来区别是 基本残差块 还是 瓶颈残差块block_num:每个残差块所需要堆叠的次数(也是论文中提供的有)"""def __init__(self, in_ch=3, num_classes=100, block=Bottleneck, block_num=[3, 4, 6, 3]):super().__init__()# 因为在各层之间通道数会发生变化,因此要进行跟踪self.in_ch = in_ch# 对于残差网络来说,不管是什么类型其一开始都要进行 7x7 的卷积和 3x3 的池化# 因此我们直接照搬即可(论文中已经有了)self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.in_ch = 64# 将残差块堆叠起来,形成一个一个的残差层,进而构建成ResNetself.layer1 = self._make_layer(block, 64, block_num[0], stride=1)self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)# 最后是全连接层,做预测的self.fc_layer = nn.Sequential(nn.Linear(512*block.expansion*7*7, num_classes),nn.Softmax(dim=-1))def _make_layer(self, block, block_ch, block_num, stride=2):layers = []downSample = nn.Conv2d(self.in_ch, block_ch * block.expansion, kernel_size=1, stride=stride)layers += [block(self.in_ch, block_ch, stride=stride, downSample=downSample)]self.in_ch = block_ch * block.expansionfor _ in range(1, block_num):layers += [block(self.in_ch, block_ch)]return nn.Sequential(*layers)def forward(self, x):out = self.maxpool1(self.bn1(self.conv1(x))) #(1, 3, 224, 224) -> (1, 64, 56, 56)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = out.reshape(out.shape[0], -1)out = self.fc_layer(out)return outif __name__ == '__main__':# 接下来进行测试# 这行代码创建了一个形状为 (1, 3, 224, 224) 的四维张量 x,# 其中包含了一个大小为 1 的批次中的一个 224x224 像素的 RGB 图像。x = torch.randn(1, 3, 224, 224)resnet = ResNet(in_ch=3, num_classes=100, block=Bottleneck, block_num=[2, 2, 2, 2])y = resnet(x)print(y.shape)

相关文章:

ResNet 原理剖析以及代码复现

原理 ResNet 解决了什么问题? 一言以蔽之:解决了深度的神经网络难以训练的问题。 具体的说,理论上神经网络的深度越深,其训练效果应该越好,但实际上并非如此,层数越深会导致越差的结果并且容易产生梯度爆炸…...

数据结构(十)图

文章目录 图的简介图的定义图的结构图的分类无向图有向图带权图(Wighted Graph) 图的存储邻接矩阵(Adjacency Matrix)邻接表代码实现 图的遍历深度优先搜索(DFS,Depth Fisrt Search)遍历抖索过程…...

四数之和-力扣

本题在三数之和的基础上&#xff0c;再增加一重循环进行解答 首先注意的点是&#xff0c;一级剪枝处理&#xff0c;target > 0 && nums[i] > target 此处只有整数才可剪枝处理&#xff0c;如果target为负数&#xff0c;nums[i] < target&#xff0c;也不能代…...

JS 中怎么删除数组元素?有哪几种方法?

正文开始之前推荐一位宝藏博主免费分享的学习教程,学起来! 编号学习链接1Cesium: 保姆级教程+源码示例2openlayers: 保姆级教程+源码示例3Leaflet: 保姆级教程+源码示例4MapboxGL: 保姆级教程+源码示例splice() JavaScript中的splice()方法是一个内置的数组对象函数, 用于…...

Git如何将pre-commit也提交到仓库

我一开始准备将pre-commit提交到仓库进行备份的&#xff0c;但是却发现提交不了&#xff0c;即使我使用强制提交都不行。 (main) $ git add ./.git/hooks/pre-commit(main) $ git status On branch main nothing to commit, working tree clean# 强制提交(main) $ git add -f .…...

vmware中Ubuntu虚拟机和本地电脑Win10互相ping通

初始状态 使用vmware17版本安装的Ubuntu的20版本&#xff0c;安装之后什么配置都要不懂&#xff0c;然后进行下述配置。 初始的时候是NAT&#xff0c;没动的. 设置 点击右键编辑“属性” 常规选择“启用”&#xff1a; 高级选择全部&#xff1a; 打开网络配置&#xff0c;右键属…...

比较含退格的字符串-力扣

做这道题时出现了许多问题 第一次做题思路是使用双指针去解决&#xff0c;快慢指针遇到字母则前进&#xff0c;遇到 # 则慢指针退1&#xff0c;最开始并未考虑到 slowindex < 0 ,从而导致越界。第二个问题在于&#xff0c;在最后判断两个字符串是否相同时&#xff0c;最初使…...

NSSCTF-Web题目4

[SWPUCTF 2021 新生赛]hardrce 1、题目 2、知识点 rce&#xff1a;远程代码执行、url取反编码 3、解题思路 打开题目 出现一段代码&#xff0c;审计源代码 题目需要我们通过get方式输入变量wllm的值 但是变量的值被过滤了&#xff0c;不能输入字母和\t、\n等值 所以我们需…...

7. CSS 网格布局

CSS3引入了强大的网格布局&#xff08;Grid Layout&#xff09;&#xff0c;它提供了一种二维的布局方式&#xff0c;使得创建复杂的网页布局变得更加简单和直观。通过定义行和列&#xff0c;我们可以精确控制网页元素的排列和对齐。本章将详细介绍网格布局的基本概念和属性&am…...

如何配置才能连接远程服务器上的 redis server ?

文章目录 Intro修改点 Intro 以阿里云服为例。 首先&#xff0c;我在我买的阿里云服务器中以下载源码、手动编译的方式安装了 redis-server&#xff0c;操作流程见&#xff1a;Ubuntu redis 下载解压配置使用及密码管理 && 包管理工具联网安装。 接着&#xff0c;我…...

MindSpore实践图神经网络之环境篇

MindSpore在Windows11系统下的环境配置。 MindSpore环境配置大概分为三步&#xff1a;&#xff08;1&#xff09;安装Python环境&#xff0c;&#xff08;2&#xff09;安装MindSpore&#xff0c;&#xff08;3&#xff09;验证是否成功 如果是GPU环境还需安装CUDA等环境&…...

MVS net笔记和理解

文章目录 传统的方法有什么缺陷吗&#xff1f;MVSnet深度的预估 传统的方法有什么缺陷吗&#xff1f; 传统的mvs算法它对图像的光照要求相对较高&#xff0c;但是在实际中要保证照片的光照效果很好是很难的。所以传统算法对镜面反射&#xff0c;白墙这种的重建效果就比较差。 …...

Linux 编译屏障之 ACCESS_ONCE()

文章目录 1. 前言2. 背景3. 为什么要有 ACCESS_ONCE() &#xff1f;4. ACCESS_ONCE() 代码实现5. ACCESS_ONCE() 实例分析6. ACCESS() 的演进7. 结语8. 参考资料 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&#xff0c;作者不做…...

Discuz!X3.4论坛网站公安备案号怎样放到网站底部?

Discuz&#xff01;网站的工信部备案号都知道在后台——全局——站点信息——网站备案信息代码填写&#xff0c;那公安备案号要添加在哪里呢&#xff1f;并没有看到公安备案号填写栏&#xff0c;今天驰网飞飞和你分享 1&#xff09;工信部备案号和公安备案号统一填写到网站备案…...

LPDDR6带宽预计将翻倍增长:应对低功耗挑战与AI时代能源需求激增

在当前科技发展的背景下&#xff0c;低能耗问题成为了业界关注的焦点。国际能源署(IEA)近期报告显示&#xff0c;日常的数字活动对电力消耗产生显著影响——每次Google搜索平均消耗0.3瓦时&#xff08;Wh&#xff09;&#xff0c;而向OpenAI的ChatGPT提出的每一次请求则消耗2.9…...

云原生架构内涵_3.主要架构模式

云原生架构有非常多的架构模式&#xff0c;这里列举一些对应用收益更大的主要架构模式&#xff0c;如服务化架构模式、Mesh化架构模式、Serverless模式、存储计算分离模式、分布式事务模式、可观测架构、事件驱动架构等。 1.服务化架构模式 服务化架构是云时代构建云原生应用的…...

宏基因组分析流程(Metagenomic workflow)202405|持续更新

Logs 增加R包pctax内的一些帮助上游分析的小脚本&#xff08;2024.03.03&#xff09;增加Mmseqs2用于去冗余&#xff0c;基因聚类的速度非常快&#xff0c;且随序列量线性增长&#xff08;2024.03.12&#xff09;更新全文细节&#xff08;2024.05.29&#xff09; 注意&#x…...

一千题,No.0037(组个最小数)

给定数字 0-9 各若干个。你可以以任意顺序排列这些数字&#xff0c;但必须全部使用。目标是使得最后得到的数尽可能小&#xff08;注意 0 不能做首位&#xff09;。例如&#xff1a;给定两个 0&#xff0c;两个 1&#xff0c;三个 5&#xff0c;一个 8&#xff0c;我们得到的最…...

PV PVC

默写 1 如何将pod创建在指定的Node节点上 node亲和、pod亲和、pod反亲和: 调度策略 匹配标签 操作符 nodeAffinity 主机 In,NotIn,Exists,DoesNotExist&#xff0c;Gt&#xff0c;Lt podAffinity …...

深入理解Nginx配置文件:全面指南

Nginx 是一个高性能的 HTTP 服务器和反向代理服务器&#xff0c;也是一个电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器。由于其高效性和灵活性&#xff0c;Nginx 被广泛应用于各种 web 服务中。本文将详细介绍 Nginx 配置文件的结构和主要配置项&#xff0c;帮助你深入…...

Vim 调用外部命令学习笔记

Vim 外部命令集成完全指南 文章目录 Vim 外部命令集成完全指南核心概念理解命令语法解析语法对比 常用外部命令详解文本排序与去重文本筛选与搜索高级 grep 搜索技巧文本替换与编辑字符处理高级文本处理编程语言处理其他实用命令 范围操作示例指定行范围处理复合命令示例 实用技…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中&#xff0c;Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染&#xff08;即CPU被阻塞&#xff09;&#xff0c;这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案&#xff1a; 对惹&#xff0c;这里有一个游戏开发交流小组&…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明&#xff1a;server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

Windows安装Miniconda

一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...

【Linux】自动化构建-Make/Makefile

前言 上文我们讲到了Linux中的编译器gcc/g 【Linux】编译器gcc/g及其库的详细介绍-CSDN博客 本来我们将一个对于编译来说很重要的工具&#xff1a;make/makfile 1.背景 在一个工程中源文件不计其数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;mak…...

API网关Kong的鉴权与限流:高并发场景下的核心实践

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 引言 在微服务架构中&#xff0c;API网关承担着流量调度、安全防护和协议转换的核心职责。作为云原生时代的代表性网关&#xff0c;Kong凭借其插件化架构…...

保姆级【快数学会Android端“动画“】+ 实现补间动画和逐帧动画!!!

目录 补间动画 1.创建资源文件夹 2.设置文件夹类型 3.创建.xml文件 4.样式设计 5.动画设置 6.动画的实现 内容拓展 7.在原基础上继续添加.xml文件 8.xml代码编写 (1)rotate_anim (2)scale_anim (3)translate_anim 9.MainActivity.java代码汇总 10.效果展示 逐帧…...