当前位置: 首页 > news >正文

【论文笔记】Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks

论文地址:Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks

代码地址:https://github.com/jierunchen/fasternet

该论文主要提出了PConv,通过优化FLOPS提出了快速推理模型FasterNet。

在设计神经网络结构的时候,大部分注意力都会放在降低FLOPs( floating-point opera-
tions)上,有的时候FLOPs降低了,并不意味了推理速度加快了,这主要是因为没考虑到FLOPS(floating-point operations per second)。针对该问题,作者提出了PConv( partial convolution),通过提高FLOPS来加快推理速度。

一、引言

      非常多的实时推理模型都将重点放在降低FLOPs上,比如:MobileNet,ShuffleNet,GhostNet等等。虽然这些网络都降低了FLOPs,但是他们没有考虑到FLOPS,所以推理速度仍有优化空间,推理的延时计算公式如下:

由上式可以看出,要想加快推理速度,不仅可以从FLOPs入手,也可以优化FLOPS。作者在多个模型上做了实验,发现很多模型的FLOPS低于ResNet50。于是作者提出了PConv,通过提高FLOPS来加快推理速度。

二、PConv

为了提高FLOPS,作者提出了PConv,其结构如下图:

部分通道数经过卷积运算,其他通道不进行运算。再看了几眼。。。。这个和GhostConv好像呀。。。。

网络整体结构如下:

三、模型性能

FasterNet在ImageNet-1K上的表现如下:

在coco数据集上的表现如下:

四、代码

给出PConv的代码,也是非常简单:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import partial
from typing import List
from torch import Tensor
import copy
import ostry:from mmdet.models.builder import BACKBONES as det_BACKBONESfrom mmdet.utils import get_root_loggerfrom mmcv.runner import _load_checkpointhas_mmdet = True
except ImportError:print("If for detection, please install mmdetection first")has_mmdet = Falseclass Partial_conv3(nn.Module):def __init__(self, dim, n_div, forward):super().__init__()self.dim_conv3 = dim // n_divself.dim_untouched = dim - self.dim_conv3self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)if forward == 'slicing':self.forward = self.forward_slicingelif forward == 'split_cat':self.forward = self.forward_split_catelse:raise NotImplementedErrordef forward_slicing(self, x: Tensor) -> Tensor:# only for inferencex = x.clone()   # !!! Keep the original input intact for the residual connection laterx[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])return xdef forward_split_cat(self, x: Tensor) -> Tensor:# for training/inferencex1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)x1 = self.partial_conv3(x1)x = torch.cat((x1, x2), 1)return xclass MLPBlock(nn.Module):def __init__(self,dim,n_div,mlp_ratio,drop_path,layer_scale_init_value,act_layer,norm_layer,pconv_fw_type):super().__init__()self.dim = dimself.mlp_ratio = mlp_ratioself.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.n_div = n_divmlp_hidden_dim = int(dim * mlp_ratio)mlp_layer: List[nn.Module] = [nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),norm_layer(mlp_hidden_dim),act_layer(),nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)]self.mlp = nn.Sequential(*mlp_layer)self.spatial_mixing = Partial_conv3(dim,n_div,pconv_fw_type)if layer_scale_init_value > 0:self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)self.forward = self.forward_layer_scaleelse:self.forward = self.forwarddef forward(self, x: Tensor) -> Tensor:shortcut = xx = self.spatial_mixing(x)x = shortcut + self.drop_path(self.mlp(x))return xdef forward_layer_scale(self, x: Tensor) -> Tensor:shortcut = xx = self.spatial_mixing(x)x = shortcut + self.drop_path(self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))return xclass BasicStage(nn.Module):def __init__(self,dim,depth,n_div,mlp_ratio,drop_path,layer_scale_init_value,norm_layer,act_layer,pconv_fw_type):super().__init__()blocks_list = [MLPBlock(dim=dim,n_div=n_div,mlp_ratio=mlp_ratio,drop_path=drop_path[i],layer_scale_init_value=layer_scale_init_value,norm_layer=norm_layer,act_layer=act_layer,pconv_fw_type=pconv_fw_type)for i in range(depth)]self.blocks = nn.Sequential(*blocks_list)def forward(self, x: Tensor) -> Tensor:x = self.blocks(x)return xclass PatchEmbed(nn.Module):def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = nn.Identity()def forward(self, x: Tensor) -> Tensor:x = self.norm(self.proj(x))return xclass PatchMerging(nn.Module):def __init__(self, patch_size2, patch_stride2, dim, norm_layer):super().__init__()self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)if norm_layer is not None:self.norm = norm_layer(2 * dim)else:self.norm = nn.Identity()def forward(self, x: Tensor) -> Tensor:x = self.norm(self.reduction(x))return xclass FasterNet(nn.Module):def __init__(self,in_chans=3,num_classes=1000,embed_dim=96,depths=(1, 2, 8, 2),mlp_ratio=2.,n_div=4,patch_size=4,patch_stride=4,patch_size2=2,  # for subsequent layerspatch_stride2=2,patch_norm=True,feature_dim=1280,drop_path_rate=0.1,layer_scale_init_value=0,norm_layer='BN',act_layer='RELU',fork_feat=False,init_cfg=None,pretrained=None,pconv_fw_type='split_cat',**kwargs):super().__init__()if norm_layer == 'BN':norm_layer = nn.BatchNorm2delse:raise NotImplementedErrorif act_layer == 'GELU':act_layer = nn.GELUelif act_layer == 'RELU':act_layer = partial(nn.ReLU, inplace=True)else:raise NotImplementedErrorif not fork_feat:self.num_classes = num_classesself.num_stages = len(depths)self.embed_dim = embed_dimself.patch_norm = patch_normself.num_features = int(embed_dim * 2 ** (self.num_stages - 1))self.mlp_ratio = mlp_ratioself.depths = depths# split image into non-overlapping patchesself.patch_embed = PatchEmbed(patch_size=patch_size,patch_stride=patch_stride,in_chans=in_chans,embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)# stochastic depth decay ruledpr = [x.item()for x in torch.linspace(0, drop_path_rate, sum(depths))]# build layersstages_list = []for i_stage in range(self.num_stages):stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),n_div=n_div,depth=depths[i_stage],mlp_ratio=self.mlp_ratio,drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],layer_scale_init_value=layer_scale_init_value,norm_layer=norm_layer,act_layer=act_layer,pconv_fw_type=pconv_fw_type)stages_list.append(stage)# patch merging layerif i_stage < self.num_stages - 1:stages_list.append(PatchMerging(patch_size2=patch_size2,patch_stride2=patch_stride2,dim=int(embed_dim * 2 ** i_stage),norm_layer=norm_layer))self.stages = nn.Sequential(*stages_list)self.fork_feat = fork_featif self.fork_feat:self.forward = self.forward_det# add a norm layer for each outputself.out_indices = [0, 2, 4, 6]for i_emb, i_layer in enumerate(self.out_indices):if i_emb == 0 and os.environ.get('FORK_LAST3', None):raise NotImplementedErrorelse:layer = norm_layer(int(embed_dim * 2 ** i_emb))layer_name = f'norm{i_layer}'self.add_module(layer_name, layer)else:self.forward = self.forward_cls# Classifier headself.avgpool_pre_head = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(self.num_features, feature_dim, 1, bias=False),act_layer())self.head = nn.Linear(feature_dim, num_classes) \if num_classes > 0 else nn.Identity()self.apply(self.cls_init_weights)self.init_cfg = copy.deepcopy(init_cfg)if self.fork_feat and (self.init_cfg is not None or pretrained is not None):self.init_weights()def cls_init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, (nn.Conv1d, nn.Conv2d)):trunc_normal_(m.weight, std=.02)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)# init for mmdetection by loading imagenet pre-trained weightsdef init_weights(self, pretrained=None):logger = get_root_logger()if self.init_cfg is None and pretrained is None:logger.warn(f'No pre-trained weights for 'f'{self.__class__.__name__}, 'f'training start from scratch')passelse:assert 'checkpoint' in self.init_cfg, f'Only support ' \f'specify `Pretrained` in ' \f'`init_cfg` in ' \f'{self.__class__.__name__} 'if self.init_cfg is not None:ckpt_path = self.init_cfg['checkpoint']elif pretrained is not None:ckpt_path = pretrainedckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')if 'state_dict' in ckpt:_state_dict = ckpt['state_dict']elif 'model' in ckpt:_state_dict = ckpt['model']else:_state_dict = ckptstate_dict = _state_dictmissing_keys, unexpected_keys = \self.load_state_dict(state_dict, False)# show for debugprint('missing_keys: ', missing_keys)print('unexpected_keys: ', unexpected_keys)def forward_cls(self, x):# output only the features of last layer for image classificationx = self.patch_embed(x)x = self.stages(x)x = self.avgpool_pre_head(x)  # B C 1 1x = torch.flatten(x, 1)x = self.head(x)return xdef forward_det(self, x: Tensor) -> Tensor:# output the features of four stages for dense predictionx = self.patch_embed(x)outs = []for idx, stage in enumerate(self.stages):x = stage(x)if self.fork_feat and idx in self.out_indices:norm_layer = getattr(self, f'norm{idx}')x_out = norm_layer(x)outs.append(x_out)return outs

相关文章:

【论文笔记】Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks

论文地址&#xff1a;Run, Dont Walk: Chasing Higher FLOPS for Faster Neural Networks 代码地址&#xff1a;https://github.com/jierunchen/fasternet 该论文主要提出了PConv&#xff0c;通过优化FLOPS提出了快速推理模型FasterNet。 在设计神经网络结构的时候&#xff…...

python常用函数汇总

python常用函数汇总 对准蓝字按下左键可以跳转哦 类型函数数值相关函数abs() divmod() max() min() pow() round() sum()类型转换函数ascii() bin() hex() oct() bool() bytearray() bytes() chr() complex() float() int() 迭代和循环函数iter() next() e…...

阶段十-物业项目

可能遇到的错误&#xff1a; 解决jdk17javax.xml.bind.DatatypeConverter错误 <!--解决jdk17javax.xml.bind.DatatypeConverter错误--><dependency><groupId>javax.xml.bind</groupId><artifactId>jaxb-api</artifactId><version>…...

使用 Jekyll 构建你的网站 - 初入门

文章目录 一、Jekyll介绍二、Jekyll安装和启动2.1 配置Ruby环境1&#xff09;Windows2&#xff09;macOS 2.2 安装 Jekyll2.3 构建Jekyll项目2.4 启动 Jekyll 服务 三、Jekyll常用命令四、目录结构4.1 主要目录4.2 其他的约定目录 五、使用GitLink构建Jekyll博客5.1 生成Jekyll…...

【数据库】postgressql设置数据库执行超时时间

在这篇文章中&#xff0c;我们将深入探讨PostgreSQL数据库中的一个关键设置&#xff1a;SET statement_timeout。这个设置对于管理数据库性能和优化查询执行时间非常重要。让我们一起来了解它的工作原理以及如何有效地使用它。 什么是statement_timeout&#xff1f; statemen…...

SQL语言之DDL

目录结构 SQL语言之DDLDDL操作数据库查询数据库创建数据库删除数据库使用某个数据库案例 DDL操作表创建表查看表结构查询表修改表添加字段删除字段修改字段的类型修改字段名和字段类型 修改表名删除表案例 SQL语言之DDL ​ DDL&#xff1a;数据定义语言&#xff0c;用来定义数…...

hive高级查询(2)

-- 分组查询 SELECT sex,SUM(mark) sum_mark FROM score GROUP BY sex HAVING sum_mark > 555; SELECT sex,sum_mark FROM( SELECT sex,SUM(mark) sum_mark FROM score GROUP BY sex ) t WHERE sum_mark > 555; SELECT AVG(gid),SUM(gid)/COUNT(gid) FROM …...

golang的jwt学习笔记

文章目录 初始化项目加密一步一步编写程序另一个参数--加密方式关于StandardClaims 解密解析出来的怎么用关于`MapClaims`上面使用结构体的全代码实战项目关于验证这个项目的前端初始化项目 自然第一步是暗转jwt-go的依赖啦 #go get github.com/golang-jwt/jwt/v5 go get githu…...

第十五节TypeScript 接口

1、简介 接口是一系列抽象方法的声明&#xff0c;是一些方法特征的集合&#xff0c;这些方法都应该是抽象的&#xff0c;需要有由具体的类去实现&#xff0c;然后第三方就可以通过这组抽象方法调用&#xff0c;让具体的类执行具体的方法。 2、接口的定义 interface interface_…...

【hadoop】解决浏览器不能访问Hadoop的50070、8088等端口?!

【hadoop】解决浏览器不能访问Hadoop的50070、8088等端口&#xff1f;&#xff01;&#x1f60e; 前言&#x1f64c;【hadoop】解决浏览器不能访问Hadoop的50070、8088等端口&#xff1f;&#xff01;查看自己的配置文件&#xff1a;最终成功访问如图所示&#xff1a; 总结撒花…...

14.bash shell中的for/while/until循环

文章目录 shell循环语句for命令**读取列表中的值****读取列表中的复杂值****从变量读取列表**迭代数组**从命令读取值****用通配符读取目录**C语言风格的shell for循环 shell循环while命令shell 循环的until命令shell循环跳出的break/continue命令break命令continue命令trick 欢…...

RPC(6):RMI实现RPC

1RMI简介 RMI(Remote Method Invocation) 远程方法调用。 RMI是从JDK1.2推出的功能&#xff0c;它可以实现在一个Java应用中可以像调用本地方法一样调用另一个服务器中Java应用&#xff08;JVM&#xff09;中的内容。 RMI 是Java语言的远程调用&#xff0c;无法实现跨语言。…...

strlen和sizeof的初步理解

大家好我是Beilef&#xff0c;一个美好的下我接触到编程并且逐渐喜欢。我虽然不是科班出身但是我会更加努力地去学&#xff0c;有啥不对的地方请斧正 文章目录 目录 文章目录 前言 想必大家对sizeof肯定很了解&#xff0c;那对strlen又了解多少。其实这个问题应该让不少人困扰。…...

纯CSS的华为充电动画,它来了

&#x1f4e2; 鸿蒙专栏&#xff1a;想学鸿蒙的&#xff0c;冲 &#x1f4e2; C语言专栏&#xff1a;想学C语言的&#xff0c;冲 &#x1f4e2; VUE专栏&#xff1a;想学VUE的&#xff0c;冲这里 &#x1f4e2; Krpano专栏&#xff1a;想学Krpano的&#xff0c;冲 &#x1f514…...

在架构设计中,前后端分离有什么好处?

前后端分离是一种架构设计模式&#xff0c;将前端和后端的开发分别独立进行&#xff0c;它带来了多方面的好处&#xff1a; 1、独立开发和维护&#xff1a; 前后端分离允许前端和后端开发团队独立进行工作。这意味着两个团队可以并行开发&#xff0c;提高了整体的开发效率。前…...

C语言中的结构体和联合体:异同及应用

文章目录 C语言中的结构体和联合体&#xff1a;异同及应用1. 结构体&#xff08;Struct&#xff09;的概述代码示例&#xff1a; 2. 联合体&#xff08;Union&#xff09;的概述代码示例&#xff1a; 3. 结构体与联合体的异同点相同点&#xff1a;不同点&#xff1a;代码说明 结…...

文件夹共享(普通共享和高级共享的区别)防火墙设置(包括了jdk安装和Tomcat)

文章目录 一、共享文件1.1为什么需要配置文件夹共享功能&#xff1f;1.2配置文件共享功能1.3高级共享和普通共享的区别&#xff1a; 二、防火墙设置2.1先要在虚拟机上安装JDK和Tomcat供外部访问。2.2设置防火墙&#xff1a; 一、共享文件 1.1为什么需要配置文件夹共享功能&…...

❀My排序算法学习之冒泡排序❀

目录 冒泡排序(Bubble Sort):) 一、定义 二、算法原理 三、算法分析 时间复杂度 算法稳定性 算法描述 C语言 C++ 算法比较 插入排序 选择排序 快速排序 归并排序 冒泡排序(Bubble Sort):) 一、定义 冒泡排序(Bubble Sort),是一种计算机科学领域的较简单…...

服务器数据恢复-raid6离线磁盘强制上线后分区打不开的数据恢复案例

服务器数据恢复环境&#xff1a; 服务器上有一组由12块硬盘组建的raid6磁盘阵列&#xff0c;raid6阵列上层有一个lun&#xff0c;映射到WINDOWS系统上使用&#xff0c;WINDOWS系统划分了一个GPT分区。 服务器故障&分析&#xff1a; 服务器在运行过程中突然无法访问。对服务…...

Zookeeper在分布式命名服务中的实践

Java学习面试指南&#xff1a;https://javaxiaobear.cn 命名服务是为系统中的资源提供标识能力。ZooKeeper的命名服务主要是利用ZooKeeper节点的树形分层结构和子节点的顺序维护能力&#xff0c;来为分布式系统中的资源命名。 哪些应用场景需要用到分布式命名服务呢&#xff1…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0&#xff1a;开发环境同步测试 cookie 至 localhost&#xff0c;便于本地请求服务携带 cookie 参考地址&#xff1a;https://juejin.cn/post/7139354571712757767 里面有源码下载下来&#xff0c;加在到扩展即可使用FeHelp…...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题&#xff1a; 下面创建一个简单的Flask RESTful API示例。首先&#xff0c;我们需要创建环境&#xff0c;安装必要的依赖&#xff0c;然后…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

蓝桥杯 2024 15届国赛 A组 儿童节快乐

P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡&#xff0c;轻快的音乐在耳边持续回荡&#xff0c;小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下&#xff0c;六一来了。 今天是六一儿童节&#xff0c;小蓝老师为了让大家在节…...

Module Federation 和 Native Federation 的比较

前言 Module Federation 是 Webpack 5 引入的微前端架构方案&#xff0c;允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...