当前位置: 首页 > news >正文

PyTorch - Conv2d 和 MaxPool2d

文章目录

    • Conv2d
      • 计算
      • Conv2d 函数解析
      • 代码示例
    • MaxPool2d
      • 计算
      • 函数说明
    • 卷积过程动画
        • Transposed convolution animations
        • Transposed convolution animations


参考视频:土堆说 卷积计算
https://www.bilibili.com/video/BV1hE411t7RN


关于 torch.nn 和 torch.nn.function
torch.nn 是对 torch.nn.function 的封装,前者更方便实用。


Conv2d

卷积过程可见文末动画


计算

卷积层输入特征图(input feature map)的尺寸为:H_i × W_i × C_i

  • H_i :输入特征图的高
  • W_i :输入特征图的宽
  • C_i :输入特征图的通道数
    (如果是第一个卷积层则是输入图像的通道数,如果是中间的卷积层,则是上一层的输出通道数

卷积层的参数如下:

  • P:padding,补零的行数和列数
  • F:正方形卷积核的边长
  • S:stride,步幅
  • K:输出通道数

输出特征图(output feature map)的尺寸为 H_o × W_o × C_o ,其中每一个变量的计算方式如下:

  • H_o = (H_i + 2P − F)/S + 1
  • W_o = (W_i + 2P − F)/S + 1
  • C_o = K

  • 卷积时,超出边界的不计算。

参数量大小的计算,分为weights和biases:

首先,计算weights的参数量:F × F × C_i × K
接着计算biases的参数量:K
所以总参数量为:F × F × C_i × K + K


计算示例

输入卷积核步长padding输出计算
5x52x2104x44 = (5-2)/1 + 1
5x53x3103x33 = (5-3)/1 + 1
5x52x2202x22 = (5-2)/2 + 1
6x62x2203x33 = (6-2)/2 + 1
5x52x2116x64 = (5 + 1*2 - 2)/1 + 1
5x53x3115x53 = (5 + 1*2 - 3)/1 + 1
5x53x3224x43 = (5 + 2*2 - 3)/2 + 1

Conv2d 函数解析

  • torch.nn.functional.conv2d 官方说明
    https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html#torch.nn.functional.conv2d

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

  • in_channels
  • out_channels
  • kernel_size,卷积核大小;可以是一个数(n*n矩阵),也可以是一个元组。这个值在训练过程中,会不断被调整。
  • stride=1
  • padding=0
  • dilation=1,卷积核对应位的距离
  • groups=1,分组卷积;一般为1,很少改动。
  • bias=True,偏置,一般为True
  • padding_mode=‘zeros’,如果设置了 padding,填充模式。默认为 zeros,即填充0。
  • device=None
  • dtype=None)

一般只设置前五个参数



代码示例

import torch
import torch.nn.functional as Ft1 = torch.Tensor([[1, 2, 0, 3, 1], [0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1], ])kernel = torch.Tensor([[1, 2, 1],[0, 1, 0],[2, 1, 0]
])
t1.shape, kernel.shape
# (torch.Size([5, 5]), torch.Size([3, 3]))# channel 和 batch_size 为 1
ip = torch.reshape(t1, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
ip.shape, kernel.shape
# (torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 3, 3]))op = F.conv2d(ip, kernel, stride=1) 
op, op.shape 
'''
(tensor([[[[10., 12., 12.],[18., 16., 16.],[13.,  9.,  3.]]]]),torch.Size([1, 1, 3, 3]))
'''# 不同 stride
op = F.conv2d(ip, kernel, stride=2) 
op, op.shape 
'''
(tensor([[[[10., 12.],[13.,  3.]]]]),torch.Size([1, 1, 2, 2]))
'''# 增加 padding
op = F.conv2d(ip, kernel, stride=2, padding=1) 
op, op.shape 
'''
(tensor([[[[ 1.,  4.,  8.],[ 7., 16.,  8.],[14.,  9.,  4.]]]]),torch.Size([1, 1, 3, 3]))
'''

MaxPool2d

池化的目的是保留特征,减少数据量;
最大池化也被称为 下采样;
另外池化操作是分别应用到每一个深度切片层。输出深度 与 输入的深度 相同。


计算

  • 输入宽高深:H_i,W_i, D_i
  • 滤波器宽高:f_w, f_h
  • S: stride,步长

输出为:
H_o = (H_i - f_h)/S + 1
W_o = (W_i - f_w)/S + 1
D_o = D_i


输入维度是 4x4x5 (HxWxD)
滤波器大小 2x2 (HxW)
stride 的高和宽都是 2 (S)


在这里插入图片描述


函数说明

  • 官方说明
    https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

  • ceil_mode,超出范围时是否计算

代码实现

import torch
import torch.nn as nn# MaxPool2d 函数 input 需要是 4维
ip = torch.reshape(t1, (-1, 1, 5, 5))class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=False)def forward(self, input):output = self.maxpool(input)return outputnet = Net()
ret = net(ip)
ret# tensor([[[[2., 3.], [5., 1.]]]])  # ceil_mode=True
# tensor([[[[2.]]]])  # ceil_mode=False

数据集中调用

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdata_path = '/xxxx/cifar10'
datasets = torchvision.datasets.CIFAR10(data_path, train=False, download=True, transform=torchvision.transforms.ToTensor())data_loader = DataLoader(datasets, batch_size=64)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.maxpool1 = nn.MaxPool2d(kernel_size=3, ceil_mode=False)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=False)def forward(self, input):output = self.maxpool1(input)return outputwriter = SummaryWriter('logs_maxpool1')
step = 0
net = Net()
for data in data_loader:imgs, targets = datawriter.add_images('input', imgs, step)output = net(imgs) writer.add_images('output', output, step)step = step + 1writer.close() 
  • 启动 tensorboard:
tensorboard --logdir=logs_maxpool1 


卷积过程动画

图片来自:https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

Transposed convolution animations


No padding, no strides
请添加图片描述


Arbitrary padding, no strides

请添加图片描述


Half padding, no strides
请添加图片描述


Full padding, no strides

请添加图片描述


No padding, strides

请添加图片描述


Padding, strides

请添加图片描述


Padding, strides (odd)

请添加图片描述


Transposed convolution animations


No padding, no strides, transposed

请添加图片描述


Arbitrary padding, no strides, transposed

请添加图片描述


Half padding, no strides, transposed

请添加图片描述


Full padding, no strides, transposed

请添加图片描述


No padding, strides, transposed
请添加图片描述


Padding, strides, transposed
请添加图片描述


Padding, strides, transposed (odd)

请添加图片描述

相关文章:

PyTorch - Conv2d 和 MaxPool2d

文章目录Conv2d计算Conv2d 函数解析代码示例MaxPool2d计算函数说明卷积过程动画Transposed convolution animationsTransposed convolution animations参考视频:土堆说 卷积计算 https://www.bilibili.com/video/BV1hE411t7RN 关于 torch.nn 和 torch.nn.function t…...

leetcode Day2(昨天实习有点bug,心态要崩了)

int carry 0;for(int i a.size() - 1, j b.size() - 1; i > 0 || j > 0 || carry; --i, --j) {int x i < 0 ? 0 : a[i] - 0;int y j < 0 ? 0 : b[j] - 0;int sum (x y carry) % 2;carry (x y carry) / 2;str.insert(0, 1, sum 0);}return str;加一&a…...

另一种思考:为什么不选JPA、MyBatis,而选择JDBCTemplate

以下内容转载自&#xff1a;https://segmentfault.com/a/1190000018472572 作者&#xff1a;scherman 因为项目需要选择数据持久化框架&#xff0c;看了一下主要几个流行的和不流行的框架&#xff0c;对于复杂业务系统&#xff0c;最终的结论是&#xff0c;JOOQ是总体上最好的…...

LeetCode 338. 比特位计数

给你一个整数 n &#xff0c;对于 0 < i < n 中的每个 i &#xff0c;计算其二进制表示中 1 的个数 &#xff0c;返回一个长度为 n 1 的数组 ans 作为答案。 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;[0,1,1] 解释&#xff1a; 0 --> 0 1 --> …...

排序评估指标——NDCG和MAP

在搜索和推荐任务中&#xff0c;系统常返回一个item列表。如何衡量这个返回的列表是否优秀呢&#xff1f; 例如&#xff0c;当我们检索【推荐排序】&#xff0c;网页返回了与推荐排序相关的链接列表。列表可能会是[A,B,C,G,D,E,F],也可能是[C,F,A,E,D]&#xff0c;现在问题来了…...

[Android Studio] Android Studio Virtual Device(AVD)虚拟机的功能试用

&#x1f7e7;&#x1f7e8;&#x1f7e9;&#x1f7e6;&#x1f7ea; Android Debug&#x1f7e7;&#x1f7e8;&#x1f7e9;&#x1f7e6;&#x1f7ea; Topic 发布安卓学习过程中遇到问题解决过程&#xff0c;希望我的解决方案可以对小伙伴们有帮助。 &#x1f680;write…...

kafka-3-kafka应用的核心要点和内外网访问

kafka实战教程(python操作kafka)&#xff0c;kafka配置文件详解 Kafka内外网访问的设置 1 kafka简介 根据官网的介绍&#xff0c;ApacheKafka是一个分布式流媒体平台&#xff0c;它主要有3种功能&#xff1a; (1)发布和订阅消息流&#xff0c;这个功能类似于消息队列&#x…...

VS2017+OpenCV4.5.5 决策树-评估是否发放贷款

决策树是一种非参数的监督学习方法&#xff0c;主要用于分类和回归。 决策树结构 决策树在逻辑上以树的形式存在&#xff0c;包含根节点、内部结点和叶节点。 根节点&#xff1a;包含数据集中的所有数据的集合内部节点&#xff1a;每个内部节点为一个判断条件&#xff0c;并且…...

Prometheus 记录规则和警报规则

前提环境&#xff1a; Docker环境 涉及参考文档&#xff1a; Prometheus 录制规则Prometheus 警报规则 语法检查规则 promtool check rules /path/to/example.rules.yml一&#xff1a;录制规则语法 groups 语法&#xff1a; groups:[ - <rule_group> ]rule_group…...

(API)接口测试的关键技术

接口测试也就是API测试&#xff0c;从名字上可以知道是面向接口的测试活动。所以在讲API测试之前&#xff0c;我们应该说清楚接口是什么&#xff0c;那么接口就是有特定输入和特定输出的一套逻辑处理单元&#xff0c;而对于接口调用方来说&#xff0c;不用知道自身的内部实现逻…...

快速排序算法原理 Quicksort —— 图解(精讲) JAVA

快速排序是 Java 中 sort 函数主要的排序方法&#xff0c;所以今天要对快速排序法这种重要算法的详细原理进行分析。 思路&#xff1a;首先快速排序之所以高效一部分原因是利用了离散数学中的传递性。 例如 1 < 2 且 2 < 3 所以可以推出 1 < 3。在快速排序的过程中巧…...

linux环境搭建私有gitlab仓库

搭建之前&#xff0c;需要安装相应的依赖包&#xff0c;并且要启动sshd服务(1).安装policycoreutils-python openssh-server openssh-clients [rootVM-0-2-centos ~]# sudo yum install -y curl policycoreutils-python openssh-server openssh-clients [rootVM-0-2-centos ~]…...

SpringSecurity授权

文章目录工具类使用自定义失败处理代码配置跨域其他权限授权hasAnyAuthority自定义权限校验方法基于配置的权限控制工具类 import javax.servlet.http.HttpServletResponse; import java.io.IOException;public class WebUtils {/*** 将字符串渲染到客户端** param response 渲…...

学习 Python 之 Pygame 开发坦克大战(一)

学习 Python 之 Pygame 开发坦克大战&#xff08;一&#xff09;Pygame什么是Pygame?初识pygame1. 使用pygame创建窗口2. 设置窗口背景颜色3. 获取窗口中的事件4. 在窗口中展示图片(1). pygame中的直角坐标系(2). 展示图片(3). 给部分区域设置颜色5. 在窗口中显示文字6. 播放音…...

2.5|iot冯|方元-嵌入式linux系统开发入门|2.13+2.18

一、 Linux 指令操作题&#xff08;共5题&#xff08;共 20 分&#xff0c;每小题 4分&#xff09;与系统工作、系统状态、工作目录、文件、目录、打包压缩与搜索等主题相关。1.文件1.1文件属性1.2文件类型属性字段的第1个字符表示文件类型&#xff0c;后9个字符中&#xff0c;…...

一起Talk Android吧(第四百九十六回:自定义View实例二:环形进度条)

文章目录 知识回顾实现思路实现方法示例代码各位看官们大家好,上一回中咱们说的例子是"如何使用Java版MQTT客户端",这一回中咱们说的例子是"自定义View实例二:环形进度条"。闲话休提,言归正转,让我们一起Talk Android吧! 知识回顾 看官们,我们又回…...

上传图片尺寸校验

使用方法 ● Image ● URL ● onload代码&#xff1a; async validImageSize(file, imgWidth, imgHeight) {const img new Image()img.src URL.createObjectURL(file)const { w, h } await new Promise((resolve, reject) > {img.onload () > {const { width: w, he…...

【Python】缺失值处理和拉格朗日插值法(含源代码实现)

目录&#xff1a;缺失值处理和拉格朗日插值法一、前言二、理论知识三、代码实现一、前言 对于含有缺失值的数据集&#xff0c;如果通过删除小部分记录达到既定的目标&#xff0c;那么删除含有缺失值的记录的方法是最有效的。然而&#xff0c;这种方法也有很多问题&#xff0c;…...

SpringCloudAlibaba-Sentinel

一、介绍官网&#xff1a;https://github.com/alibaba/Sentinel/下载jar包,启动,访问http://localhost:8080/创建module添加如下依赖<!--SpringCloud ailibaba sentinel --><dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring…...

【程序化天空盒】过程记录02:云扰动 边缘光 消散效果

写在前面 写在前面唉&#xff0c;最近筋疲力竭&#xff0c;课题组的东西一堆没做&#xff0c;才刚刚开始带着思考准备练习作品&#xff0c;从去年5月份开始到现在真得学了快一年了&#xff0c;转行学其他的真的好累&#xff0c;&#xff0c;不过还是加油&#xff01; 下面是做…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全&#xff1a;零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言&#xff1a;云原生安全的范式革命 随着云原生技术的普及&#xff0c;安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测&#xff0c;到2025年&#xff0c;零信任架构将成为超…...

JVM 内存结构 详解

内存结构 运行时数据区&#xff1a; Java虚拟机在运行Java程序过程中管理的内存区域。 程序计数器&#xff1a; ​ 线程私有&#xff0c;程序控制流的指示器&#xff0c;分支、循环、跳转、异常处理、线程恢复等基础功能都依赖这个计数器完成。 ​ 每个线程都有一个程序计数…...

使用Spring AI和MCP协议构建图片搜索服务

目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式&#xff08;本地调用&#xff09; SSE模式&#xff08;远程调用&#xff09; 4. 注册工具提…...

快刀集(1): 一刀斩断视频片头广告

一刀流&#xff1a;用一个简单脚本&#xff0c;秒杀视频片头广告&#xff0c;还你清爽观影体验。 1. 引子 作为一个爱生活、爱学习、爱收藏高清资源的老码农&#xff0c;平时写代码之余看看电影、补补片&#xff0c;是再正常不过的事。 电影嘛&#xff0c;要沉浸&#xff0c;…...

MyBatis中关于缓存的理解

MyBatis缓存 MyBatis系统当中默认定义两级缓存&#xff1a;一级缓存、二级缓存 默认情况下&#xff0c;只有一级缓存开启&#xff08;sqlSession级别的缓存&#xff09;二级缓存需要手动开启配置&#xff0c;需要局域namespace级别的缓存 一级缓存&#xff08;本地缓存&#…...

0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化

是不是受够了安装了oracle database之后sqlplus的简陋&#xff0c;无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话&#xff0c;配置.bahs_profile后也能解决上下翻页这些&#xff0c;但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可&#xff0c…...

VisualXML全新升级 | 新增数据库编辑功能

VisualXML是一个功能强大的网络总线设计工具&#xff0c;专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑&#xff08;如DBC、LDF、ARXML、HEX等&#xff09;&#xff0c;并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...

yaml读取写入常见错误 (‘cannot represent an object‘, 117)

错误一&#xff1a;yaml.representer.RepresenterError: (‘cannot represent an object’, 117) 出现这个问题一直没找到原因&#xff0c;后面把yaml.safe_dump直接替换成yaml.dump&#xff0c;确实能保存&#xff0c;但出现乱码&#xff1a; 放弃yaml.dump&#xff0c;又切…...