ResNet残差神经网络的模型结构定义(pytorch实现)
ResNet残差神经网络的模型结构定义(pytorch实现)
ResNet‑34
ResNet‑34的实现思路。核心在于:
- 定义残差块(BasicBlock)
- 用
_make_layer
方法堆叠多个残差块 - 按照 ResNet‑34 的通道和层数配置来搭建网络
import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):expansion = 1 # 对于 BasicBlock,输出通道 = base_channels * expansiondef __init__(self, in_channels, out_channels, stride=1):super().__init__()# 第一个 3×3 卷积self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个 3×3 卷积self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 如果输入输出通道或下采样不一致,则用 1×1 卷积做一下“shortcut”self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# 残差连接out += self.shortcut(x)return F.relu(out)class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):"""block: 残差块类型(BasicBlock 或 Bottleneck)layers: 每个 stage 包含多少个 block,例如 [3, 4, 6, 3] 对应 ResNet‑34num_classes: 最后分类数"""super().__init__()self.in_channels = 64# Stem:7×7 conv + maxpoolself.conv1 = nn.Conv2d(3, 64, kernel_size=7,stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 四个 stage,通道分别是 [64,128,256,512]self.layer1 = self._make_layer(block, 64, layers[0], stride=1)self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 全局平均池化 + 全连接self.avgpool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):"""构造一个 stage,由 num_blocks 个 block 组成。第一个 block 可能带 stride 下采样,其余 block stride=1。"""strides = [stride] + [1] * (num_blocks - 1)layers = []for s in strides:layers.append(block(self.in_channels, out_channels, stride=s))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = self.pool1(x)x = self.layer1(x) # output size /4x = self.layer2(x) # output size /8x = self.layer3(x) # output size /16x = self.layer4(x) # output size /32x = self.avgpool(x) # [B, C, 1, 1]x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000):"""返回一个 ResNet-34 实例"""return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
关键点解析
-
BasicBlock
- 两个连续的 3×3 卷积,均附带 BatchNorm 和 ReLU
- 当通道数或步幅不匹配时,用 1×1 卷积对输入做一下线性变换,才能做元素相加
-
_make_layer
- 每个 stage 第一个残差块如果要做下采样,则 stride=2;其余都保持 stride=1
layers
参数[3,4,6,3]
精确对应了图中红、粉、灰、蓝四部分每层 block 的数量
-
整体流程
- 7×7、stride=2 下采样 → 最大池化 →
- 四个 stage(通道 64→128→256→512,每段下采样一次)→
- 全局平均池化 → 全连接分类
这样就完整复现了图中右侧那张“34-layer residual”结构。你可以直接调用 resnet34()
,并像下面这样测试一下输出形状:
if __name__ == "__main__":model = resnet34(num_classes=1000)x = torch.randn(8, 3, 224, 224)y = model(x)print(y.shape) # torch.Size([8, 1000])
ResNet‑50
PyTorch 实现 ResNet‑50 。它与 ResNet‑34 唯一不同之处在于使用了 Bottleneck 模块,并且每个 stage 的 block 数量依次为 [3, 4, 6, 3]
(同 ResNet‑34),但每个 block 内部由三个卷积层组成,expansion 值为 4。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Bottleneck(nn.Module):# 输出通道 = base_channels * expansionexpansion = 4def __init__(self, in_channels, base_channels, stride=1):super().__init__()# 1×1 降维self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=1,bias=False)self.bn1 = nn.BatchNorm2d(base_channels)# 3×3 卷积(可能下采样)self.conv2 = nn.Conv2d(base_channels, base_channels, kernel_size=3,stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(base_channels)# 1×1 升维self.conv3 = nn.Conv2d(base_channels, base_channels * Bottleneck.expansion,kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(base_channels * Bottleneck.expansion)# shortcut 分支self.shortcut = nn.Sequential()if stride != 1 or in_channels != base_channels * Bottleneck.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, base_channels * Bottleneck.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(base_channels * Bottleneck.expansion))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.relu(self.bn2(self.conv2(out)))out = self.bn3(self.conv3(out))out += self.shortcut(x)return F.relu(out)class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):super().__init__()self.in_channels = 64# Stem:7×7 conv + maxpoolself.conv1 = nn.Conv2d(3, 64, kernel_size=7,stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 四个 stageself.layer1 = self._make_layer(block, 64, layers[0], stride=1)self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 池化 + 全连接self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, base_channels, num_blocks, stride):"""构造一个 stage,由 num_blocks 个 block 组成。第一个 block 可能下采样(stride>1),其余保持 stride=1。"""strides = [stride] + [1] * (num_blocks - 1)layers = []for s in strides:layers.append(block(self.in_channels, base_channels, stride=s))self.in_channels = base_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = self.pool1(x)x = self.layer1(x) # /4x = self.layer2(x) # /8x = self.layer3(x) # /16x = self.layer4(x) # /32x = self.avgpool(x) # [B, C, 1, 1]x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet50(num_classes=1000):"""返回一个 ResNet-50 实例"""return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)# 简单测试
if __name__ == "__main__":model = resnet50(num_classes=1000)x = torch.randn(4, 3, 224, 224)y = model(x)print(y.shape) # -> torch.Size([4, 1000])
说明
- Bottleneck 模块:三个卷积层依次为 1×1 → 3×3 → 1×1,最后一个 1×1 用来恢复维度(乘以
expansion=4
)。 - shortcut 分支:当需下采样(stride=2)或输入输出维度不一致时,使用 1×1 卷积对齐后相加。
- layers 参数
[3,4,6,3]
:分别对应四个 stage 中 Bottleneck block 的个数。
这样就完成了 ResNet‑50 的全结构定义。你可以直接调用 resnet50()
并将其与预训练权重或自己的数据集一起使用。
参考:Kaiming He 等人,Deep Residual Learning for Image Recognition (CVPR 2016).
相关文章:

ResNet残差神经网络的模型结构定义(pytorch实现)
ResNet残差神经网络的模型结构定义(pytorch实现) ResNet‑34 ResNet‑34的实现思路。核心在于: 定义残差块(BasicBlock)用 _make_layer 方法堆叠多个残差块按照 ResNet‑34 的通道和层数配置来搭建网络 import torch…...

uniapp|商品列表加入购物车实现抛物线动画效果、上下左右抛入、多端兼容(H5、APP、微信小程序)
以uniapp框架为基础,详细解析商品列表加入购物车抛物线动画的实现方案。通过动态获取商品点击位置与购物车坐标,结合CSS过渡动画模拟抛物线轨迹,实现从商品图到购物车图标的动态效果。 目录 核心实现原理坐标动态计算抛物线轨迹模拟动画元素控制代码实现详解模板层设计脚本…...

谈AI/OT 的融合
过去的十几年间,工业界讨论最多的话题之一就是IT/OT 融合,现在,我们不仅要实现IT/OT 的融合,更要面向AI/OT 的融合。看起来不太靠谱,却留给我们无限的想象空间。OT 领域的专家们不要再当“九斤老太”,指责这…...

USB传输模式
USB有四种传输模式: 控制传输, 中断传输, 同步传输, 批量传输 1. 中断传输 中断传输一般用于小批量, 非连续的传输. 对实时性要求较高. 常见的使用此传输模式的设备有: 鼠标, 键盘等. 要注意的是, 这里的 “中断” 和我们常见的中断概念有差异. Linux中的中断是设备主动发起的…...
Tomcat的`context.xml`配置详解!
全文目录: 开篇语前言一、context.xml 文件的基本结构二、常见的 context.xml 配置项1. **数据源(DataSource)配置**示例: 2. **日志配置**示例: 3. **设置环境变量(Environment Variables)**示…...
MapReduce 的工作原理
MapReduce 是一种分布式计算框架,用于处理和生成大规模数据集。它将任务分为两个主要阶段:Map 阶段和 Reduce 阶段。开发人员可以使用存储在 HDFS 中的数据,编写 Hadoop 的 MapReduce 任务,从而实现并行处理1。 MapReduce 的工作…...

.NET10 - 尝试一下Open Api的一些新特性
1.简单介绍 .NET9中Open Api有了很大的变化,在默认的Asp.NET Core Web Api项目中,已经移除了Swashbuckle.AspNetCore package,同时progrom中也变更为 builder.Servers.AddOpenApi() builder.Services.MapOpenApi() 2025年微软将发布…...

RabbitMQ 工作模式
RabbitMQ 一共有 7 中工作模式,可以先去官网上了解一下(一下截图均来自官网):RabbitMQ 官网 Simple P:生产者,要发送消息的程序;C:消费者,消息的接受者;hell…...

基于C++的多线程网络爬虫设计与实现(CURL + 线程池)
在当今大数据时代,网络爬虫作为数据采集的重要工具,其性能直接决定了数据获取的效率。传统的单线程爬虫在面对海量网页时往往力不从心,而多线程技术可以充分利用现代多核CPU的计算能力,显著提升爬取效率。本文将详细介绍如何使用C…...
Android11.0 framework第三方无源码APP读写断电后数据丢失问题解决
1.前言 在11.0中rom定制化开发中,在某些产品开发中,在某些情况下在App用FileOutputStream读写完毕后,突然断电 会出现写完的数据丢失的问题,接下来就需要分析下关于使用FileOutputStream读写数据的相关流程,来实现相关 功能 2.framework第三方无源码APP读写断电后数据丢…...
国产大模型「五强争霸」:决战AGI,谁主沉浮?
引言 中国AI大模型市场正经历一场史无前例的洗牌!曾经“百模混战”的局面已落幕,字节、阿里、阶跃星辰、智谱和DeepSeek五大巨头强势崛起,形成“基模五强”新格局。这场竞争不仅是技术实力的较量,更是资源、人才与生态的全面博弈。…...
【Python 基础语法】
Python 基础语法是编程的基石,以下从核心要素到实用技巧进行系统梳理: 一、代码结构规范 缩进规则 使用4个空格缩进(PEP 8标准)缩进定义代码块(如函数、循环、条件语句) def greet(name):if name: # 正确缩…...

【日撸 Java 三百行】Day 11(顺序表(一))
目录 Day 11:顺序表(一) 一、关于顺序表 二、关于面向对象 三、代码模块分析 1. 顺序表的属性 2. 顺序表的方法 四、代码及测试 拓展: 小结 Day 11:顺序表(一) Task: 在《数…...
path环境变量满了如何处理,分割 PATH 到 Path1 和 Path2
要正确设置 Path1 的值,你需要将现有的 PATH 环境变量 中的部分路径复制到 Path1 和 Path2 中。以下是详细步骤: 步骤 1:获取当前 PATH 的值 打开环境变量窗口: 按 Win R,输入 sysdm.cpl,点击 确定。在 系…...

软考 系统架构设计师系列知识点之杂项集萃(55)
接前一篇文章:软考 系统架构设计师系列知识点之杂项集萃(54) 第89题 某软件公司欲开发一个Windows平台上的公告板系统。在明确用户需求后,该公司的架构师决定采用Command模式实现该系统的界面显示部分,并设计UML类图如…...

保持Word中插入图片的清晰度
大家有没有遇到这个问题,原本绘制的高清晰度图片,插入word后就变模糊了。先说原因,word默认启动了自动压缩图片功能,分享一下如何关闭这项功能,保持Word中插入图片的清晰度。 ①在Word文档中,点击左上角的…...
Web应用开发指南
一、引言 随着互联网的迅猛发展,Web应用已深度融入日常生活的各个方面。为满足用户对性能、交互与可维护性的日益增长的需求,开发者需要一整套高效、系统化的解决方案。在此背景下,前端框架应运而生。不同于仅提供UI组件的工具库,…...
贝叶斯算法
贝叶斯算法是一类基于贝叶斯定理的机器学习算法,它们在分类任务中表现出色,尤其在处理具有不确定性和 probabilistic 关系的数据时具有独特优势。本文将深入探讨贝叶斯算法的核心原理、主要类型以及实际应用案例,带你领略贝叶斯算法在概率推理…...

Linux复习笔记(三) 网络服务配置(web)
遇到的问题,都有解决方案,希望我的博客能为你提供一点帮助。 二、网络服务配置 2.3 web服务配置 2.3.1通信基础:HTTP协议与C/S架构(了解) HTTP协议的核心作用 Web服务基于HTTP/HTTPS协议实现客户端ÿ…...

springboot旅游小程序-计算机毕业设计源码76696
目 录 摘要 1 绪论 1.1研究背景与意义 1.2研究现状 1.3论文结构与章节安排 2 基于微信小程序旅游网站系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统…...

uniapp自定义导航栏搭配插槽
<uni-nav-bar dark :fixed"true" shadow background-color"#007AFF" left-icon"left" left-text"返回" clickLeft"back"><view class"nav-bar-title">{{ navBarTitle }}</view><block v-slo…...

MFC listctrl修改背景颜色
在 MFC 中修改 ListCtrl 控件的行背景颜色,需要通过自绘(Owner-Draw)机制实现。以下是详细的实现方法: 方法一:通过自绘(Owner-Draw)实现 步骤 1:启用自绘属性 在对话框设计器中选…...
Kotlin跨平台Compose Multiplatform实战指南
Kotlin Multiplatform(KMP)结合 Compose Multiplatform 正在成为跨平台开发的热门选择,它允许开发者用一套代码构建 Android、iOS、桌面(Windows/macOS/Linux)和 Web 应用。以下是一个实战指南,涵盖核心概念…...

SpringBoot+Dubbo+Zookeeper实现分布式系统步骤
SpringBootDubboZookeeper实现分布式系统 一、分布式系统通俗解释二、环境准备(详细版)1. 软件版本2. 安装Zookeeper(单机模式) 三、完整项目结构(带详细注释)四、手把手代码实现步骤1:创建父工…...
一个极简单的 VUE3 + Element-Plus 查询表单展开收起功能组件
在管理系统页面开发时,会遇到一个简单又令人头痛的问题,那就是:搜索页面太多,搜索表单项内容太多。对于过多的内容,往往采取折叠的形式,仅展示部分内容,需要时展开查看全部。 如果在程序设计时…...
es 里的Filesystem Cache 理解
文章目录 背景问题1,Filesystem Cache 里放的是啥问题2,哪些查询它们会受益于文件系统缓存问题3 查询分析 背景 对于es 优化来说常常看到会有一条结论给,给 JVM Heap 最多不超过物理内存的 50%,且不要超过 31GB(避免压…...

Linux进程10-有名管道概述、创建、读写操作、两个管道进程间通信、读写规律(只读、只写、读写区别)、设置阻塞/非阻塞
目录 1.有名管道 1.1概述 1.2与无名管道的差异 2.有名管道的创建 2.1 直接用shell命令创建有名管道 2.2使用mkfifo函数创建有名管道 3.有名管道读写操作 3.1单次读写 3.2多次读写 4.有名管道进程间通信 4.1回合制通信 4.2父子进程通信 5.有名管道读写规律ÿ…...

精品可编辑PPT | 全面风险管理信息系统项目建设风控一体化标准方案
这份文档是一份全面风险管理信息系统项目建设风控一体化标准方案,涵盖了业务架构、功能方案、系统技术架构设计、项目实施及服务等多个方面的详细内容。方案旨在通过信息化手段提升企业全面风险管理工作水平,促进风险管理落地和内部控制规范化࿰…...
YOLOv8网络结构
YOLOv8的网络结构由输入端(Input)、骨干网络(Backbone)、颈部网络(Neck)和检测头(Head)四部分组成。 YOLOv8的网络结构如下图所示: 在整个系统架构中,图像首先进入输入处理模块,该模块承担着图像预处理与数据增强的双重任务。接着,…...
数组对象 按照对象中的某个字段排序
在JavaScript中,可以使用数组的sort()方法按照对象中的某个字段对数组进行排序。 按照对象中的某个字段对数组进行排序: 基本排序方法 升序排序 const array [{ name: John, age: 25 },{ name: Jane, age: 21 },{ name: Bob, age: 30 } ];// 按照age字…...