当前位置: 首页 > news >正文

[yolo系列:YOLOV7改进-添加CoordConv,SAConv.]

文章目录

    • 概要
    • CoordConv
    • SAConv

概要

CoordConv(Coordinate Convolution)和SAConv(Spatial Attention Convolution)是两种用于神经网络中的特殊卷积操作,用于处理图像数据或其他多维数据。以下是它们的简要介绍:
CoordConv(Coordinate Convolution)

CoordConv 是由Uber AI Labs的研究人员提出的一种卷积操作,用于处理图像中的坐标信息。在传统的卷积操作中,卷积核在图像上滑动并执行卷积操作,但是它们对于图像中的位置信息是不敏感的。CoordConv 的目标是使卷积操作变得位置敏感,它在输入特征图中加入了位置信息作为额外的通道。这个位置信息可以是像素的坐标,也可以是归一化的坐标值,具体取决于应用的场景。

通过将坐标信息与输入特征图拼接在一起,CoordConv 能够帮助神经网络更好地学习到输入数据中的空间关系,从而提高模型的性能。它在需要考虑输入数据的空间位置信息时,特别有用。
SAConv(Spatial Attention Convolution)

SAConv 是一种引入了空间注意力机制的卷积操作。传统的卷积操作在所有位置都应用相同的卷积核,而SAConv 具有可学习的空间注意力权重,这意味着它能够动态地调整不同位置的卷积核权重。

SAConv 的关键思想是,在进行卷积操作之前,先计算每个位置的空间注意力权重。这些权重由神经网络学习得出,然后被用来加权输入特征图的不同位置,从而生成具有位置敏感性的特征表示。这种机制使得神经网络在处理输入数据时能够更加关注重要的区域,从而提高了模型的感知能力和性能。

总的来说,CoordConv 和 SAConv 都是为了增强神经网络对输入数据的空间信息处理能力而提出的方法。CoordConv 引入了位置信息通道,使得网络对位置信息更敏感,而 SAConv 引入了空间注意力机制,使得网络能够动态地调整卷积核的权重,提高了对不同位置信息的关注度。这两种方法在特定的任务和场景下都能够带来性能的提升。

CoordConv

common.py添加如下

class AddCoords(nn.Module):def __init__(self, with_r=False):super().__init__()self.with_r = with_rdef forward(self, input_tensor):"""Args:input_tensor: shape(batch, channel, x_dim, y_dim)"""batch_size, _, x_dim, y_dim = input_tensor.size()xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)xx_channel = xx_channel.float() / (x_dim - 1)yy_channel = yy_channel.float() / (y_dim - 1)xx_channel = xx_channel * 2 - 1yy_channel = yy_channel * 2 - 1xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)ret = torch.cat([input_tensor,xx_channel.type_as(input_tensor),yy_channel.type_as(input_tensor)], dim=1)if self.with_r:rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))ret = torch.cat([ret, rr], dim=1)return retclass CoordConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, with_r=False):super().__init__()self.addcoords = AddCoords(with_r=with_r)in_channels += 2if with_r:in_channels += 1self.conv = Conv(in_channels, out_channels, k=kernel_size, s=stride)def forward(self, x):x = self.addcoords(x)x = self.conv(x)return x

在yolo.py

在这里插入图片描述

# yolov7 head
head:[[-1, 1, SPPCSPC, [512]], # 51[-1, 1, CoordConv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[37, 1, CoordConv, [256, 1, 1]], # route backbone P4[[-1, -2], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]], # 63[-1, 1, CoordConv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[24, 1, CoordConv, [128, 1, 1]], # route backbone P3[[-1, -2], 1, Concat, [1]],[-1, 1, Conv, [128, 1, 1]],[-2, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [128, 1, 1]], # 75[-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3, 63], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]], # 88[-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]],[[-1, -3, 51], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]],[-2, 1, Conv, [512, 1, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]], # 101[75, 1, CoordConv, [256, 3, 1]],[88, 1, CoordConv, [512, 3, 1]],[101, 1, CoordConv, [1024, 3, 1]],[[102,103,104], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)]

SAConv

在common.py添加

class ConvAWS2d(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))def _get_weight(self, weight):weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)weight = weight - weight_meanstd = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)weight = weight / stdweight = self.weight_gamma * weight + self.weight_betareturn weightdef forward(self, x):weight = self._get_weight(self.weight)return super()._conv_forward(x, weight, None)def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs):self.weight_gamma.data.fill_(-1)super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs)if self.weight_gamma.data.mean() > 0:returnweight = self.weight.dataweight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)self.weight_beta.data.copy_(weight_mean)std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)self.weight_gamma.data.copy_(std)class SAConv2d(ConvAWS2d):def __init__(self,in_channels,out_channels,kernel_size,s=1,p=None,g=1,d=1,act=True,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=s,padding=autopad(kernel_size, p),dilation=d,groups=g,bias=bias)self.switch = torch.nn.Conv2d(self.in_channels,1,kernel_size=1,stride=s,bias=True)self.switch.weight.data.fill_(0)self.switch.bias.data.fill_(1)self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))self.weight_diff.data.zero_()self.pre_context = torch.nn.Conv2d(self.in_channels,self.in_channels,kernel_size=1,bias=True)self.pre_context.weight.data.fill_(0)self.pre_context.bias.data.fill_(0)self.post_context = torch.nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,bias=True)self.post_context.weight.data.fill_(0)self.post_context.bias.data.fill_(0)self.bn = nn.BatchNorm2d(out_channels)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())def forward(self, x):# pre-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)avg_x = self.pre_context(avg_x)avg_x = avg_x.expand_as(x)x = x + avg_x# switchavg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)switch = self.switch(avg_x)# sacweight = self._get_weight(self.weight)out_s = super()._conv_forward(x, weight, None)ori_p = self.paddingori_d = self.dilationself.padding = tuple(3 * p for p in self.padding)self.dilation = tuple(3 * d for d in self.dilation)weight = weight + self.weight_diffout_l = super()._conv_forward(x, weight, None)out = switch * out_s + (1 - switch) * out_lself.padding = ori_pself.dilation = ori_d# post-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)avg_x = self.post_context(avg_x)avg_x = avg_x.expand_as(out)out = out + avg_xreturn self.act(self.bn(out))

然后在yolo.py里面添加
在这里插入图片描述
在这里插入图片描述
和可变形卷积加法一样,但是不建议加太多,也是只替换3x3卷积上面。比普通卷积复杂度高,不建议加太多,推理速度变慢,尽量少用,提高精度。

相关文章:

[yolo系列:YOLOV7改进-添加CoordConv,SAConv.]

文章目录 概要CoordConvSAConv 概要 CoordConv(Coordinate Convolution)和SAConv(Spatial Attention Convolution)是两种用于神经网络中的特殊卷积操作,用于处理图像数据或其他多维数据。以下是它们的简要介绍&#x…...

【万字实操】可视化运维平台openGauss Datakit,带你轻松玩转openGauss 5.0

openGauss Datakit:openGauss社区推出的可视化的运维工具. 特性优势 初级用户学习openGauss门槛高让你望而却步?openGauss Datakit一键化安装企业版集群、监控、日志分析、SQL诊断,让你快速上手,快速部署,从容面对企业环境&#…...

《动手学深度学习 Pytorch版》 10.1 注意力提示

10.1.1 生物学中的注意力提示 “美国心理学之父” 威廉詹姆斯提出的双组件(two-component)框架: 非自主性提示:基于环境中物体的突出性和易见性 自主性提示:受到了认知和意识的控制 10.1.2 查询、键和值 注意力机制…...

C# 写入文件比较

数据长度:128188个long BinaryWriter每次写一个long 耗时14.7828ms StreamWriter每次写一个long 耗时44.0934 ms FileStream每次写一个long 耗时20.5142 ms FileStream固定chunk写入,循环操作数组,耗时13.4126 ms byte[] chunk new byte[d…...

医院设备利用(Use of Hospital Facilities, ACM/ICPC World Finals 1991, UVa212)rust解法

医院里有n(n≤10)个手术室和m(m≤30)个恢复室。每个病人首先会被分配到一个手术室,手术后会被分配到一个恢复室。从任意手术室到任意恢复室的时间均为t1,准备一个手术室和恢复室的时间分别为t2和t3&#xf…...

解决github ping不通的问题(1024程序员节快乐!

1024程序员节快乐!(随便粘贴一个文档,参加活动 解决github ping不通的问题 域名解析(域名->IP):https://www.ipaddress.com/ Ubuntu平台 github经常ping不通或者访问缓慢,方法是更改host…...

QT基础 柱状图

目录 1.QBarSeries 2.QHorizontalBarSeries 3.QPercentBarSeries 4.QHorizontalPercentBarSeries 5.QStackedBarSeries 6.QHorizontalStackedBarSeries 从上图得知柱状的基类是QAbstractBarSeries,派生出来分别是柱状图的水平和垂直类,只是类型…...

微机原理与接口技术-第七章输入输出接口

文章目录 I/O接口概述I/O接口的典型结构基本功能 I/O端口的编址独立编址统一编址 输入输出指令I/O寻址方式I/O数据传输量I/O保护 16位DOS应用程序DOS平台的源程序框架DOS功能调用 无条件传送和查询传送无条件传送三态缓冲器锁存器接口电路 查询传送查询输入端口查询输出端口 中…...

YoloV8改进策略:独家原创,LSKA(大可分离核注意力)改进YoloV8,比Transformer更有效,包括论文翻译和实验结果

文章目录 摘要论文:《LSKA(大可分离核注意力):重新思考CNN大核注意力设计》1、简介2、相关工作3、方法4、实验5、消融研究6、与最先进方法的比较7、ViTs和CNNs的鲁棒性评估基准比较8、结论YoloV8官方结果改进一:测试结果摘要 本文给大家带来一种超大核注意力机制的改进方…...

7天易语言从入门到实战(一)

1.1易语言简介 易语言是一门有着伟大理想的语言。公司用的少,开发者也很少,并不影响国人对他的热情。曾经的多玩LOL,朗读女,都是陪伴再那个国产PC应用匮乏的时代。 2001年1月 吴涛研发了中国自主知识产权的的中文编程语言——易语…...

redis缓存问题

缓存击穿 缓存击穿是指某个热点数据存储在redis中,该数据在高并发的场景下,当该key过期时就会有大量的请求去查询数据库,对数据库的压力非常大,可能会导致数据库压垮。 解决方案 1.不为热点的key设置过期时间。 2.使用分布式锁…...

mysql创建自定义函数报错

mysql创建自定义函数报错:This function has none of DETERMINISTIC, NO SQL, or READS SQL DATA in its declarat… 这是我们开启了bin-log,我们就必须指定我们的函数是否是 1.DETERMINISTIC 不确定的 2.NO SQL没有sql语句,当然也不会修改数…...

Docker 的数据管理与网络通信以及Docker镜像的创建

目录 Docker的数据管理 1、数据卷 2、数据卷容器 3、端口映射 4、容器互联 二、Docker网络 1、Docker网络实现原理 2、Docker的网桥模式 1)Host 2)Container 3)none 4)bridge 5)自定义网络 3、创建自定义…...

linux系统查看bash的history

要输出最近的20条命令,可以使用history命令。在Bash终端中,输入以下命令即可获取最近的20条命令历史记录: history 20这将显示你最近执行的20条命令及其相应的行号。 要将最近的20条命令写入到一个名为 “command.txt” 的文本文件中&#…...

【T+】畅捷通T+增加会计科目提示执行超时已过期。

【问题描述】 在畅捷通T软件中, 增加会计科目的时候提示: 通过DataTable插入ext扩展表出错:执行超时已过期。 完成操作之前已超时或服务器未响应。 操作已被用户取消。 语句已终止。 【解决方法】 【方法一】 注销用户登录,回到软件登录界面…...

0基础学习VR全景平台篇第111篇:全景图拼接和编辑 - PTGui Pro教程

上课!全体起立~ 大家好,欢迎观看蛙色官方系列全景摄影课程! 前情回顾:上节,我们将源图像导入了PTGui,也设置好了各项参数。 下面我们就开始拼接全景图,并且在编辑器里进行一系列检查错位和设…...

Dynamics 365 使用ILMerge 合并CRM开发后的DLL

很久以前写过一篇博文,关于用ILMerge 命令合并DLL,当时时纯敲命令行的,现在有了更简单的方式,只需要在NuGet下载如下两个包 另外插件引用第三方dll的新方案Preview来了,不久的将来就不需要使用ILMerge了...

SpringBoot Web请求响应

目录 前言请求PostmanPostman使用 简单参数原始方式接收普通参数SpringBoot方式接收普通参数参数名不一致问题 实体参数简单实体参数复杂实体对象 数组集合参数数组参数集合参数 日期参数JSON参数路径参数 响应ResponseBody统一响应结果请求响应案例案例需求与准备工作案例实现…...

Jenkins CLI二次开发工具类

使用Jenkins CLI进行二次开发 使用背景 公司自研CI/DI平台,借助JenkinsSonarQube进行代码质量管理。对接版本 Jenkins版本为:Version 2.428 SonarQube版本为:Community EditionVersion 10.2.1 (build 78527)技术选型 Java对接Jenkins有第…...

2. 计算WPL

题目 Huffman编码是通信系统中常用的一种不等长编码,它的特点是:能够使编码之后的电文长度最短。 更多关于Huffman编码的内容参考教材第十章。 输入: 第一行为要编码的符号数量n 第二行~第n1行为每个符号出现的频率 输…...

接口测试中缓存处理策略

在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...

三体问题详解

从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...

NFT模式:数字资产确权与链游经济系统构建

NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南 在数字化营销时代,邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天,我们将深入解析邮件打开率、网站可用性、页面参与时…...

2023赣州旅游投资集团

单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

day36-多路IO复用

一、基本概念 (服务器多客户端模型) 定义:单线程或单进程同时监测若干个文件描述符是否可以执行IO操作的能力 作用:应用程序通常需要处理来自多条事件流中的事件,比如我现在用的电脑,需要同时处理键盘鼠标…...

协议转换利器,profinet转ethercat网关的两大派系,各有千秋

随着工业以太网的发展,其高效、便捷、协议开放、易于冗余等诸多优点,被越来越多的工业现场所采用。西门子SIMATIC S7-1200/1500系列PLC集成有Profinet接口,具有实时性、开放性,使用TCP/IP和IT标准,符合基于工业以太网的…...

9-Oracle 23 ai Vector Search 特性 知识准备

很多小伙伴是不是参加了 免费认证课程(限时至2025/5/15) Oracle AI Vector Search 1Z0-184-25考试,都顺利拿到certified了没。 各行各业的AI 大模型的到来,传统的数据库中的SQL还能不能打,结构化和非结构的话数据如何和…...

pgsql:还原数据库后出现重复序列导致“more than one owned sequence found“报错问题的解决

问题: pgsql数据库通过备份数据库文件进行还原时,如果表中有自增序列,还原后可能会出现重复的序列,此时若向表中插入新行时会出现“more than one owned sequence found”的报错提示。 点击菜单“其它”-》“序列”,…...