当前位置: 首页 > news >正文

PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []if stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

相关文章:

PyTorch从零开始实现ResNet

文章目录 代码实现参考 代码实现 本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。 代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用…...

企业微信 企业内部开发 学习笔记

官方文档 文档 术语介绍 引入pom <dependency><groupId>com.github.binarywang</groupId><artifactId>wx-java-cp-spring-boot-starter</artifactId><version>4.5.3.B</version></dependency>核心代码 推送消息 final WxCp…...

03 QT基本控件和功能类

一 进度条 、水平滑动条 垂直滑动条 当在QT中,在已知类名的情况下,要了解类的构造函数 常用属性 及 信号和槽 常用api 特征:可以获取当前控件的值和设置它的当值 ---- int ui->progressBar->setValue(value); //给进度条设置一个整型值 ui->progressBar->value…...

epoll数据结构

目录 1.大量的fd 集合。选择什么数据结构&#xff1f;2、Epoll 数据结构Epitem 的定义Eventpoll 的定义 1.大量的fd 集合。选择什么数据结构&#xff1f; 查找频率很高的数据结构 1.红黑树 2.哈希&#xff08;扩容缩容&#xff09; 3. b/btree &#xff08;降低树的高度&#…...

LINUX学习笔记_GIT操作命令

LINUX学习笔记 GIT操作命令 基本命令 git init&#xff1a;初始化仓库git status&#xff1a;查看文件状态git add&#xff1a;添加文件到暂存区&#xff08;index&#xff09;git commit -m “注释”&#xff1a;提交文件到仓库&#xff08;repository&#xff09;git log&a…...

第一百二十九天学习记录:数据结构与算法基础:栈和队列(中)(王卓教学视频)

栈的表示和实现 顺序栈的初始化 ##入栈 链栈的表示...

C语言 — qsort 函数

介绍&#xff1a;qsort是一个库函数&#xff0c;用来对数据进行排序&#xff0c;可以排序任意类型的数据。 void qsort &#xff08;void*base&#xff0c; size_t num, size_t size, int(*compart)(const void*,constvoid*) &#xff09; qsort 具有四个参数&#xff1a; …...

开放式耳机哪个好一点?推荐几款优秀的开放式耳机

在追求更广阔的音场和更真实的音质时&#xff0c;开放式耳机是绝对值得考虑的选择。它们以其通透感和自然的音质而备受推崇&#xff0c;带来更逼真的音乐体验。下面我来推荐几款优秀的开放式耳机&#xff0c;满足你对音质和舒适度的要求&#xff0c;可尽情享受音乐的魅力。 一…...

vue-cli前端工程化——创建vue-cli工程 router版本的创建 目录结构 案例初步

目录 引出创建vue-cli前端工程vue-cli是什么自动构建创建vue-cli项目选择Vue的版本号 手动安装进行选择创建成功 手动创建router版多了一个router 运行测试bug解决 Vue项目结构main.jspackage.jsonvue.config.js Vue项目初步hello案例 总结 引出 1.vue-cli是啥&#xff0c;创建…...

Go和Java实现外观模式

Go和Java实现外观模式 下面我们通过一个构造各种形状的案例来说明外观模式的使用。 1、外观模式 外观模式隐藏系统的复杂性&#xff0c;并向客户端提供了一个客户端可以访问系统的接口。这种类型的设计模式属于结构型 模式&#xff0c;它向现有的系统添加一个接口&#xff…...

人工智能(一)基本概念

人工智能之基本概念 常见问题什么是人工智能&#xff1f;人工智能应用在那些地方&#xff1f;人工智能的三种形态图灵测试是啥&#xff1f;人工智能、机器学习和深度学习之间是什么关系&#xff1f;为什么人工智能计算会用到GPU&#xff1f; 机器学习什么是机器学习&#xff1f…...

〔AI 绘画〕Stable Diffusion 之 解决绘制多人或面部很小的人物时面部崩坏问题 篇

✨ 目录 &#x1f388; 脸部崩坏&#x1f388; 下载脸部修复插件&#x1f388; 启用脸部修复插件&#x1f388; 插件生成效果&#x1f388; 插件功能详解 &#x1f388; 脸部崩坏 相信很多人在画图时候&#xff0c;特别是画 有多个人物 图片或者 人物在图片中很小 的时候&…...

初步认识OSI/TCP/IP一(第三十八课)

1 初始OSI模型 OSI参考模型(Open Systems Interconnection Reference Model)是一个由国际标准化组织(ISO)和国际电报电话咨询委员会(CCITT)联合制定的网络通信协议规范,它将网络通信分为七个不同的层次,每个层次负责不同的功能和任务。 2 网络功能 数据通信、资源共享…...

英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析

伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP]&#xff08;2&#xff09;——代码分析 ASP整个模块的结果如下&#xff1a; . ├── COPYRIGHT ├── README.md ├── __init__.py ├── asp.py ├── permutation_lib.py ├── permutation_search_kernels…...

FileNotFoundError: [WinError 2] 系统找不到指定的文件。

pyspark demo程序创建spark上下文 完整报错如下&#xff1a; sc SparkContext(“local”, “Partition ID Example”) File “C:\ProgramData\anaconda3\envs\python36\lib\site-packages\pyspark\context.py”, line 133, in init SparkContext._ensure_initialized(self, ga…...

Linux: sysctl:net: IPV4_DEVCONF_ALL ignore_routes_with_linkdown; all vs default

文章目录 简介实例 ignore_routes_with_linkdownlinkdown 的引入dead的引入简介 一般下边这种类型的配置都有三种类型选项:all,default,specific net.ipv6.conf.acc.ignore_routes_with_linkdown = 0 net.ipv6.conf.all.ignore_routes_with_linkdown = 0 net.ipv6.conf.def…...

光耦继电器:实现电气隔离的卓越选择

光耦继电器是一种常用的电子元件&#xff0c;用于实现电气隔离和信号传输。在工业控制、自动化系统和电力电子等领域&#xff0c;光耦继电器具有独特的特点和优势。本文将从可靠性、隔离性、响应速度和适应性等方面对光耦继电器的特点进行概述。 光耦继电器是一种典型的固态继电…...

鸿蒙开发学习笔记2——实现页面之间跳转

鸿蒙开发学习笔记2——实现页面之间跳转 问题背景 上篇文章中&#xff0c;介绍了鸿蒙开发如何新建一个项目跑通hello world&#xff0c;本文将介绍在新建的项目中实现页面跳转的功能。 问题分析 ArkTS工程目录结构&#xff08;FA模型&#xff09; 各目录和路径的介绍如下…...

电子商务类网站需要什么配置的服务器?

随着电子商务的迅猛发展&#xff0c;越来越多的企业和创业者选择在互联网上开设自己的电商网站。为了确保电商网站能够高效运行&#xff0c;给用户提供良好的体验&#xff0c;选择合适的服务器配置至关重要。今天飞飞将和你分享电子商务类网站所需的服务器配置&#xff0c;希望…...

table 根据窗口缩放,自适应

element-plus中&#xff0c;直接应用在页面样式上&#xff0c; ::v-deep .el-table{width: 100%; } ::v-deep .el-table__header-wrapper table,::v-deep .el-table__body-wrapper table{width: 100% !important; } ::v-deep .el-table__body,::v-deep .el-table__footer,::v-d…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法&#xff0c;用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理&#xff0c;能够自动确定一个阈值&#xff0c;将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

是否存在路径(FIFOBB算法)

题目描述 一个具有 n 个顶点e条边的无向图&#xff0c;该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序&#xff0c;确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数&#xff0c;分别表示n 和 e 的值&#xff08;1…...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下&#xff0c;推客小程序系统凭借其裂变传播、精准营销等特性&#xff0c;成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径&#xff0c;助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement 1. LAB环境2. L2公告策略2.1 部署Death Star2.2 访问服务2.3 部署L2公告策略2.4 服务宣告 3. 可视化 ARP 流量3.1 部署新服务3.2 准备可视化3.3 再次请求 4. 自动IPAM4.1 IPAM Pool4.2 …...