当前位置: 首页 > news >正文

pytorch复现4_Resnet

ResNet在《Deep Residual Learning for Image Recognition》论文中提出,是在CVPR 2016发表的一种影响深远的网络模型,由何凯明大神团队提出来,在ImageNet的分类比赛上将网络深度直接提高到了152层,前一年夺冠的VGG只有19层。ImageNet的目标检测以碾压的优势成功夺得了当年识别和目标检测的冠军,COCO数据集的目标检测和图像分割比赛上同样碾压夺冠,可以说ResNet的出现对深度神经网络来说具有重大的历史意义。

在这里插入图片描述
在resnet出现之前,网络层数的增加会导致梯度消失或者梯度爆炸
在ResNet网络中有如下几个亮点:
(1)提出residual结构(残差结构),并搭建超深的网络结构(突破1000层)
(2)使用Batch Normalization加速训练(丢弃dropout)

残差结构(residual)

下图是论文中给出的两种残差结构。左边的残差结构是针对层数较少网络,例如ResNet18层和ResNet34层网络
右边是针对网络层数较多的网络,例如ResNet101,ResNet152等。
为什么深层网络要使用右侧的残差结构呢。因为,右侧的残差结构能够减少网络参数与运算量。同样输入、输出一个channel为256的特征矩阵,如果使用左侧的残差结构需要大约1170648个参数,但如果使用右侧的残差结构只需要69632个参数。明显搭建深层网络时,使用右侧的残差结构更合适。

在这里插入图片描述
代码:

class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return out

class Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out

完整代码:

import torch.nn as nn
import torchclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)

相关文章:

pytorch复现4_Resnet

ResNet在《Deep Residual Learning for Image Recognition》论文中提出,是在CVPR 2016发表的一种影响深远的网络模型,由何凯明大神团队提出来,在ImageNet的分类比赛上将网络深度直接提高到了152层,前一年夺冠的VGG只有19层。Image…...

【数据库】形式化关系查询语言(一):关系代数Relational Algebra:基本运算、附加关系代数、扩展的关系代数

目录 一、关系代数Relational Algebra 1. 基本运算 a. 选择运算(Select Operation) b. 投影运算(Project Operation) 组合 c. 并运算(Union Operation) d. 集合差运算(Set Difference Op…...

【计算机网络】计算机网络和因特网

一.基本术语介绍 端系统通过通信链路(communication link)和分组交换机(packet switch)连接到一起,连接这些端系统和分组交换机的物理媒体包括:同轴电缆,铜线,光纤和无线电频谱。而…...

JAVA面经整理(9)

一)什么是Spring?它有什么优点? spring是一款顶级的开源框架,他是包含了众多工具方法的IOC容器,Spring中包含了很多模块,比如说Spring-core,Spring-context,Spring-aop,Spring-web,…...

IPD(集成产品开发)模式下的产品研发流程

IPD(集成产品开发)涵盖了产品从创意提出到研发、生产、运营等,包含了产品开发到营销运营的整个过程。围绕产品(或项目)生命周期的过程的管理模式,是一套生产流程,更是时下国际先进的管理体系。I…...

Flutter GetX的使用

比较强大的状态管理框架 引入库: dependencies:get: ^4.6.6一.实现一个简单的demo 实现一个计数器功能 代码如下: import package:flutter/material.dart; import package:get/get.dart;void main() > runApp(const GetMaterialApp(home: Home()…...

【Amazon】AWS实战 | 快速发布安全传输的静态页面

文章目录 一、实验架构图二、实验涉及的AWS服务三、实验操作步骤1. 创建S3存储桶,存放网站网页2. 使用ACM建立域名证书3. 设置Cloudfront,连接S3存储桶✴️4. 设置Route53,解析域名服务5. 通过CLI工具上传网页更新内容【可选】 四、实验总结 …...

前后端登录的密码加密和解密

在一个典型的前后端应用中,前端对密码进行加密后传给后端,后端再进行解密或验证。这通常涉及前端加密、后端解密或验证的相互配合。下面是一个基本的流程: 前端加密: 前端可以使用各种加密库或算法对密码进行加密。常见的是使用哈…...

使用 Curl 和 DomCrawler 下载抖音视频链接并存储到指定文件夹

项目需求 假设我们需要从抖音平台上下载一些特定的视频,以便进行分析、编辑或其他用途。为了实现这个目标,我们需要编写一个爬虫程序来获取抖音视频的链接,并将其保存到本地文件夹中。 目标分析 在开始编写爬虫之前,我们需要了…...

取消Excel打开密码的两种方法

Excel设置了打开密码,想要取消打开密码是由两种方法的,今天分享这两种方法给大家。 想要取消密码是需要直到正确密码的,因为只有打开文件才能进行取消密码的操作 方法一: 是大家常见的取消方法,打开excel文件之后&a…...

多测师肖sir_高级金牌讲师_jmeter 反向代理录制脚本

jemeter自带的录制脚本功能,是利用代理服务器来进行录制的 1,新建一个线程组 2,新建一个代理服务器 右击工作台-添加-非测试元件-http代理服务器 3, 配置http代理服务器 端口: 默认为8888,可修改。但…...

网络取证-Tomcat-简单

题干: 我们的 SOC 团队在公司内部网的一台 Web 服务器上检测到可疑活动。为了更深入地了解情况,团队捕获了网络流量进行分析。此 pcap 文件可能包含一系列恶意活动,这些活动已导致 Apache Tomcat Web 服务器遭到破坏。我们需要进一步调查这一…...

3.Linux常用操作(传输、crontab定时、匹配日期删除文件等)

1. 服务器之间传输文件 1.1 传输文件到本服务器 scp -P 19622 -C dockeruser192.168.100.96:/home/dockeruser/lgr/lgr.dmp /home/dockeruser/lgr描述: 用dockeruser账号登录端口号为19622的192.168.100.96服务器,将此服务器的/home/dockeruser/lgr/l…...

ChatGPT对未来发展的影响?一般什么时候用到GPT

ChatGPT以其强大的自然语言处理能力对未来的发展具有重要影响。以下是ChatGPT的潜在影响和一般使用情况: 改善自然语言理解和生成:ChatGPT和类似的模型可以改善机器对人类语言的理解和生成。这将有助于改进各种应用领域,包括智能助手、聊天机…...

在Win10系统进行MySQL的安装、连接、卸载

在Win10系统进行MySQL的安装、连接、卸载 MySQL的安装 本教程在Win10系统下安装部署MySQL-8.0.32版。 MySQL安装参考地址 MySQL安装包地址 提取码: rnbc。 选择下载mysql-installer-community-8.0.32.0安装包。 连接数据库 方式一: 安装后,可以在开始…...

Windows下pm2调用npm和nuxt的办法

pm2调用npm pm2 start C:\Users\xiao\AppData\Roaming\npm\node_modules\npm\index.js --name test -- run start 其中index.js的路径就是npm全局安装的路径,可通过以下命令获取 npm root -g require全局npm模块的一种方法 新建文件pm2npm.js const root req…...

本地仓库转为git仓库推送到gitee

通常有两种获取 Git 项目仓库的方式: 方式一:将尚未进行版本控制的本地目录转换为 Git 仓库; 方式二:从其它服务器 克隆 一个已存在的 Git 仓库。 两种方式都会在你的本地机器上得到一个工作就绪的 Git 仓库。 方式一&#xff1a…...

CSS以及JavaScript

目录 一.CSS 1.overflow溢出属性 2.定位 二.JavaScript基础 1.JavaScript引入方式 2.JavaScript数据类型 常用方法: 字符串常用方法: 在js里,什么是真,什么是假 数组的常用方法 运算符 (1)算数运…...

JVM——类的生命周期(加载阶段,连接阶段,初始化阶段)

目录 1.加载阶段2.连接阶段1.验证2.准备3.解析 3.初始化阶段4.总结 类的生命周期 1.加载阶段 ⚫ 1、加载(Loading)阶段第一步是类加载器根据类的全限定名通过不同的渠道以二进制流的方式获取字节码信息。 程序员可以使用Java代码拓展的不同的渠道。 ⚫ 2、类加载器在加载完类…...

CSS中实现元素居中的几种方法总结

一、使用 text-align: center 居中 使用 text-align: center; 可以在CSS中实现内联元素的水平居中。这个技术利用了CSS的 text-align 属性&#xff0c;通过对元素的文本对齐方式进行调整来实现居中效果。注&#xff1a;只展示主要代码。 <div class"container"&…...

保护听力戴什么耳机比较好?开放式耳机能保护听力吗?

如果想要在保护听力的前提下戴耳机&#xff0c;那么我是推荐戴骨传导耳机的&#xff01;&#xff01;&#xff01; 所谓骨传导即是一种声音传递的方式&#xff0c;跟普通耳机不同的是传统耳机是通过空气将声音通过耳膜以此完成传递&#xff0c;而骨传导耳机的原理是将声音以不同…...

【JVM】垃圾回收机制

【JVM】垃圾回收机制 文章目录 【JVM】垃圾回收机制1. 方法区的回收2. 堆的回收2.1 引用计数法2.2 可达性分析算法 3. 对象引用3.1 强引用3.2 软引用3.3 弱引用3.4 虚引用和终结器引用 4. 垃圾回收算法4.1 标记清除算法4.2 复制算法4.3 标记整理算法4.4 分代垃圾回收算法 5. 垃…...

MySQL数据库入门到精通——运维篇(2)

MySQL数据库入门到精通——运维篇&#xff08;2&#xff09; 1. 分库分表1.1 分库分表介绍1.1.1 现在的问题1.1.2 拆分策略1.1.2.1 垂直拆分策略1.1.2.2 水平拆分策略 1.2 Mycat概述1.3 Mycat入门1.4 Mycat配置1.4.1 Schema标签1.4.2 Datanode标签1.4.3 Datahost标签1.4.4 rule…...

投资者如何保障个人利益?行业律师与欧科云链专家给出建议

香港作为全球加速拥抱Web3变革的引领之地&#xff0c;规定自今年6月起在香港经营虚拟资产服务业务需申领牌照。蜂拥而至的Web3创业公司&#xff0c;伺机而动的加密货币交易所&#xff0c;以及跃跃欲试的行业从业者&#xff0c;都让这座金融之都热闹非凡。但近期伴随JPEX诈骗案等…...

【办公软件】C#调用NPOI实现Excel文件的加载、导出功能

文章目录 1. 引言2. 环境准备3. 示例代码4. 结果5. 总结 1. 引言 本文将介绍如何使用C#和NPOI库实现Excel文件的读写操作&#xff0c;并通过加载文件和导出文件的按钮进行封装。NPOI是一个强大的.NET库&#xff0c;可以轻松处理Excel文件。我们将学习如何使用NPOI打开现有的Ex…...

UVA 11990 “Dynamic‘‘ Inversion 区域树 + 树状数组

一、题目大意 我们有 1 2 3 ... n 这些数字组成的一个排列数组 a &#xff0c;需要从这个排列中取出m个数字&#xff0c;要求计算出出每次取出数字之前&#xff0c;数组中的逆序数&#xff08;逆序数就是 i < j&#xff0c;但是 ai > aj的数&#xff09; 二、解题思路 …...

邮件钓鱼分析

三大协议 SPF Sender Policy Framework 的缩写&#xff0c;一种以IP地址认证电子邮件发件人身份的技术。 注&#xff1a;收信人怀疑币是假的&#xff0c;查看这个送信包裹里面记录的发出地是不是央行&#xff0c;如果是黑市有可能是黑钱 DKIM 加密签名和域名关联。 注&am…...

Android 小技巧

1. Android Studio下载地址 Android 开发者 | Android Developers (google.cn) 2.Android Aosp 在线查看地址&#xff1a; AOSPXRef 3.Android 官方文档地址&#xff1a; Android 开源项目 | Android Open Source Project (google.cn)...

Centos MySQL --skip-grant-tables详解

跳过权限验证&#xff0c;导出数据备份 主机系统&#xff1a;Centos7 64位 数据库版本&#xff1a;MySQL5.7.40 使用–skip-grant-tables场景 1、忘记管理员密码 2、修改管理员密码 mysql -uroot -p显示错误内容如下&#xff1a; ERROR 1045 (28000): Access denied for …...

Linux:进程控制的概念和理解

文章目录 进程的创建fork函数写时拷贝的原理fork函数的用法和失败原因 进程终止进程的退出进程异常的问题 进程终止进程退出 进程等待什么是进程等待&#xff1f;为什么要进行进程等待&#xff1f;如何进行进程等待&#xff1f;父进程如何知道子进程的退出信息&#xff1f; wai…...