当前位置: 首页 > news >正文

Pytorch个人学习记录总结 06

目录

神经网络-卷积层 torch.nn.Conv2d

神经网络-最大池化的使用 torch.nn.MaxPool2d


神经网络-卷积层 torch.nn.Conv2d

torch.nn.Conv2d的官方文档地址

CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)


卷积动画的链接:https://github.com/vdumoulin/conv_arithmetic/blob/master/README.mdicon-default.png?t=N6B9https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

注意

  • 默认bias=True,这说明PyTorch中Con2d是默认给卷积操作加了偏置的。
  • 还有一些默认值:stride=1,padding=0等。
  • out_channels输出通道数,相当于就是卷积核的个数
  • dilation:需要使用空洞卷积时再进行设置。
    import torch
    from torch import nn
    from torch.nn import Conv2d
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets
    from torchvision.transforms import transforms# 1. 加载数据
    dataset = datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=False)# 2. 构造模型
    class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1)def forward(self, x):return self.conv1(x)writer = SummaryWriter('./logs/Conv2d')# 3. 实例化一个模型对象,进行卷积
    model = Model()
    step = 0for data in dataloader:imgs, targets = datawriter.add_images('imgs_ch3', imgs, step)# 4. 用tensorboard打开查看图像。但是注意,add_images的输入图像的通道数只能是3
    #    所以如果通道数>3,则可以先采用小土堆的这个不严谨的做法,在tensorboard中查看一下图片outputs = model(imgs)outputs = torch.reshape(outputs, (-1, 3, 30, 30))writer.add_images('imgs_ch6', outputs, step)step += 1writer.close()
    

    神经网络-最大池化的使用 torch.nn.MaxPool2d

池化也可成为下采样(就是缩小输入图像尺寸,但是不会改变输入图像的通道数)。常见的有MaxPool2d、AvgPool2d等。相反有上采样MaxUnPool2d。

MaxPool2d的官方文档地址:MaxPool2d — PyTorch 2.0 documentation 

CLASS torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

注意

  • stride默认=kernel_size
  • ceil_mode默认是False,也就是说事向下取整

pool和conv后的图像尺寸N计算公式是一样的:N = ( W − F + 2 ∗ P ) / S + 1 N=(W-F+2*P)/S+1N=(W−F+2∗P)/S+1,且都是默认N向下取整。
 

  1. 在构造tensor的时候,最好指定元素的数据类型是float,即在最后加上dtype=torch.float32,这样后面有些操作才不会出错。
  2. 池化的作用:保持输入图像的特征,且减小输入量,能加快训练。(就类似于B站视频有10080P的也会有720P的,720P虽然不如1080P那么高清,但是仍然能够看出视频中物体的特征信息,有点像打了马赛克一样)
    import torch
    import torchvision.datasets
    from torch import nn
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriterclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.maxpool1 = nn.MaxPool2d(kernel_size=3)  # 默认:stride=kernel_size,ceil_mode=Falseself.maxpool2 = nn.MaxPool2d(kernel_size=3, ceil_mode=True)def forward(self, x):return self.maxpool1(x), self.maxpool2(x)model = Model()# -------------1.上图例子,查看ceil_mode为True或False的池化结果--------------- #
    input = torch.tensor([[1, 2, 0, 3, 1],[0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1]], dtype=torch.float32)input = torch.reshape(input, (-1, 1, 5, 5))
    out1, out2 = model(input)
    print('out1={}\nout2={}'.format(out1, out2))# --------------2.加载数据集,并放入tensorboard查看图片----------------------- #
    dataset = torchvision.datasets.CIFAR10('dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)writer = SummaryWriter('./logs/maxpool')
    step = 0
    for data in dataloader:imgs, targets = datawriter.add_images('imgs', imgs, step)imgs, _ = model(imgs)writer.add_images('imgs_maxpool', imgs, step)step += 1writer.close()
    

相关文章:

Pytorch个人学习记录总结 06

目录 神经网络-卷积层 torch.nn.Conv2d 神经网络-最大池化的使用 torch.nn.MaxPool2d 神经网络-卷积层 torch.nn.Conv2d torch.nn.Conv2d的官方文档地址 CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride1, padding0, dilation1, groups1, biasTrue,…...

Rust之泛型、特性和生命期(四):验证有生存期的引用

开发环境 Windows 10Rust 1.71.0 VS Code 1.80.1 项目工程 这里继续沿用上次工程rust-demo 验证具有生存期的引用 生存期是我们已经在使用的另一种泛型。生存期不是确保一个类型具有我们想要的行为,而是确保引用在我们需要时有效。 我们在第4章“引用和借用”一…...

kubesphere安装中间件

kubesphere安装mysql 创建configMap [client] default-character-setutf8mb4[mysql] default-character-setutf8mb4[mysqld] init_connectSET collation_connection utf8mb4_unicode_ci init_connectSET NAMES utf8mb4 character-set-serverutf8mb4 collation-serverutf8mb4_…...

zookeeper学习(二) 集群模式安装

前置环境 三台centos7服务器 192.168.2.201 192.168.2.202 192.168.2.150三台服务器都需要安装jdk1.8以上zookeeper安装包 安装jdk 在单机模式已经描述过,这里略过,有需要可以去看单机模式中的这部分,注意的是三台服务器都需要安装 安装…...

选择合适的图表,高效展现数据魅力

随着大数据时代的来临,数据的重要性愈发凸显,数据分析和可视化成为了决策和传递信息的重要手段。在数据可视化中,选择合适的图表是至关重要的一环,它能让数据更加生动、直观地呈现,为观众提供更有说服力的信息。本文将…...

springboot自动装配

SPI spi : service provider interface : 是java的一种服务提供机制,spi 允许开发者在不修改代码的情况下,为某个接口提供实现类,来扩展应用程序 将实现类独立到配置文件中,通过配置文件控制导入&#xff…...

python小记-队列

队列(Queue)是一种常见的数据结构,它遵循先进先出(First-In-First-Out,FIFO)的原则。在队列中,新元素(也称为项)总是添加到队列的末尾,而最早添加的元素总是在…...

SpringBoot——持久化技术

简单介绍 在之前我们使用的数据层持久化技术使用的是MyBatis或者是MyBatis-plus,其实都是一样的。在使用之前,我们要导入对应的坐标,然后配置MyBatis特有的配置,比如说Mapper接口,或者XML配置文件,那么除了…...

Kafka 入门到起飞 - 生产者参数详解 ,什么是生产者确认机制? 什么是ISR? 什么是 OSR?

上回书我们讲了,生产者发送消息流程解析传送门 那么这篇我们来看下,生产者发送消息时几个重要的参数详解 ,什么是生产者确认机制? 什么是ISR? 什么是 OSR? 参数: bootstrap.servers : Kafka 集…...

【文献分享】比目前最先进的模型轻30%!高效多机器人SLAM蒸馏描述符!

论文题目:Descriptor Distillation for Efficient Multi-Robot SLAM 中文题目:高效多机器人SLAM蒸馏描述符 作者:Xiyue Guo, Junjie Hu, Hujun Bao and Guofeng Zhang 作者机构:浙江大学CAD&CG国家重点实验室 香港中文大学…...

【数据动态填充到element表格;将带有标签的数据展示为文本格式】

一&#xff1a;数据动态填充到element表格&#xff1b; 二&#xff1a;将带有标签的数据展示为文本格式&#xff1b; 1、 <el-row><el-col :span"24"><el-tabs type"border-card"><el-tab-pane label"返回值"><el-…...

小程序轮播图的两种后台方式(PHP)--【浅入深出系列008】

微信目录集链接在此&#xff1a; 详细解析黑马微信小程序视频–【思维导图知识范围】难度★✰✰✰✰ 不会导入/打开小程序的看这里&#xff1a;参考 让别人的小程序长成自己的样子-更换window上下颜色–【浅入深出系列001】 文章目录 本系列校训学习资源的选择啥是轮播图轮播…...

使用ComPDFKit PDF SDK 构建iOS PDF阅读器

在当今以移动为先的世界中&#xff0c;为企业和开发人员创建一个iOS应用程序是必不可少的。随着对PDF文档处理需求的增加&#xff0c;使用ComPDFKit这个强大的PDF软件开发工具包&#xff08;SDK&#xff09;来构建iOS PDF阅读器和编辑器可以让最终用户轻松查看和编辑PDF文档。 …...

一套流程6个步骤,教你如何正确采购询价

采购询价&#xff08;RFQ&#xff09;是一种竞争性投标文件&#xff0c;用于邀请供应商或承包商就标准化或重复生产的产品或服务提交报价。 询价通常用于大批量/低价值项目&#xff0c;买方必须提供技术规格和商业要求&#xff0c;该文件有时也称为招标书或投标邀请书。询价流…...

git使用

常用命令 git init git库初始化&#xff0c;初始化后会在文件中出现一个.git的隐藏文件 git clone 从远程克隆仓库 git pull 从远程库中拉取 git commit 将暂存提交到本地仓库 git push 提交本地仓库到远程 git branch 查看当前分支 git branch <branchName> 切换分支 …...

SkyWalking链路追踪-搭建-spring-boot-cloud-单机环境 之《10 分钟快速搭建 SkyWalking 服务》

首先了解一下单机环境 第一步&#xff0c;搭建一个 Elasticsearch 服务。第二步&#xff0c;下载 SkyWalking 软件包。第三步&#xff0c;搭建一个 SkyWalking OAP 服务。第四步&#xff0c;启动一个 Spring Boot 应用&#xff0c;并配置 SkyWalking Agent。第五步&#xff0c;…...

Rabbit MQ整合springBoot

一、pom依赖二、消费端2.1、application.properties 配置文件2.2、消费端核心组件 三、生产端3.1、application.properties 配置文件2.2、生产者 MQ消息发送组件四、测试1、生产端控制台2、消费端控制台 一、pom依赖 <dependency><groupId>org.springframework.boo…...

Golang 中的 time 包详解(一):time.Time

在日常开发过程中&#xff0c;会频繁遇到对时间进行操作的场景&#xff0c;使用 Golang 中的 time 包可以很方便地实现对时间的相关操作。接下来的几篇文章会详细讲解 time 包&#xff0c;本文先讲解一下 time 包中的结构体 time.Time。 time.Time time.Time 类型用来表示一个…...

CMU 15-445 -- Database Recovery - 18

CMU 15-445 -- Database Recovery - 18 引言ARIESLog Sequence NumbersNormal ExecutionTransaction CommitTransaction AbortCompensation Log Records Non-fuzzy & fuzzy CheckpointsSlightly Better CheckpointsFuzzy Checkpoints ARIES - Recovery PhasesAnalysis Phas…...

HTTP Header定制,客户端使用Request,服务器端使用Response

在服务器端通过request.getHeaders()是无效的&#xff0c;只能使用response.getHeaders()。 Overridepublic Object beforeBodyWrite(Object body, MethodParameter returnType, MediaType mediaType,Class selectedConverterType, ServerHttpRequest request, ServerHttpRespo…...

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

【git】把本地更改提交远程新分支feature_g

创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...

多种风格导航菜单 HTML 实现(附源码)

下面我将为您展示 6 种不同风格的导航菜单实现&#xff0c;每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...

如何在最短时间内提升打ctf(web)的水平?

刚刚刷完2遍 bugku 的 web 题&#xff0c;前来答题。 每个人对刷题理解是不同&#xff0c;有的人是看了writeup就等于刷了&#xff0c;有的人是收藏了writeup就等于刷了&#xff0c;有的人是跟着writeup做了一遍就等于刷了&#xff0c;还有的人是独立思考做了一遍就等于刷了。…...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

Selenium常用函数介绍

目录 一&#xff0c;元素定位 1.1 cssSeector 1.2 xpath 二&#xff0c;操作测试对象 三&#xff0c;窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四&#xff0c;弹窗 五&#xff0c;等待 六&#xff0c;导航 七&#xff0c;文件上传 …...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机&#xff0c;它可以执行Java字节码。Java虚拟机是Java平台的一部分&#xff0c;Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...