【Block总结】WTConv,小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野
论文解读:Wavelet Convolutions for Large Receptive Fields
论文信息
- 标题: Wavelet Convolutions for Large Receptive Fields
- 作者: Shahaf E. Finder, Roy Amoyal, Eran Treister, Oren Freifeld
- 提交日期: 2024年7月8日
- arXiv链接: Wavelet Convolutions for Large Receptive Fields
- Github: https://github.com/BGU-CS-VIL/WTConv
概述
论文《Wavelet Convolutions for Large Receptive Fields》提出了一种新型卷积层,称为WTConv(Wavelet Transform Convolution),旨在通过小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野。该方法能够在不显著增加参数数量的情况下,获得接近全局的感受野,从而提高模型对低频信息的捕捉能力。
主要贡献
-
感受野扩展:传统的卷积神经网络通过增加卷积核的大小来扩展感受野,但这种方法在达到一定程度后会遇到参数过多的问题。WTConv通过小波变换实现了感受野的有效扩展,且参数数量仅以对数方式增长。
-
多频率响应:WTConv能够有效地响应不同频率的输入信号,增强了模型对形状的响应能力,而不仅仅是对纹理的响应。
-
架构兼容性:WTConv可以作为现有架构的替代层,适用于多种网络结构,如ConvNeXt和MobileNetV2,且在图像分类等下游任务中表现出色。
WTConv如何在不增加参数的情况下扩展感受野
WTConv(Wavelet Transform Convolution)是一种新型卷积层,旨在通过小波变换(Wavelet Transform)有效扩展卷积神经网络(CNN)的感受野,而不显著增加模型的参数数量。这一方法的核心在于利用小波变换的特性,使得感受野的扩展与参数的增长呈对数关系。
-
小波变换的优势:小波变换能够将信号分解为不同频率的成分,这使得WTConv能够同时捕捉到低频和高频信息。通过这种方式,WTConv可以在保持较小卷积核的情况下,获得较大的感受野。
-
参数增长控制:传统的卷积层通过增加卷积核的大小来扩展感受野,但这会导致参数数量的急剧增加。WTConv的设计使得对于一个 k × k k \times k k×k 的感受野,所需的可训练参数数量仅以对数方式增长,这样可以有效避免过度参数化的问题[7][8]。
-
架构兼容性:WTConv可以作为现有网络架构的替代层,例如ConvNeXt和MobileNetV2,能够无缝集成到这些模型中,增强其对形状的响应能力,并提高对图像损坏的鲁棒性[5][10]。
实验结果
在多个图像分类任务中,WTConv表现出色,尤其是在处理复杂形状和纹理时,显示出更强的适应性和准确性,在图像分类任务中优于传统卷积层,尤其在处理图像损坏和复杂形状时表现出更强的鲁棒性。
。这表明WTConv不仅在理论上有效,而且在实际应用中也具有良好的性能。
通过这些机制,WTConv实现了感受野的有效扩展,同时保持了模型的参数效率,适应了现代深度学习对计算资源的需求。
代码:
import torch
import torch.nn as nn
import pywt
import pywt.dataimport torch.nn.functional as Fdef create_wavelet_filter(wave, in_size, out_size, type=torch.float):w = pywt.Wavelet(wave)dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)return dec_filters, rec_filtersdef wavelet_transform(x, filters):b, c, h, w = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)x = x.reshape(b, c, 4, h // 2, w // 2)return xdef inverse_wavelet_transform(x, filters):b, c, _, h_half, w_half = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = x.reshape(b, c * 4, h_half, w_half)x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)return xclass WTConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):super(WTConv2d, self).__init__()assert in_channels == out_channelsself.in_channels = in_channelsself.wt_levels = wt_levelsself.stride = strideself.dilation = 1self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels, bias=bias)self.base_scale = _ScaleModule([1, in_channels, 1, 1])self.wavelet_convs = nn.ModuleList([nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)])self.wavelet_scale = nn.ModuleList([_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)])if self.stride > 1:self.do_stride = nn.AvgPool2d(kernel_size=1, stride=stride)else:self.do_stride = Nonedef forward(self, x):x_ll_in_levels = []x_h_in_levels = []shapes_in_levels = []curr_x_ll = xfor i in range(self.wt_levels):curr_shape = curr_x_ll.shapeshapes_in_levels.append(curr_shape)if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)curr_x_ll = F.pad(curr_x_ll, curr_pads)curr_x =wavelet_transform(curr_x_ll, self.wt_filter)curr_x_ll = curr_x[:, :, 0, :, :]shape_x = curr_x.shapecurr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))curr_x_tag = curr_x_tag.reshape(shape_x)x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])next_x_ll = 0for i in range(self.wt_levels - 1, -1, -1):curr_x_ll = x_ll_in_levels.pop()curr_x_h = x_h_in_levels.pop()curr_shape = shapes_in_levels.pop()curr_x_ll = curr_x_ll + next_x_llcurr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)next_x_ll = inverse_wavelet_transform(curr_x, self.iwt_filter)next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]x_tag = next_x_llassert len(x_ll_in_levels) == 0x = self.base_scale(self.base_conv(x))x = x + x_tagif self.do_stride is not None:x = self.do_stride(x)return xclass _ScaleModule(nn.Module):def __init__(self, dims, init_scale=1.0, init_bias=0):super(_ScaleModule, self).__init__()self.dims = dimsself.weight = nn.Parameter(torch.ones(*dims) * init_scale)self.bias = Nonedef forward(self, x):return torch.mul(self.weight, x)if __name__ == '__main__':# 创建一个随机输入张量,形状为 (batch_size,height×width,channels)input1 = torch.rand(1, 64,40, 40)# 实例化EFC模块block = WTConv2d(64,64,kernel_size=7)# 前向传播output = block(input1)# 打印输入和输出的形状print(input1.size())print(output.size())
输出结果:
torch.Size([1, 64, 40, 40])
torch.Size([1, 64, 40, 40])
相关文章:

【Block总结】WTConv,小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野
论文解读:Wavelet Convolutions for Large Receptive Fields 论文信息 标题: Wavelet Convolutions for Large Receptive Fields作者: Shahaf E. Finder, Roy Amoyal, Eran Treister, Oren Freifeld提交日期: 2024年7月8日arXiv链接: Wavelet Convolutions for La…...

深入探究分布式日志系统 Graylog:架构、部署与优化
文章目录 一、Graylog简介二、Graylog原理架构三、日志系统对比四、Graylog部署传统部署MongoDB部署OS或者ES部署Garylog部署容器化部署 五、配置详情六、优化网络和 REST APIMongoDB 七、升级八、监控九、常见问题及处理 一、Graylog简介 Graylog是一个简单易用、功能较全面的…...

构建高可用和高防御力的云服务架构第五部分:PolarDB(55)
引言 云计算与数据库服务 云计算作为一种革命性的技术,已经深刻改变了信息技术行业的面貌。它通过提供按需分配的计算资源,使得数据存储、处理和分析变得更加灵活和高效。在云计算的众多服务中,数据库服务扮演着核心角色。数据库服务不仅负…...

【Java 学习】深度剖析Java多态:从向上转型到向下转型,解锁动态绑定的奥秘,让代码更优雅灵活
💬 欢迎讨论:如对文章内容有疑问或见解,欢迎在评论区留言,我需要您的帮助! 👍 点赞、收藏与分享:如果这篇文章对您有所帮助,请不吝点赞、收藏或分享,谢谢您的支持&#x…...

HTTP / 2
序言 在之前的文章中我们介绍过了 HTTP/1.1 协议,现在再来认识一下迭代版本 2。了解比起 1.1 版本,后面的版本改进在哪里,特点在哪里?话不多说,开始吧⭐️! 一、 HTTP / 1.1 存在的问题 很多时候新的版本的…...
【深度学习】利用Java DL4J 训练金融投资组合模型
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c=1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计,Springboot和微服务,熟悉Linux,ESXI虚拟化以及云原生Docker和K8s…...
跨域cookie携带问题总结
背景 我们知道很多场景,都需要前端请求带上cookie,例如用户鉴权、登陆校验等。而有些场景下,我们会发现请求不会带上cookie,这是为什么呢? 概念 cookie是种在域名下的信息。只有请求同域且同站的请求,才…...

Pytorch使用教程(12)-如何进行并行训练?
在使用GPU训练大模型时,往往会面临单卡显存不足的情况。这时,通过多卡并行的形式来扩大显存是一个有效的解决方案。PyTorch主要提供了两个类来实现多卡并行:数据并行torch.nn.DataParallel(DP)和模型并行torch.nn.Dist…...

指针之旅:从基础到进阶的全面讲解
大家好,这里是小编的博客频道 小编的博客:就爱学编程 很高兴在CSDN这个大家庭与大家相识,希望能在这里与大家共同进步,共同收获更好的自己!!! 本文目录 引言正文(1)内置数…...

FPGA与ASIC:深度解析与职业选择
IC(集成电路)行业涵盖广泛,涉及数字、模拟等不同研究方向,以及设计、制造、封测等不同产业环节。其中,FPGA(现场可编程门阵列)和ASIC(专用集成电路)是两种重要的芯片类型…...
PostgreSQL 中进行数据导入和导出
在数据库管理中,数据的导入和导出是非常常见的操作。特别是在 PostgreSQL 中,提供了多种工具和方法来实现数据的有效管理。无论是备份数据,还是将数据迁移到其他数据库,或是进行数据分析,掌握数据导入和导出的技巧都是…...
SDL2基本的绘制流程与步骤
SDL2(Simple DirectMedia Layer 2)是一个跨平台的多媒体库,它为游戏开发和图形应用提供了一个简单的接口,允许程序直接访问音频、键盘、鼠标、硬件加速的渲染等功能。在 SDL2 中,屏幕绘制的流程通常涉及到窗口的创建、渲染目标的设置、图像的绘制、事件的处理等几个步骤。…...
面试-业务逻辑2
应用 给定2个数组a、b,若a[i] b[j],则记(i,j)为一个二元数组,求具体的二元数组及其个数。 实现 a input("请输入数组a的元素个数:") # print(a) a_list list(map(int, input("请输入数组a的元素,…...

HTML之拜年/跨年APP(改进版)
目录: 一:目录 二:效果 三:页面分析/开发逻辑 1.页面详细分析: 2.开发逻辑: 四:完整代码(不多废话) index.html部分 app.json部分 二:效果 三:页面…...
嵌入式硬件篇---ADC模拟-数字转换
文章目录 前言第一部分:STM32 ADC的主要特点1.分辨率2.多通道3.转换模式4.转换速度5.触发源6.数据对齐7.温度传感器和Vrefint通道 第二部分:STM32 ADC的工作流程:1.配置ADC2.启动ADC转换 第三部分:ADC转化1.抽样2.量化3.编码 第四…...

每打开一个chrome页面都会【自动打开F12开发者模式】,原因是 使用HBuilderX会影响谷歌浏览器的浏览模式
打开 HBuilderX,点击 运行 -> 运行到浏览器 -> 设置web服务器 -> 添加chrome浏览器安装路径 chrome谷歌浏览器插件 B站视频下载助手插件: 参考地址:Chrome插件 - B站下载助手(轻松下载bilibili哔哩哔哩视频)…...
Access数据库教案(Excel+VBA+Access数据库SQL Server编程)
文章目录: 一:Access基础知识 1.前言 1.1 基本流程 1.2 基本概念?? 2.使用步骤方法 2.1 表【设计】 2.1.1 表的理论基础 2.1.2 Access建库建表? 2.1.3 表的基本操作 2.2 SQL语句代码【设计】 2.3 窗体【交互】? 2.3.1 多方式创建窗体 2.3.2 窗体常用的控件 …...
09、PT工具用法
目录 1、PT工具原理 2、在线修改表结构 3、使用pt-query-diges分析慢查询 4、使用pt-kill来kill掉一些垃圾SQL 5、pt-table-checksum进行主从一致性排查和修复 6、pt-archiver进行数据归档 7、其他一些pt工具 1、PT工具原理 创建一张与原始表结构相同的临时表 然后对临时…...
华为OD机试E卷 --矩形相交的面积--24年OD统一考试(Java JS Python C C++)
文章目录 题目描述输入描述输出描述用例题目解析JS算法源码Java算法源码python算法源码题目描述 给出3组点坐标(x, y, w, h),-1000<x,y<1000,w,h为正整数。 (x,y, w, h)表示平面直角坐标系中的一个矩形:x, y为矩形左上角坐标点,w, h向右w,向下h。(X, y, w, h)表示x轴…...
C++ 内存分配和管理(八股总结)
C是如何做内存管理的(有哪些内存区域)? (1)堆,使用malloc、free动态分配和释放空间,能分配较大的内存; (2)栈,为函数的局部变量分配内存,能分配…...

Appium+python自动化(十六)- ADB命令
简介 Android 调试桥(adb)是多种用途的工具,该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具,其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利,如安装和调试…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例
使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件,常用于在两个集合之间进行数据转移,如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model:绑定右侧列表的值&…...

Day131 | 灵神 | 回溯算法 | 子集型 子集
Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

MySQL 8.0 OCP 英文题库解析(十三)
Oracle 为庆祝 MySQL 30 周年,截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始,将英文题库免费公布出来,并进行解析,帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...

【笔记】WSL 中 Rust 安装与测试完整记录
#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统:Ubuntu 24.04 LTS (WSL2)架构:x86_64 (GNU/Linux)Rust 版本:rustc 1.87.0 (2025-05-09)Cargo 版本:cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...

在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...

Linux操作系统共享Windows操作系统的文件
目录 一、共享文件 二、挂载 一、共享文件 点击虚拟机选项-设置 点击选项,设置文件夹共享为总是启用,点击添加,可添加需要共享的文件夹 查询是否共享成功 ls /mnt/hgfs 如果显示Download(这是我共享的文件夹)&…...
计算机系统结构复习-名词解释2
1.定向:在某条指令产生计算结果之前,其他指令并不真正立即需要该计算结果,如果能够将该计算结果从其产生的地方直接送到其他指令中需要它的地方,那么就可以避免停顿。 2.多级存储层次:由若干个采用不同实现技术的存储…...
ABB馈线保护 REJ601 BD446NN1XG
配电网基本量程数字继电器 REJ601是一种专用馈线保护继电器,用于保护一次和二次配电网络中的公用事业和工业电力系统。该继电器在一个单元中提供了保护和监控功能的优化组合,具有同类产品中最佳的性能和可用性。 REJ601是一种专用馈线保护继电器…...