当前位置: 首页 > news >正文

深度学习pytorch之简单方法自定义9种卷积即插即用

本文详细解析了 PyTorch 中 torch.nn.Conv2d 的核心参数,通过代码示例演示了如何利用这一基础函数实现多种卷积操作。涵盖的卷积类型包括:标准卷积、逐点卷积(1x1 卷积)、非对称卷积(长宽不等的卷积核)、空洞卷积(扩大感受野)、深度卷积(逐通道滤波)、组卷积(分组独立处理)、深度可分离卷积(深度+逐点组合)、转置卷积(上采样)和动态卷积(动态生成卷积核),帮助读者理解如何通过调整参数灵活构建卷积层,适应不同任务需求。
深度学习pytorch之22种损失函数数学公式和代码定义
深度学习pytorch之19种优化算法(optimizer)解析
深度学习pytorch之4种归一化方法(Normalization)原理公式解析和参数使用
深度学习pytorch之简单方法自定义多种卷积即插即用

基础函数torch.nn.Conv2d()

torch.nn.Conv2d()定义了一个用于量化的 2D 卷积层 Conv2d,其继承自 _ConvNd 类,专门用于处理量化输入信号。

class Conv2d(_ConvNd):def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, dilation=1, groups=1, bias=True,padding_mode='zeros', device=None, dtype=None):........def forward(self, input):.........

在使用其定义常用卷积前,我们需要理解每一个参数的具体概念。

  • in_channels (int):
    输入数据的通道数。比如在图像处理中,对于彩色图像通常是 3(RGB),对于灰度图像通常是 1。
  • out_channels (int):
    输出数据的通道数,即卷积操作后得到的特征图的深度。通过改变输出通道数,可以改变卷积后的特征图的维度。
  • kernel_size (int or tuple):
    卷积核的大小。可以是一个整数,表示卷积核是正方形的(例如,3 表示一个 3x3 的卷积核);也可以是一个元组,表示非正方形卷积核(例如,(3, 5) 表示卷积核的高度为 3,宽度为 5)。
  • stride (int or tuple, default=1):
    卷积操作的步幅。表示卷积核在输入图像上滑动的步长。可以是一个整数,表示水平和垂直方向的步幅相同;也可以是一个元组,表示水平和垂直方向的步幅不同(例如,(2, 1) 表示水平步幅为 2,垂直步幅为 1)。
  • padding (int or tuple, default=0):
    填充(Padding)大小。填充是指在输入图像的边界加上额外的像素,防止卷积操作减少图像的空间尺寸。可以是一个整数,表示四个边的填充相同;也可以是一个元组,表示每个边的填充不同(例如,(2, 4) 表示上、下边的填充为 2,左、右边的填充为 4)。
  • dilation (int or tuple, default=1):
    卷积核的膨胀因子。膨胀卷积是指在卷积核中插入“空洞”,即卷积核的元素之间的距离增大。可以是一个整数,表示膨胀因子在各个方向上相同;也可以是一个元组,表示膨胀因子在水平和垂直方向上不同。
  • groups (int, default=1):
    卷积层的组数。groups 控制卷积核的分组方式。groups=1 表示标准卷积;groups 大于 1 时,表示分组卷积。
  • bias (bool, default=True):
    是否使用偏置项。如果为 True,卷积层会学习偏置参数;如果为 False,卷积层没有偏置项。
  • padding_mode (str, default=‘zeros’):
    填充模式。该参数控制如何填充输入图像的边界。仅支持 ‘zeros’ 模式,即使用零填充。
  • device (torch.device, optional):
    用于指定模型所在的设备,通常是 ‘cpu’ 或 ‘cuda’。
  • dtype (torch.dtype, optional):
    用于指定模型的张量数据类型。

接下来,我们根据各类卷积的概念,通过torch.nn.Conv2d()实现逐个定义。

1. 标准卷积(Standard Convolution)

标准卷积是最基本的卷积,使用的是常规的卷积层,没有分组和膨胀。

import torch
import torch.nn as nnclass StandardConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):super(StandardConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)def forward(self, x):return self.conv(x)# 示例:使用标准卷积
model_standard = StandardConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
input_tensor = torch.randn(1, 3, 64, 64)  # 假设输入是 64x64 的 RGB 图像
output = model_standard(input_tensor)
print(output.shape)  # 输出大小

2. 逐点卷积(Pointwise Convolution)

逐点卷积是 1x1 卷积,它的作用是对每个像素进行线性变换,通常用于改变通道数(如通道数的升降)。

class PointwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels):super(PointwiseConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)def forward(self, x):return self.conv(x)# 示例:使用逐点卷积
model_pointwise = PointwiseConv2d(in_channels=3, out_channels=16)
output_pointwise = model_pointwise(input_tensor)
print(output_pointwise.shape)  # 输出大小

3. 非对称卷积(Asymmetric Convolution)

非对称卷积指的是卷积核的长宽比不相等,通常使用较长或较窄的卷积核。例如,3x1 或 1x5 的卷积核。

class AsymmetricConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=(3, 1)):super(AsymmetricConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size[0]//2, kernel_size[1]//2))def forward(self, x):return self.conv(x)# 示例:使用非对称卷积
model_asymmetric = AsymmetricConv2d(in_channels=3, out_channels=16, kernel_size=(3, 1))
output_asymmetric = model_asymmetric(input_tensor)
print(output_asymmetric.shape)  # 输出大小

4. 空洞卷积(Dilated Convolution)

空洞卷积(又叫扩展卷积、膨胀卷积)通过在卷积核之间插入“空洞”来扩大感受野。通过设置 dilation 参数来实现。

class DilatedConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=2):super(DilatedConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation)def forward(self, x):return self.conv(x)# 示例:使用扩展卷积
model_dilated = DilatedConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=2, dilation=2)
output_dilated = model_dilated(input_tensor)
print(output_dilated.shape)  # 输出大小

5. 深度卷积(Depthwise Convolution)

深度卷积是每个输入通道单独进行卷积。它的特点是 groups 参数等于 in_channels。

class DepthwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):super(DepthwiseConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=in_channels)def forward(self, x):return self.conv(x)# 示例:使用深度卷积
model_depthwise = DepthwiseConv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
output_depthwise = model_depthwise(input_tensor)
print(output_depthwise.shape)  # 输出大小

6. 组卷积(Group Convolution)

组卷积是将输入通道划分为若干组,每组使用不同的卷积核进行卷积。我们可以通过 groups 参数来设置。

class GroupConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=2):super(GroupConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups)def forward(self, x):return self.conv(x)# 示例:使用组卷积
model_group = GroupConv2d(in_channels=6, out_channels=12, kernel_size=3, stride=1, padding=1, groups=3)
input_group = torch.randn(1, 6, 64, 64)  # 假设输入是 6 通道的图像
output_group = model_group(input_group)
print(output_group.shape)  # 输出大小

7. 深度可分离卷积(Depthwise + Pointwise Convolution)

深度可分离卷积由深度卷积(Depthwise Convolution)和逐点卷积(Pointwise Convolution)组合而成。深度卷积负责捕捉空间信息,逐点卷积用于改变通道数。

class SeparableConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(SeparableConv2d, self).__init__()self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):x = self.depthwise_conv(x)x = self.pointwise_conv(x)return x# 示例:使用空间可分离卷积
model_separable = SeparableConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
output_separable = model_separable(input_tensor)
print(output_separable.shape)  # 输出大小

8. 转置卷积(Transpose Convolution)

转置卷积(也叫反卷积)通常用于上采样。PyTorch 中使用 ConvTranspose2d 实现。

conv = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

9. 动态卷积(Dynamic Convolution)

动态卷积是根据输入或时间步动态生成卷积核。这意味着卷积核在处理过程中是根据输入数据进行变化的,通常使用与输入张量的大小相关的方式来动态调整卷积核。

class DynamicConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3):super(DynamicConv2d, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = kernel_sizedef forward(self, x):batch_size, _, height, width = x.size()# 根据输入动态生成卷积核(例如:使用全连接层生成卷积核)weight = torch.randn(batch_size, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size).to(x.device)return torch.nn.functional.conv2d(x, weight, padding=self.kernel_size//2)# 示例:使用动态卷积
model_dynamic = DynamicConv2d(in_channels=3, out_channels=16, kernel_size=3)
output_dynamic = model_dynamic(input_tensor)
print(output_dynamic.shape)  # 输出大小

相关文章:

深度学习pytorch之简单方法自定义9种卷积即插即用

本文详细解析了 PyTorch 中 torch.nn.Conv2d 的核心参数,通过代码示例演示了如何利用这一基础函数实现多种卷积操作。涵盖的卷积类型包括:标准卷积、逐点卷积(1x1 卷积)、非对称卷积(长宽不等的卷积核)、空…...

TMS320F28P550SJ9学习笔记2:Sysconfig 配置与点亮LED

今日学习使用Sysconfig 对引脚进行配置,并点亮开发板上的LED4 与LED5 我的单片机开发板平台是 LAUNCHXL_F28P55x 我是在上文描述的驱动库C2000ware官方例程example的工程基础之上进行添加功能的 该例程路径如下:D:\C2000Ware_5_04_00_00\driverlib\f28p…...

zRAM内存压缩技术:原理与实践初探

zRAM内存压缩技术:原理与实践指南 1. 技术背景与原理 zRAM是Linux内核中的一项内存压缩技术,于2014年进入Linux 3.14内核主线。它的核心思想是利用CPU压缩算法压缩内存数据,在不增加物理内存的情况下扩展系统有效内存容量。 当系统内存紧张…...

Hive 3.1 在 metastore 运行的 remote threads

Remote threads 是仅当 Hive metastore 作为单独的服务运行是启动,请求需要开启 compactor。 有以下几种: 1. AcidOpenTxnsCounterService 统计当前 open 的事务数 从表 TXNS 中统计状态为 open 的事务。此事务数量可以再 hive metrics 中。 2. Acid…...

大语言模型揭秘:从诞生到智能

引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)无疑是技术领域最耀眼的明星之一。它们不仅能够理解人类的自然语言,还能生成流畅的文本,甚至在对话、翻译、创作等任务中表现出接近人类的智能…...

基于模糊PID控制的供热控制系统设计Simulink仿真

1.模型简介 本仿真模型基于MATLAB/Simulink(版本MATLAB 2017Ra)软件。建议采用matlab2017 Ra及以上版本打开。(若需要其他版本可联系店主代为转换) 换热站干扰因素多导致传统PID控制无法满足控制要求的问题,提出利用…...

宝塔找不到php扩展swoole,服务器编译安装

1. 在php7.4中安装swoole,但找不到这个扩展安装 2. 服务器下载源码解压安装 http://pecl.php.net/package/swoole 下载4.8.0版本 解压到/www/server/php/74/下 3. 发现报错问题; 更新一下依赖 yum update yum -y install gcc gcc-c autoconf libjpe…...

LeetCode 1745.分割回文串 IV:动态规划(用III或II能直接秒)

【LetMeFly】1745.分割回文串 IV:动态规划(用III或II能直接秒) 力扣题目链接:https://leetcode.cn/problems/palindrome-partitioning-iv/ 给你一个字符串 s ,如果可以将它分割成三个 非空 回文子字符串,…...

C++发展

目录 ​编辑C 的发展总结:​编辑 1. C 的早期发展(1979-1985) 2. C 标准化过程(1985-1998) 3. C 标准演化(2003-2011) 4. C11(2011年) 5. C14(2014年&a…...

Python:函数,return返回值与形参实参

函数: 如: def login():print("这是登陆函数") login() #调用几次,函数里面的代码就会运行几次,每次调用的时候函数都会从头开始运行 return返回值:函数执行结束后最后给调用着的一个结果 作用&#xff1a…...

DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)示例2: 分页和排序

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏+关注哦 💕 目录 DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)示例2: 分页和排序📚前言📚页面效果📚指令…...

pandas 文本数据处理

文本数据处理 获取字符串长度: ​ 需要用到函数:str.len() 例: # 求字符串长度 # 引用 pandas import pandas as pd # 定义数据 data {"姓名":["张三","李四","王五","赵六"],"…...

GCC RISCV 后端 -- GCC 后端框架的一些理解

GCC 已经提供了一整套的编译框架,从前端(Frontend / GENERIC-Tree)对编程语言的语法语义处理,到中端(Middle-End / GIMPLE-Tree)的目标机器无关(Target Indepndent)的优化处理&#…...

FastGPT 源码:如何实现 “问题优化“

文章目录 FastGPT 源码:如何实现 "问题优化"一、前言二、源码分析2.1 queryExtension.ts 提示词2.2 queryExtension.ts 核心逻辑2.3 queryExtension 引用位置 三、流程总结 FastGPT 源码:如何实现 “问题优化” 一、前言 问题优化的背景和目…...

CSS—flex布局、过渡transition属性、2D转换transform属性、3D转换transform属性

​ 1.flex布局 也叫弹性布局,是浏览器提倡的布局模型,非常适合结构化布局,提供了强大的空间分布和对齐能力,不会产生浮动布局中脱标现象,布局网页更简单,更灵活。 flex容器属性: 属性描述d…...

Spring Boot Gradle 项目中使用 @Slf4j 注解

Spring Boot Gradle 项目中,如果想使用 Slf4j 注解来启用日志记录,首先需要添加 Lombok 和 SLF4J 的依赖。可以通过以下步骤来添加它们: 1. 添加 Lombok 依赖 在 build.gradle 文件中添加以下 Lombok 依赖: dependencies {impl…...

FreeRTOS系列---程序正常,但任务无法创建

实验环境 stm32F103RCT6核心板 keil5 vscode stm32cubemx 使用stm32cubemx 问题现场 void my_task_init(void) {xTaskCreate(LED1_Task, "LED1_Task", configMINIMAL_STACK_SIZE, NULL, 1, NULL);xTaskCreate(LED2_Task, "LED2_Task", configMINIMA…...

linux应用:errno、perror、open、fopen

errno errno 是一个全局变量,定义在 头文件中。当系统调用(如 open、read、write 等)或库函数执行失败时,会将一个错误码赋值给 errno。不同的错误码代表不同的错误类型,通过检查 errno 的值,可以判断具体…...

物联网中的气象监测设备具备顶级功能

物联网中的气象监测设备具备顶级功能时,通常集成GPS、数据上报和预警系统,以确保精准监测和及时响应。以下是这些功能的详细说明: 1. GPS定位 精准定位:GPS模块提供设备的精确地理位置,确保数据与具体位置关联&#…...

15-YOLOV8OBB损失函数详解

一、YOLO OBB支持的OBB 在Ultralytics YOLO 模型中,OBB 由YOLO OBB 格式中的四个角点表示。这样可以更准确地检测到物体,因为边界框可以旋转以更好地适应物体。其坐标在 0 和 1 之间归一化: class_index x1 y1 x2 y2 x3 y3 x4 y4 YOLO 在内部处理损失和输出是xywhr 格式,x…...

【雷达信号优化】第八章 阵列校准与误差补偿

目录 第八章 阵列校准与误差补偿 8.1 阵列误差模型 8.1.1 幅相误差 8.1.1.1 互耦效应建模 8.1.1.1.1 互耦矩阵的逆矩阵简化 8.2 阵列自校准算法 8.2.1 信号子空间拟合算法 8.2.1.1 交替优化策略 8.2.1.1.1 信源方向与误差参数的迭代更新 8.2.2 辅助源校准 8.2.2.1 单…...

告别彻夜等待:SteamShutdown让游戏下载完成后自动关机的智能解决方案

告别彻夜等待:SteamShutdown让游戏下载完成后自动关机的智能解决方案 【免费下载链接】SteamShutdown Automatic shutdown after Steam download(s) has finished. 项目地址: https://gitcode.com/gh_mirrors/st/SteamShutdown 你是否也曾经历过这样的困扰&a…...

终极指南:如何轻松解包Godot PCK文件并提取游戏资源

终极指南:如何轻松解包Godot PCK文件并提取游戏资源 【免费下载链接】godot-unpacker godot .pck unpacker 项目地址: https://gitcode.com/gh_mirrors/go/godot-unpacker 还在为Godot游戏的PCK文件无法解包而烦恼吗?无论你是游戏开发者想要复用资…...

Windows 11/10扩展属性冲突:输入法与UAC的隐藏关联

1. Windows扩展属性冲突的典型表现 最近在帮同事调试一个自动化脚本时,遇到了一个奇怪的问题。每次运行那个bat文件,系统就会弹出"扩展属性不一致"的错误提示。这个bat脚本本身很简单,就是用来启动一个内部工具的可执行文件。但无…...

基于DAMOYOLO-S与计算机网络技术:构建分布式视频分析集群

基于DAMOYOLO-S与计算机网络技术:构建分布式视频分析集群 想象一下,一个大型物流园区,上百个摄像头日夜不停地运转,管理者需要实时知道:哪条通道拥堵了?哪个区域有异常人员闯入?传统的监控方式…...

FastAPI状态管理:FastAPI 全局状态管理的 3 种最佳实践

更多内容请见: 《Python Web项目集锦》 - 专栏介绍和目录 在构建生产级FastAPI应用时,全局状态管理是确保资源高效利用和系统稳定性的关键。不当的状态管理可能导致资源泄漏、线程安全问题和不可预测的行为。本文将深入分析FastAPI中实现全局状态的三种最佳实践,揭示其底层机…...

CSS动画播放状态控制终极指南:掌握交互式动画实现技巧

CSS动画播放状态控制终极指南:掌握交互式动画实现技巧 【免费下载链接】css-reference CSS Reference: a free visual guide to the most popular CSS properties 项目地址: https://gitcode.com/gh_mirrors/cs/css-reference CSS动画播放状态控制是网页交互…...

番茄小说下载器:一站式离线阅读与听书解决方案

番茄小说下载器:一站式离线阅读与听书解决方案 【免费下载链接】Tomato-Novel-Downloader 番茄小说下载器不精简版 项目地址: https://gitcode.com/gh_mirrors/to/Tomato-Novel-Downloader 还在为网络不稳定而无法畅快阅读番茄小说烦恼吗?想要在通…...

OpenClaw我的龙虾怎么识别不了图片

问题现象 图片发送给龙虾,要么一直说没收到图片,要么提示不支持,要么提示安装OCR工具,要么就是识别出来的完全牛头不对马嘴。 解决方案 这里面涉及三个因素: 模型是否支撑图片识别配置中的input是否配置了image聊天渠道…...

构建语音驱动的智能Agent:集成SenseVoice-Small与AI决策框架

构建语音驱动的智能Agent:集成SenseVoice-Small与AI决策框架 你有没有想过,对着电脑说句话,它就能帮你写代码、查资料、甚至控制智能家居?这听起来像是科幻电影里的场景,但现在,通过将强大的语音识别模型与…...