当前位置: 首页 > news >正文

ResNet 原理剖析以及代码复现

原理

ResNet 解决了什么问题?

一言以蔽之:解决了深度的神经网络难以训练的问题。
具体的说,理论上神经网络的深度越深,其训练效果应该越好,但实际上并非如此,层数越深会导致越差的结果并且容易产生梯度爆炸或梯度消失等问题。

ResNet 怎么解决的?

提出了一个残差学习网络的框架,该框架解决了上述问题。

残差网络的架构

在这里插入图片描述

整个架构如上图所示。

首先我们要学习的东西是 H(x),假设现在已经有了一个浅的网络,然后我们要在上面新加一些层,让网络变得更深,如果按传统的做法那么新加的层就继续跟之前一样进行学习就行了。但是现在在新加的层中我们不直接去学 H(x),而是应该去学 H(x) - x。x 就是之前比较浅的网络已经学到的那个东西,也就是在新加的层中不去重新学个东西而只是学学到的东西和真实的东西二者之间的残差 H(x) - x,然后该层最后的输出结果 F(x) 再加上原始数据 x 就是最终结果也就是 F(x) + x ,此时优化的目标就不再是原始的 H(x),而是 H(x) - x 这个东西。

这就是 ResNet 的核心思想。

我感觉有一篇文章讲的很好,可以参考一下:ResNet网络详细讲解

下面是论文原文的描述:

在本文中,我们通过引入一个深度残差学习框架来解决退化问题。我们不希望每几个堆叠层直接拟合一个期望的底层映射,而是明确地让这些层拟合一个残差映射。在形式上,我们将期望的底层映射表示为H ( x ),并让堆叠的非线性层拟合F ( x )的另一个映射:= H ( x ) - x。原始映射被重铸成F ( x ) + x。我们假设优化残差映射比优化原始的、未引用的映射更容易。在极端情况下,如果一个恒等映射是最优的,那么将残差推到零比用一堆非线性层拟合一个恒等映射更容易。

F ( x ) + x的表达式可以通过具有"捷径连接"的前馈神经网络来实现(图2 )。快捷方式连接[ 2、33、48]是那些跳过一个或多个层的连接。在我们的例子中,快捷连接只是执行身份映射,它们的输出被添加到堆叠层的输出中(图2 )。身份捷径连接既不增加额外的参数,也不增加计算复杂度。整个网络仍然可以通过反向传播的SGD进行端到端的训练,并且可以很容易地使用公共库(例如, Caffe )实现,无需修改求解器。

我们在ImageNet [ 35 ]上进行了全面的实验来展示退化问题并评估我们的方法。研究表明:

1 )我们的深度残差网络易于优化,但对应的"普通"网络(简单地堆叠层)在深度增加时表现出更高的训练误差;
2 )我们的深度残差网络可以很容易地从大幅增加的深度中获得精度增益,产生的结果明显优于以前的网络。

在ImageNet分类数据集上[ 35 ],我们通过极深的残差网络获得了优异的结果。我们的152层残差网络是ImageNet上有史以来最深层的网络,但仍比VGG网络具有更低的复杂度[ 40 ]。我们的集成在ImageNet测试集上有3.57 %的top - 5误差,并在ILSVRC 2015分类竞赛中获得第一名。在其他识别任务上也具有出色的泛化性能,并引领我们在ILSVRC & COCO 2015竞赛中进一步获得第1名:ImageNet检测、ImageNet定位、COCO检测和COCO分割。这有力的证据表明,残差学习原理具有一般性,我们预期它在其他视觉和非视觉问题中也适用。

代码复现

这里给出我自己的模型代码:

import torch
from torch import nn# 基本残差块
class BasicBlock(nn.Module):expansion = 1"""参数解释:in_ch:输入通道数block_ch:输出通道数stride:步长,通过该参数我们就可以实现网络结构中特征图Size减半、通道数增加一倍的效果downSample:其本身也是一个网络,用来实现残差网络中的跳跃连接(也就是论文中虚线和实线)同时跳跃连接也是用来区别基本残差块和瓶颈残差块的,二者区别如下:基本残差块:输入输出通道数相同瓶颈残差块:输入输出通道数不同,需要进行升维操作才能对位相加另外二者的结构不同,可以通过论文看到"""def __init__(self, in_ch, block_ch, stride=1, downSample=None):super().__init__()self.downSample = downSample# 从网络结构图中可以看到,先进行第一层卷积self.conv1 = nn.Conv2d(in_ch, block_ch, kernel_size=3, stride=stride, padding=1, bias=False)# 在网络模型中添加一个二维批归一化(Batch Normalization)层。# 批归一化是一种用于加速神经网络训练并提高其性能的技术,类似于将上面所输出的数据进行了统一整理self.bn1 = nn.BatchNorm2d(block_ch)# 激活函数self.relu1 = nn.ReLU()# 第二层卷积self.conv2 = nn.Conv2d(block_ch, block_ch * self.expansion, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(block_ch * self.expansion)self.relu2 = nn.ReLU()def forward(self, x):identity = x# 如果downSample参数不为空,说明其需要升维(也就是论文中虚线的样子)if self.downSample is not None:# 升维,让输入输出的通道数对齐identity = self.downSample(x)out = self.relu1(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# 这里就是论文中的输出与原始输入进对位相加的步骤out += identity# 对位相加结束后再进行 relu 函数的激活,然后输出结果return self.relu2(out)# 瓶颈残差块
class Bottleneck(nn.Module):# 从论文的网络结构图中不难发现,瓶颈残差块在第三层卷积时通道数会放大四倍# 因此定义一个 expansion 变量expansion = 4"""参数解释:in_ch:输入通道数block_ch:输出通道数stride:步长,通过该参数我们就可以实现网络结构中特征图Size减半、通道数增加一倍的效果downSample:其本身也是一个网络,用来实现残差网络中的跳跃连接(也就是论文中虚线和实线)同时跳跃连接也是用来区别基本残差块和瓶颈残差块的,二者区别如下:基本残差块:输入输出通道数相同瓶颈残差块:输入输出通道数不同,需要进行升维操作才能对位相加另外二者的结构不同,可以通过论文看到"""def __init__(self, in_ch, block_ch, stride=1, downSample=None):super().__init__()self.downSample = downSample# 从网络结构图中可以看到,先进行第一层卷积self.conv1 = nn.Conv2d(in_ch, block_ch, kernel_size=1, stride=stride, bias=False)# 在网络模型中添加一个二维批归一化(Batch Normalization)层。# 批归一化是一种用于加速神经网络训练并提高其性能的技术,类似于将上面所输出的数据进行了统一整理self.bn1 = nn.BatchNorm2d(block_ch)# 激活函数self.relu1 = nn.ReLU()# 第二层卷积self.conv2 = nn.Conv2d(block_ch, block_ch, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(block_ch)self.relu2 = nn.ReLU()# 第三层卷积self.conv3 = nn.Conv2d(block_ch, block_ch * self.expansion, kernel_size=1, stride=1, bias=False)self.bn3 = nn.BatchNorm2d(block_ch * self.expansion)self.relu3 = nn.ReLU()def forward(self, x):identity = x# 如果downSample参数不为空,说明其需要升维(也就是论文中虚线的样子)if self.downSample is not None:# 升维,让输入输出的通道数对齐identity = self.downSample(x)out = self.relu1(self.bn1(self.conv1(x)))out = self.relu2(self.bn2(self.conv2(out)))out = self.bn3(self.conv3(out))# 这里就是论文中的输出与原始输入进对位相加的步骤out += identity# 对位相加结束后再进行 relu 函数的激活,然后输出结果return self.relu3(out)# 残差网络
class ResNet(nn.Module):"""in_ch: 默认为3,因为残差网络就是用来图片分类的,所以输入通道数默认为 3num_classes:分类的数量,默认设置为100,即 100 种分类block:用来区别是 基本残差块 还是 瓶颈残差块block_num:每个残差块所需要堆叠的次数(也是论文中提供的有)"""def __init__(self, in_ch=3, num_classes=100, block=Bottleneck, block_num=[3, 4, 6, 3]):super().__init__()# 因为在各层之间通道数会发生变化,因此要进行跟踪self.in_ch = in_ch# 对于残差网络来说,不管是什么类型其一开始都要进行 7x7 的卷积和 3x3 的池化# 因此我们直接照搬即可(论文中已经有了)self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.in_ch = 64# 将残差块堆叠起来,形成一个一个的残差层,进而构建成ResNetself.layer1 = self._make_layer(block, 64, block_num[0], stride=1)self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)# 最后是全连接层,做预测的self.fc_layer = nn.Sequential(nn.Linear(512*block.expansion*7*7, num_classes),nn.Softmax(dim=-1))def _make_layer(self, block, block_ch, block_num, stride=2):layers = []downSample = nn.Conv2d(self.in_ch, block_ch * block.expansion, kernel_size=1, stride=stride)layers += [block(self.in_ch, block_ch, stride=stride, downSample=downSample)]self.in_ch = block_ch * block.expansionfor _ in range(1, block_num):layers += [block(self.in_ch, block_ch)]return nn.Sequential(*layers)def forward(self, x):out = self.maxpool1(self.bn1(self.conv1(x))) #(1, 3, 224, 224) -> (1, 64, 56, 56)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = out.reshape(out.shape[0], -1)out = self.fc_layer(out)return outif __name__ == '__main__':# 接下来进行测试# 这行代码创建了一个形状为 (1, 3, 224, 224) 的四维张量 x,# 其中包含了一个大小为 1 的批次中的一个 224x224 像素的 RGB 图像。x = torch.randn(1, 3, 224, 224)resnet = ResNet(in_ch=3, num_classes=100, block=Bottleneck, block_num=[2, 2, 2, 2])y = resnet(x)print(y.shape)

相关文章:

ResNet 原理剖析以及代码复现

原理 ResNet 解决了什么问题? 一言以蔽之:解决了深度的神经网络难以训练的问题。 具体的说,理论上神经网络的深度越深,其训练效果应该越好,但实际上并非如此,层数越深会导致越差的结果并且容易产生梯度爆炸…...

数据结构(十)图

文章目录 图的简介图的定义图的结构图的分类无向图有向图带权图(Wighted Graph) 图的存储邻接矩阵(Adjacency Matrix)邻接表代码实现 图的遍历深度优先搜索(DFS,Depth Fisrt Search)遍历抖索过程…...

四数之和-力扣

本题在三数之和的基础上&#xff0c;再增加一重循环进行解答 首先注意的点是&#xff0c;一级剪枝处理&#xff0c;target > 0 && nums[i] > target 此处只有整数才可剪枝处理&#xff0c;如果target为负数&#xff0c;nums[i] < target&#xff0c;也不能代…...

JS 中怎么删除数组元素?有哪几种方法?

正文开始之前推荐一位宝藏博主免费分享的学习教程,学起来! 编号学习链接1Cesium: 保姆级教程+源码示例2openlayers: 保姆级教程+源码示例3Leaflet: 保姆级教程+源码示例4MapboxGL: 保姆级教程+源码示例splice() JavaScript中的splice()方法是一个内置的数组对象函数, 用于…...

Git如何将pre-commit也提交到仓库

我一开始准备将pre-commit提交到仓库进行备份的&#xff0c;但是却发现提交不了&#xff0c;即使我使用强制提交都不行。 (main) $ git add ./.git/hooks/pre-commit(main) $ git status On branch main nothing to commit, working tree clean# 强制提交(main) $ git add -f .…...

vmware中Ubuntu虚拟机和本地电脑Win10互相ping通

初始状态 使用vmware17版本安装的Ubuntu的20版本&#xff0c;安装之后什么配置都要不懂&#xff0c;然后进行下述配置。 初始的时候是NAT&#xff0c;没动的. 设置 点击右键编辑“属性” 常规选择“启用”&#xff1a; 高级选择全部&#xff1a; 打开网络配置&#xff0c;右键属…...

比较含退格的字符串-力扣

做这道题时出现了许多问题 第一次做题思路是使用双指针去解决&#xff0c;快慢指针遇到字母则前进&#xff0c;遇到 # 则慢指针退1&#xff0c;最开始并未考虑到 slowindex < 0 ,从而导致越界。第二个问题在于&#xff0c;在最后判断两个字符串是否相同时&#xff0c;最初使…...

NSSCTF-Web题目4

[SWPUCTF 2021 新生赛]hardrce 1、题目 2、知识点 rce&#xff1a;远程代码执行、url取反编码 3、解题思路 打开题目 出现一段代码&#xff0c;审计源代码 题目需要我们通过get方式输入变量wllm的值 但是变量的值被过滤了&#xff0c;不能输入字母和\t、\n等值 所以我们需…...

7. CSS 网格布局

CSS3引入了强大的网格布局&#xff08;Grid Layout&#xff09;&#xff0c;它提供了一种二维的布局方式&#xff0c;使得创建复杂的网页布局变得更加简单和直观。通过定义行和列&#xff0c;我们可以精确控制网页元素的排列和对齐。本章将详细介绍网格布局的基本概念和属性&am…...

如何配置才能连接远程服务器上的 redis server ?

文章目录 Intro修改点 Intro 以阿里云服为例。 首先&#xff0c;我在我买的阿里云服务器中以下载源码、手动编译的方式安装了 redis-server&#xff0c;操作流程见&#xff1a;Ubuntu redis 下载解压配置使用及密码管理 && 包管理工具联网安装。 接着&#xff0c;我…...

MindSpore实践图神经网络之环境篇

MindSpore在Windows11系统下的环境配置。 MindSpore环境配置大概分为三步&#xff1a;&#xff08;1&#xff09;安装Python环境&#xff0c;&#xff08;2&#xff09;安装MindSpore&#xff0c;&#xff08;3&#xff09;验证是否成功 如果是GPU环境还需安装CUDA等环境&…...

MVS net笔记和理解

文章目录 传统的方法有什么缺陷吗&#xff1f;MVSnet深度的预估 传统的方法有什么缺陷吗&#xff1f; 传统的mvs算法它对图像的光照要求相对较高&#xff0c;但是在实际中要保证照片的光照效果很好是很难的。所以传统算法对镜面反射&#xff0c;白墙这种的重建效果就比较差。 …...

Linux 编译屏障之 ACCESS_ONCE()

文章目录 1. 前言2. 背景3. 为什么要有 ACCESS_ONCE() &#xff1f;4. ACCESS_ONCE() 代码实现5. ACCESS_ONCE() 实例分析6. ACCESS() 的演进7. 结语8. 参考资料 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&#xff0c;作者不做…...

Discuz!X3.4论坛网站公安备案号怎样放到网站底部?

Discuz&#xff01;网站的工信部备案号都知道在后台——全局——站点信息——网站备案信息代码填写&#xff0c;那公安备案号要添加在哪里呢&#xff1f;并没有看到公安备案号填写栏&#xff0c;今天驰网飞飞和你分享 1&#xff09;工信部备案号和公安备案号统一填写到网站备案…...

LPDDR6带宽预计将翻倍增长:应对低功耗挑战与AI时代能源需求激增

在当前科技发展的背景下&#xff0c;低能耗问题成为了业界关注的焦点。国际能源署(IEA)近期报告显示&#xff0c;日常的数字活动对电力消耗产生显著影响——每次Google搜索平均消耗0.3瓦时&#xff08;Wh&#xff09;&#xff0c;而向OpenAI的ChatGPT提出的每一次请求则消耗2.9…...

云原生架构内涵_3.主要架构模式

云原生架构有非常多的架构模式&#xff0c;这里列举一些对应用收益更大的主要架构模式&#xff0c;如服务化架构模式、Mesh化架构模式、Serverless模式、存储计算分离模式、分布式事务模式、可观测架构、事件驱动架构等。 1.服务化架构模式 服务化架构是云时代构建云原生应用的…...

宏基因组分析流程(Metagenomic workflow)202405|持续更新

Logs 增加R包pctax内的一些帮助上游分析的小脚本&#xff08;2024.03.03&#xff09;增加Mmseqs2用于去冗余&#xff0c;基因聚类的速度非常快&#xff0c;且随序列量线性增长&#xff08;2024.03.12&#xff09;更新全文细节&#xff08;2024.05.29&#xff09; 注意&#x…...

一千题,No.0037(组个最小数)

给定数字 0-9 各若干个。你可以以任意顺序排列这些数字&#xff0c;但必须全部使用。目标是使得最后得到的数尽可能小&#xff08;注意 0 不能做首位&#xff09;。例如&#xff1a;给定两个 0&#xff0c;两个 1&#xff0c;三个 5&#xff0c;一个 8&#xff0c;我们得到的最…...

PV PVC

默写 1 如何将pod创建在指定的Node节点上 node亲和、pod亲和、pod反亲和: 调度策略 匹配标签 操作符 nodeAffinity 主机 In,NotIn,Exists,DoesNotExist&#xff0c;Gt&#xff0c;Lt podAffinity …...

深入理解Nginx配置文件:全面指南

Nginx 是一个高性能的 HTTP 服务器和反向代理服务器&#xff0c;也是一个电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器。由于其高效性和灵活性&#xff0c;Nginx 被广泛应用于各种 web 服务中。本文将详细介绍 Nginx 配置文件的结构和主要配置项&#xff0c;帮助你深入…...

Admin.Net中的消息通信SignalR解释

定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

BLEU评分:机器翻译质量评估的黄金标准

BLEU评分&#xff1a;机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域&#xff0c;衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标&#xff0c;自2002年由IBM的Kishore Papineni等人提出以来&#xff0c;…...

DAY 26 函数专题1

函数定义与参数知识点回顾&#xff1a;1. 函数的定义2. 变量作用域&#xff1a;局部变量和全局变量3. 函数的参数类型&#xff1a;位置参数、默认参数、不定参数4. 传递参数的手段&#xff1a;关键词参数5 题目1&#xff1a;计算圆的面积 任务&#xff1a; 编写一…...

数据结构:泰勒展开式:霍纳法则(Horner‘s Rule)

目录 &#x1f50d; 若用递归计算每一项&#xff0c;会发生什么&#xff1f; Horners Rule&#xff08;霍纳法则&#xff09; 第一步&#xff1a;我们从最原始的泰勒公式出发 第二步&#xff1a;从形式上重新观察展开式 &#x1f31f; 第三步&#xff1a;引出霍纳法则&…...

结构化文件管理实战:实现目录自动创建与归类

手动操作容易因疲劳或疏忽导致命名错误、路径混乱等问题&#xff0c;进而引发后续程序异常。使用工具进行标准化操作&#xff0c;能有效降低出错概率。 需要快速整理大量文件的技术用户而言&#xff0c;这款工具提供了一种轻便高效的解决方案。程序体积仅有 156KB&#xff0c;…...