当前位置: 首页 > news >正文

PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []if stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

相关文章:

PyTorch从零开始实现ResNet

文章目录 代码实现参考 代码实现 本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。 代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用…...

企业微信 企业内部开发 学习笔记

官方文档 文档 术语介绍 引入pom <dependency><groupId>com.github.binarywang</groupId><artifactId>wx-java-cp-spring-boot-starter</artifactId><version>4.5.3.B</version></dependency>核心代码 推送消息 final WxCp…...

03 QT基本控件和功能类

一 进度条 、水平滑动条 垂直滑动条 当在QT中,在已知类名的情况下,要了解类的构造函数 常用属性 及 信号和槽 常用api 特征:可以获取当前控件的值和设置它的当值 ---- int ui->progressBar->setValue(value); //给进度条设置一个整型值 ui->progressBar->value…...

epoll数据结构

目录 1.大量的fd 集合。选择什么数据结构&#xff1f;2、Epoll 数据结构Epitem 的定义Eventpoll 的定义 1.大量的fd 集合。选择什么数据结构&#xff1f; 查找频率很高的数据结构 1.红黑树 2.哈希&#xff08;扩容缩容&#xff09; 3. b/btree &#xff08;降低树的高度&#…...

LINUX学习笔记_GIT操作命令

LINUX学习笔记 GIT操作命令 基本命令 git init&#xff1a;初始化仓库git status&#xff1a;查看文件状态git add&#xff1a;添加文件到暂存区&#xff08;index&#xff09;git commit -m “注释”&#xff1a;提交文件到仓库&#xff08;repository&#xff09;git log&a…...

第一百二十九天学习记录:数据结构与算法基础:栈和队列(中)(王卓教学视频)

栈的表示和实现 顺序栈的初始化 ##入栈 链栈的表示...

C语言 — qsort 函数

介绍&#xff1a;qsort是一个库函数&#xff0c;用来对数据进行排序&#xff0c;可以排序任意类型的数据。 void qsort &#xff08;void*base&#xff0c; size_t num, size_t size, int(*compart)(const void*,constvoid*) &#xff09; qsort 具有四个参数&#xff1a; …...

开放式耳机哪个好一点?推荐几款优秀的开放式耳机

在追求更广阔的音场和更真实的音质时&#xff0c;开放式耳机是绝对值得考虑的选择。它们以其通透感和自然的音质而备受推崇&#xff0c;带来更逼真的音乐体验。下面我来推荐几款优秀的开放式耳机&#xff0c;满足你对音质和舒适度的要求&#xff0c;可尽情享受音乐的魅力。 一…...

vue-cli前端工程化——创建vue-cli工程 router版本的创建 目录结构 案例初步

目录 引出创建vue-cli前端工程vue-cli是什么自动构建创建vue-cli项目选择Vue的版本号 手动安装进行选择创建成功 手动创建router版多了一个router 运行测试bug解决 Vue项目结构main.jspackage.jsonvue.config.js Vue项目初步hello案例 总结 引出 1.vue-cli是啥&#xff0c;创建…...

Go和Java实现外观模式

Go和Java实现外观模式 下面我们通过一个构造各种形状的案例来说明外观模式的使用。 1、外观模式 外观模式隐藏系统的复杂性&#xff0c;并向客户端提供了一个客户端可以访问系统的接口。这种类型的设计模式属于结构型 模式&#xff0c;它向现有的系统添加一个接口&#xff…...

人工智能(一)基本概念

人工智能之基本概念 常见问题什么是人工智能&#xff1f;人工智能应用在那些地方&#xff1f;人工智能的三种形态图灵测试是啥&#xff1f;人工智能、机器学习和深度学习之间是什么关系&#xff1f;为什么人工智能计算会用到GPU&#xff1f; 机器学习什么是机器学习&#xff1f…...

〔AI 绘画〕Stable Diffusion 之 解决绘制多人或面部很小的人物时面部崩坏问题 篇

✨ 目录 &#x1f388; 脸部崩坏&#x1f388; 下载脸部修复插件&#x1f388; 启用脸部修复插件&#x1f388; 插件生成效果&#x1f388; 插件功能详解 &#x1f388; 脸部崩坏 相信很多人在画图时候&#xff0c;特别是画 有多个人物 图片或者 人物在图片中很小 的时候&…...

初步认识OSI/TCP/IP一(第三十八课)

1 初始OSI模型 OSI参考模型(Open Systems Interconnection Reference Model)是一个由国际标准化组织(ISO)和国际电报电话咨询委员会(CCITT)联合制定的网络通信协议规范,它将网络通信分为七个不同的层次,每个层次负责不同的功能和任务。 2 网络功能 数据通信、资源共享…...

英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析

伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP]&#xff08;2&#xff09;——代码分析 ASP整个模块的结果如下&#xff1a; . ├── COPYRIGHT ├── README.md ├── __init__.py ├── asp.py ├── permutation_lib.py ├── permutation_search_kernels…...

FileNotFoundError: [WinError 2] 系统找不到指定的文件。

pyspark demo程序创建spark上下文 完整报错如下&#xff1a; sc SparkContext(“local”, “Partition ID Example”) File “C:\ProgramData\anaconda3\envs\python36\lib\site-packages\pyspark\context.py”, line 133, in init SparkContext._ensure_initialized(self, ga…...

Linux: sysctl:net: IPV4_DEVCONF_ALL ignore_routes_with_linkdown; all vs default

文章目录 简介实例 ignore_routes_with_linkdownlinkdown 的引入dead的引入简介 一般下边这种类型的配置都有三种类型选项:all,default,specific net.ipv6.conf.acc.ignore_routes_with_linkdown = 0 net.ipv6.conf.all.ignore_routes_with_linkdown = 0 net.ipv6.conf.def…...

光耦继电器:实现电气隔离的卓越选择

光耦继电器是一种常用的电子元件&#xff0c;用于实现电气隔离和信号传输。在工业控制、自动化系统和电力电子等领域&#xff0c;光耦继电器具有独特的特点和优势。本文将从可靠性、隔离性、响应速度和适应性等方面对光耦继电器的特点进行概述。 光耦继电器是一种典型的固态继电…...

鸿蒙开发学习笔记2——实现页面之间跳转

鸿蒙开发学习笔记2——实现页面之间跳转 问题背景 上篇文章中&#xff0c;介绍了鸿蒙开发如何新建一个项目跑通hello world&#xff0c;本文将介绍在新建的项目中实现页面跳转的功能。 问题分析 ArkTS工程目录结构&#xff08;FA模型&#xff09; 各目录和路径的介绍如下…...

电子商务类网站需要什么配置的服务器?

随着电子商务的迅猛发展&#xff0c;越来越多的企业和创业者选择在互联网上开设自己的电商网站。为了确保电商网站能够高效运行&#xff0c;给用户提供良好的体验&#xff0c;选择合适的服务器配置至关重要。今天飞飞将和你分享电子商务类网站所需的服务器配置&#xff0c;希望…...

table 根据窗口缩放,自适应

element-plus中&#xff0c;直接应用在页面样式上&#xff0c; ::v-deep .el-table{width: 100%; } ::v-deep .el-table__header-wrapper table,::v-deep .el-table__body-wrapper table{width: 100% !important; } ::v-deep .el-table__body,::v-deep .el-table__footer,::v-d…...

变量 varablie 声明- Rust 变量 let mut 声明与 C/C++ 变量声明对比分析

一、变量声明设计&#xff1a;let 与 mut 的哲学解析 Rust 采用 let 声明变量并通过 mut 显式标记可变性&#xff0c;这种设计体现了语言的核心哲学。以下是深度解析&#xff1a; 1.1 设计理念剖析 安全优先原则&#xff1a;默认不可变强制开发者明确声明意图 let x 5; …...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

优选算法第十二讲:队列 + 宽搜 优先级队列

优选算法第十二讲&#xff1a;队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

ip子接口配置及删除

配置永久生效的子接口&#xff0c;2个IP 都可以登录你这一台服务器。重启不失效。 永久的 [应用] vi /etc/sysconfig/network-scripts/ifcfg-eth0修改文件内内容 TYPE"Ethernet" BOOTPROTO"none" NAME"eth0" DEVICE"eth0" ONBOOT&q…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局&#xff1a;刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断"&#xff0c;医生需通过显微镜观察组织切片&#xff0c;在细胞迷宫中捕捉癌变信号。某省病理质控报告显示&#xff0c;基层医院误诊率达12%-15%&#xff0c;专家会诊…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...