Haar小波下采样模块
论文原址:Haar wavelet downsampling: A simple but effective downsampling module for semantic segmentation - ScienceDirect
原文代码:HWD/HWD.py at main · apple1986/HWD (github.com)
介绍
深度卷积神经网络 (DCNN) 通常采用标准的下采样操作,例如最大池化、平均池化和跨步卷积,这可能会导致信息丢失。丢失的信息,如边界和纹理,对于语义分割可能是必不可少的。为了缓解这个问题,一般有下面四种方法:
- 通过跳过连接到解码器子网(如U-Net、LCU-Net、CENet、LinkNet和RefineNet )。
- 提取具有空间金字塔池化或扩展卷积的多尺度特征图到融合模块中(如DeepLab、PSPNet、PCPLP-Net、BiSenet和ICNet)。
- 向编码器提供多模态图像(如DiSegNet、MMADT、CANet和CCFFNet)。
- 增加先验信息。轮廓增强关注模块,旨在从CT图像中提取边界和形状线索,以细化分割区域。
这些方法的主要目的是通过基于多尺度、先验指导、多模态等各种策略提供更多的学习信息或特征,帮助下采样特征与分割标签之间建立良好的关系。
因此,是否可以设计一个保留信息的下采样模块,使DCNNs中尽可能多地保留信息进行语义分割?这就是作者的想法。
下采样模块
最大池化与平均池化
池化过程类似于卷积过程。在这个示意图中,我们看到对一个 4x4 的特征图邻域进行操作,使用了一个 2x2 的滤波器,步长为2进行扫描。这个过程被称为最大池化(Max Pooling),其中选择邻域内的最大值并输出到下一层。
常用的 max pooling 参数是 S=2、f=2,其效果是将特征图的高度和宽度减半,而通道数保持不变。
如上图所示,描述的是对一个 4x4 的特征图邻域内的数值进行操作。使用了一个 2x2 的滤波器,步长为2进行扫描,计算邻域内数值的平均值并将其输出到下一层。这种操作被称为平均池化(Mean Pooling)。
"""
Copyright (c) 2023, Auorui.
All rights reserved.The Torch implementation of average pooling and maximum pooling has been compared with the official Torch implementation
"""
import torch
import torch.nn as nn__all__ = ["MaxPool2d", "AvgPool2d"]class MaxPool2d(nn.Module):"""池化层计算公式:output_size = [(input_size−kernel_size) // stride + 1]"""def __init__(self, kernel_size, stride):super(MaxPool2d, self).__init__()self.kernel_size = kernel_sizeself.stride = stridedef max_pool2d(self, input_tensor, kernel_size, stride):batch_size, channels, height, width = input_tensor.size()output_height = (height - kernel_size) // stride + 1output_width = (width - kernel_size) // stride + 1output_tensor = torch.zeros(batch_size, channels, output_height, output_width)for i in range(output_height):for j in range(output_width):# 获取输入张量中与池化窗口对应的部分window = input_tensor[:, :,i * stride: i * stride + kernel_size, j * stride: j * stride + kernel_size]output_tensor[:, :, i, j] = torch.max(window.reshape(batch_size, channels, -1), dim=2)[0]return output_tensordef forward(self, input_tensor):return self.max_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)class AvgPool2d(nn.Module):"""池化层计算公式:output_size = [(input_size−kernel_size) // stride + 1]"""def __init__(self, kernel_size, stride):super(AvgPool2d, self).__init__()self.kernel_size = kernel_sizeself.stride = stridedef avg_pool2d(self, input_tensor, kernel_size, stride):batch_size, channels, height, width = input_tensor.size()output_height = (height - kernel_size) // stride + 1output_width = (width - kernel_size) // stride + 1output_tensor = torch.zeros(batch_size, channels, output_height, output_width)for i in range(output_height):for j in range(output_width):# 获取输入张量中与池化窗口对应的部分window = input_tensor[:, :,i * stride: i * stride + kernel_size, j * stride:j * stride + kernel_size]output_tensor[:, :, i, j] = torch.mean(window.reshape(batch_size, channels, -1), dim=2)return output_tensordef forward(self, input_tensor):return self.avg_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)if __name__=="__main__":# input_data = torch.rand((1, 3, 3, 3))input_data = torch.Tensor([[[[0.3939, 0.8964, 0.3681],[0.5134, 0.3780, 0.0047],[0.0681, 0.0989, 0.5962]],[[0.7954, 0.4811, 0.3329],[0.8804, 0.3986, 0.3561],[0.2797, 0.3672, 0.6508]],[[0.6309, 0.1340, 0.0564],[0.3101, 0.9927, 0.5554],[0.0947, 0.2305, 0.8299]]]])print(input_data.shape)kernel_size = 3stride = 1MaxPool2d1 = nn.MaxPool2d(kernel_size, stride)output_data_with_torch_max = MaxPool2d1(input_data)AvgPool2d1 = nn.AvgPool2d(kernel_size, stride)output_data_with_torch_avg = AvgPool2d1(input_data)AvgPool2d2 = AvgPool2d(kernel_size, stride)output_data_with_torch_Avg = AvgPool2d2(input_data)MaxPool2d2 = MaxPool2d(kernel_size, stride)output_data_with_torch_Max = MaxPool2d2(input_data)# output_data_with_max = max_pool2d(input_data, kernel_size, stride)# output_data_with_avg = avg_pool2d(input_data, kernel_size, stride)print("\ntorch.nn pooling Output:")print(output_data_with_torch_max,"\n",output_data_with_torch_max.size())print(output_data_with_torch_avg,"\n",output_data_with_torch_avg.size())print("\npooling Output:")print(output_data_with_torch_Max,"\n",output_data_with_torch_Max.size())print(output_data_with_torch_Avg,"\n",output_data_with_torch_Avg.size())# 直接使用bool方法判断会因为浮点数的原因出现偏差print(torch.allclose(output_data_with_torch_max,output_data_with_torch_Max))print(torch.allclose(output_data_with_torch_avg,output_data_with_torch_Avg))# tensor([[[[0.8964]], # output_data_with_max# [[0.8804]],# [[0.9927]]]])# tensor([[[[0.3686]], # output_data_with_avg# [[0.5047]],# [[0.4261]]]])
在这里,简单地与PyTorch官方的实现进行了比对,成功的进行复现。
跨步卷积
import torch
import torch.nn as nnclass StridedConvolution(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):super(StridedConvolution, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)self.relu = nn.ReLU(inplace=True)self.is_relu = is_reludef forward(self, x):x = self.conv(x)if self.is_relu:x = self.relu(x)return xif __name__ == '__main__':input_data = torch.rand((1, 3, 64, 64))strided_conv = StridedConvolution(3, 64)output_data = strided_conv(input_data)print("Input shape:", input_data.shape)print("Output shape:", output_data.shape)
对输入进行跨步卷积,并根据 is_relu 参数选择是否添加ReLU激活函数。在构建卷积神经网络时经常被用于下采样步骤,以减小特征图的尺寸。
Haar小波下采样
这一部分就直接参考的作者的代码,与池化不同的是,这里它是要指定输入输出几个通道。
"""
Haar Wavelet-based Downsampling (HWD)Original address of the paper: https://www.sciencedirect.com/science/article/abs/pii/S0031320323005174
Code reference: https://github.com/apple1986/HWD/tree/main
"""
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForwardclass HWDownsampling(nn.Module):def __init__(self, in_channel, out_channel):super(HWDownsampling, self).__init__()self.wt = DWTForward(J=1, wave='haar', mode='zero')self.conv_bn_relu = nn.Sequential(nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True),)def forward(self, x):yL, yH = self.wt(x)y_HL = yH[0][:, :, 0, ::]y_LH = yH[0][:, :, 1, ::]y_HH = yH[0][:, :, 2, ::]x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)x = self.conv_bn_relu(x)return xif __name__ == '__main__':downsampling_layer = HWDownsampling(3, 64)input_data = torch.rand((1, 3, 64, 64))output_data = downsampling_layer(input_data)print("Input shape:", input_data.shape)print("Output shape:", output_data.shape)
Haar小波变换是一种基于小波的信号处理方法,它将信号分解成低频和细节高频两个部分。在图像处理中,Haar小波通常用于图像压缩和特征提取,代码中使用的DWTForward模块中离散小波变换,通过选择 yH 中的不同方向上的高频分量,构建了新的特征图。将原始低频分量 yL 与新构建的高频分量拼接在一起。最后通过一个包含卷积、批归一化和ReLU激活函数的序列处理最终的特征图。
实验验证
这是作者论文中做的实验,这样看起来,似乎HWD在细节上确实是比池化和跨步卷积效果要好。
这里因为我也用我自己的数据进行了实验:
最大池化效果
平均池化效果
跨步卷积效果
HDW效果
从肉眼上来看,HDW的效果确实要比其他的效果要好一些。
下面是我做实验的代码,感兴趣的可以在自己的数据上面进行实验,我觉得用于交通和医学上应该会有比较好的效果。
import cv2
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
from pytorch_wavelets import DWTForwardclass StridedConvolution(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):super(StridedConvolution, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)self.relu = nn.ReLU(inplace=True)self.is_relu = is_reludef forward(self, x):x = self.conv(x)if self.is_relu:x = self.relu(x)return xclass HWDownsampling(nn.Module):def __init__(self, in_channel, out_channel):super(HWDownsampling, self).__init__()self.wt = DWTForward(J=1, wave='haar', mode='zero')self.conv_bn_relu = nn.Sequential(nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True),)def forward(self, x):yL, yH = self.wt(x)y_HL = yH[0][:, :, 0, ::]y_LH = yH[0][:, :, 1, ::]y_HH = yH[0][:, :, 2, ::]x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)x = self.conv_bn_relu(x)return xclass DeeperCNN(nn.Module):def __init__(self):super(DeeperCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.batch_norm1 = nn.BatchNorm2d(16)self.relu = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)# self.pool1 = HWDownsampling(16, 16)self.pool1 = StridedConvolution(16, 16, is_relu=True)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.batch_norm2 = nn.BatchNorm2d(32)# self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)# self.pool2 = HWDownsampling(32, 32)self.pool2 = StridedConvolution(32, 32, is_relu=True)self.conv6 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.pool1(self.relu(self.batch_norm1(self.conv1(x))))print(x.shape)x = self.pool2(self.relu(self.batch_norm2(self.conv2(x))))print(x.shape)x = self.conv6(x)return ximage_path = r'D:\PythonProject\Crack_classification_training_script\data\base\val\crack\2416.png'
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)transform = transforms.Compose([transforms.ToTensor()])
input_image = transform(image).unsqueeze(0)
import numpy as np
model = DeeperCNN()
output = model(input_image)
print("Output shape:", output.shape)input_image = input_image.squeeze(0).permute(1, 2, 0).numpy()
output_image = output.squeeze(0).permute(1, 2, 0).detach().numpy()
output_image = output_image / output_image.max()
output_image = np.clip(output_image, 0, 1)plt.subplot(1, 2, 1)
plt.imshow(input_image)
plt.title('Input Image')plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.title('Output Image')plt.show()
总结
在论文当中,作者也做了大量的消融实验去证实这个下采样模块的有效性,建议大家去看看原著作,或许会有更多的收获。
相关文章:

Haar小波下采样模块
论文原址:Haar wavelet downsampling: A simple but effective downsampling module for semantic segmentation - ScienceDirect 原文代码:HWD/HWD.py at main apple1986/HWD (github.com) 介绍 深度卷积神经网络 (DCNN) 通…...

k8s的包管理工具helm
Helm是什么? 之前的这篇文章介绍了一开始接触k8s的时候接触到的几个命令工具 kubectl&kubelet&rancher&helm&kubeadm这几个命令行工具是什么关系?-CSDN博客 Helm 是一个用于管理和部署 Kubernetes 应用程序的包管理工具。它允许用户定义、安装和…...

《WebKit 技术内幕》学习之八(1):硬件加速机制
《WebKit 技术内幕》之八(1):硬件加速机制 1 硬件加速基础 1.1 概念 这里说的硬件加速技术是指使用GPU的硬件能力来帮助渲染网页,因为GPU的作用主要是用来绘制3D图形并且性能特别好,这是它的专长所在,它…...

【Linux对磁盘进行清理、重建、配置文件系统和挂载,进行系统存储管理调整存储结构】
Linux 调整存储结构 前言一、查看磁盘和分区列表二、创建 ext4 文件系统,即:格式化分区为ext4文件系统。1.使用命令 mkfs.ext4 (make file system)报错如下:解决办法1:(经测试,不采用)X解决办法…...
RT-DETR算法优化改进:DCNv4更快收敛、更高速度、更高性能,效果秒杀DCNv3、DCNv2等 ,助力检测
💡💡💡本文独家改进:DCNv4更快收敛、更高速度、更高性能,完美和RT-DETR结合,助力涨点 DCNv4优势:(1) 去除空间聚合中的softmax归一化,以增强其动态性和表达能力;(2) 优化存储器访问以最小化冗余操作以加速。这些改进显著加快了收敛速度,并大幅提高了处理速度,DC…...
Docker基础使用
Docker基础使用 1.查看容器挂载文件夹一定要放开权限,否则后面启动nexus时会无法启动1.查询远程镜像重启docker服务容器自启动关闭容器自启动查看docker容器是否挂载容器挂载解释保存和加载本地镜像创建mysql容器容器转换为镜像创建dockerfile容器相互通讯查看容器的…...

数据库中的经纬度数据如何在QGIS中显示
思路:必须先将经纬度数据转换成POINT,MULTILINESTRING等格式才能在QGIS中展示 步骤 1、首先在postgresql数据中建一张包括经纬度数据的表 **注意:**如果是新建数据库,一定要执行如下代码,否则后面的函数ST_GeomFrom…...
制作linux运行包
从源码制作 syslinux:https://mirrors.edge.kernel.org/pub/linux/utils/boot/syslinux/syslinux-6.03.tar.gz busybox:https://busybox.net/downloads/busybox-1.26.0.tar.bz2 kernel:https://mirrors.edge.kernel.org/pub/linux/kernel/v6.x/linux-6.5.7.tar.gz 遇到问题&…...
一些 AI 机构
文章目录 OpenAITHUDMMetaAITIIStability AINousResearch OpenAI hf : https://huggingface.co/openai 官网:https://openai.com THUDM 清华大学 KEG 和 THUDM 团队 Knowledge Engineering Group (KEG) & Data Mining at Tsinghua University hf : https://h…...

AP5191 降压恒流 双灯 12V5A 一切一LED车灯汽车大灯驱动方案
AP5191是一款PWM工作模式,高效率、外围简 单、内置功率MOS管,适用于4.5-150V输入的高 精度降压LED恒流驱动芯片。输出功率150W, 电流6A。 AP5191可实现线性调光和PWM调光,线性调 光脚有效电压范围0.55-2.6V. AP5191 工作频率可以通过RT 外部…...
淘宝/天猫获取卖出的商品订单列表 API(taobao.seller_order_list)
淘宝和天猫平台提供了一个API接口(taobao.seller_order_list),用于获取卖家出售的商品订单列表。以下是使用该API的基本步骤: 获取API密钥:首先,您需要在淘宝开放平台(Open Platform)…...
Linux常规操作指南
1. 文件系统操作 (1)查看当前目录内容 ls或查看详细信息: ls -l(2)切换工作目录 cd /path/to/directory(3)创建新目录 mkdir directory_name(4)删除空目录 rmdir d…...

原生微信小程AR序实现模型动画播放只播放一次,且停留在最后一秒
1.效果展示 0868d9b9f56517a9a07dfc180cddecb2 2.微信小程序AR是2023年初发布,还有很多问提(比如glb模型不能直接播放最后一帧;AR识别不了金属、玻璃材质的模型等…有问题解决了的小伙伴记得告诉我一声) 微信官方文档地址 3.代码…...

【Docker】在centos中安装nginx
🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是平顶山大师,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的博客专栏《【Docker】安装nginx》。🎯&#…...

leetcode:最接近的三数之和---(双指针,排序,数组)
题目: 给你一个长度为 n 的整数数组 nums 和 一个目标值 target。请你从 nums 中选出三个整数,使它们的和与 target 最接近。 返回这三个数的和。 假定每组输入只存在恰好一个解。 示例: 示例 1: 输入:nums [-1…...

dpdk网络转发环境的搭建
文章目录 前言ip命令的使用配置dpdk-basicfwd需要的网络结构测试dpdk-basicfwddpdk-basicfwd代码分析附录basicfwd在tcp转发时的失败抓包信息DPDK的相关设置 前言 上手dpdk有两难。其一为环境搭建。被绑定之后的网卡没有IP,我如何给它发送数据呢?当然&a…...

【MYSQL】存储引擎MyISAM和InnoDB
MYSQL 存储引擎 查看MySQL提供所有的存储引擎 mysql> show engines; mysql常用引擎包括:MYISAM、Innodb、Memory、MERGE 1、MYISAM:全表锁,拥有较高的执行速度,不支持事务,不支持外键,并发性能差&#x…...

什么是DOM?(JavaScript DOM是什么?)
1、DOM简洁 DOM是js中最重要的一部分,没有DOM就不会通过js实现和用户之间的交互。 window是最大的浏览器对象,在它的下面还有很多子对象,我们要学习的DOM就是window对象下面的document对象 DOM(Document Object Model)…...
UIElement编辑器扩展 组件 Inspector
UIElement编辑器扩展 组件 Inspector https://docs.unity.cn/cn/2021.3/Manual/UIE-create-a-binding-uxml-inspector.html 简单开始 声明序列化VisualTreeAsset [SerializeField] VisualTreeAsset visualTree; 声明完,直接在脚本的Inspector面板,把你…...

Flask 3.x log全域配置(包含pytest)
最近使用到flask3.x,配置了全域的log,这边记录下 首先需要创建logging的配置文件,我是放在项目根目录的, Logging 配置 logging.json {"version": 1, # 配置文件版本号"formatters": {"default&qu…...

简易版抽奖活动的设计技术方案
1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...

UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
【git】把本地更改提交远程新分支feature_g
创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...
Robots.txt 文件
什么是robots.txt? robots.txt 是一个位于网站根目录下的文本文件(如:https://example.com/robots.txt),它用于指导网络爬虫(如搜索引擎的蜘蛛程序)如何抓取该网站的内容。这个文件遵循 Robots…...

前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...

Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...

10-Oracle 23 ai Vector Search 概述和参数
一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI,使用客户端或是内部自己搭建集成大模型的终端,加速与大型语言模型(LLM)的结合,同时使用检索增强生成(Retrieval Augmented Generation &#…...

九天毕昇深度学习平台 | 如何安装库?
pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子: 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)
前言: 在Java编程中,类的生命周期是指类从被加载到内存中开始,到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期,让读者对此有深刻印象。 目录 …...

免费数学几何作图web平台
光锐软件免费数学工具,maths,数学制图,数学作图,几何作图,几何,AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...