当前位置: 首页 > news >正文

深度学习(4):torch.nn.Module

文章目录

  • 一、是什么
  • 二、`nn.Module` 的核心功能
  • 三、`nn.Module` 的基本用法
    • 1. 定义自定义模型
    • 2. 初始化模型
      • 3. 模型的使用
  • 四、`nn.Module` 的关键特性
    • 1. 自动注册子模块和参数
    • 2. `forward` 方法
    • 3. 不需要定义反向传播
  • 五、常用的内置模块
  • 六、示例:创建一个简单的神经网络
    • 1. 问题描述
    • 2. 模型定义
    • 3. 训练过程
  • 七、深入理解 `nn.Module` 的一些重要概念
    • 1. 参数访问
    • 2. 模块访问
    • 3. 保存和加载模型
    • 4. 自定义层和模块
  • 八、`nn.Module` 的实践技巧
    • 1. 使用 `Sequential` 快速构建模型
    • 2. 模型的嵌套
  • 九、总结
    • 十、参考示例:完整的训练脚本

一、是什么

torch.nn.Module 是 PyTorch 中所有神经网络模块的基类,是构建神经网络模型的核心组件。

二、nn.Module 的核心功能

  1. 参数管理:自动管理模型的可训练参数(parameters),方便参数的访问和更新。

  2. 子模块管理:支持将模型分解为多个子模块,便于组织复杂的网络结构。

  3. 前向计算(forward):定义模型的前向传播逻辑。


三、nn.Module 的基本用法

1. 定义自定义模型

要创建自定义的神经网络模型,需要继承 nn.Module,并实现以下内容:

  • 构造函数 __init__:在这里定义网络的层和子模块。
  • 前向方法 forward:定义数据如何经过网络进行前向传播。
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义网络层self.layer1 = nn.Linear(10, 20)self.relu = nn.ReLU()self.layer2 = nn.Linear(20, 1)def forward(self, x):# 定义前向传播过程out = self.layer1(x)out = self.relu(out)out = self.layer2(out)return out

2. 初始化模型

model = MyModel()

3. 模型的使用

  • 前向传播

    output = model(input_data)
    
  • 获取模型参数

    for name, param in model.named_parameters():print(name, param.size())
    

四、nn.Module 的关键特性

1. 自动注册子模块和参数

__init__ 方法中,当你将 nn.Module 的实例(如 nn.Linearnn.Conv2d 等)赋值给模型的属性时,nn.Module 会自动将这些子模块注册到模型中。这意味着:

  • 参数管理:模型的所有参数都会被自动收集,存储在 model.parameters() 中。
  • 子模块管理:可以通过 model.children()model.modules() 访问子模块。
class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.fc = nn.Linear(10, 5)self.conv = nn.Conv2d(3, 16, kernel_size=3)model = MyModule()
print(list(model.parameters()))  # 自动包含了 fc 和 conv 的参数

2. forward 方法

forward 方法定义了模型的前向传播逻辑。在调用模型实例时,会自动调用 forward 方法。

output = model(input_data)  # 等价于 output = model.forward(input_data)

3. 不需要定义反向传播

在大多数情况下,不需要手动实现反向传播函数。PyTorch 的自动求导机制(autograd)会根据前向传播中的操作,自动计算梯度。

五、常用的内置模块

PyTorch 提供了大量的内置模块,继承自 nn.Module,可以直接使用:

  • 线性层nn.Linear
  • 卷积层nn.Conv1dnn.Conv2dnn.Conv3d
  • 循环神经网络nn.RNNnn.LSTMnn.GRU
  • 归一化层nn.BatchNorm1dnn.BatchNorm2d
  • 激活函数nn.ReLUnn.Sigmoidnn.Softmax
  • 损失函数nn.MSELossnn.CrossEntropyLoss

六、示例:创建一个简单的神经网络

1. 问题描述

创建一个多层感知机(MLP),用于对 MNIST 手写数字进行分类。

2. 模型定义

class MNISTClassifier(nn.Module):def __init__(self):super(MNISTClassifier, self).__init__()self.flatten = nn.Flatten()  # 将输入展开为一维self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(64, 10)  # 输出10个类别的分数def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x

3. 训练过程

import torch.optim as optim# 初始化模型、损失函数和优化器
model = MNISTClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 假设有数据加载器 data_loader
for epoch in range(num_epochs):for images, labels in data_loader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

七、深入理解 nn.Module 的一些重要概念

1. 参数访问

  • parameters():返回一个生成器,包含模型所有可训练的参数。
  • named_parameters():返回一个生成器,生成 (name, parameter) 对,方便查看参数名称和形状。
for name, param in model.named_parameters():print(f'Parameter {name}: shape {param.shape}')

2. 模块访问

  • children():返回直接子模块的迭代器。
  • modules():返回自身及所有子模块的迭代器。
for child in model.children():print(child)for module in model.modules():print(module)

3. 保存和加载模型

  • 保存模型状态

    torch.save(model.state_dict(), 'model.pth')
    
  • 加载模型状态

    model = MNISTClassifier()
    model.load_state_dict(torch.load('model.pth'))
    

4. 自定义层和模块

通过继承 nn.Module,可以创建自定义的层或模块。

class CustomLayer(nn.Module):def __init__(self, in_features, out_features):super(CustomLayer, self).__init__()self.weight = nn.Parameter(torch.randn(in_features, out_features))self.bias = nn.Parameter(torch.zeros(out_features))def forward(self, x):return torch.matmul(x, self.weight) + self.bias

八、nn.Module 的实践技巧

1. 使用 Sequential 快速构建模型

对于简单的模型,可以使用 nn.Sequential 将多个层按顺序组合。

model = nn.Sequential(nn.Flatten(),nn.Linear(28 * 28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10)
)

2. 模型的嵌套

可以将模块嵌套使用,构建复杂的网络结构。

class ComplexModel(nn.Module):def __init__(self):super(ComplexModel, self).__init__()self.block1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),nn.ReLU())self.block2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),nn.ReLU())self.fc = nn.Linear(64 * 24 * 24, 10)def forward(self, x):x = self.block1(x)x = self.block2(x)x = x.view(x.size(0), -1)  # 展平x = self.fc(x)return x

九、总结

  • nn.Module 是 PyTorch 构建神经网络的基础,提供了参数管理、子模块管理和前向传播等功能。
  • 通过继承 nn.Module,可以方便地创建自定义模型或层,满足各种复杂的需求。
  • 在使用 nn.Module 时,注意正确地定义 __init__forward 方法,并确保在 forward 方法中定义前向计算逻辑。
  • PyTorch 提供了大量的内置模块,可以直接使用或作为自定义模块的基石。
  • 善于利用 nn.Module 的特性和工具,可以大大提高模型开发的效率和代码的可读性。

十、参考示例:完整的训练脚本

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据集和数据加载器
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)# 定义模型
class MNISTClassifier(nn.Module):def __init__(self):super(MNISTClassifier, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(64, 10)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x# 初始化模型、损失函数和优化器
model = MNISTClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for images, labels in train_loader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'mnist_classifier.pth')

相关文章:

深度学习(4):torch.nn.Module

文章目录 一、是什么二、nn.Module 的核心功能三、nn.Module 的基本用法1. 定义自定义模型2. 初始化模型3. 模型的使用 四、nn.Module 的关键特性1. 自动注册子模块和参数2. forward 方法3. 不需要定义反向传播 五、常用的内置模块六、示例:创建一个简单的神经网络1…...

(14)关于docker如何通过防火墙做策略限制

关于docker如何通过防火墙做策略限制 1、iptables相关问题 在Iptables防火墙中包含四种常见的表,分别是filter、nat、mangle、raw。 filter:负责过滤数据包。 filter表可以管理INPUT、OUTPUT、FORWARD链。 nat:用于网络地址转换。 nat表…...

新React开发人员应该如何思考

React是一个用于构建用户界面的流行JavaScript库,通过使开发人员能够创建可重用组件并有效管理复杂的UI,彻底改变了前端开发。然而,采用正确的心态对于新开发人员驾驭React独特的范式至关重要。让我们来探索塑造“React思维模式”的基本原则和…...

解密.bixi、.baxia勒索病毒:如何安全恢复被加密数据

导言 在数字化时代,数据安全已成为个人和企业面临的重大挑战之一。随着网络攻击手段的不断演进,勒索病毒的出现尤为引人关注。其中,.bixi、.baxia勒索病毒是一种新型的恶意软件,它通过加密用户的重要文件,迫使受害者支…...

开源 AI 智能名片与 S2B2C 商城小程序:嫁接权威实现信任与增长

摘要:本文探讨了嫁接权威在产品营销中的重要性,并结合开源 AI 智能名片与 S2B2C 商城小程序,阐述了如何通过与权威关联来建立客户信任,提升产品竞争力。强调了在当今商业环境中,巧妙运用嫁接权威的方法,能够…...

S-Clustr-Simple 飞机大战:骇入现实的建筑灯光游戏

项目地址:https://github.com/MartinxMax/S-Clustr/releases Video https://www.youtube.com/watch?vr3JIZY1olro 飞机大战 按键操作: ←:向左移动 →:向右移动 Space:发射子弹 这是一个影子集群的游戏插件,可以将游戏画面映射到现实的设备,允许恶…...

MySQL:存储引擎简介和库的基本操作

目录 一、存储引擎 1、什么是存储引擎? 2、存储引擎的分类 关系型数据库存储引擎: 非关系型数据库存储引擎: 分布式数据库存储引擎: 3、常用的存储引擎及优缺点 1、InnoDb存储引擎 2、MyISAM存储引擎 3、MEMORY存储引擎 …...

JavaScript类型判断(总结)

1. 使用typeof操作符 typeof操作符可以返回一个值的类型的字符串表示。例如: typeof 42; // "number" typeof "Hello"; // "string" typeof true; // "boolean" typeof undefined; // "undefined" typeof null…...

SpringBoot之登录校验关于JWT、Filter、interceptor、异常处理的使用

什么是登录校验? 所谓登录校验,指的是我们在服务器端接收到浏览器发送过来的请求之后,首先我们要对请求进行校验。先要校验一下用户登录了没有,如果用户已经登录了,就直接执行对应的业务操作就可以了;如果用…...

我的AI工具箱Tauri版-FunAsr音频转文本

本教程基于自研的AI工具箱Tauri版进行FunAsr音频转文本服务。 FunAsr音频转文本服务 是自研AI工具箱Tauri版中的一个高效模块,专为将音频或视频中的语音内容自动转化为文本或字幕而设计。用户只需简单配置输入、输出路径,即可通过FunAsr工具快速批量处理…...

C++:模版初阶

目录 一、泛型编程 二、函数模版 概念 格式 原理 实例化 模版参数的匹配原则 三、类模版 定义格式 实例化 一、泛型编程 如何实现一个通用的交换函数呢? void Swap(int& left, int& right) {int temp left;left right;right temp; } void Swa…...

Python Web 与区块链集成的最佳实践:智能合约、DApp与安全

Python Web 与区块链集成的最佳实践:智能合约、DApp与安全 📚 目录 🏗 区块链基础 区块链的基础概念与应用场景使用 Web3.py 与 Python Web 应用集成区块链网络在 Web 应用中实现加密货币支付与转账功能 🔑 智能合约与 DApp 编写…...

使用工具将截图公式转换为word公式

引言: 公式越复杂,心情越凌乱,手写都会觉得很麻烦,何况敲到电脑里面呢,特别是在写论文时,word有专属的公式格式,十分繁杂,如果照着mathTYPE软件敲,那么会耗费很长的时间…...

深度学习(6):Dataset 和 DataLoader

文章目录 Dataset 类DataLoader 类 Dataset 类 概念: Dataset 是一个抽象类,用于表示数据集。它定义了如何获取数据集中的单个样本和标签。 作用: 为数据集提供统一的接口,便于数据的读取、预处理和管理。 关键方法&#xff…...

Qt窗口——QToolBar

文章目录 工具栏创建工具栏设置toolTip工具栏配合菜单栏工具栏浮动状态 工具栏 QToolBar工具栏是应用程序中集成各种功能实现快捷键使用的一个区域。 可以有多个,也可以没有。 创建工具栏 #include "mainwindow.h" #include "ui_mainwindow.h&qu…...

MySQL—存储过程详解

基本介绍 存储过程和函数是数据库中预先编译并存储的一组SQL语句集合。它们的主要目的是提高代码的复用性、减少数据传输、简化业务逻辑处理,并且一旦编译成功,可以永久有效。 存储过程和函数的好处 提高代码的复用性:存储过程和函数可以在…...

2024ICPC网络赛2记录:CK

这一次网络赛我们过8题,排名71,算是发挥的非常好的了。这一把我们三个人手感都很好,前六题都是一遍过,然后我又切掉了非签到的E和C,最后时间不是很多,K只想到大概字典树的思路,细节不是很懂就直…...

PerparedStatement概述

PreparedStatement 是 Java 中的一个接口,用于预编译 SQL 语句并执行数据库操作。 一、主要作用 提高性能: 数据库在首次执行预编译语句时会进行语法分析、优化等操作,并将其存储在缓存中。后续执行相同的预编译语句时,数据库可…...

联影医疗嵌入式面试题及参考答案(3万字长文)

假如你要做机器人控制,你会遵循怎样的开发流程? 首先,需求分析阶段。明确机器人的功能需求,例如是用于工业生产中的物料搬运、还是家庭服务中的清洁打扫等。了解工作环境的特点,包括空间大小、障碍物分布、温度湿度等因素。同时,确定机器人的性能指标,如运动速度、精度、…...

Rust的作用?

在Linux中,Rust可以开发命令行工具,如FD、SD、Ripgep、Bat、EXA、SKIM等。虽然Rust是面向系统编程,但也不妨碍使用Rust写命令行工具,因为Rust具备现代语言特性、无依赖、生成的目标文件小。 在云计算和区块链区域,Rus…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI

前一阵子在百度 AI 开发者大会上&#xff0c;看到基于小智 AI DIY 玩具的演示&#xff0c;感觉有点意思&#xff0c;想着自己也来试试。 如果只是想烧录现成的固件&#xff0c;乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外&#xff0c;还提供了基于网页版的 ESP LA…...

LLM基础1_语言模型如何处理文本

基于GitHub项目&#xff1a;https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken&#xff1a;OpenAI开发的专业"分词器" torch&#xff1a;Facebook开发的强力计算引擎&#xff0c;相当于超级计算器 理解词嵌入&#xff1a;给词语画"…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

AI,如何重构理解、匹配与决策?

AI 时代&#xff0c;我们如何理解消费&#xff1f; 作者&#xff5c;王彬 封面&#xff5c;Unplash 人们通过信息理解世界。 曾几何时&#xff0c;PC 与移动互联网重塑了人们的购物路径&#xff1a;信息变得唾手可得&#xff0c;商品决策变得高度依赖内容。 但 AI 时代的来…...

Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?

Redis 的发布订阅&#xff08;Pub/Sub&#xff09;模式与专业的 MQ&#xff08;Message Queue&#xff09;如 Kafka、RabbitMQ 进行比较&#xff0c;核心的权衡点在于&#xff1a;简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

Docker 本地安装 mysql 数据库

Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker &#xff1b;并安装。 基础操作不再赘述。 打开 macOS 终端&#xff0c;开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...

腾讯云V3签名

想要接入腾讯云的Api&#xff0c;必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口&#xff0c;但总是卡在签名这一步&#xff0c;最后放弃选择SDK&#xff0c;这次终于自己代码实现。 可能腾讯云翻新了接口文档&#xff0c;现在阅读起来&#xff0c;清晰了很多&…...