当前位置: 首页 > news >正文

基于pytorch的深度学习基础3——模型创建与nn.Module

三 模型创建与nn.Module

3.1 nn.Module

模型构建两要素:

  1. 构建子模块——__init()__
  2. 拼接子模块——forward()

一个module可以有多个module;

一个module相当于一个运算,都必须实现forward函数;

每一个module有8个字典管理属性。

self._parameters = OrderedDict()

self._buffers = OrderedDict()

self._backward_hooks = OrderedDict()

self._forward_hooks = OrderedDict()

self._forward_pre_hooks = OrderedDict()

self._state_dict_hooks = OrderedDict()

self._load_state_dict_pre_hooks = OrderedDict()

self._modules = OrderedDict()

3.2 网络容器

nn.Sequential()

是nn.Module()的一个容器,用于按照顺序包装一组网络层;

顺序性:网络层之间严格按照顺序构建;

自带forward():

各网络层之间严格按顺序执行,常用于block构建

class LeNetSequential(nn.Module):

    def __init__(self, classes):

        super(LeNetSequential, self).__init__()

        self.features = nn.Sequential(

            nn.Conv2d(3, 6, 5),

            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(6, 16, 5),

            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),)

        self.classifier = nn.Sequential(

            nn.Linear(16*5*5, 120),

            nn.ReLU(),

            nn.Linear(120, 84),

            nn.ReLU(),

            nn.Linear(84, classes),)

    def forward(self, x):

        x = self.features(x)

        x = x.view(x.size()[0], -1)

        x = self.classifier(x)

        return x

nn.ModuleList()

是nn.Module的容器,用于包装网络层,以迭代方式调用网络层。

主要方法:

append():在ModuleList后面添加网络层;

extend():拼接两个ModuleList.

Insert():指定在ModuleList中插入网络层。

nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建

class ModuleList(nn.Module):

    def __init__(self):

        super(ModuleList, self).__init__()

        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):

        for i, linear in enumerate(self.linears):

            x = linear(x)

        return x

nn.ModuleDict()

以索引方式调用网络层

主要方法:

• clear():清空ModuleDict

• items():返回可迭代的键值对(key-value pairs)

• keys():返回字典的键(key)

• values():返回字典的值(value)

• pop():返回一对键值,并从字典中删除

n.ModuleDict:索引性,常用于可选择的网络层

class ModuleDict(nn.Module):

    def __init__(self):

        super(ModuleDict, self).__init__()

        self.choices = nn.ModuleDict({

            'conv': nn.Conv2d(10, 10, 3),

            'pool': nn.MaxPool2d(3)

        })

        self.activations = nn.ModuleDict({

            'relu': nn.ReLU(),

            'prelu': nn.PReLU()

        })

    def forward(self, x, choice, act):

        x = self.choices[choice](x)

        x = self.activations[act](x)

        return x

3.3卷积层

nn.ConV2d()

nn.Conv2d(in_channels, out_channels,kernel_size, stride=1,padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

in_channels:输入通道数,比如RGB图像是3,而后续的网络层的输入通道数为前一卷积层的输出通道数;

out_channels:输出通道数,等价于卷积核个数

kernel_size:卷积核尺寸

stride:步

padding:填充个数

dilation:空洞卷积大小

groups:分组卷积设置

bias:偏置

    conv_layer = nn.Conv2d(3, 1, 3)   # input:(i, o, size) weights:(o, i , h, w)

    nn.init.xavier_normal_(conv_layer.weight.data)

    # calculation

    img_conv = conv_layer(img_tensor)

这里使用 input*channel 为 3,output_channel 为 1 ,卷积核大小为 3×3 的卷积核nn.Conv2d(3, 1, 3),使用nn.init.xavier_normal*()方法初始化网络的权值。

我们通过`conv_layer.weight.shape`查看卷积核的 shape 是`(1, 3, 3, 3)`,对应是`(output_channel, input_channel, kernel_size, kernel_size)`。所以第一个维度对应的是卷积核的个数,每个卷积核都是`(3,3,3)`。虽然每个卷积核都是 3 维的,执行的却是 2 维卷积。

转置卷积nn.ConvTranspose2d

转置卷积又称为反卷积(Deconvolution)和部分跨越卷积(Fractionally-stridedConvolution) ,用于对图像进行上采样(UpSample)

为什么称为转置卷积?

假设图像尺寸为4*4,卷积核为3*3,padding=0,stride=1

正常卷积:

转置卷积:

假设图像尺寸为2*2,卷积核为3*3,padding=0,stride=1

nn.ConvTranspose2d(in_channels, out_channels,

kernel_size,

stride=1,

padding=0,

output_padding=0,

groups=1,

bias=True,

dilation=1, padding_mode='zeros')

输出尺寸计算:

# flag = 1

flag = 0

if flag:

    conv_layer = nn.ConvTranspose2d(3, 1, 3, stride=2)   # input:(i, o, size)

    nn.init.xavier_normal_(conv_layer.weight.data)

    # calculation

    img_conv = conv_layer(img_tensor)

print("卷积前尺寸:{}\n卷积后尺寸:{}".format(img_tensor.shape, img_conv.shape))

img_conv = transform_invert(img_conv[0, 0:1, ...], img_transform)

img_raw = transform_invert(img_tensor.squeeze(), img_transform)

plt.subplot(122).imshow(img_conv, cmap='gray')

plt.subplot(121).imshow(img_raw)

plt.show()

3.4池化层nn.MaxPool2d && nn.AvgPool2d

池化运算:对信号进行 “收集”并 “总结”,类似水池收集水资源,因而

得名池化层

“收集”:多变少

“总结”:最大值/平均值

nn.MaxPool2d

nn.MaxPool2d(kernel_size, stride=None,

padding=0, dilation=1,

return_indices=False,

ceil_mode=False)

主要参数:

• kernel_size:池化核尺寸

• stride:步长

• padding :填充个数

• dilation:池化核间隔大小

• ceil_mode:尺寸向上取整

• return_indices:记录池化像素索引

# flag = 1

flag = 0

if flag:

    maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2))   # input:(i, o, size) weights:(o, i , h, w)

    img_pool = maxpool_layer(img_tensor)

nn.AvgPool2d

nn.AvgPool2d(kernel_size,

stride=None,

padding=0,

ceil_mode=False,

count_include_pad=True,

divisor_override=None)

主要参数:

• kernel_size:池化核尺寸

• stride:步长

• padding :填充个数

• ceil_mode:尺寸向上取整

• count_include_pad:填充值用于计算

• divisor_override :除法因子

    avgpoollayer = nn.AvgPool2d((2, 2), stride=(2, 2))   # input:(i, o, size) weights:(o, i , h, w)

    img_pool = avgpoollayer(img_tensor)

    img_tensor = torch.ones((1, 1, 4, 4))

    avgpool_layer = nn.AvgPool2d((2, 2), stride=(2, 2), divisor_override=3)

    img_pool = avgpool_layer(img_tensor)

    print("raw_img:\n{}\npooling_img:\n{}".format(img_tensor, img_pool))

nn.MaxUnpool2d

功能:对二维信号(图像)进行最大值池化

上采样

主要参数:

• kernel_size:池化核尺寸

• stride:步长

• padding :填充个数

    # pooling

    img_tensor = torch.randint(high=5, size=(1, 1, 4, 4), dtype=torch.float)

    maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)

    img_pool, indices = maxpool_layer(img_tensor)

    # unpooling

    img_reconstruct = torch.randn_like(img_pool, dtype=torch.float)

    maxunpool_layer = nn.MaxUnpool2d((2, 2), stride=(2, 2))

    img_unpool = maxunpool_layer(img_reconstruct, indices)

    print("raw_img:\n{}\nimg_pool:\n{}".format(img_tensor, img_pool))

    print("img_reconstruct:\n{}\nimg_unpool:\n{}".format(img_reconstruct, img_unpool))

3.5线性层

nn.Linear(in_features, out_features, bias=True)

功能:对一维信号(向量)进行线性组合

主要参数:

• in_features:输入结点数

• out_features:输出结点数

• bias :是否需要偏置

计算公式:y = 𝒙𝑾𝑻 + 𝒃𝒊𝒂s

    inputs = torch.tensor([[1., 2, 3]])

    linear_layer = nn.Linear(3, 4)

    linear_layer.weight.data = torch.tensor([[1., 1., 1.],

                                             [2., 2., 2.],

                                             [3., 3., 3.],

                                             [4., 4., 4.]])

    linear_layer.bias.data.fill_(0.5)

    output = linear_layer(inputs)

    print(inputs, inputs.shape)

    print(linear_layer.weight.data, linear_layer.weight.data.shape)

    print(output, output.shape)

3.6 激活函数层

nn.Sigmoid

nn.tanh:

nn.ReLU

nn.LeakyReLU

negative_slope: 负半轴斜率

nn.PReLU

init: 可学习斜率

nn.RReLU

lower: 均匀分布下限

upper:均匀分布上限

参考资料

深度之眼课程

相关文章:

基于pytorch的深度学习基础3——模型创建与nn.Module

三 模型创建与nn.Module 3.1 nn.Module 模型构建两要素: 构建子模块——__init()__拼接子模块——forward() 一个module可以有多个module; 一个module相当于一个运算,都必须实现forward函数; 每一个mod…...

Debian-linux运维-docker安装和配置

腾讯云搭建docker官方文档:https://cloud.tencent.com/document/product/213/46000 阿里云安装Docker官方文档:https://help.aliyun.com/zh/ecs/use-cases/install-and-use-docker-on-a-linux-ecs-instance 天翼云常见docker源配置指导:htt…...

Docker完整技术汇总

Docker 背景引入 在实际开发过程中有三个环境,分别是:开发环境、测试环境以及生产环境,假设开发环境中开发人员用的是jdk8,而在测试环境中测试人员用的时jdk7,这就导致程序员开发完系统后将其打成jar包发给测试人员后…...

在JavaScript文件中定义方法和数据(不是在对象里定以数据和方法,不要搞错了)

在对象里定以数据和方法看这一篇 对象字面量内定义属性和方法(什么使用const等关键字,什么时候用键值对)-CSDN博客https://blog.csdn.net/m0_62961212/article/details/144788665 下是在JavaScript文件中定义方法和数据的基本方式&#xff…...

python爬虫爬抖音小店商品数据+数据可视化

爬虫代码 爬虫代码是我调用的数据接口,可能会过一段时间用不了,欢迎大家留言评论,我会不定时更新 import requests import time cookies {token: 5549EB98B15E411DA0BD05935C0F225F,tfstk: g1vopsc0sQ5SwD8TyEWSTmONZ3cA2u6CReedJ9QEgZ7byz…...

关于 覆铜与导线之间间距较小需要增加间距 的解决方法

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/144776995 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…...

uniapp中Nvue白屏问题 ReferenceError: require is not defined

uniapp控制台输出如下 exception function:createInstanceContext, exception:white screen cause create instanceContext failed,check js stack ->Uncaught ReferenceError: require is not defined 或者 exception function:createInstanceContext, exception:white s…...

在 Windows 上,如果忘记了 MySQL 密码 重置密码

在 Windows 上,如果忘记了 MySQL 密码,可以通过以下方法重置密码: 方法 1:以跳过权限验证模式启动 MySQL 并重置密码 停止 MySQL 服务: 打开 命令提示符 或 PowerShell,输入以下命令停止 MySQL 服务&#…...

《PyTorch:从基础概念到实战应用》

《PyTorch:从基础概念到实战应用》 一、PyTorch 初印象二、PyTorch 之历史溯源三、PyTorch 核心优势尽显(一)简洁高效,契合思维(二)易于上手,调试便捷(三)社区繁荣&#…...

前端:改变鼠标点击物体的颜色

需求&#xff1a; 需要改变图片中某一物体的颜色&#xff0c;该物体是纯色&#xff1b; 鼠标点击哪个物体&#xff0c;哪个物体的颜色变为指定的颜色&#xff0c;利用canvas实现。 演示案例 代码Demo <!DOCTYPE html> <html lang"en"><head>&l…...

Java-33 深入浅出 Spring - FactoryBean 和 BeanFactory BeanPostProcessor

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 大数据篇正在更新&#xff01;https://blog.csdn.net/w776341482/category_12713819.html 目前已经更新到了&#xff1a; MyBatis&#xff…...

HTML4笔记

尚硅谷 一、前序知识 1.认识两位先驱 2.计算机基础知识 3.C/S架构与B/S架构 4.浏览器相关知识 5.网页相关概念 二、HTML简介 1.什么是HTML? 2.相关国际组织(了解) 3.HTML发展历史(了解)** 三、准备工作 1.常用电脑设置 2.安装Chrome浏览器 四、HTML入门 1.HTML初体验 2.H…...

python报错ModuleNotFoundError: No module named ‘visdom‘

在用虚拟环境跑深度学习代码时&#xff0c;新建的环境一般会缺少一些库&#xff0c;而一般解决的方法就是直接conda install&#xff0c;但是我在conda install visdom之后&#xff0c;安装是没有任何报错的&#xff0c;conda list里面也有visdom的信息&#xff0c;但是再运行代…...

linux-21 目录管理(一)mkdir命令,创建空目录

对linux而言&#xff0c;对一个系统管理来讲&#xff0c;最关键的还是文件管理。那所以我们接下来就来看看如何实现文件管理。当然&#xff0c;在文件管理之前&#xff0c;我们说过&#xff0c;文件通常都放在目录下&#xff0c;对吧&#xff1f;所以先了解目录&#xff0c;可能…...

总结-常见缓存替换算法

缓存替换算法 1. 总结 1. 总结 常见的缓存替换算法除了FIFO、LRU和LFU还有下面几种&#xff1a; 算法优点缺点适用场景FIFO简单实现可能移除重要数据嵌入式系统&#xff0c;简单场景LRU局部性原理良好维护成本高&#xff0c;占用更多存储空间内存管理&#xff0c;浏览器缓存L…...

【Vue】如何在 Vue 3 中使用组合式 API 与 Vuex 进行状态管理的详细教程

如何在 Vue 3 中使用组合式 API 与 Vuex 进行状态管理的详细教程。 安装 Vuex 首先&#xff0c;在你的 Vue 3 项目中安装 Vuex。可以使用 npm 或 yarn&#xff1a; npm install vuexnext --save # or yarn add vuexnext创建 Store 在 Vue 3 中&#xff0c;你可以使用 creat…...

VSCode 插件开发实战(十五):如何支持多语言

前言 在软件开发中&#xff0c;多语言支持&#xff08;i18n&#xff09;是一个非常重要的功能。无论是桌面应用、移动应用&#xff0c;还是浏览器插件&#xff0c;都需要考虑如何支持不同国家和地区的用户&#xff0c;软件应用的多语言支持&#xff08;i18n&#xff09;已经成…...

面试241228

面试可参考 1、cas的概念 2、AQS的概念 3、redis的数据结构 使用场景 不熟 4、redis list 扩容流程 5、dubbo 怎么进行服务注册和调用&#xff0c;6、dubbo 预热 7如何解决cos上传的安全问题kafka的高并发高吞吐的原因ES倒排索引的原理 spring的 bean的 二级缓存和三级缓存 spr…...

​Python数据序列化模块pickle使用

pickle 是 Python 的一个标准库模块&#xff0c;它实现了基本的对象序列化和反序列化。序列化是指将对象转换为字节流的过程&#xff0c;这样对象就可以被保存到文件中或通过网络传输。反序列化是指从字节流中恢复对象的过程。 以下是 pickle 模块的基本使用方法&#xff1a; …...

Spring Boot对访问密钥加解密——HMAC-SHA256

HMAC-SHA256 简介 HMAC-SHA256 是一种基于 哈希函数 的消息认证码&#xff08;Message Authentication Code, MAC&#xff09;&#xff0c;它结合了哈希算法&#xff08;如 SHA-256&#xff09;和一个密钥&#xff0c;用于验证消息的完整性和真实性。 HMAC 是 “Hash-based M…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界&#xff0c;看笔记好好学多敲多打&#xff0c;每个人都是大神&#xff01; 题目&#xff1a;KubeSphere 容器平台高可用&#xff1a;环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖

在前面的练习中&#xff0c;每个页面需要使用ref&#xff0c;onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入&#xff0c;需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

视频字幕质量评估的大规模细粒度基准

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; 摘要 视频字幕在文本到视频生成任务中起着至关重要的作用&#xff0c;因为它们的质量直接影响所生成视频的语义连贯性和视觉保真度。尽管大型视觉-语言模型&#xff08;VLMs&#xff09;在字幕生成方面…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错

出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上&#xff0c;所以报错&#xff0c;到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本&#xff0c;cu、torch、cp 的版本一定要对…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...