[yolo系列:YOLOV7改进-添加CoordConv,SAConv.]
文章目录
- 概要
- CoordConv
- SAConv
概要
CoordConv(Coordinate Convolution)和SAConv(Spatial Attention Convolution)是两种用于神经网络中的特殊卷积操作,用于处理图像数据或其他多维数据。以下是它们的简要介绍:
CoordConv(Coordinate Convolution)
CoordConv 是由Uber AI Labs的研究人员提出的一种卷积操作,用于处理图像中的坐标信息。在传统的卷积操作中,卷积核在图像上滑动并执行卷积操作,但是它们对于图像中的位置信息是不敏感的。CoordConv 的目标是使卷积操作变得位置敏感,它在输入特征图中加入了位置信息作为额外的通道。这个位置信息可以是像素的坐标,也可以是归一化的坐标值,具体取决于应用的场景。
通过将坐标信息与输入特征图拼接在一起,CoordConv 能够帮助神经网络更好地学习到输入数据中的空间关系,从而提高模型的性能。它在需要考虑输入数据的空间位置信息时,特别有用。
SAConv(Spatial Attention Convolution)
SAConv 是一种引入了空间注意力机制的卷积操作。传统的卷积操作在所有位置都应用相同的卷积核,而SAConv 具有可学习的空间注意力权重,这意味着它能够动态地调整不同位置的卷积核权重。
SAConv 的关键思想是,在进行卷积操作之前,先计算每个位置的空间注意力权重。这些权重由神经网络学习得出,然后被用来加权输入特征图的不同位置,从而生成具有位置敏感性的特征表示。这种机制使得神经网络在处理输入数据时能够更加关注重要的区域,从而提高了模型的感知能力和性能。
总的来说,CoordConv 和 SAConv 都是为了增强神经网络对输入数据的空间信息处理能力而提出的方法。CoordConv 引入了位置信息通道,使得网络对位置信息更敏感,而 SAConv 引入了空间注意力机制,使得网络能够动态地调整卷积核的权重,提高了对不同位置信息的关注度。这两种方法在特定的任务和场景下都能够带来性能的提升。
CoordConv
common.py添加如下
class AddCoords(nn.Module):def __init__(self, with_r=False):super().__init__()self.with_r = with_rdef forward(self, input_tensor):"""Args:input_tensor: shape(batch, channel, x_dim, y_dim)"""batch_size, _, x_dim, y_dim = input_tensor.size()xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)xx_channel = xx_channel.float() / (x_dim - 1)yy_channel = yy_channel.float() / (y_dim - 1)xx_channel = xx_channel * 2 - 1yy_channel = yy_channel * 2 - 1xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)ret = torch.cat([input_tensor,xx_channel.type_as(input_tensor),yy_channel.type_as(input_tensor)], dim=1)if self.with_r:rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))ret = torch.cat([ret, rr], dim=1)return retclass CoordConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, with_r=False):super().__init__()self.addcoords = AddCoords(with_r=with_r)in_channels += 2if with_r:in_channels += 1self.conv = Conv(in_channels, out_channels, k=kernel_size, s=stride)def forward(self, x):x = self.addcoords(x)x = self.conv(x)return x
在yolo.py

# yolov7 head
head:[[-1, 1, SPPCSPC, [512]], # 51[-1, 1, CoordConv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[37, 1, CoordConv, [256, 1, 1]], # route backbone P4[[-1, -2], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]], # 63[-1, 1, CoordConv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[24, 1, CoordConv, [128, 1, 1]], # route backbone P3[[-1, -2], 1, Concat, [1]],[-1, 1, Conv, [128, 1, 1]],[-2, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [128, 1, 1]], # 75[-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3, 63], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]], # 88[-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]],[[-1, -3, 51], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]],[-2, 1, Conv, [512, 1, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]], # 101[75, 1, CoordConv, [256, 3, 1]],[88, 1, CoordConv, [512, 3, 1]],[101, 1, CoordConv, [1024, 3, 1]],[[102,103,104], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)]
SAConv
在common.py添加
class ConvAWS2d(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))def _get_weight(self, weight):weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)weight = weight - weight_meanstd = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)weight = weight / stdweight = self.weight_gamma * weight + self.weight_betareturn weightdef forward(self, x):weight = self._get_weight(self.weight)return super()._conv_forward(x, weight, None)def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs):self.weight_gamma.data.fill_(-1)super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs)if self.weight_gamma.data.mean() > 0:returnweight = self.weight.dataweight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)self.weight_beta.data.copy_(weight_mean)std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)self.weight_gamma.data.copy_(std)class SAConv2d(ConvAWS2d):def __init__(self,in_channels,out_channels,kernel_size,s=1,p=None,g=1,d=1,act=True,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=s,padding=autopad(kernel_size, p),dilation=d,groups=g,bias=bias)self.switch = torch.nn.Conv2d(self.in_channels,1,kernel_size=1,stride=s,bias=True)self.switch.weight.data.fill_(0)self.switch.bias.data.fill_(1)self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))self.weight_diff.data.zero_()self.pre_context = torch.nn.Conv2d(self.in_channels,self.in_channels,kernel_size=1,bias=True)self.pre_context.weight.data.fill_(0)self.pre_context.bias.data.fill_(0)self.post_context = torch.nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,bias=True)self.post_context.weight.data.fill_(0)self.post_context.bias.data.fill_(0)self.bn = nn.BatchNorm2d(out_channels)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())def forward(self, x):# pre-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)avg_x = self.pre_context(avg_x)avg_x = avg_x.expand_as(x)x = x + avg_x# switchavg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)switch = self.switch(avg_x)# sacweight = self._get_weight(self.weight)out_s = super()._conv_forward(x, weight, None)ori_p = self.paddingori_d = self.dilationself.padding = tuple(3 * p for p in self.padding)self.dilation = tuple(3 * d for d in self.dilation)weight = weight + self.weight_diffout_l = super()._conv_forward(x, weight, None)out = switch * out_s + (1 - switch) * out_lself.padding = ori_pself.dilation = ori_d# post-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)avg_x = self.post_context(avg_x)avg_x = avg_x.expand_as(out)out = out + avg_xreturn self.act(self.bn(out))
然后在yolo.py里面添加


和可变形卷积加法一样,但是不建议加太多,也是只替换3x3卷积上面。比普通卷积复杂度高,不建议加太多,推理速度变慢,尽量少用,提高精度。
相关文章:
[yolo系列:YOLOV7改进-添加CoordConv,SAConv.]
文章目录 概要CoordConvSAConv 概要 CoordConv(Coordinate Convolution)和SAConv(Spatial Attention Convolution)是两种用于神经网络中的特殊卷积操作,用于处理图像数据或其他多维数据。以下是它们的简要介绍&#x…...
【万字实操】可视化运维平台openGauss Datakit,带你轻松玩转openGauss 5.0
openGauss Datakit:openGauss社区推出的可视化的运维工具. 特性优势 初级用户学习openGauss门槛高让你望而却步?openGauss Datakit一键化安装企业版集群、监控、日志分析、SQL诊断,让你快速上手,快速部署,从容面对企业环境&#…...
《动手学深度学习 Pytorch版》 10.1 注意力提示
10.1.1 生物学中的注意力提示 “美国心理学之父” 威廉詹姆斯提出的双组件(two-component)框架: 非自主性提示:基于环境中物体的突出性和易见性 自主性提示:受到了认知和意识的控制 10.1.2 查询、键和值 注意力机制…...
C# 写入文件比较
数据长度:128188个long BinaryWriter每次写一个long 耗时14.7828ms StreamWriter每次写一个long 耗时44.0934 ms FileStream每次写一个long 耗时20.5142 ms FileStream固定chunk写入,循环操作数组,耗时13.4126 ms byte[] chunk new byte[d…...
医院设备利用(Use of Hospital Facilities, ACM/ICPC World Finals 1991, UVa212)rust解法
医院里有n(n≤10)个手术室和m(m≤30)个恢复室。每个病人首先会被分配到一个手术室,手术后会被分配到一个恢复室。从任意手术室到任意恢复室的时间均为t1,准备一个手术室和恢复室的时间分别为t2和t3…...
解决github ping不通的问题(1024程序员节快乐!
1024程序员节快乐!(随便粘贴一个文档,参加活动 解决github ping不通的问题 域名解析(域名->IP):https://www.ipaddress.com/ Ubuntu平台 github经常ping不通或者访问缓慢,方法是更改host…...
QT基础 柱状图
目录 1.QBarSeries 2.QHorizontalBarSeries 3.QPercentBarSeries 4.QHorizontalPercentBarSeries 5.QStackedBarSeries 6.QHorizontalStackedBarSeries 从上图得知柱状的基类是QAbstractBarSeries,派生出来分别是柱状图的水平和垂直类,只是类型…...
微机原理与接口技术-第七章输入输出接口
文章目录 I/O接口概述I/O接口的典型结构基本功能 I/O端口的编址独立编址统一编址 输入输出指令I/O寻址方式I/O数据传输量I/O保护 16位DOS应用程序DOS平台的源程序框架DOS功能调用 无条件传送和查询传送无条件传送三态缓冲器锁存器接口电路 查询传送查询输入端口查询输出端口 中…...
YoloV8改进策略:独家原创,LSKA(大可分离核注意力)改进YoloV8,比Transformer更有效,包括论文翻译和实验结果
文章目录 摘要论文:《LSKA(大可分离核注意力):重新思考CNN大核注意力设计》1、简介2、相关工作3、方法4、实验5、消融研究6、与最先进方法的比较7、ViTs和CNNs的鲁棒性评估基准比较8、结论YoloV8官方结果改进一:测试结果摘要 本文给大家带来一种超大核注意力机制的改进方…...
7天易语言从入门到实战(一)
1.1易语言简介 易语言是一门有着伟大理想的语言。公司用的少,开发者也很少,并不影响国人对他的热情。曾经的多玩LOL,朗读女,都是陪伴再那个国产PC应用匮乏的时代。 2001年1月 吴涛研发了中国自主知识产权的的中文编程语言——易语…...
redis缓存问题
缓存击穿 缓存击穿是指某个热点数据存储在redis中,该数据在高并发的场景下,当该key过期时就会有大量的请求去查询数据库,对数据库的压力非常大,可能会导致数据库压垮。 解决方案 1.不为热点的key设置过期时间。 2.使用分布式锁…...
mysql创建自定义函数报错
mysql创建自定义函数报错:This function has none of DETERMINISTIC, NO SQL, or READS SQL DATA in its declarat… 这是我们开启了bin-log,我们就必须指定我们的函数是否是 1.DETERMINISTIC 不确定的 2.NO SQL没有sql语句,当然也不会修改数…...
Docker 的数据管理与网络通信以及Docker镜像的创建
目录 Docker的数据管理 1、数据卷 2、数据卷容器 3、端口映射 4、容器互联 二、Docker网络 1、Docker网络实现原理 2、Docker的网桥模式 1)Host 2)Container 3)none 4)bridge 5)自定义网络 3、创建自定义…...
linux系统查看bash的history
要输出最近的20条命令,可以使用history命令。在Bash终端中,输入以下命令即可获取最近的20条命令历史记录: history 20这将显示你最近执行的20条命令及其相应的行号。 要将最近的20条命令写入到一个名为 “command.txt” 的文本文件中&#…...
【T+】畅捷通T+增加会计科目提示执行超时已过期。
【问题描述】 在畅捷通T软件中, 增加会计科目的时候提示: 通过DataTable插入ext扩展表出错:执行超时已过期。 完成操作之前已超时或服务器未响应。 操作已被用户取消。 语句已终止。 【解决方法】 【方法一】 注销用户登录,回到软件登录界面…...
0基础学习VR全景平台篇第111篇:全景图拼接和编辑 - PTGui Pro教程
上课!全体起立~ 大家好,欢迎观看蛙色官方系列全景摄影课程! 前情回顾:上节,我们将源图像导入了PTGui,也设置好了各项参数。 下面我们就开始拼接全景图,并且在编辑器里进行一系列检查错位和设…...
Dynamics 365 使用ILMerge 合并CRM开发后的DLL
很久以前写过一篇博文,关于用ILMerge 命令合并DLL,当时时纯敲命令行的,现在有了更简单的方式,只需要在NuGet下载如下两个包 另外插件引用第三方dll的新方案Preview来了,不久的将来就不需要使用ILMerge了...
SpringBoot Web请求响应
目录 前言请求PostmanPostman使用 简单参数原始方式接收普通参数SpringBoot方式接收普通参数参数名不一致问题 实体参数简单实体参数复杂实体对象 数组集合参数数组参数集合参数 日期参数JSON参数路径参数 响应ResponseBody统一响应结果请求响应案例案例需求与准备工作案例实现…...
Jenkins CLI二次开发工具类
使用Jenkins CLI进行二次开发 使用背景 公司自研CI/DI平台,借助JenkinsSonarQube进行代码质量管理。对接版本 Jenkins版本为:Version 2.428 SonarQube版本为:Community EditionVersion 10.2.1 (build 78527)技术选型 Java对接Jenkins有第…...
2. 计算WPL
题目 Huffman编码是通信系统中常用的一种不等长编码,它的特点是:能够使编码之后的电文长度最短。 更多关于Huffman编码的内容参考教材第十章。 输入: 第一行为要编码的符号数量n 第二行~第n1行为每个符号出现的频率 输…...
盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...
蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练
前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1):从基础到实战的深度解析-CSDN博客,但实际面试中,企业更关注候选人对复杂场景的应对能力(如多设备并发扫描、低功耗与高发现率的平衡)和前沿技术的…...
Java多线程实现之Callable接口深度解析
Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...
LLM基础1_语言模型如何处理文本
基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken:OpenAI开发的专业"分词器" torch:Facebook开发的强力计算引擎,相当于超级计算器 理解词嵌入:给词语画"…...
《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...
华硕a豆14 Air香氛版,美学与科技的馨香融合
在快节奏的现代生活中,我们渴望一个能激发创想、愉悦感官的工作与生活伙伴,它不仅是冰冷的科技工具,更能触动我们内心深处的细腻情感。正是在这样的期许下,华硕a豆14 Air香氛版翩然而至,它以一种前所未有的方式&#x…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的
修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...
