当前位置: 首页 > news >正文

深度学习pytorch之简单方法自定义9种卷积即插即用

本文详细解析了 PyTorch 中 torch.nn.Conv2d 的核心参数,通过代码示例演示了如何利用这一基础函数实现多种卷积操作。涵盖的卷积类型包括:标准卷积、逐点卷积(1x1 卷积)、非对称卷积(长宽不等的卷积核)、空洞卷积(扩大感受野)、深度卷积(逐通道滤波)、组卷积(分组独立处理)、深度可分离卷积(深度+逐点组合)、转置卷积(上采样)和动态卷积(动态生成卷积核),帮助读者理解如何通过调整参数灵活构建卷积层,适应不同任务需求。
深度学习pytorch之22种损失函数数学公式和代码定义
深度学习pytorch之19种优化算法(optimizer)解析
深度学习pytorch之4种归一化方法(Normalization)原理公式解析和参数使用
深度学习pytorch之简单方法自定义多种卷积即插即用

基础函数torch.nn.Conv2d()

torch.nn.Conv2d()定义了一个用于量化的 2D 卷积层 Conv2d,其继承自 _ConvNd 类,专门用于处理量化输入信号。

class Conv2d(_ConvNd):def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, dilation=1, groups=1, bias=True,padding_mode='zeros', device=None, dtype=None):........def forward(self, input):.........

在使用其定义常用卷积前,我们需要理解每一个参数的具体概念。

  • in_channels (int):
    输入数据的通道数。比如在图像处理中,对于彩色图像通常是 3(RGB),对于灰度图像通常是 1。
  • out_channels (int):
    输出数据的通道数,即卷积操作后得到的特征图的深度。通过改变输出通道数,可以改变卷积后的特征图的维度。
  • kernel_size (int or tuple):
    卷积核的大小。可以是一个整数,表示卷积核是正方形的(例如,3 表示一个 3x3 的卷积核);也可以是一个元组,表示非正方形卷积核(例如,(3, 5) 表示卷积核的高度为 3,宽度为 5)。
  • stride (int or tuple, default=1):
    卷积操作的步幅。表示卷积核在输入图像上滑动的步长。可以是一个整数,表示水平和垂直方向的步幅相同;也可以是一个元组,表示水平和垂直方向的步幅不同(例如,(2, 1) 表示水平步幅为 2,垂直步幅为 1)。
  • padding (int or tuple, default=0):
    填充(Padding)大小。填充是指在输入图像的边界加上额外的像素,防止卷积操作减少图像的空间尺寸。可以是一个整数,表示四个边的填充相同;也可以是一个元组,表示每个边的填充不同(例如,(2, 4) 表示上、下边的填充为 2,左、右边的填充为 4)。
  • dilation (int or tuple, default=1):
    卷积核的膨胀因子。膨胀卷积是指在卷积核中插入“空洞”,即卷积核的元素之间的距离增大。可以是一个整数,表示膨胀因子在各个方向上相同;也可以是一个元组,表示膨胀因子在水平和垂直方向上不同。
  • groups (int, default=1):
    卷积层的组数。groups 控制卷积核的分组方式。groups=1 表示标准卷积;groups 大于 1 时,表示分组卷积。
  • bias (bool, default=True):
    是否使用偏置项。如果为 True,卷积层会学习偏置参数;如果为 False,卷积层没有偏置项。
  • padding_mode (str, default=‘zeros’):
    填充模式。该参数控制如何填充输入图像的边界。仅支持 ‘zeros’ 模式,即使用零填充。
  • device (torch.device, optional):
    用于指定模型所在的设备,通常是 ‘cpu’ 或 ‘cuda’。
  • dtype (torch.dtype, optional):
    用于指定模型的张量数据类型。

接下来,我们根据各类卷积的概念,通过torch.nn.Conv2d()实现逐个定义。

1. 标准卷积(Standard Convolution)

标准卷积是最基本的卷积,使用的是常规的卷积层,没有分组和膨胀。

import torch
import torch.nn as nnclass StandardConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):super(StandardConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)def forward(self, x):return self.conv(x)# 示例:使用标准卷积
model_standard = StandardConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
input_tensor = torch.randn(1, 3, 64, 64)  # 假设输入是 64x64 的 RGB 图像
output = model_standard(input_tensor)
print(output.shape)  # 输出大小

2. 逐点卷积(Pointwise Convolution)

逐点卷积是 1x1 卷积,它的作用是对每个像素进行线性变换,通常用于改变通道数(如通道数的升降)。

class PointwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels):super(PointwiseConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)def forward(self, x):return self.conv(x)# 示例:使用逐点卷积
model_pointwise = PointwiseConv2d(in_channels=3, out_channels=16)
output_pointwise = model_pointwise(input_tensor)
print(output_pointwise.shape)  # 输出大小

3. 非对称卷积(Asymmetric Convolution)

非对称卷积指的是卷积核的长宽比不相等,通常使用较长或较窄的卷积核。例如,3x1 或 1x5 的卷积核。

class AsymmetricConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=(3, 1)):super(AsymmetricConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size[0]//2, kernel_size[1]//2))def forward(self, x):return self.conv(x)# 示例:使用非对称卷积
model_asymmetric = AsymmetricConv2d(in_channels=3, out_channels=16, kernel_size=(3, 1))
output_asymmetric = model_asymmetric(input_tensor)
print(output_asymmetric.shape)  # 输出大小

4. 空洞卷积(Dilated Convolution)

空洞卷积(又叫扩展卷积、膨胀卷积)通过在卷积核之间插入“空洞”来扩大感受野。通过设置 dilation 参数来实现。

class DilatedConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=2):super(DilatedConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation)def forward(self, x):return self.conv(x)# 示例:使用扩展卷积
model_dilated = DilatedConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=2, dilation=2)
output_dilated = model_dilated(input_tensor)
print(output_dilated.shape)  # 输出大小

5. 深度卷积(Depthwise Convolution)

深度卷积是每个输入通道单独进行卷积。它的特点是 groups 参数等于 in_channels。

class DepthwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):super(DepthwiseConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=in_channels)def forward(self, x):return self.conv(x)# 示例:使用深度卷积
model_depthwise = DepthwiseConv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
output_depthwise = model_depthwise(input_tensor)
print(output_depthwise.shape)  # 输出大小

6. 组卷积(Group Convolution)

组卷积是将输入通道划分为若干组,每组使用不同的卷积核进行卷积。我们可以通过 groups 参数来设置。

class GroupConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=2):super(GroupConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups)def forward(self, x):return self.conv(x)# 示例:使用组卷积
model_group = GroupConv2d(in_channels=6, out_channels=12, kernel_size=3, stride=1, padding=1, groups=3)
input_group = torch.randn(1, 6, 64, 64)  # 假设输入是 6 通道的图像
output_group = model_group(input_group)
print(output_group.shape)  # 输出大小

7. 深度可分离卷积(Depthwise + Pointwise Convolution)

深度可分离卷积由深度卷积(Depthwise Convolution)和逐点卷积(Pointwise Convolution)组合而成。深度卷积负责捕捉空间信息,逐点卷积用于改变通道数。

class SeparableConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(SeparableConv2d, self).__init__()self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):x = self.depthwise_conv(x)x = self.pointwise_conv(x)return x# 示例:使用空间可分离卷积
model_separable = SeparableConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
output_separable = model_separable(input_tensor)
print(output_separable.shape)  # 输出大小

8. 转置卷积(Transpose Convolution)

转置卷积(也叫反卷积)通常用于上采样。PyTorch 中使用 ConvTranspose2d 实现。

conv = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

9. 动态卷积(Dynamic Convolution)

动态卷积是根据输入或时间步动态生成卷积核。这意味着卷积核在处理过程中是根据输入数据进行变化的,通常使用与输入张量的大小相关的方式来动态调整卷积核。

class DynamicConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3):super(DynamicConv2d, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = kernel_sizedef forward(self, x):batch_size, _, height, width = x.size()# 根据输入动态生成卷积核(例如:使用全连接层生成卷积核)weight = torch.randn(batch_size, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size).to(x.device)return torch.nn.functional.conv2d(x, weight, padding=self.kernel_size//2)# 示例:使用动态卷积
model_dynamic = DynamicConv2d(in_channels=3, out_channels=16, kernel_size=3)
output_dynamic = model_dynamic(input_tensor)
print(output_dynamic.shape)  # 输出大小

相关文章:

深度学习pytorch之简单方法自定义9种卷积即插即用

本文详细解析了 PyTorch 中 torch.nn.Conv2d 的核心参数,通过代码示例演示了如何利用这一基础函数实现多种卷积操作。涵盖的卷积类型包括:标准卷积、逐点卷积(1x1 卷积)、非对称卷积(长宽不等的卷积核)、空…...

TMS320F28P550SJ9学习笔记2:Sysconfig 配置与点亮LED

今日学习使用Sysconfig 对引脚进行配置,并点亮开发板上的LED4 与LED5 我的单片机开发板平台是 LAUNCHXL_F28P55x 我是在上文描述的驱动库C2000ware官方例程example的工程基础之上进行添加功能的 该例程路径如下:D:\C2000Ware_5_04_00_00\driverlib\f28p…...

zRAM内存压缩技术:原理与实践初探

zRAM内存压缩技术:原理与实践指南 1. 技术背景与原理 zRAM是Linux内核中的一项内存压缩技术,于2014年进入Linux 3.14内核主线。它的核心思想是利用CPU压缩算法压缩内存数据,在不增加物理内存的情况下扩展系统有效内存容量。 当系统内存紧张…...

Hive 3.1 在 metastore 运行的 remote threads

Remote threads 是仅当 Hive metastore 作为单独的服务运行是启动,请求需要开启 compactor。 有以下几种: 1. AcidOpenTxnsCounterService 统计当前 open 的事务数 从表 TXNS 中统计状态为 open 的事务。此事务数量可以再 hive metrics 中。 2. Acid…...

大语言模型揭秘:从诞生到智能

引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)无疑是技术领域最耀眼的明星之一。它们不仅能够理解人类的自然语言,还能生成流畅的文本,甚至在对话、翻译、创作等任务中表现出接近人类的智能…...

基于模糊PID控制的供热控制系统设计Simulink仿真

1.模型简介 本仿真模型基于MATLAB/Simulink(版本MATLAB 2017Ra)软件。建议采用matlab2017 Ra及以上版本打开。(若需要其他版本可联系店主代为转换) 换热站干扰因素多导致传统PID控制无法满足控制要求的问题,提出利用…...

宝塔找不到php扩展swoole,服务器编译安装

1. 在php7.4中安装swoole,但找不到这个扩展安装 2. 服务器下载源码解压安装 http://pecl.php.net/package/swoole 下载4.8.0版本 解压到/www/server/php/74/下 3. 发现报错问题; 更新一下依赖 yum update yum -y install gcc gcc-c autoconf libjpe…...

LeetCode 1745.分割回文串 IV:动态规划(用III或II能直接秒)

【LetMeFly】1745.分割回文串 IV:动态规划(用III或II能直接秒) 力扣题目链接:https://leetcode.cn/problems/palindrome-partitioning-iv/ 给你一个字符串 s ,如果可以将它分割成三个 非空 回文子字符串,…...

C++发展

目录 ​编辑C 的发展总结:​编辑 1. C 的早期发展(1979-1985) 2. C 标准化过程(1985-1998) 3. C 标准演化(2003-2011) 4. C11(2011年) 5. C14(2014年&a…...

Python:函数,return返回值与形参实参

函数: 如: def login():print("这是登陆函数") login() #调用几次,函数里面的代码就会运行几次,每次调用的时候函数都会从头开始运行 return返回值:函数执行结束后最后给调用着的一个结果 作用&#xff1a…...

DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)示例2: 分页和排序

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏+关注哦 💕 目录 DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)示例2: 分页和排序📚前言📚页面效果📚指令…...

pandas 文本数据处理

文本数据处理 获取字符串长度: ​ 需要用到函数:str.len() 例: # 求字符串长度 # 引用 pandas import pandas as pd # 定义数据 data {"姓名":["张三","李四","王五","赵六"],"…...

GCC RISCV 后端 -- GCC 后端框架的一些理解

GCC 已经提供了一整套的编译框架,从前端(Frontend / GENERIC-Tree)对编程语言的语法语义处理,到中端(Middle-End / GIMPLE-Tree)的目标机器无关(Target Indepndent)的优化处理&#…...

FastGPT 源码:如何实现 “问题优化“

文章目录 FastGPT 源码:如何实现 "问题优化"一、前言二、源码分析2.1 queryExtension.ts 提示词2.2 queryExtension.ts 核心逻辑2.3 queryExtension 引用位置 三、流程总结 FastGPT 源码:如何实现 “问题优化” 一、前言 问题优化的背景和目…...

CSS—flex布局、过渡transition属性、2D转换transform属性、3D转换transform属性

​ 1.flex布局 也叫弹性布局,是浏览器提倡的布局模型,非常适合结构化布局,提供了强大的空间分布和对齐能力,不会产生浮动布局中脱标现象,布局网页更简单,更灵活。 flex容器属性: 属性描述d…...

Spring Boot Gradle 项目中使用 @Slf4j 注解

Spring Boot Gradle 项目中,如果想使用 Slf4j 注解来启用日志记录,首先需要添加 Lombok 和 SLF4J 的依赖。可以通过以下步骤来添加它们: 1. 添加 Lombok 依赖 在 build.gradle 文件中添加以下 Lombok 依赖: dependencies {impl…...

FreeRTOS系列---程序正常,但任务无法创建

实验环境 stm32F103RCT6核心板 keil5 vscode stm32cubemx 使用stm32cubemx 问题现场 void my_task_init(void) {xTaskCreate(LED1_Task, "LED1_Task", configMINIMAL_STACK_SIZE, NULL, 1, NULL);xTaskCreate(LED2_Task, "LED2_Task", configMINIMA…...

linux应用:errno、perror、open、fopen

errno errno 是一个全局变量,定义在 头文件中。当系统调用(如 open、read、write 等)或库函数执行失败时,会将一个错误码赋值给 errno。不同的错误码代表不同的错误类型,通过检查 errno 的值,可以判断具体…...

物联网中的气象监测设备具备顶级功能

物联网中的气象监测设备具备顶级功能时,通常集成GPS、数据上报和预警系统,以确保精准监测和及时响应。以下是这些功能的详细说明: 1. GPS定位 精准定位:GPS模块提供设备的精确地理位置,确保数据与具体位置关联&#…...

15-YOLOV8OBB损失函数详解

一、YOLO OBB支持的OBB 在Ultralytics YOLO 模型中,OBB 由YOLO OBB 格式中的四个角点表示。这样可以更准确地检测到物体,因为边界框可以旋转以更好地适应物体。其坐标在 0 和 1 之间归一化: class_index x1 y1 x2 y2 x3 y3 x4 y4 YOLO 在内部处理损失和输出是xywhr 格式,x…...

基于算法竞赛的c++编程(28)结构体的进阶应用

结构体的嵌套与复杂数据组织 在C中,结构体可以嵌套使用,形成更复杂的数据结构。例如,可以通过嵌套结构体描述多层级数据关系: struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...

遍历 Map 类型集合的方法汇总

1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)

0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述,后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作,其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...

蓝桥杯 2024 15届国赛 A组 儿童节快乐

P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候,写过一篇简单实现,后期随着对该模型的深入研究,本次记录涉及到prophet 的公式以及参数调优,从公式可以更直观…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异&#xff…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

Linux --进程控制

本文从以下五个方面来初步认识进程控制: 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程,创建出来的进程就是子进程,原来的进程为父进程。…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...