当前位置: 首页 > news >正文

动手学深度学习——残差网络ResNet(原理解释+代码详解)

残差网络ResNet

      • 1. 函数类
      • 2. 残差块
      • 3. ResNet模型
      • 4. 训练模型

ResNet为了解决“新添加的层如何提升神经网络的性能”,它在2015年的ImageNet图像识别挑战赛夺魁
在这里插入图片描述

它深刻影响了后来的深度神经网络的设计,ResNet的被引用量更是达到了19万+。
在这里插入图片描述

1. 函数类

假设有一类特定的神经网络架构F,它包括学习速率和其他超参数设置。对于所有f∈F,存在一些参数集(例如权重和偏置),这些参数可以通过在合适的数据集上进行训练而获得。

现在假设 f* 是我们真正想要找到的函数,如果是 f*∈F,那可以轻而易举的训练得到它。

给定一个具有X特性和y标签的数据集
在这里插入图片描述
在这里插入图片描述是我们要找的函数,为了使其更近似真正的 f* ,则需要更强的架构F’。

在这里插入图片描述
对于非嵌套函数类,较复杂的函数类并不总是向“真”函数 f* 靠拢(复杂度由F1向F6递增)。虽然F3比F1更接近 f*,但却离F6的更远了。

而右侧的嵌套函数可以避免上述问题。

2. 残差块

残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。
在这里插入图片描述
F(x) + x包含了原始元素。

ResNet沿用了VGG完整的3x3卷积层设计。

  • 残差块里首先有2个有相同输出通道数的3x3卷积层。
  • 每个卷积层后接一个批量规范化层和ReLU激活函数。
  • 然后通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module): #@save# use_ixiconv:残差连接是直接连接还是通过卷积层连接def __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += X return F.relu(Y)

此代码生成两种类型的网络:

  • 一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。
  • 另一种是当use_1x1conv=True时,添加通过1x1卷积调整通道和分辨率。
    在这里插入图片描述查看输入和输出形状一致的情况
# 输入、输出的情况
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape

在这里插入图片描述
在增加输出通道数的同时,减半输出的高和宽

# 增加输出通道的同时,减半输出的高度和宽度
blk = Residual(3, 6, use_1x1conv=True, strides=2)
blk(X).shape

在这里插入图片描述

3. ResNet模型

ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的7x7卷积层后,接步幅为2的3x3的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。

# ResNet在每个卷积层后增加了批量规范层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

ResNet使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。
第一个模块的通道数同输入通道数一致。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

# 实现残差连接模块:由4个残差连接块组成
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):# 定义空网络结构blk = []for i in range(num_residuals):# 第2,3,4个Inception块的第一个残差模块连接1x1卷积层if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk

接着在ResNet加入所有残差块,这里每个模块使用2个残差块。

# 在ResNet加入所有残差块,每个模版使用2个残差块
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

最后在ResNet中加入全局平均汇聚层,以及全连接层输出。

# 在ResNet加入全局平均汇聚层以及全连接层输出
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(), nn.Linear(512, 10))

每个模块有4个卷积层,加上第一个7x7卷积层和最后一个全连接层,共有18层,这种模型通常被称为ResNet-18。
在这里插入图片描述
观察一下ResNet中不同模块的输入形状是如何变化

# 每个模块有4个卷积层(不包括恒等映射的1x1卷积层)。 
# 加上第一个7x7卷积层和最后一个全连接层,共有18层。 
# 因此,这种模型通常被称为ResNet-18。
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output_shape:\t', X.shape)

在这里插入图片描述

4. 训练模型

在Fashion-MNIST数据集上训练ResNet
定义精度评估函数

"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

定义GPU训练函数

"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

训练模型

# 训练模型
lr, num_epochs, batch_size = 0.05, 10 ,256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

相关文章:

动手学深度学习——残差网络ResNet(原理解释+代码详解)

残差网络ResNet 1. 函数类2. 残差块3. ResNet模型4. 训练模型 ResNet为了解决“新添加的层如何提升神经网络的性能”,它在2015年的ImageNet图像识别挑战赛夺魁 它深刻影响了后来的深度神经网络的设计,ResNet的被引用量更是达到了19万。 1. 函数类 假…...

MYSQL 8.0 配置CDC(binlog)

CDC(Change Data Capture)即数据变更抓取,通过源端数据源开启CDC,ROMA Connect 可实现数据源的实时数据同步以及物理表的物理删除同步。这里介绍通过开启Binlog模式CDC功能。 注意:1、使用MYSQL8.0及以上版本。 2、不…...

软件测试/测试开发丨ChatGPT能否成为PPT最佳伴侣

点此获取更多相关资料 简介 PPT 已经渗透到我们的日常工作中,无论是工作汇报、商务报告、学术演讲、培训材料都常常要求编写一个正式的 PPT,协助完成一次汇报或一次演讲。PPT相比于传统文本的就是有布局、图片、动画效果等,可以给到观众更好…...

java对象的创建过程

一.类的加载与检查 当我们new了一个对象的时候,首先会去检查一下这个指令是否在常量池中存在符号引用,并且检查这个符号引用代表的对象是否被加载,解析初始化过,如果没有就要先去进行类加载过程 二.分配内存 我们通过第一步的检…...

Salesforce创建一个页面,能够配置各种提示语,而不需要修改代码

在Salesforce中创建一个页面,并使其能够配置各种提示语,可以使用自定义设置、自定义对象或自定义标签等方法来实现。以下是一种常见的方法: 自定义对象或自定义设置:您可以创建一个自定义对象或自定义设置来存储各种提示语的信息。…...

轻松管理MySQL权限:Python脚本带你飞

数据库管理是 IT 专家和开发者日常工作中的重要组成部分。一个合适的用户权限管理系统不仅确保了数据的安全性,还能确保数据能够按照预期的方式被正确地访问和修改。在本文中,我们将探讨如何使用 Python 脚本来管理和查询 MySQL 数据库中的用户权限。 用户权限管理:创建或修…...

Py之transformers_stream_generator:transformers_stream_generator的简介、安装、使用方法之详细攻略

Py之transformers_stream_generator:transformers_stream_generator的简介、安装、使用方法之详细攻略 目录 transformers_stream_generator的简介 1、Web Demo T1、original T2、stream transformers_stream_generator的安装 transformers_stream_generator的…...

2023年Zotero最新同步教程-使用TeraCloud的25G免费空间实时跨设备同步文献

文章目录 1. 前言2.1. 注册账号2.1.1. 填写注册信息2.1.2. 创建账号成功2.1.3. 注意2.2. 扩容空间2.3. 打开WebDAV 3. Zotero配置WebDAV同步3.1. 设置网址3.2. 验证服务器3.3. 文件同步成功 4. 结语 1. 前言 Zotero免费版的存储空间是300m,一个图文PDF动辄两三M&am…...

面试题:用宏定义写出swap(x,y),即交换两数。

鼠标选中查看答案↓: #define swap(x,y) do{(x)(x)(y);(y)(x)-(y);(x)(x)-(y);}while(0) 这个题考查宏定义的语法,尤其是多行代码的宏定义,加上do{}while(),,可以保证这些语句只执行一次。...

微服务框架SpringcloudAlibaba+Nacos集成RabbitMQ

目前公司使用jeepluscloud版本,这个版本没有集成消息队列,这里记录一下,集成的过程;这个框架跟ruoyi的那个微服务版本结构一模一样,所以也可以快速上手。 1.项目结构图: 配置类的东西做成一个公共的模块 …...

低代码开发,一场深度的IT效率革命

目录 一、前言 二、低代码迅速流行的理由 三、稳定性和生产率的最佳实践 四、程序员用低代码开发应用有哪些益处? 1、提升开发价值 2、利于团队升级 五、总结 一、前言 尽管IT技术支撑了全球的信息化浪潮,然而困扰行业已久的软件开发效率并未像摩尔定律那…...

虚拟串口软件使用介绍

对于上位机开发来说(特别是串口通信应用),上机位软件的调试尤为重要,但是上机位软件的调试并不关心硬件,只需要关注验证发送的数据帧的接收情况,为了便于调试,可以将上机位软件与串口软件互通,实现数据的交互,但由于互通需要串口,可以借助串口虚拟软件(VSPD),虚拟出…...

如何编写一份完整的软件测试报告?(进阶版)百分之90不知道

背景 作为测试从业者,编写测试用例,测试计划,测试报告都是必经之路,最近完成了年终述职以及版本准出,感觉测试报告或者各类报告真是职场人不可或缺的一项技能,趁着热乎劲🔥,写下一些…...

python企业微信小程序发送信息

python企业微信小程序发送信息 在使用下面代码之前先配置webhook 教程如下: https://www.bilibili.com/video/BV1oH4y1S7pN/?vd_sourcebee29ac3f59b719e046019f637738769 然后使用如下代码就可以发消息了: 代码如下: #codinggbk import r…...

Java入门篇 之 逻辑控制(练习题篇)

博主碎碎念: 练习题是需要大家自己打的请在自己尝试后再看答案哦; 个人认为,只要自己努力在将来的某一天一定会看到回报,在看这篇博客的你,不就是在努力吗,所以啊,不要放弃,路上必定坎坷&#x…...

Android Google登录并获取token(亲测有效)

背景: Android 需要用到Google的登录授权,用去token给到服务器,服务器再通过token去获取用户信息,实现第三方登录。 我们通过登录之后的email来获取token,不需要server_clientId;如果用server_clientId还…...

npm ERR! code ELIFECYCLE

问题: 一个老项目,现在想运行下,打不开了 npm install 也出错 尝试1 、使用cnpm npm install -g cnpm --registryhttps://registry.npm.taobao.org cnpm install 还是不行 尝试2、 package.json 文件,去掉 那个插件 chorm…...

Mgeo:multi-modalgeographic language model pre-training

文章目录 question5.1 Geographic Encoder5.1.1 Encoding5.1.2 5.2 multi-modal pre-training 7 conclusionGeo-Encoder: A Chunk-Argument Bi-Encoder Framework for Chinese Geographic Re-Rankingabs ERNIE-GeoL: A Geography-and-Language Pre-trained Model and its Appli…...

[激光原理与应用-75]:西门子PLC系列选型

目录 一、西门子PLC PLC系列 二、西门子PLC S7 1200系列 2.1 概述 2.2 12xx系列比较 三、西门子 PLC 1212C系列 四、主要类别比较 4.1 AC/DC/RLY的含义 4.2 AC/DC/RLY与DC/DC/DC 4.3 直流输入与交流输入比较 4.4 继电器输出与DC输出的区别 一、西门子PLC PLC系列 …...

Linux上编译sqlite3库出现undefined reference to `sqlite3_column_table_name‘

作者:朱金灿 来源:clever101的专栏 为什么大多数人学不会人工智能编程?>>> 在Ubuntu 18上编译sqlite3库后在运行程序时出现undefined reference to sqlite3_column_table_name’的错误。网上的说法是说缺少SQLITE_ENABLE_COLUMN_M…...

【Java学习笔记】Arrays类

Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...

UDP(Echoserver)

网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...

在Ubuntu24上采用Wine打开SourceInsight

1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...

R语言速释制剂QBD解决方案之三

本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

人机融合智能 | “人智交互”跨学科新领域

本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...

CSS | transition 和 transform的用处和区别

省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...

深入浅出Diffusion模型:从原理到实践的全方位教程

I. 引言:生成式AI的黎明 – Diffusion模型是什么? 近年来,生成式人工智能(Generative AI)领域取得了爆炸性的进展,模型能够根据简单的文本提示创作出逼真的图像、连贯的文本,乃至更多令人惊叹的…...

MySQL的pymysql操作

本章是MySQL的最后一章,MySQL到此完结,下一站Hadoop!!! 这章很简单,完整代码在最后,详细讲解之前python课程里面也有,感兴趣的可以往前找一下 一、查询操作 我们需要打开pycharm …...

《Docker》架构

文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器,docker,镜像,k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...