当前位置: 首页 > news >正文

【Block总结】WTConv,小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野

论文解读:Wavelet Convolutions for Large Receptive Fields

论文信息

  • 标题: Wavelet Convolutions for Large Receptive Fields
  • 作者: Shahaf E. Finder, Roy Amoyal, Eran Treister, Oren Freifeld
  • 提交日期: 2024年7月8日
  • arXiv链接: Wavelet Convolutions for Large Receptive Fields
  • Github: https://github.com/BGU-CS-VIL/WTConv

概述

论文《Wavelet Convolutions for Large Receptive Fields》提出了一种新型卷积层,称为WTConv(Wavelet Transform Convolution),旨在通过小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野。该方法能够在不显著增加参数数量的情况下,获得接近全局的感受野,从而提高模型对低频信息的捕捉能力。
在这里插入图片描述

主要贡献

  1. 感受野扩展:传统的卷积神经网络通过增加卷积核的大小来扩展感受野,但这种方法在达到一定程度后会遇到参数过多的问题。WTConv通过小波变换实现了感受野的有效扩展,且参数数量仅以对数方式增长。

  2. 多频率响应:WTConv能够有效地响应不同频率的输入信号,增强了模型对形状的响应能力,而不仅仅是对纹理的响应。

  3. 架构兼容性:WTConv可以作为现有架构的替代层,适用于多种网络结构,如ConvNeXt和MobileNetV2,且在图像分类等下游任务中表现出色。
    在这里插入图片描述

WTConv如何在不增加参数的情况下扩展感受野

WTConv(Wavelet Transform Convolution)是一种新型卷积层,旨在通过小波变换(Wavelet Transform)有效扩展卷积神经网络(CNN)的感受野,而不显著增加模型的参数数量。这一方法的核心在于利用小波变换的特性,使得感受野的扩展与参数的增长呈对数关系。

  1. 小波变换的优势:小波变换能够将信号分解为不同频率的成分,这使得WTConv能够同时捕捉到低频和高频信息。通过这种方式,WTConv可以在保持较小卷积核的情况下,获得较大的感受野。

  2. 参数增长控制:传统的卷积层通过增加卷积核的大小来扩展感受野,但这会导致参数数量的急剧增加。WTConv的设计使得对于一个 k × k k \times k k×k 的感受野,所需的可训练参数数量仅以对数方式增长,这样可以有效避免过度参数化的问题[7][8]。

  3. 架构兼容性:WTConv可以作为现有网络架构的替代层,例如ConvNeXt和MobileNetV2,能够无缝集成到这些模型中,增强其对形状的响应能力,并提高对图像损坏的鲁棒性[5][10]。

实验结果

在多个图像分类任务中,WTConv表现出色,尤其是在处理复杂形状和纹理时,显示出更强的适应性和准确性,在图像分类任务中优于传统卷积层,尤其在处理图像损坏和复杂形状时表现出更强的鲁棒性。
。这表明WTConv不仅在理论上有效,而且在实际应用中也具有良好的性能。

通过这些机制,WTConv实现了感受野的有效扩展,同时保持了模型的参数效率,适应了现代深度学习对计算资源的需求。

代码:

import torch
import torch.nn as nn
import pywt
import pywt.dataimport torch.nn.functional as Fdef create_wavelet_filter(wave, in_size, out_size, type=torch.float):w = pywt.Wavelet(wave)dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)return dec_filters, rec_filtersdef wavelet_transform(x, filters):b, c, h, w = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)x = x.reshape(b, c, 4, h // 2, w // 2)return xdef inverse_wavelet_transform(x, filters):b, c, _, h_half, w_half = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = x.reshape(b, c * 4, h_half, w_half)x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)return xclass WTConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):super(WTConv2d, self).__init__()assert in_channels == out_channelsself.in_channels = in_channelsself.wt_levels = wt_levelsself.stride = strideself.dilation = 1self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels, bias=bias)self.base_scale = _ScaleModule([1, in_channels, 1, 1])self.wavelet_convs = nn.ModuleList([nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)])self.wavelet_scale = nn.ModuleList([_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)])if self.stride > 1:self.do_stride = nn.AvgPool2d(kernel_size=1, stride=stride)else:self.do_stride = Nonedef forward(self, x):x_ll_in_levels = []x_h_in_levels = []shapes_in_levels = []curr_x_ll = xfor i in range(self.wt_levels):curr_shape = curr_x_ll.shapeshapes_in_levels.append(curr_shape)if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)curr_x_ll = F.pad(curr_x_ll, curr_pads)curr_x =wavelet_transform(curr_x_ll, self.wt_filter)curr_x_ll = curr_x[:, :, 0, :, :]shape_x = curr_x.shapecurr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))curr_x_tag = curr_x_tag.reshape(shape_x)x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])next_x_ll = 0for i in range(self.wt_levels - 1, -1, -1):curr_x_ll = x_ll_in_levels.pop()curr_x_h = x_h_in_levels.pop()curr_shape = shapes_in_levels.pop()curr_x_ll = curr_x_ll + next_x_llcurr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)next_x_ll = inverse_wavelet_transform(curr_x, self.iwt_filter)next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]x_tag = next_x_llassert len(x_ll_in_levels) == 0x = self.base_scale(self.base_conv(x))x = x + x_tagif self.do_stride is not None:x = self.do_stride(x)return xclass _ScaleModule(nn.Module):def __init__(self, dims, init_scale=1.0, init_bias=0):super(_ScaleModule, self).__init__()self.dims = dimsself.weight = nn.Parameter(torch.ones(*dims) * init_scale)self.bias = Nonedef forward(self, x):return torch.mul(self.weight, x)if __name__ == '__main__':# 创建一个随机输入张量,形状为 (batch_size,height×width,channels)input1 = torch.rand(1, 64,40, 40)# 实例化EFC模块block = WTConv2d(64,64,kernel_size=7)# 前向传播output = block(input1)# 打印输入和输出的形状print(input1.size())print(output.size())

输出结果:

torch.Size([1, 64, 40, 40])
torch.Size([1, 64, 40, 40])

相关文章:

【Block总结】WTConv,小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野

论文解读:Wavelet Convolutions for Large Receptive Fields 论文信息 标题: Wavelet Convolutions for Large Receptive Fields作者: Shahaf E. Finder, Roy Amoyal, Eran Treister, Oren Freifeld提交日期: 2024年7月8日arXiv链接: Wavelet Convolutions for La…...

深入探究分布式日志系统 Graylog:架构、部署与优化

文章目录 一、Graylog简介二、Graylog原理架构三、日志系统对比四、Graylog部署传统部署MongoDB部署OS或者ES部署Garylog部署容器化部署 五、配置详情六、优化网络和 REST APIMongoDB 七、升级八、监控九、常见问题及处理 一、Graylog简介 Graylog是一个简单易用、功能较全面的…...

构建高可用和高防御力的云服务架构第五部分:PolarDB(55)

引言 云计算与数据库服务 云计算作为一种革命性的技术,已经深刻改变了信息技术行业的面貌。它通过提供按需分配的计算资源,使得数据存储、处理和分析变得更加灵活和高效。在云计算的众多服务中,数据库服务扮演着核心角色。数据库服务不仅负…...

【Java 学习】深度剖析Java多态:从向上转型到向下转型,解锁动态绑定的奥秘,让代码更优雅灵活

💬 欢迎讨论:如对文章内容有疑问或见解,欢迎在评论区留言,我需要您的帮助! 👍 点赞、收藏与分享:如果这篇文章对您有所帮助,请不吝点赞、收藏或分享,谢谢您的支持&#x…...

HTTP / 2

序言 在之前的文章中我们介绍过了 HTTP/1.1 协议,现在再来认识一下迭代版本 2。了解比起 1.1 版本,后面的版本改进在哪里,特点在哪里?话不多说,开始吧⭐️! 一、 HTTP / 1.1 存在的问题 很多时候新的版本的…...

【深度学习】利用Java DL4J 训练金融投资组合模型

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c=1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计,Springboot和微服务,熟悉Linux,ESXI虚拟化以及云原生Docker和K8s…...

跨域cookie携带问题总结

背景 我们知道很多场景,都需要前端请求带上cookie,例如用户鉴权、登陆校验等。而有些场景下,我们会发现请求不会带上cookie,这是为什么呢? 概念 cookie是种在域名下的信息。只有请求同域且同站的请求,才…...

Pytorch使用教程(12)-如何进行并行训练?

在使用GPU训练大模型时,往往会面临单卡显存不足的情况。这时,通过多卡并行的形式来扩大显存是一个有效的解决方案。PyTorch主要提供了两个类来实现多卡并行:数据并行torch.nn.DataParallel(DP)和模型并行torch.nn.Dist…...

指针之旅:从基础到进阶的全面讲解

大家好,这里是小编的博客频道 小编的博客:就爱学编程 很高兴在CSDN这个大家庭与大家相识,希望能在这里与大家共同进步,共同收获更好的自己!!! 本文目录 引言正文(1)内置数…...

FPGA与ASIC:深度解析与职业选择

IC(集成电路)行业涵盖广泛,涉及数字、模拟等不同研究方向,以及设计、制造、封测等不同产业环节。其中,FPGA(现场可编程门阵列)和ASIC(专用集成电路)是两种重要的芯片类型…...

PostgreSQL 中进行数据导入和导出

在数据库管理中,数据的导入和导出是非常常见的操作。特别是在 PostgreSQL 中,提供了多种工具和方法来实现数据的有效管理。无论是备份数据,还是将数据迁移到其他数据库,或是进行数据分析,掌握数据导入和导出的技巧都是…...

SDL2基本的绘制流程与步骤

SDL2(Simple DirectMedia Layer 2)是一个跨平台的多媒体库,它为游戏开发和图形应用提供了一个简单的接口,允许程序直接访问音频、键盘、鼠标、硬件加速的渲染等功能。在 SDL2 中,屏幕绘制的流程通常涉及到窗口的创建、渲染目标的设置、图像的绘制、事件的处理等几个步骤。…...

面试-业务逻辑2

应用 给定2个数组a、b,若a[i] b[j],则记(i,j)为一个二元数组,求具体的二元数组及其个数。 实现 a input("请输入数组a的元素个数:") # print(a) a_list list(map(int, input("请输入数组a的元素,…...

HTML之拜年/跨年APP(改进版)

目录: 一:目录 二:效果 三:页面分析/开发逻辑 1.页面详细分析: 2.开发逻辑: 四:完整代码(不多废话) index.html部分 app.json部分 二:效果 三:页面…...

嵌入式硬件篇---ADC模拟-数字转换

文章目录 前言第一部分:STM32 ADC的主要特点1.分辨率2.多通道3.转换模式4.转换速度5.触发源6.数据对齐7.温度传感器和Vrefint通道 第二部分:STM32 ADC的工作流程:1.配置ADC2.启动ADC转换 第三部分:ADC转化1.抽样2.量化3.编码 第四…...

每打开一个chrome页面都会【自动打开F12开发者模式】,原因是 使用HBuilderX会影响谷歌浏览器的浏览模式

打开 HBuilderX,点击 运行 -> 运行到浏览器 -> 设置web服务器 -> 添加chrome浏览器安装路径 chrome谷歌浏览器插件 B站视频下载助手插件: 参考地址:Chrome插件 - B站下载助手(轻松下载bilibili哔哩哔哩视频&#xff09…...

Access数据库教案(Excel+VBA+Access数据库SQL Server编程)

文章目录: 一:Access基础知识 1.前言 1.1 基本流程 1.2 基本概念?? 2.使用步骤方法 2.1 表【设计】 2.1.1 表的理论基础 2.1.2 Access建库建表? 2.1.3 表的基本操作 2.2 SQL语句代码【设计】 2.3 窗体【交互】? 2.3.1 多方式创建窗体 2.3.2 窗体常用的控件 …...

09、PT工具用法

目录 1、PT工具原理 2、在线修改表结构 3、使用pt-query-diges分析慢查询 4、使用pt-kill来kill掉一些垃圾SQL 5、pt-table-checksum进行主从一致性排查和修复 6、pt-archiver进行数据归档 7、其他一些pt工具 1、PT工具原理 创建一张与原始表结构相同的临时表 然后对临时…...

华为OD机试E卷 --矩形相交的面积--24年OD统一考试(Java JS Python C C++)

文章目录 题目描述输入描述输出描述用例题目解析JS算法源码Java算法源码python算法源码题目描述 给出3组点坐标(x, y, w, h),-1000<x,y<1000,w,h为正整数。 (x,y, w, h)表示平面直角坐标系中的一个矩形:x, y为矩形左上角坐标点,w, h向右w,向下h。(X, y, w, h)表示x轴…...

C++ 内存分配和管理(八股总结)

C是如何做内存管理的&#xff08;有哪些内存区域&#xff09;? &#xff08;1&#xff09;堆&#xff0c;使用malloc、free动态分配和释放空间&#xff0c;能分配较大的内存&#xff1b; &#xff08;2&#xff09;栈&#xff0c;为函数的局部变量分配内存&#xff0c;能分配…...

Linux 进程管理学习指南:架构、计划与关键问题全解

Linux 进程管理学习指南&#xff1a;架构、计划与关键问题全解 本文面向初学者&#xff0c;旨在帮助你从架构视角理解 Linux 进程管理子系统&#xff0c;构建系统化学习路径&#xff0c;并通过结构化笔记方法与典型问题总结&#xff0c;夯实基础、明确方向&#xff0c;逐步掌握…...

基于开源AI大模型AI智能名片S2B2C商城小程序源码的中等平台型社交电商运营模式研究

摘要&#xff1a;本文聚焦中等平台型社交电商&#xff0c;探讨其与传统微商及大型社交电商平台的差异&#xff0c;尤其关注产品品类管理对代理运营的影响。通过引入开源AI大模型、AI智能名片与S2B2C商城小程序源码技术&#xff0c;构建智能化运营体系。研究结果表明&#xff0c…...

青少年编程与数学 01-011 系统软件简介 08 Windows操作系统

青少年编程与数学 01-011 系统软件简介 08 Windows操作系统 1. Windows操作系统的起源与发展1.1 早期版本&#xff08;1985-1995&#xff09;1.2 Windows 9x系列&#xff08;1995-2000&#xff09;1.3 Windows NT系列&#xff08;1993-2001&#xff09;1.4 Windows XP及以后版…...

互斥锁与消息队列的架构哲学

更多精彩内容请访问&#xff1a;通义灵码2.5——基于编程智能体开发Wiki多功能搜索引擎更多精彩内容请访问&#xff1a;更多精彩内容请访问&#xff1a;通义灵码2.5——基于编程智能体开发Wiki多功能搜索引擎 一、资源争用的现实镜像 当多个ATM机共用一个现金库时&#xff0c;…...

AT_abc409_e [ABC409E] Pair Annihilation

AT_abc409_e [ABC409E] Pair Annihilation 赛时没开longlong挂了。 思路 首先我们可以把这棵树转化为一颗有根树&#xff0c;且所有电子的都朝根节点移动。 那么接下来我们就需要选择一个最优的树根。 考虑换根dp。 但是可以发现换根时答案其实是没有变化的。 我们设 f…...

第四讲:类和对象(下)

1. 再探构造函数 • 之前我们实现构造函数时&#xff0c;初始化成员变量主要使⽤函数体内赋值&#xff0c;构造函数初始化还有⼀种⽅ 式&#xff0c;就是初始化列表&#xff0c;初始化列表的使⽤⽅式是以⼀个冒号开始&#xff0c;接着是⼀个以逗号分隔的数据成 员列表&#xff…...

【Elasticsearch】映射:Join 类型、Flattened 类型、多表关联设计

映射&#xff1a;Join 类型、Flattened 类型、多表关联设计 1.Join 类型1.1 主要应用场景1.1.1 一对多关系建模1.1.2 多层级关系建模1.1.3 需要独立更新子文档的场景1.1.4 文档分离但需要关联查询 1.2 使用注意事项1.3 与 Nested 类型的区别 2.Flattened 类型2.1 实际运用场景和…...

鸿蒙仓颉语言开发实战教程:商城应用个人中心页面

又到了高考的日子&#xff0c;幽蓝君在这里祝各位考生朋友冷静答题&#xff0c;超常发挥。 今天要分享的内容是仓颉语言商城应用的个人中心页面&#xff0c;先看效果图&#xff1a; 下面介绍下这个页面的实现过程。 我们可以先分析下整个页面的布局结构。可以看出它是纵向的布…...

附加模块--Qt OpenGL模块功能及架构

一、模块功能&#xff1a; 主要变化 Qt OpenGL 模块的分离&#xff1a; 在 Qt 6 中&#xff0c;原来的 Qt OpenGL 功能被拆分为多个模块 传统的 Qt OpenGL 模块 (QGL*) 已被标记为废弃 新的图形架构&#xff1a; Qt 6 引入了基于 QRhi (Qt Rendering Hardware Interface) 的…...

数学建模期末速成 主成分分析的基本步骤

设有 n n n个研究对象&#xff0c; m m m个指标变量 x 1 , x 2 , ⋯ , x m x_1,x_2,\cdots,x_m x1​,x2​,⋯,xm​&#xff0c;第 i i i个对象关于第 j j j个指标取值为 a i j a_{ij} aij​,构造数据矩阵 A ( a i j ) n m A\left(\begin{array}{c}a_{ij}\end{array}\right)_{…...