基于pytorch的深度学习基础3——模型创建与nn.Module
三 模型创建与nn.Module
3.1 nn.Module

模型构建两要素:
- 构建子模块——__init()__
- 拼接子模块——forward()

一个module可以有多个module;
一个module相当于一个运算,都必须实现forward函数;
每一个module有8个字典管理属性。
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
3.2 网络容器

nn.Sequential()
是nn.Module()的一个容器,用于按照顺序包装一组网络层;
顺序性:网络层之间严格按照顺序构建;
自带forward():
各网络层之间严格按顺序执行,常用于block构建
class LeNetSequential(nn.Module):
def __init__(self, classes):
super(LeNetSequential, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes),)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
nn.ModuleList()
是nn.Module的容器,用于包装网络层,以迭代方式调用网络层。
主要方法:
append():在ModuleList后面添加网络层;
extend():拼接两个ModuleList.
Insert():指定在ModuleList中插入网络层。
nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
class ModuleList(nn.Module):
def __init__(self):
super(ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])
def forward(self, x):
for i, linear in enumerate(self.linears):
x = linear(x)
return x
nn.ModuleDict()
以索引方式调用网络层
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除
n.ModuleDict:索引性,常用于可选择的网络层
class ModuleDict(nn.Module):
def __init__(self):
super(ModuleDict, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'prelu': nn.PReLU()
})
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
3.3卷积层
nn.ConV2d()
nn.Conv2d(in_channels, out_channels,kernel_size, stride=1,padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
in_channels:输入通道数,比如RGB图像是3,而后续的网络层的输入通道数为前一卷积层的输出通道数;
out_channels:输出通道数,等价于卷积核个数
kernel_size:卷积核尺寸
stride:步
padding:填充个数
dilation:空洞卷积大小
groups:分组卷积设置
bias:偏置
conv_layer = nn.Conv2d(3, 1, 3) # input:(i, o, size) weights:(o, i , h, w)
nn.init.xavier_normal_(conv_layer.weight.data)
# calculation
img_conv = conv_layer(img_tensor)
这里使用 input*channel 为 3,output_channel 为 1 ,卷积核大小为 3×3 的卷积核nn.Conv2d(3, 1, 3),使用nn.init.xavier_normal*()方法初始化网络的权值。
我们通过`conv_layer.weight.shape`查看卷积核的 shape 是`(1, 3, 3, 3)`,对应是`(output_channel, input_channel, kernel_size, kernel_size)`。所以第一个维度对应的是卷积核的个数,每个卷积核都是`(3,3,3)`。虽然每个卷积核都是 3 维的,执行的却是 2 维卷积。
转置卷积nn.ConvTranspose2d
转置卷积又称为反卷积(Deconvolution)和部分跨越卷积(Fractionally-stridedConvolution) ,用于对图像进行上采样(UpSample)
为什么称为转置卷积?
假设图像尺寸为4*4,卷积核为3*3,padding=0,stride=1
正常卷积:
![]()
转置卷积:
假设图像尺寸为2*2,卷积核为3*3,padding=0,stride=1

nn.ConvTranspose2d(in_channels, out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1, padding_mode='zeros')
输出尺寸计算:
![]()
![]()
# flag = 1
flag = 0
if flag:
conv_layer = nn.ConvTranspose2d(3, 1, 3, stride=2) # input:(i, o, size)
nn.init.xavier_normal_(conv_layer.weight.data)
# calculation
img_conv = conv_layer(img_tensor)
print("卷积前尺寸:{}\n卷积后尺寸:{}".format(img_tensor.shape, img_conv.shape))
img_conv = transform_invert(img_conv[0, 0:1, ...], img_transform)
img_raw = transform_invert(img_tensor.squeeze(), img_transform)
plt.subplot(122).imshow(img_conv, cmap='gray')
plt.subplot(121).imshow(img_raw)
plt.show()
3.4池化层nn.MaxPool2d && nn.AvgPool2d
池化运算:对信号进行 “收集”并 “总结”,类似水池收集水资源,因而
得名池化层
“收集”:多变少
“总结”:最大值/平均值
nn.MaxPool2d
nn.MaxPool2d(kernel_size, stride=None,
padding=0, dilation=1,
return_indices=False,
ceil_mode=False)
主要参数:
• kernel_size:池化核尺寸
• stride:步长
• padding :填充个数
• dilation:池化核间隔大小
• ceil_mode:尺寸向上取整
• return_indices:记录池化像素索引
# flag = 1
flag = 0
if flag:
maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2)) # input:(i, o, size) weights:(o, i , h, w)
img_pool = maxpool_layer(img_tensor)
nn.AvgPool2d
nn.AvgPool2d(kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None)
主要参数:
• kernel_size:池化核尺寸
• stride:步长
• padding :填充个数
• ceil_mode:尺寸向上取整
• count_include_pad:填充值用于计算
• divisor_override :除法因子
avgpoollayer = nn.AvgPool2d((2, 2), stride=(2, 2)) # input:(i, o, size) weights:(o, i , h, w)
img_pool = avgpoollayer(img_tensor)
img_tensor = torch.ones((1, 1, 4, 4))
avgpool_layer = nn.AvgPool2d((2, 2), stride=(2, 2), divisor_override=3)
img_pool = avgpool_layer(img_tensor)
print("raw_img:\n{}\npooling_img:\n{}".format(img_tensor, img_pool))
nn.MaxUnpool2d
功能:对二维信号(图像)进行最大值池化
上采样
主要参数:
• kernel_size:池化核尺寸
• stride:步长
• padding :填充个数
# pooling
img_tensor = torch.randint(high=5, size=(1, 1, 4, 4), dtype=torch.float)
maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
img_pool, indices = maxpool_layer(img_tensor)
# unpooling
img_reconstruct = torch.randn_like(img_pool, dtype=torch.float)
maxunpool_layer = nn.MaxUnpool2d((2, 2), stride=(2, 2))
img_unpool = maxunpool_layer(img_reconstruct, indices)
print("raw_img:\n{}\nimg_pool:\n{}".format(img_tensor, img_pool))
print("img_reconstruct:\n{}\nimg_unpool:\n{}".format(img_reconstruct, img_unpool))
3.5线性层
nn.Linear(in_features, out_features, bias=True)
功能:对一维信号(向量)进行线性组合
主要参数:
• in_features:输入结点数
• out_features:输出结点数
• bias :是否需要偏置
计算公式:y = 𝒙𝑾𝑻 + 𝒃𝒊𝒂s
inputs = torch.tensor([[1., 2, 3]])
linear_layer = nn.Linear(3, 4)
linear_layer.weight.data = torch.tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]])
linear_layer.bias.data.fill_(0.5)
output = linear_layer(inputs)
print(inputs, inputs.shape)
print(linear_layer.weight.data, linear_layer.weight.data.shape)
print(output, output.shape)
3.6 激活函数层
nn.Sigmoid
nn.tanh:
nn.ReLU
nn.LeakyReLU
negative_slope: 负半轴斜率
nn.PReLU
init: 可学习斜率
nn.RReLU
lower: 均匀分布下限
upper:均匀分布上限

参考资料
深度之眼课程
相关文章:
基于pytorch的深度学习基础3——模型创建与nn.Module
三 模型创建与nn.Module 3.1 nn.Module 模型构建两要素: 构建子模块——__init()__拼接子模块——forward() 一个module可以有多个module; 一个module相当于一个运算,都必须实现forward函数; 每一个mod…...
Debian-linux运维-docker安装和配置
腾讯云搭建docker官方文档:https://cloud.tencent.com/document/product/213/46000 阿里云安装Docker官方文档:https://help.aliyun.com/zh/ecs/use-cases/install-and-use-docker-on-a-linux-ecs-instance 天翼云常见docker源配置指导:htt…...
Docker完整技术汇总
Docker 背景引入 在实际开发过程中有三个环境,分别是:开发环境、测试环境以及生产环境,假设开发环境中开发人员用的是jdk8,而在测试环境中测试人员用的时jdk7,这就导致程序员开发完系统后将其打成jar包发给测试人员后…...
在JavaScript文件中定义方法和数据(不是在对象里定以数据和方法,不要搞错了)
在对象里定以数据和方法看这一篇 对象字面量内定义属性和方法(什么使用const等关键字,什么时候用键值对)-CSDN博客https://blog.csdn.net/m0_62961212/article/details/144788665 下是在JavaScript文件中定义方法和数据的基本方式ÿ…...
python爬虫爬抖音小店商品数据+数据可视化
爬虫代码 爬虫代码是我调用的数据接口,可能会过一段时间用不了,欢迎大家留言评论,我会不定时更新 import requests import time cookies {token: 5549EB98B15E411DA0BD05935C0F225F,tfstk: g1vopsc0sQ5SwD8TyEWSTmONZ3cA2u6CReedJ9QEgZ7byz…...
关于 覆铜与导线之间间距较小需要增加间距 的解决方法
若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/144776995 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…...
uniapp中Nvue白屏问题 ReferenceError: require is not defined
uniapp控制台输出如下 exception function:createInstanceContext, exception:white screen cause create instanceContext failed,check js stack ->Uncaught ReferenceError: require is not defined 或者 exception function:createInstanceContext, exception:white s…...
在 Windows 上,如果忘记了 MySQL 密码 重置密码
在 Windows 上,如果忘记了 MySQL 密码,可以通过以下方法重置密码: 方法 1:以跳过权限验证模式启动 MySQL 并重置密码 停止 MySQL 服务: 打开 命令提示符 或 PowerShell,输入以下命令停止 MySQL 服务&#…...
《PyTorch:从基础概念到实战应用》
《PyTorch:从基础概念到实战应用》 一、PyTorch 初印象二、PyTorch 之历史溯源三、PyTorch 核心优势尽显(一)简洁高效,契合思维(二)易于上手,调试便捷(三)社区繁荣&#…...
前端:改变鼠标点击物体的颜色
需求: 需要改变图片中某一物体的颜色,该物体是纯色; 鼠标点击哪个物体,哪个物体的颜色变为指定的颜色,利用canvas实现。 演示案例 代码Demo <!DOCTYPE html> <html lang"en"><head>&l…...
Java-33 深入浅出 Spring - FactoryBean 和 BeanFactory BeanPostProcessor
点一下关注吧!!!非常感谢!!持续更新!!! 大数据篇正在更新!https://blog.csdn.net/w776341482/category_12713819.html 目前已经更新到了: MyBatisÿ…...
HTML4笔记
尚硅谷 一、前序知识 1.认识两位先驱 2.计算机基础知识 3.C/S架构与B/S架构 4.浏览器相关知识 5.网页相关概念 二、HTML简介 1.什么是HTML? 2.相关国际组织(了解) 3.HTML发展历史(了解)** 三、准备工作 1.常用电脑设置 2.安装Chrome浏览器 四、HTML入门 1.HTML初体验 2.H…...
python报错ModuleNotFoundError: No module named ‘visdom‘
在用虚拟环境跑深度学习代码时,新建的环境一般会缺少一些库,而一般解决的方法就是直接conda install,但是我在conda install visdom之后,安装是没有任何报错的,conda list里面也有visdom的信息,但是再运行代…...
linux-21 目录管理(一)mkdir命令,创建空目录
对linux而言,对一个系统管理来讲,最关键的还是文件管理。那所以我们接下来就来看看如何实现文件管理。当然,在文件管理之前,我们说过,文件通常都放在目录下,对吧?所以先了解目录,可能…...
总结-常见缓存替换算法
缓存替换算法 1. 总结 1. 总结 常见的缓存替换算法除了FIFO、LRU和LFU还有下面几种: 算法优点缺点适用场景FIFO简单实现可能移除重要数据嵌入式系统,简单场景LRU局部性原理良好维护成本高,占用更多存储空间内存管理,浏览器缓存L…...
【Vue】如何在 Vue 3 中使用组合式 API 与 Vuex 进行状态管理的详细教程
如何在 Vue 3 中使用组合式 API 与 Vuex 进行状态管理的详细教程。 安装 Vuex 首先,在你的 Vue 3 项目中安装 Vuex。可以使用 npm 或 yarn: npm install vuexnext --save # or yarn add vuexnext创建 Store 在 Vue 3 中,你可以使用 creat…...
VSCode 插件开发实战(十五):如何支持多语言
前言 在软件开发中,多语言支持(i18n)是一个非常重要的功能。无论是桌面应用、移动应用,还是浏览器插件,都需要考虑如何支持不同国家和地区的用户,软件应用的多语言支持(i18n)已经成…...
面试241228
面试可参考 1、cas的概念 2、AQS的概念 3、redis的数据结构 使用场景 不熟 4、redis list 扩容流程 5、dubbo 怎么进行服务注册和调用,6、dubbo 预热 7如何解决cos上传的安全问题kafka的高并发高吞吐的原因ES倒排索引的原理 spring的 bean的 二级缓存和三级缓存 spr…...
Python数据序列化模块pickle使用
pickle 是 Python 的一个标准库模块,它实现了基本的对象序列化和反序列化。序列化是指将对象转换为字节流的过程,这样对象就可以被保存到文件中或通过网络传输。反序列化是指从字节流中恢复对象的过程。 以下是 pickle 模块的基本使用方法: …...
Spring Boot对访问密钥加解密——HMAC-SHA256
HMAC-SHA256 简介 HMAC-SHA256 是一种基于 哈希函数 的消息认证码(Message Authentication Code, MAC),它结合了哈希算法(如 SHA-256)和一个密钥,用于验证消息的完整性和真实性。 HMAC 是 “Hash-based M…...
wordpress后台更新后 前端没变化的解决方法
使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...
UE5 学习系列(二)用户操作界面及介绍
这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...
CTF show Web 红包题第六弹
提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框,很难让人不联想到SQL注入,但提示都说了不是SQL注入,所以就不往这方面想了 先查看一下网页源码,发现一段JavaScript代码,有一个关键类ctfs…...
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...
Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例
使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件,常用于在两个集合之间进行数据转移,如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model:绑定右侧列表的值&…...
Debian系统简介
目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版ÿ…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...
Python如何给视频添加音频和字幕
在Python中,给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加,包括必要的代码示例和详细解释。 环境准备 在开始之前,需要安装以下Python库:…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
