当前位置: 首页 > article >正文

pytorch小记(十三):pytorch中`nn.ModuleList` 详解

pytorch小记(十三):pytorch中`nn.ModuleList` 详解

  • PyTorch 中的 `nn.ModuleList` 详解
    • 1. 什么是 `nn.ModuleList`?
    • 2. 为什么不直接使用普通的 Python 列表?
    • 3. `nn.ModuleList` 的基本用法
      • 示例:构建一个包含两层全连接网络的模型
    • 4. 使用 `nn.ModuleList` 计算参数总数(与普通列表对比)
      • 示例代码
    • 5. `nn.ModuleList` 的其他应用
      • 示例:构建动态 MLP 模型
      • Transformers中的多头注意力机制
    • 6. 总结


PyTorch 中的 nn.ModuleList 详解

在构建深度学习模型时,经常需要管理多个网络层(例如多个 nn.Linearnn.Conv2d 等)。在 PyTorch 中,nn.ModuleList 是一个非常有用的容器,可以帮助我们存储多个子模块,并自动注册它们的参数。这对于确保所有参数能够参与训练非常重要。本文将详细介绍 nn.ModuleList 的作用、使用方法及与普通 Python 列表的区别,并给出清晰的代码示例。


1. 什么是 nn.ModuleList

nn.ModuleList 是一个类似于 Python 列表的容器,但专门用来存储 PyTorch 的子模块(也就是继承自 nn.Module 的对象)。其主要特点是:

  • 自动注册子模块:将 nn.Module 存储在 ModuleList 中后,这些模块的参数会自动被添加到父模块的参数列表中。这意味着当你调用 model.parameters() 时,这些子模块的参数也会被包含进去,从而参与梯度计算和优化。

  • 灵活管理:它可以像普通列表一样进行索引、迭代和切片操作,方便构建动态网络结构。

注意nn.ModuleList 不会像 nn.Sequential 那样自动定义前向传播(forward)流程。你需要在模型的 forward() 方法中手动遍历 ModuleList 并调用各个子模块。


2. 为什么不直接使用普通的 Python 列表?

虽然可以将 nn.Module 对象存储在普通列表中,但这样做有一个主要问题:
普通列表中的模块不会自动注册为父模块的子模块
这会导致:

  • 调用 model.parameters() 时无法获取这些模块的参数;
  • 优化器无法更新这些参数,从而影响模型训练。

而使用 nn.ModuleList 可以避免这个问题,因为它会自动将内部所有的模块注册到父模块中。


3. nn.ModuleList 的基本用法

下面通过一个简单的示例来说明如何使用 nn.ModuleList 构建一个简单的神经网络模型。

示例:构建一个包含两层全连接网络的模型

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 创建一个 ModuleList 来存储各层self.layers = nn.ModuleList([nn.Linear(10, 20),  # 第 1 层:输入 10 个特征,输出 20 个特征nn.ReLU(),          # 激活层nn.Linear(20, 5)    # 第 2 层:输入 20 个特征,输出 5 个特征])def forward(self, x):# 手动遍历 ModuleList 中的每个模块,并依次调用 forwardfor layer in self.layers:x = layer(x)return x# 创建模型实例
model = MyModel()# 打印模型结构
print("模型结构:")
print(model)# 生成一组示例输入
input_tensor = torch.randn(3, 10)  # 3 个样本,每个样本 10 个特征# 得到模型输出
output = model(input_tensor)
print("\n模型输出:")
print(output)
模型结构:
MyModel((layers): ModuleList((0): Linear(in_features=10, out_features=20, bias=True)(1): ReLU()(2): Linear(in_features=20, out_features=5, bias=True))
)模型输出:
tensor([[ 0.3741,  0.0883,  0.3550, -0.3930,  0.5173],[ 0.2171, -0.0978, -0.0585, -0.4568,  0.3331],[ 0.1232, -0.1491,  0.2026, -0.0978,  0.5478]],grad_fn=<AddmmBackward0>)

说明

  • __init__() 方法中,我们将各个层放在了 nn.ModuleList 中。
  • forward() 方法中,我们使用了一个简单的 for 循环,依次调用 self.layers 中的每个子模块。

4. 使用 nn.ModuleList 计算参数总数(与普通列表对比)

为了进一步说明 nn.ModuleList 与普通列表的区别,我们分别计算一下两种方式下模型的参数总数。

示例代码

import torch.nn as nn# 使用 ModuleList 存储模型层
layers_ml = nn.ModuleList([nn.Linear(10, 20),nn.Linear(20, 5)
])# 计算 ModuleList 中的参数总数
ml_params = 0
for p in layers_ml.parameters():ml_params += p.numel()# 使用普通 Python 列表存储模型层
layers_list = [nn.Linear(10, 20),nn.Linear(20, 5)
]# 计算普通列表中的参数总数
list_params = 0
# 先遍历列表中的每个层
for layer in layers_list:# 再遍历每个层的参数for p in layer.parameters():list_params += p.numel()print("ModuleList 参数总数:", ml_params)
print("普通列表参数总数:", list_params)
ModuleList 参数总数: 325
普通列表参数总数: 325

说明

  • 第一个 for 循环遍历 layers_ml.parameters(),直接累加所有参数的元素数。
  • 第二部分中,我们先遍历普通列表中的每个 layer,再单独遍历每个层的参数。这样做使每一步都清晰易懂。

5. nn.ModuleList 的其他应用

示例:构建动态 MLP 模型

当网络结构比较复杂或层数不固定时,可以利用列表生成器动态构建 ModuleList

class DynamicMLP(nn.Module):def __init__(self, layer_sizes):super(DynamicMLP, self).__init__()# 使用 for 循环构造每一层,存储在 ModuleList 中layers = []  # 先用普通列表保存层for i in range(len(layer_sizes) - 1):linear_layer = nn.Linear(layer_sizes[i], layer_sizes[i + 1])layers.append(linear_layer)# 将普通列表转换为 ModuleListself.layers = nn.ModuleList(layers)def forward(self, x):# 遍历每一层(没有嵌套循环,逐个执行)for layer in self.layers:x = torch.relu(layer(x))return x# 创建一个动态 MLP:输入 10,隐藏层 20, 30,输出 5
dynamic_model = DynamicMLP([10, 20, 30, 5])
print("动态 MLP 模型:")
print(dynamic_model)# 测试模型
input_tensor = torch.randn(4, 10)  # 4 个样本,每个样本 10 个特征
output = dynamic_model(input_tensor)
print("\n动态 MLP 模型输出:")
print(output)

说明

  • __init__() 方法中,我们使用一个普通列表 layers 存储每个 nn.Linear 层,然后再将它转换为 nn.ModuleList
  • forward() 方法中,使用单独的 for 循环逐个调用每一层,并对输出应用 ReLU 激活函数。
  • 这种写法适用于层数动态变化的网络(例如 MLP、RNN、Transformer 中部分模块)。

Transformers中的多头注意力机制

class SingleHeadAttention(nn.Module):def __init__(self, embed_dim, head_dim):super().__init__()self.query = nn.Linear(embed_dim, head_dim)self.key = nn.Linear(embed_dim, head_dim)self.value = nn.Linear(embed_dim, head_dim)def forward(self, x):# 实现注意力计算逻辑...return attended_valuesclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.head_dim = embed_dim // num_heads# 显式创建每个注意力头self.head1 = SingleHeadAttention(embed_dim, self.head_dim)self.head2 = SingleHeadAttention(embed_dim, self.head_dim)self.head3 = SingleHeadAttention(embed_dim, self.head_dim)# 使用ModuleList管理多个头self.heads = nn.ModuleList([self.head1,self.head2,self.head3])self.output_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):# 分别处理每个头head1_out = self.head1(x)head2_out = self.head2(x) head3_out = self.head3(x)# 拼接结果combined = torch.cat([head1_out, head2_out, head3_out], dim=-1)return self.output_proj(combined)

关键点解析:

  • 显式声明每个注意力头(避免循环)

  • 使用ModuleList统一管理注意力头

  • 在forward中分别调用每个头

  • 保持各头独立性,便于后续调试


6. 总结

  • nn.ModuleList 是专门用于存储多个子模块的容器,它会自动注册子模块,确保所有参数能参与训练。
  • 与普通 Python 列表相比,ModuleList 可以直接通过 model.parameters() 获取其中所有参数,从而方便地进行优化。
  • 使用 ModuleList 时,前向传播需要手动遍历其中的模块,这提供了更大的灵活性,但也要求开发者理解循环过程。

相关文章:

pytorch小记(十三):pytorch中`nn.ModuleList` 详解

pytorch小记&#xff08;十三&#xff09;&#xff1a;pytorch中nn.ModuleList 详解 PyTorch 中的 nn.ModuleList 详解1. 什么是 nn.ModuleList&#xff1f;2. 为什么不直接使用普通的 Python 列表&#xff1f;3. nn.ModuleList 的基本用法示例&#xff1a;构建一个包含两层全连…...

SpringData Redis:RedisTemplate配置与数据操作

文章目录 引言一、Redis概述与环境准备二、RedisTemplate基础配置三、连接属性配置四、操作String类型数据五、操作Hash类型数据六、操作List类型数据七、操作Set类型数据八、操作ZSet类型数据九、事务与管道操作总结 引言 Redis作为高性能的NoSQL数据库&#xff0c;在分布式系…...

Qt按钮控件常用的API

1.创建按钮 QPushButton *btnnew QPushButton; 以顶层方式弹出窗口控件 代码&#xff1a; #include "widget.h" #include "ui_widget.h" #include"QPushButton"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui-&…...

如何检查CMS建站系统的插件是否安全?

检查好CMS建站系统的插件安全是确保网站安全的重要环节&#xff0c;对于常见的安全检查&#xff0c;大家可以利用以下几种有效的方法和工具&#xff0c;来帮你评估插件的安全性。 1. 检查插件来源和开发者信誉 选择可信来源&#xff1a;仅从官方插件库或可信的第三方开发者处…...

修改HuggingFace模型默认缓存路径

huggingface模型的默认缓存路径是~/.cache/huggingface/hub/ 通常修改为自己的路径会更为方便。 方式一&#xff1a;cache_dir 参数 可以通过from_pretrained函数中的 cache_dir 参数来指定&#xff0c;缺点&#xff0c;每次都需要手动指定&#xff0c;比较麻烦。 如&#x…...

ORA-12541: TNS:no listener

问题描述&#xff1a;使用 PL/SQL Developer 连接数据库时出现ORA-12541错误。 1.查询监听状态 [oraclelocalhost admin]$ lsnrctl statusLSNRCTL for Linux: Version 19.0.0.0.0 - Production on 15-MAR-2025 10:47:57Copyright (c) 1991, 2021, Oracle. All rights reserv…...

【Matlab GUI】封装matlab GUI为exe文件

注&#xff1a;封装后的exe还是需要有matlab环境才能运行 &#xff08;1&#xff09;安装MCRinstaller.exe文件&#xff0c;在matlab安装目录下的toolbox/compiler/deploy/win64文件夹里 &#xff08;2&#xff09;安装完MCRinstaller.exe&#xff0c;字命令窗口输入&#x…...

【eNSP实战】(续)一个AC多个VAP的实现—将隧道转发改成直接转发

在 一个AC多个VAP的实现—CAPWAP隧道转发 此篇文章配置的基础上&#xff0c;将隧道转发改成直接转发 一、改成直接转发需要改动的配置 &#xff08;一&#xff09;将连接AP的接口改成trunk口&#xff0c;并允许vlan100、101、102通过 [AC1]interface GigabitEthernet 0/0/8 …...

Redis能否替代MySQL作为主数据库?深入解析两者的持久化差异与适用边界——基于AOF持久化与关系型数据库的对比

一、Redis的持久化机制与可靠性分析 ​AOF持久化原理与策略 Redis的AOF&#xff08;Append Only File&#xff09;通过记录所有写操作命令实现持久化&#xff0c;支持三种策略&#xff1a; ​**always模式**&#xff1a;每条命令执行后立即同步到磁盘&#xff0c;理论上数据丢失…...

机器人ROS学习:Ubuntu22.04安装ROS2和Moveit2实现运动规划

通过本篇文章学习&#xff0c;你可以收获以下内容&#xff1a; 学会在 Ubuntu22.04 上安装 Moveit2学会下载编译运行 Moveit2 样例程序学会使用样例程序进行运动规划等 版本平台 系统版本&#xff1a;ubuntu22.04ROS2 版本&#xff1a;humbleMoveit 版本&#xff1a;moveit2…...

GIT标签(Tag)操作

在Git中&#xff0c;标签&#xff08;Tag&#xff09;用于标记特定的提交点&#xff0c;通常用于发布版本。 切换到需要打标签的分支&#xff1a; git checkout <branch-name>创建标签 git tag v1.0.0 git tag -a v1.0.0 -m "Release version 1.0.0"查看所…...

生成式AI红队测试:如何有效评估大语言模型

OWASP最新指南为组建生成式AI红队或调整现有红队以适应新技术提供了详细的指导。 红队测试是一种经过时间检验的网络安全系统测试和加固方法&#xff0c;但它需要不断适应技术的演变。近年来&#xff0c;生成式AI和大语言模型&#xff08;LLM&#xff09;的爆发&#xff0c;是…...

技术路线图ppt模板_流程图ppt图表_PPT架构图

技术路线图ppt模板 / 学术ppt模板 - 院士增选、国家科技奖、杰青、长江学者特聘教授、校企联聘教授、重点研发、优青、青长、青拔.. / 学术ppt案例 WordinPPT / 持续为双一流高校、科研院所、企业等提供PPT制作系统服务。 - 科学技术奖ppt&#xff1a;自然科学奖 | 技术…...

Leetcode-131.Palindrome Partitioning [C++][Java]

目录 一、题目描述 二、解题思路 【C】 【Java】 Leetcode-131.Palindrome Partitioninghttps://leetcode.com/problems/palindrome-partitioning/description/131. 分割回文串 - 力扣&#xff08;LeetCode&#xff09;131. 分割回文串 - 给你一个字符串 s&#xff0c;请你…...

LeetCode 解题思路 20(Hot 100)

解题思路&#xff1a; 递归定义对称性&#xff1a; 若两棵子树镜像对称&#xff0c;需满足&#xff1a; 当前节点值相等&#xff1b;左子树的左节点与右子树的右节点对称&#xff1b;左子树的右节点与右子树的左节点对称。 终止条件&#xff1a; 两个节点均为空 → 对称&am…...

挖矿------获取以太坊测试币

文章目录 挖矿------获取以太坊测试币通过水龙头获取以太坊测试币了解Sepolia是什么&#xff1f;水龙头&#xff08;Faucet&#xff09;是什么&#xff1f;Gitcoin Passport是什么&#xff1f; 操作1.MetaMask钱包2.将MetaMask切换到Sepolia测试网络3.用MetaMask连接Gitcoin Pa…...

每天五分钟深度学习框架pytorch:基于pytorch搭建循环神经网络RNN

本文重点 我们前面介绍了循环神经网络RNN,主要分析了它的维度信息,其实它的维度信息是最重要的,一旦我们把维度弄清楚了,一起就很简单了,本文我们正式的来学习一下,如何使用pytorch搭建循环神经网络RNN。 RNN的搭建 在pytorch中我们使用nn.RNN()就可以创建出RNN神经网络…...

XEasyWork:面向AI应用的可视化工作流开发平台

文章目录 前言 一、平台核心价值 1.1产品定位 1.2 技术优势 二、技术架构解析 2.1战略级整合 自主开发模块 2.2集成开源项目 三、体验地址 三、未来规划 总结 前言 在人工智能技术快速落地的今天&#xff0c;开发者在构建AI应用时仍面临两大挑战&#xff1a;技术栈复杂带来的高…...

C#进阶(多线程相关)

1。进程&#xff1f; 进程&#xff08;Process&#xff09;是计算机中的程序关于某数据集合上的一次运行活动&#xff0c;【是系统进行资源分配的基本单位】&#xff0c;是操作系统结构的基础。在早期面向进程设计的计算机结构中&#xff0c;进程是程序的基本执行实体&#xf…...

【C++】:C++11详解 —— 右值引用

目录 左值和右值 左值的概念 右值的概念 左值 vs 右值 左值引用 和 右值引用 左值引用 右值引用 左值引用 vs 右值引用 使用场景 左值引用的使用场景 左值引用的短板 右值引用的使用场景 1. 实现移动语义&#xff08;资源高效转移&#xff09; 2. 优化容器操作&a…...

【css酷炫效果】纯CSS实现虫洞穿越效果

【css酷炫效果】纯CSS实现穿越效果 缘创作背景html结构css样式完整代码基础版进阶版&#xff08;虫洞穿越&#xff09; 效果图 想直接拿走的老板&#xff0c;链接放在这里&#xff1a;https://download.csdn.net/download/u011561335/90491973 缘 创作随缘&#xff0c;不定时…...

Linux IP 配置

Linux IP 配置 1 环境介绍2 网卡信息配置3 使用nmtui工具配置4 更多Linux命令学习使用列表 1 环境介绍 虚拟机&#xff0c;服务器安装系统完成后&#xff0c;先要配置ip 地址&#xff0c;这样可以方便远程若是物理服务器一般会有4个网卡信息麒麟v10&#xff0c;CentOS7&#x…...

设计模式之装饰器模式:原理、实现与应用

引言 装饰器模式&#xff08;Decorator Pattern&#xff09;是一种结构型设计模式&#xff0c;它允许你通过将对象放入包含行为的特殊封装对象中来为原对象动态添加新的行为。装饰器模式提供了一种灵活的替代方案&#xff0c;避免了通过继承扩展功能的局限性。本文将深入探讨装…...

c语言笔记 结构体基础

目录 基础知识 结构体定义 基础知识 在c语言中变量是有类型的&#xff0c;比如整型&#xff0c;char型&#xff0c;浮点型等&#xff0c;这些都是单一的类型&#xff0c;那么如果说我要定义一个学生的信息&#xff0c;那么这些单一的类型是不足以表达一个学生的全部信息&#…...

Vue 登录 记住密码,设置存储时间

Vue 登录 记住密码&#xff0c;设置存储时间 一、手动存储login.vue 二、使用vue-cookies插件main.jslogin.vue 一、手动存储 login.vue 提示&#xff1a; // 设置cookie方法 setCookie(loginName, password, days) {let text encryptDes(password, des123)//使用des方法加…...

C++ 学习笔记(三)—— 入门+类和对象

1、内联函数&#xff08;inline&#xff09; 内联函数主要是解决C语言的宏的缺陷提出来的&#xff1b; 宏的缺陷&#xff1a; 1&#xff09;容易出错&#xff0c;语法坑很多&#xff1b; 2&#xff09;不能调试&#xff1b; 3&#xff09;没有类型安全的检查&#xff1b; 宏的…...

KVM安全模块生产环境配置与优化指南

KVM安全模块生产环境配置与优化指南 一、引言 在当今复杂多变的网络安全环境下&#xff0c;生产环境中KVM&#xff08;Kernel-based Virtual Machine&#xff09;的安全配置显得尤为重要。本指南旨在详细阐述KVM安全模块的配置方法&#xff0c;结合强制访问控制&#xff08;M…...

基于 SSE 和 WebSocket 的在线文本实时传输工具

简介 在线文本实时传输工具支持 SSE&#xff08;Server-Sent Events&#xff09; 和 WebSocket&#xff0c;可在不同设备间快速共享和同步文本&#xff0c;适用于跨设备协作、远程办公和即时通讯。 核心功能 实时同步&#xff1a;文本输入后&#xff0c;另一端用户可立即看到…...

数图亮相第三届全国生鲜创新峰会,赋能生鲜零售数字化转型

2025年3月15-18日&#xff0c;第三届全国生鲜创新峰会在湖北宜昌召开&#xff0c;主题为“生鲜破局&#xff0c;重塑价值”。峰会汇聚行业专家、企业领袖及精英&#xff0c;探讨生鲜零售新机遇与挑战。作为领先的“智慧零售”服务商&#xff0c;数图信息科技受邀出席&#xff0…...

go 安装swagger

1、依赖安装&#xff1a; # 安装 swag 命令行工具 go install github.com/swaggo/swag/cmd/swaglatest# 安装 gin-swagger 和 swagger 文件的依赖 go get -u github.com/swaggo/gin-swagger go get -u github.com/swaggo/files 2、测试 cmd中输入&#xff1a; swag -v 3、…...