当前位置: 首页 > news >正文

PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []if stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

相关文章:

PyTorch从零开始实现ResNet

文章目录 代码实现参考 代码实现 本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。 代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用…...

企业微信 企业内部开发 学习笔记

官方文档 文档 术语介绍 引入pom <dependency><groupId>com.github.binarywang</groupId><artifactId>wx-java-cp-spring-boot-starter</artifactId><version>4.5.3.B</version></dependency>核心代码 推送消息 final WxCp…...

03 QT基本控件和功能类

一 进度条 、水平滑动条 垂直滑动条 当在QT中,在已知类名的情况下,要了解类的构造函数 常用属性 及 信号和槽 常用api 特征:可以获取当前控件的值和设置它的当值 ---- int ui->progressBar->setValue(value); //给进度条设置一个整型值 ui->progressBar->value…...

epoll数据结构

目录 1.大量的fd 集合。选择什么数据结构&#xff1f;2、Epoll 数据结构Epitem 的定义Eventpoll 的定义 1.大量的fd 集合。选择什么数据结构&#xff1f; 查找频率很高的数据结构 1.红黑树 2.哈希&#xff08;扩容缩容&#xff09; 3. b/btree &#xff08;降低树的高度&#…...

LINUX学习笔记_GIT操作命令

LINUX学习笔记 GIT操作命令 基本命令 git init&#xff1a;初始化仓库git status&#xff1a;查看文件状态git add&#xff1a;添加文件到暂存区&#xff08;index&#xff09;git commit -m “注释”&#xff1a;提交文件到仓库&#xff08;repository&#xff09;git log&a…...

第一百二十九天学习记录:数据结构与算法基础:栈和队列(中)(王卓教学视频)

栈的表示和实现 顺序栈的初始化 ##入栈 链栈的表示...

C语言 — qsort 函数

介绍&#xff1a;qsort是一个库函数&#xff0c;用来对数据进行排序&#xff0c;可以排序任意类型的数据。 void qsort &#xff08;void*base&#xff0c; size_t num, size_t size, int(*compart)(const void*,constvoid*) &#xff09; qsort 具有四个参数&#xff1a; …...

开放式耳机哪个好一点?推荐几款优秀的开放式耳机

在追求更广阔的音场和更真实的音质时&#xff0c;开放式耳机是绝对值得考虑的选择。它们以其通透感和自然的音质而备受推崇&#xff0c;带来更逼真的音乐体验。下面我来推荐几款优秀的开放式耳机&#xff0c;满足你对音质和舒适度的要求&#xff0c;可尽情享受音乐的魅力。 一…...

vue-cli前端工程化——创建vue-cli工程 router版本的创建 目录结构 案例初步

目录 引出创建vue-cli前端工程vue-cli是什么自动构建创建vue-cli项目选择Vue的版本号 手动安装进行选择创建成功 手动创建router版多了一个router 运行测试bug解决 Vue项目结构main.jspackage.jsonvue.config.js Vue项目初步hello案例 总结 引出 1.vue-cli是啥&#xff0c;创建…...

Go和Java实现外观模式

Go和Java实现外观模式 下面我们通过一个构造各种形状的案例来说明外观模式的使用。 1、外观模式 外观模式隐藏系统的复杂性&#xff0c;并向客户端提供了一个客户端可以访问系统的接口。这种类型的设计模式属于结构型 模式&#xff0c;它向现有的系统添加一个接口&#xff…...

人工智能(一)基本概念

人工智能之基本概念 常见问题什么是人工智能&#xff1f;人工智能应用在那些地方&#xff1f;人工智能的三种形态图灵测试是啥&#xff1f;人工智能、机器学习和深度学习之间是什么关系&#xff1f;为什么人工智能计算会用到GPU&#xff1f; 机器学习什么是机器学习&#xff1f…...

〔AI 绘画〕Stable Diffusion 之 解决绘制多人或面部很小的人物时面部崩坏问题 篇

✨ 目录 &#x1f388; 脸部崩坏&#x1f388; 下载脸部修复插件&#x1f388; 启用脸部修复插件&#x1f388; 插件生成效果&#x1f388; 插件功能详解 &#x1f388; 脸部崩坏 相信很多人在画图时候&#xff0c;特别是画 有多个人物 图片或者 人物在图片中很小 的时候&…...

初步认识OSI/TCP/IP一(第三十八课)

1 初始OSI模型 OSI参考模型(Open Systems Interconnection Reference Model)是一个由国际标准化组织(ISO)和国际电报电话咨询委员会(CCITT)联合制定的网络通信协议规范,它将网络通信分为七个不同的层次,每个层次负责不同的功能和任务。 2 网络功能 数据通信、资源共享…...

英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析

伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP]&#xff08;2&#xff09;——代码分析 ASP整个模块的结果如下&#xff1a; . ├── COPYRIGHT ├── README.md ├── __init__.py ├── asp.py ├── permutation_lib.py ├── permutation_search_kernels…...

FileNotFoundError: [WinError 2] 系统找不到指定的文件。

pyspark demo程序创建spark上下文 完整报错如下&#xff1a; sc SparkContext(“local”, “Partition ID Example”) File “C:\ProgramData\anaconda3\envs\python36\lib\site-packages\pyspark\context.py”, line 133, in init SparkContext._ensure_initialized(self, ga…...

Linux: sysctl:net: IPV4_DEVCONF_ALL ignore_routes_with_linkdown; all vs default

文章目录 简介实例 ignore_routes_with_linkdownlinkdown 的引入dead的引入简介 一般下边这种类型的配置都有三种类型选项:all,default,specific net.ipv6.conf.acc.ignore_routes_with_linkdown = 0 net.ipv6.conf.all.ignore_routes_with_linkdown = 0 net.ipv6.conf.def…...

光耦继电器:实现电气隔离的卓越选择

光耦继电器是一种常用的电子元件&#xff0c;用于实现电气隔离和信号传输。在工业控制、自动化系统和电力电子等领域&#xff0c;光耦继电器具有独特的特点和优势。本文将从可靠性、隔离性、响应速度和适应性等方面对光耦继电器的特点进行概述。 光耦继电器是一种典型的固态继电…...

鸿蒙开发学习笔记2——实现页面之间跳转

鸿蒙开发学习笔记2——实现页面之间跳转 问题背景 上篇文章中&#xff0c;介绍了鸿蒙开发如何新建一个项目跑通hello world&#xff0c;本文将介绍在新建的项目中实现页面跳转的功能。 问题分析 ArkTS工程目录结构&#xff08;FA模型&#xff09; 各目录和路径的介绍如下…...

电子商务类网站需要什么配置的服务器?

随着电子商务的迅猛发展&#xff0c;越来越多的企业和创业者选择在互联网上开设自己的电商网站。为了确保电商网站能够高效运行&#xff0c;给用户提供良好的体验&#xff0c;选择合适的服务器配置至关重要。今天飞飞将和你分享电子商务类网站所需的服务器配置&#xff0c;希望…...

table 根据窗口缩放,自适应

element-plus中&#xff0c;直接应用在页面样式上&#xff0c; ::v-deep .el-table{width: 100%; } ::v-deep .el-table__header-wrapper table,::v-deep .el-table__body-wrapper table{width: 100% !important; } ::v-deep .el-table__body,::v-deep .el-table__footer,::v-d…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应&#xff0c;这是一种非线性光学现象&#xff0c;主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场&#xff0c;对材料产生非线性响应&#xff0c;可能…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

dify打造数据可视化图表

一、概述 在日常工作和学习中&#xff0c;我们经常需要和数据打交道。无论是分析报告、项目展示&#xff0c;还是简单的数据洞察&#xff0c;一个清晰直观的图表&#xff0c;往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server&#xff0c;由蚂蚁集团 AntV 团队…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比

在机器学习的回归分析中&#xff0c;损失函数的选择对模型性能具有决定性影响。均方误差&#xff08;MSE&#xff09;作为经典的损失函数&#xff0c;在处理干净数据时表现优异&#xff0c;但在面对包含异常值的噪声数据时&#xff0c;其对大误差的二次惩罚机制往往导致模型参数…...

LLMs 系列实操科普(1)

写在前面&#xff1a; 本期内容我们继续 Andrej Karpathy 的《How I use LLMs》讲座内容&#xff0c;原视频时长 ~130 分钟&#xff0c;以实操演示主流的一些 LLMs 的使用&#xff0c;由于涉及到实操&#xff0c;实际上并不适合以文字整理&#xff0c;但还是决定尽量整理一份笔…...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

C++ 设计模式 《小明的奶茶加料风波》

&#x1f468;‍&#x1f393; 模式名称&#xff1a;装饰器模式&#xff08;Decorator Pattern&#xff09; &#x1f466; 小明最近上线了校园奶茶配送功能&#xff0c;业务火爆&#xff0c;大家都在加料&#xff1a; 有的同学要加波霸 &#x1f7e4;&#xff0c;有的要加椰果…...

在树莓派上添加音频输入设备的几种方法

在树莓派上添加音频输入设备可以通过以下步骤完成&#xff0c;具体方法取决于设备类型&#xff08;如USB麦克风、3.5mm接口麦克风或HDMI音频输入&#xff09;。以下是详细指南&#xff1a; 1. 连接音频输入设备 USB麦克风/声卡&#xff1a;直接插入树莓派的USB接口。3.5mm麦克…...