当前位置: 首页 > news >正文

PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []if stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

相关文章:

PyTorch从零开始实现ResNet

文章目录 代码实现参考 代码实现 本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。 代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用…...

企业微信 企业内部开发 学习笔记

官方文档 文档 术语介绍 引入pom <dependency><groupId>com.github.binarywang</groupId><artifactId>wx-java-cp-spring-boot-starter</artifactId><version>4.5.3.B</version></dependency>核心代码 推送消息 final WxCp…...

03 QT基本控件和功能类

一 进度条 、水平滑动条 垂直滑动条 当在QT中,在已知类名的情况下,要了解类的构造函数 常用属性 及 信号和槽 常用api 特征:可以获取当前控件的值和设置它的当值 ---- int ui->progressBar->setValue(value); //给进度条设置一个整型值 ui->progressBar->value…...

epoll数据结构

目录 1.大量的fd 集合。选择什么数据结构&#xff1f;2、Epoll 数据结构Epitem 的定义Eventpoll 的定义 1.大量的fd 集合。选择什么数据结构&#xff1f; 查找频率很高的数据结构 1.红黑树 2.哈希&#xff08;扩容缩容&#xff09; 3. b/btree &#xff08;降低树的高度&#…...

LINUX学习笔记_GIT操作命令

LINUX学习笔记 GIT操作命令 基本命令 git init&#xff1a;初始化仓库git status&#xff1a;查看文件状态git add&#xff1a;添加文件到暂存区&#xff08;index&#xff09;git commit -m “注释”&#xff1a;提交文件到仓库&#xff08;repository&#xff09;git log&a…...

第一百二十九天学习记录:数据结构与算法基础:栈和队列(中)(王卓教学视频)

栈的表示和实现 顺序栈的初始化 ##入栈 链栈的表示...

C语言 — qsort 函数

介绍&#xff1a;qsort是一个库函数&#xff0c;用来对数据进行排序&#xff0c;可以排序任意类型的数据。 void qsort &#xff08;void*base&#xff0c; size_t num, size_t size, int(*compart)(const void*,constvoid*) &#xff09; qsort 具有四个参数&#xff1a; …...

开放式耳机哪个好一点?推荐几款优秀的开放式耳机

在追求更广阔的音场和更真实的音质时&#xff0c;开放式耳机是绝对值得考虑的选择。它们以其通透感和自然的音质而备受推崇&#xff0c;带来更逼真的音乐体验。下面我来推荐几款优秀的开放式耳机&#xff0c;满足你对音质和舒适度的要求&#xff0c;可尽情享受音乐的魅力。 一…...

vue-cli前端工程化——创建vue-cli工程 router版本的创建 目录结构 案例初步

目录 引出创建vue-cli前端工程vue-cli是什么自动构建创建vue-cli项目选择Vue的版本号 手动安装进行选择创建成功 手动创建router版多了一个router 运行测试bug解决 Vue项目结构main.jspackage.jsonvue.config.js Vue项目初步hello案例 总结 引出 1.vue-cli是啥&#xff0c;创建…...

Go和Java实现外观模式

Go和Java实现外观模式 下面我们通过一个构造各种形状的案例来说明外观模式的使用。 1、外观模式 外观模式隐藏系统的复杂性&#xff0c;并向客户端提供了一个客户端可以访问系统的接口。这种类型的设计模式属于结构型 模式&#xff0c;它向现有的系统添加一个接口&#xff…...

人工智能(一)基本概念

人工智能之基本概念 常见问题什么是人工智能&#xff1f;人工智能应用在那些地方&#xff1f;人工智能的三种形态图灵测试是啥&#xff1f;人工智能、机器学习和深度学习之间是什么关系&#xff1f;为什么人工智能计算会用到GPU&#xff1f; 机器学习什么是机器学习&#xff1f…...

〔AI 绘画〕Stable Diffusion 之 解决绘制多人或面部很小的人物时面部崩坏问题 篇

✨ 目录 &#x1f388; 脸部崩坏&#x1f388; 下载脸部修复插件&#x1f388; 启用脸部修复插件&#x1f388; 插件生成效果&#x1f388; 插件功能详解 &#x1f388; 脸部崩坏 相信很多人在画图时候&#xff0c;特别是画 有多个人物 图片或者 人物在图片中很小 的时候&…...

初步认识OSI/TCP/IP一(第三十八课)

1 初始OSI模型 OSI参考模型(Open Systems Interconnection Reference Model)是一个由国际标准化组织(ISO)和国际电报电话咨询委员会(CCITT)联合制定的网络通信协议规范,它将网络通信分为七个不同的层次,每个层次负责不同的功能和任务。 2 网络功能 数据通信、资源共享…...

英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析

伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP]&#xff08;2&#xff09;——代码分析 ASP整个模块的结果如下&#xff1a; . ├── COPYRIGHT ├── README.md ├── __init__.py ├── asp.py ├── permutation_lib.py ├── permutation_search_kernels…...

FileNotFoundError: [WinError 2] 系统找不到指定的文件。

pyspark demo程序创建spark上下文 完整报错如下&#xff1a; sc SparkContext(“local”, “Partition ID Example”) File “C:\ProgramData\anaconda3\envs\python36\lib\site-packages\pyspark\context.py”, line 133, in init SparkContext._ensure_initialized(self, ga…...

Linux: sysctl:net: IPV4_DEVCONF_ALL ignore_routes_with_linkdown; all vs default

文章目录 简介实例 ignore_routes_with_linkdownlinkdown 的引入dead的引入简介 一般下边这种类型的配置都有三种类型选项:all,default,specific net.ipv6.conf.acc.ignore_routes_with_linkdown = 0 net.ipv6.conf.all.ignore_routes_with_linkdown = 0 net.ipv6.conf.def…...

光耦继电器:实现电气隔离的卓越选择

光耦继电器是一种常用的电子元件&#xff0c;用于实现电气隔离和信号传输。在工业控制、自动化系统和电力电子等领域&#xff0c;光耦继电器具有独特的特点和优势。本文将从可靠性、隔离性、响应速度和适应性等方面对光耦继电器的特点进行概述。 光耦继电器是一种典型的固态继电…...

鸿蒙开发学习笔记2——实现页面之间跳转

鸿蒙开发学习笔记2——实现页面之间跳转 问题背景 上篇文章中&#xff0c;介绍了鸿蒙开发如何新建一个项目跑通hello world&#xff0c;本文将介绍在新建的项目中实现页面跳转的功能。 问题分析 ArkTS工程目录结构&#xff08;FA模型&#xff09; 各目录和路径的介绍如下…...

电子商务类网站需要什么配置的服务器?

随着电子商务的迅猛发展&#xff0c;越来越多的企业和创业者选择在互联网上开设自己的电商网站。为了确保电商网站能够高效运行&#xff0c;给用户提供良好的体验&#xff0c;选择合适的服务器配置至关重要。今天飞飞将和你分享电子商务类网站所需的服务器配置&#xff0c;希望…...

table 根据窗口缩放,自适应

element-plus中&#xff0c;直接应用在页面样式上&#xff0c; ::v-deep .el-table{width: 100%; } ::v-deep .el-table__header-wrapper table,::v-deep .el-table__body-wrapper table{width: 100% !important; } ::v-deep .el-table__body,::v-deep .el-table__footer,::v-d…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

如何为服务器生成TLS证书

TLS&#xff08;Transport Layer Security&#xff09;证书是确保网络通信安全的重要手段&#xff0c;它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书&#xff0c;可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...

Rust 异步编程

Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

k8s业务程序联调工具-KtConnect

概述 原理 工具作用是建立了一个从本地到集群的单向VPN&#xff0c;根据VPN原理&#xff0c;打通两个内网必然需要借助一个公共中继节点&#xff0c;ktconnect工具巧妙的利用k8s原生的portforward能力&#xff0c;简化了建立连接的过程&#xff0c;apiserver间接起到了中继节…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!

简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求&#xff0c;并检查收到的响应。它以以下模式之一…...

华为OD机考-机房布局

import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...