PyTorch之nn.Module与nn.functional用法区别
文章目录
- 1. nn.Module
- 2. nn.functional
- 2.1 基本用法
- 2.2 常用函数
- 3. nn.Module 与 nn.functional
- 3.1 主要区别
- 3.2 具体样例:nn.ReLU() 与 F.relu()
- 参考资料
1. nn.Module
在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module 并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。
关于 nn.Module 的更多介绍可以参考博客:PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解
这里,我们基于nn.Module创建一个简单的神经网络模型,实现代码如下:
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(MyModel, self).__init__()self.layer1 = nn.Linear(input_size, hidden_size)self.layer2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.layer1(x))x = self.layer2(x)return x
2. nn.functional
nn.functional 是PyTorch中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module 不同,nn.functional 中的函数不具有可学习的参数。这些函数通常用于执行各种非线性操作、损失函数、激活函数等。
2.1 基本用法
如何在神经网络中使用nn.functional?
在PyTorch中,你可以轻松地在神经网络中使用 nn.functional 函数。通常,你只需将输入数据传递给这些函数,并将它们作为网络的一部分。
以下是一个简单的示例,演示如何在一个全连接神经网络中使用ReLU激活函数:
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc1 = nn.Linear(64, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x
在上述示例中,我们首先导入nn.functional 模块,然后在网络的forward 方法中使用F.relu 函数作为激活函数。
nn.functional的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。
2.2 常用函数
(1)激活函数
激活函数是神经网络中的关键组件,它们引入非线性性,使网络能够拟合复杂的数据。以下是一些常见的激活函数:
- ReLU(Rectified Linear Unit)
ReLU是一种简单而有效的激活函数,它将输入值小于零的部分设为零,大于零的部分保持不变。它的数学表达式如下:
output = F.relu(input)
- Sigmoid
Sigmoid函数将输入值映射到0和1之间,常用于二分类问题的输出层。它的数学表达式如下:
output = F.sigmoid(input)
- Tanh(双曲正切)
Tanh函数将输入值映射到-1和1之间,它具有零中心化的特性,通常在循环神经网络中使用。它的数学表达式如下:
output = F.tanh(input)
(2)损失函数
损失函数用于度量模型的预测与真实标签之间的差距。PyTorch的nn.functional 模块包含了各种常用的损失函数,例如:
- 交叉熵损失(Cross-Entropy Loss)
交叉熵损失通常用于多分类问题,计算模型的预测分布与真实分布之间的差异。它的数学表达式如下:
loss = F.cross_entropy(input, target)
- 均方误差损失(Mean Squared Error Loss)
均方误差损失通常用于回归问题,度量模型的预测值与真实值之间的平方差。它的数学表达式如下:
loss = F.mse_loss(input, target)
- L1 损失
L1损失度量预测值与真实值之间的绝对差距,通常用于稀疏性正则化。它的数学表达式如下:
loss = F.l1_loss(input, target)
(3)非线性操作
nn.functional 模块还包含了许多非线性操作,如池化、归一化等。
- 最大池化(Max Pooling)
最大池化是一种用于减小特征图尺寸的操作,通常用于卷积神经网络中。它的数学表达式如下:
output = F.max_pool2d(input, kernel_size)
- 批量归一化(Batch Normalization)
批量归一化是一种用于提高训练稳定性和加速收敛的技术。它的数学表达式如下:
output = F.batch_norm(input, mean, std, weight, bias)
3. nn.Module 与 nn.functional
3.1 主要区别
nn.Module 与 nn.functional 的主要区别在于:
- nn.Module实现的layers是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数;
- nn.functional中的函数更像是纯函数,由def function(input)定义。
注意:
- 如果模型有可学习的参数时,最好使用nn.Module。
- 激活函数(ReLU、sigmoid、Tanh)、池化(MaxPool)等层没有可学习的参数,可以使用对应的functional函数。
- 卷积、全连接等有可学习参数的网络建议使用nn.Module。
- dropout没有可学习参数,但建议使用nn.Dropout而不是nn.functional.dropout。
3.2 具体样例:nn.ReLU() 与 F.relu()
nn.ReLU() :
import torch.nn as nn
'''
nn.ReLU()
F.relu():
import torch.nn.functional as F
'''
out = F.relu(input)
其实这两种方法都是使用relu激活,只是使用的场景不一样,F.relu()是函数调用,一般使用在foreward函数里。而nn.ReLU()是模块调用,一般在定义网络层的时候使用。
当用print(net)输出时,nn.ReLU()会有对应的层,而F.ReLU()是没有输出的。
import torch.nn as nn
import torch.nn.functional as Fclass NET1(nn.Module):def __init__(self):super(NET1, self).__init__()self.conv = nn.Conv2d(3, 16, 3, 1, 1)self.bn = nn.BatchNorm2d(16)self.relu = nn.ReLU() # 模块的激活函数def forward(self, x):out = self.conv(x)x = self.bn(x)out = self.relu()return outclass NET2(nn.Module):def __init__(self):super(NET2, self).__init__()self.conv = nn.Conv2d(3, 16, 3, 1, 1)self.bn = nn.BatchNorm2d(16)def forward(self, x):x = self.conv(x)x = self.bn(x)out = F.relu(x) # 函数的激活函数return outnet1 = NET1()
net2 = NET2()
print(net1)
print(net2)

参考资料
- PyTorch的nn.Module类的详细介绍
- PyTorch
nn.functional模块详解:探索神经网络的魔法工具箱 - pytorch:F.relu() 与 nn.ReLU() 的区别
相关文章:
PyTorch之nn.Module与nn.functional用法区别
文章目录 1. nn.Module2. nn.functional2.1 基本用法2.2 常用函数 3. nn.Module 与 nn.functional3.1 主要区别3.2 具体样例:nn.ReLU() 与 F.relu() 参考资料 1. nn.Module 在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网…...
2024.06.24 校招 实习 内推 面经
绿*泡*泡VX: neituijunsir 交流*裙 ,内推/实习/校招汇总表格 1、校招 | 昂瑞微2025届校园招聘正式启动 校招 | 昂瑞微2025届校园招聘正式启动 2、实习 | 东风公司研发总院暑期实习生火爆招募中 实习 | 东风公司研发总院暑期实习生火爆招募中 3、实习…...
【C++】using namespace std 到底什么意思
📢博客主页:https://blog.csdn.net/2301_779549673 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! 📢本文作为 JohnKi 的学习笔记,引用了部分大佬的案例 📢未来很长&a…...
基于ESP32 IDF的WebServer实现以及OTA固件升级实现记录(三)
经过前面两篇的前序铺垫,对webserver以及restful api架构有了大体了解后本篇描述下最终的ota实现的代码以及调试中遇到的诡异bug。 eps32的实际ota实现过程其实esp32官方都已经基本实现好了,我们要做到无非就是把要升级的固件搬运到对应ota flash分区里面…...
116-基于5VLX110T FPGA FMC接口功能验证6U CPCI平台
一、板卡概述 本板卡是Xilinx公司芯片V5系列芯片设计信号处理板卡。由一片Xilinx公司的XC5VLX110T-1FF1136 / XC5VSX95T-1FF1136 / XC5VFX70T-1FF1136芯片组成。FPGA接1片DDR2内存条 2GB,32MB Nor flash存储器,用于存储程序。外扩 SATA、PCI、PCI expres…...
Android - Json/Gson
Json数据解析 json对象:花括号开头和结尾,中间是键值对形式————”属性”:属性值”” json数组:中括号里放置 json 数组,里面是多个json对象或者数字等 JSONObject 利用 JSONObject 解析 1.创建 JSONObject 对象,传…...
盲信号处理的发展现状
盲源分离技术最早在上个世纪中期提出,在1991年Herault和Jutten提出基于反馈神经网络的盲源分离方法,但该方法缺乏理论基础,后来Tong和Liu分析了盲源分离问题的可辨识性和不确定性,Cardoso于1993年提出了基于高阶统计的联合对角化盲…...
二轴机器人装箱机:重塑物流效率,精准灵活,引领未来装箱新潮流
在现代化物流领域,高效、精准与灵活性无疑是各大企业追求的核心目标。而在这个日益追求自动化的时代,二轴机器人装箱机凭借其较佳的性能和出色的表现,正逐渐成为装箱作业的得力助手,引领着未来装箱新潮流。 一、高效:重…...
使用python做飞机大战
代码地址: 点击跳转...
Python面向对象编程:派生
本套课在线学习视频(网盘地址,保存到网盘即可免费观看): https://pan.quark.cn/s/69d1cc25d4ba 面向对象编程(OOP)是一种编程范式,它通过将数据和操作数据的方法封装在一起࿰…...
华为仓颉编程语言
目录 一、引言 二、仓颉编程语言概述 三、技术特征 四、应用场景 五、社区支持 六、结论与展望 一、引言 随着信息技术的快速发展,编程语言作为软件开发的核心工具,其重要性日益凸显。近年来,华为公司投入大量研发资源,成功…...
【微信小程序开发实战项目】——如何制作一个属于自己的花店微信小程序(2)
👨💻个人主页:开发者-曼亿点 👨💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨💻 本文由 曼亿点 原创 👨💻 收录于专栏:…...
解锁数据资产的无限潜能:深入探索创新的数据分析技术,挖掘其在实际应用场景中的广阔价值,助力企业发掘数据背后的深层信息,实现业务的持续增长与创新
目录 一、引言 二、创新数据分析技术的发展 1、大数据分析技术 2、人工智能与机器学习 3、可视化分析技术 三、创新数据分析技术在实际应用场景中的价值 1、市场洞察与竞争分析 2、客户细分与个性化营销 3、业务流程优化与风险管理 4、产品创新与研发 四、案例分析 …...
Bridging nonnull in Objective-C to Swift: Is It Safe?
Bridging nonnull in Objective-C to Swift: Is It Safe? In the world of iOS development, bridging between Objective-C and Swift is a common practice, especially for legacy codebases (遗留代码库) or when integrating (集成) third-party libraries. One importa…...
算法训练 | 图论Part1 | 98.所有可达路径
目录 98.所有可达路径 深度搜索法 98.所有可达路径 题目链接:98. 所有可达路径 文章讲解:代码随想录 深度搜索法 代码一:邻接矩阵写法 #include <iostream> #include <vector> using namespace std; vector<vector<…...
【JVM基础篇】垃圾回收
文章目录 垃圾回收常见内存管理方式手动回收:C内存管理自动回收(GC):Java内存管理自动、手动回收优缺点 应用场景垃圾回收器需要对哪些部分内存进行回收?不需要垃圾回收器回收需要垃圾回收器回收 方法区的回收代码测试手动调用垃圾回收方法Sy…...
Spark join数据倾斜调优
Spark中常见的两种数据倾斜现象如下 stage部分task执行特别慢 一般情况下是某个task处理的数据量远大于其他task处理的数据量,当然也不排除是程序代码没有冗余,异常数据导致程序运行异常。 作业重试多次某几个task总会失败 常见的退出码143、53、137…...
YOLOv5初学者问题——用自己的模型预测图片不画框
如题,我在用自己的数据集训练权重模型的时候,在训练完成输出的yolov5-v5.0\runs\train\exp2目录下可以看到,在训练测试的时候是有输出描框的。 但是当我引用训练好的best.fangpt去进行预测的时候, 程序输出的图片并没有描框。根据…...
【linux学习---1】点亮一个LED---驱动一个GPIO
文章目录 1、原理图找对应引脚2、IO复用3、IO配置4、GPIO配置5、GPIO时钟使能6、总结 1、原理图找对应引脚 从上图 可以看出, 蜂鸣器 接到了 BEEP 上, BEEP 就是 GPIO5_IO05 2、IO复用 查找IMX6UL参考手册 和 STM32一样,如果某个 IO 要作为…...
Redis分布式锁代码实现详解
引言 在分布式系统中,资源竞争和数据一致性问题常常需要通过锁机制来解决。Redis作为一个高性能的键值存储系统,因其提供的原子操作、丰富的数据结构以及网络延迟低等特点,成为了实现分布式锁的理想选择。本文将详细介绍如何使用Redis来实现…...
Pixel Fashion Atelier保姆级教程:从INSERT COIN按钮到像素粒子物理引擎解析
Pixel Fashion Atelier保姆级教程:从INSERT COIN按钮到像素粒子物理引擎解析 1. 像素时装锻造坊简介 像素时装锻造坊是一款融合了复古游戏美学与现代AI技术的图像生成工具。它基于Stable Diffusion和Anything-v5模型构建,专为时尚设计和像素艺术创作而…...
Sodaq_RN2483库详解:LoRaWAN Class A终端嵌入式实现
1. Sodaq_RN2483库深度解析:面向Class A LoRaWAN终端的嵌入式通信实现 1.1 库定位与工程价值 Sodaq_RN2483是一个专为Microchip RN2483 LoRaWAN模块设计的Arduino兼容C库,其核心目标是为资源受限的嵌入式系统提供稳定、可复用、符合LoRaWAN协议规范的无…...
RWKV7-1.5B-g1a镜像部署教程:CSDN平台一键拉起Web服务,7860端口直连体验
RWKV7-1.5B-g1a镜像部署教程:CSDN平台一键拉起Web服务,7860端口直连体验 1. 模型简介 rwkv7-1.5B-g1a 是基于新一代 RWKV-7 架构的多语言文本生成模型,特别适合中文场景下的轻量级应用。这个1.5B参数的版本在保持较高生成质量的同时&#x…...
OneAPI 百度文心一言ERNIE-Bot接入:千帆平台Key对接指南
OneAPI 百度文心一言ERNIE-Bot接入:千帆平台Key对接指南 安全提示:使用 root 用户初次登录系统后,务必修改默认密码 123456! 1. 引言:为什么需要统一的API管理平台 在当今AI技术快速发展的时代,企业和开发…...
Windows下OpenClaw安装指南:对接ollama GLM-4.7-Flash模型
Windows下OpenClaw安装指南:对接ollama GLM-4.7-Flash模型 1. 为什么选择OpenClaw GLM-4.7-Flash组合 作为一个长期在Windows环境下折腾AI工具的开发者,我一直在寻找一个既能保持本地数据隐私,又能灵活对接各类开源模型的自动化框架。Open…...
ArcMap地图数字化实战:从加载地形图到保存成果的完整流程(附常见问题解决)
ArcMap地图数字化实战:从加载地形图到保存成果的完整流程(附常见问题解决) 在GIS领域,地图数字化是将纸质地图或图像转换为计算机可识别和处理的数字格式的基础工作。这项技能不仅是GIS专业学生的必修课,也是城市规划、…...
【AI工程化硬核考点】:FastAPI 2.0 + async/await + StreamingResponse三重协程调度机制精讲
第一章:FastAPI 2.0 异步 AI 流式响应 面试题汇总FastAPI 2.0 原生强化了对异步流式响应(StreamingResponse)的支持,尤其适用于大语言模型(LLM)推理、实时日志推送、AI 生成内容分块返回等场景。面试官常聚…...
从沙子到芯片:保姆级图解CMOS制造18步核心工艺(附高清流程图)
从沙子到芯片:图解CMOS制造18步核心工艺 想象一下,你手中智能手机的核心处理器,其内部晶体管数量已突破百亿级——这相当于将整个银河系的恒星数量压缩到指甲盖大小的硅片上。而这一切的起点,竟是海滩上最普通的沙子。本文将用18张…...
跨境电商多语种支持:SenseVoice-Small ONNX语音识别模型部署与本地化适配
跨境电商多语种支持:SenseVoice-Small ONNX语音识别模型部署与本地化适配 1. 环境准备与快速部署 SenseVoice-Small ONNX模型是一个经过量化处理的高效语音识别解决方案,特别适合跨境电商场景中的多语言语音处理需求。这个模型支持超过50种语言&#x…...
解码 DINO 核心:三大创新如何重塑端到端目标检测
1. 从DETR到DINO:目标检测的范式革命 记得我第一次用Faster R-CNN做目标检测时,光是调整锚框尺寸就花了整整三天。这种传统检测方法就像用老式打字机写代码——每个环节都需要手工微调。直到2020年DETR横空出世,才让我意识到目标检测还能这么…...
