当前位置: 首页 > news >正文

PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

import torch.nn as nnclass MyConvNet(nn.Module):def __init__(self):super(MyConvNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.conv1(x)return x# 实例化模型
model = MyConvNet()# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_conv_net(input_tensor, weight, bias=None):output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)return output_tensor# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.relu = nn.ReLU()def forward(self, x):x = self.relu(x)return x# 实例化模型
model = MyModel()# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = F.relu(input_tensor)return output_tensor# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 2)def forward(self, x):x = self.linear(x)return x# 实例化模型
model = MyModel()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = torch.matmul(input_tensor, weight.t()) + biasreturn output_tensor# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述

相关文章:

PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn 和 torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别: 1. 继承方式与结构 torch.nn torch.nn 中的模块大多数是通过继承 torch.nn…...

React的应用级框架推荐——Next、Modern、Blitz等,快速搭建React项目

在 React 企业级应用开发中,Next.js、Modern.js 和 Blitz 是三个常见的框架,它们提供了不同的特性和功能,旨在简化开发流程并提高应用的性能和扩展性。以下是它们的详解与比较: Next、Modern、Blitz 1. Next.js Next.js 是由 Ve…...

基于GRU实现股价多变量时间序列预测(PyTorch版)

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记…...

Java创建对象有几种方式?

大家好,我是锋哥。今天分享关于【Java创建对象有几种方式?】面试题。希望对大家有帮助; Java创建对象有几种方式? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在Java中,创建对象主要有以下几种方式&…...

Vue3初学之Element Plus Dialog对话框,Message组件,MessageBox组件

Dialog的使用&#xff1a; 控制弹窗的显示和隐藏 <template><div><el-button click"dialogVisible true">打开弹窗</el-button><el-dialogv-model"dialogVisible"title"提示"width"30%":before-close&qu…...

基于Python机器学习的双色球数据分析与预测

python统计分析2003-2024年所有的中奖记录,通过人工智能机器学习预测双色球,个人意见,仅供参考. 声明&#xff1a;双色球具有随机性&#xff0c;任何工具无法预测。本文章仅作为技术交流&#xff0c;提供学习参考。本文所涉及的代码均为python之机器学习的代码。双色球为公益事…...

微软Win10 RP 19045.5435(KB5050081)预览版发布!

系统之家1月20日最新报道&#xff0c;微软面向Release Preview频道的Windows Insider项目成员&#xff0c;发布了适用于Windows10 22H2版本的KB5050081更新&#xff0c;更新后系统版本号将升至19045.5435。本次更新增加了对GB18030-2022标准的支持&#xff0c;同时新版日历将为…...

使用 Parcel 和 NPM 脚本进行打包

使用 Parcel 和 NPM 脚本进行打包 Parcel Parcel 是一个零配置的网页应用程序打包工具&#xff0c;主要用于快速构建现代 JavaScript 应用。 我们可以使用npm直接安装它 npm install --save-dev parcel //这将把 Parcel 添加到 devDependencies 中&#xff0c;表明它是一个…...

HTML<center>标签

HTML5不支持。 <center>标签在HTML4中用于使文本居中对齐。 用什么来代替呢&#xff1f; 例子 居中对齐文本&#xff08;使用 CSS&#xff09;&#xff1a; <html> <head> <style> h1 {text-align: center;} p {text-align: center;} div {text-a…...

LatentSync本地部署教程:基于音频精准生成唇形高度同步视频

LatentSync 是字节跳动联合北京交通大学推出的一个端到端的唇形同步框架&#xff0c;以下是对其的详细介绍&#xff1a; 一、技术基础 LatentSync 基于音频条件的潜在扩散模型&#xff0c;无需任何中间的 3D 表示或 2D 特征点。它利用了 Stable Diffusion 的强大生成能力&…...

ES使用笔记,聚合分组后再分页,探索性能优化问题

之前分享过一篇文档,也是关于聚合分组后再分页的具体实现,当时只想着怎么实现,没有去主要探索ES性能优化的问题, 这篇我会换一种方式,重新实现这个聚合分组后再分页的操作,并且指出能优化性能点,可能我们再使用的时候,并没有注意过的点,希望对你有帮助!大佬的话,请忽略! 上源码…...

VUE3 vite下的axios跨域

在使用 Vite 开发时&#xff0c;如果你的前端项目需要请求后端 API&#xff0c;且后端和前端不在同一个域上&#xff0c;可能会遇到跨域问题。跨域是指浏览器出于安全考虑&#xff0c;阻止了前端网页向不同源&#xff08;域名、协议、端口&#xff09;发送请求。 解决跨域问题…...

Mac下安装ADB环境的三种方式

参考网址&#xff1a; Mac下安装ADB环境的三种方式-百度开发者中心 ADB&#xff0c;即Android Debug Bridge&#xff0c;是Android开发过程中不可或缺的工具。通过ADB&#xff0c;开发者可以在计算机上管理设备或模拟器上的应用&#xff0c;提供了丰富的调试功能。然而&#…...

在Vue中,<img> 标签的 src 值

1. 直接指定 src 的值&#xff08;适用于网络图片&#xff09; 如果你使用的是网络图片&#xff08;即图片的URL是完整的HTTP或HTTPS链接&#xff09;&#xff0c;可以直接指定 src 的值&#xff1a; vue 复制 <template><div><img src"https://exampl…...

Kotlin基础知识学习(三)

函数使用 基本用法 函数声明变化 如果函数是公开的&#xff0c;则public关键字可以省略。用fun关键字表示函数的定义。如果函数没有返回值可以不用声明。如果函数表示重载&#xff0c;直接在fun同一行用override修饰。函数参数格式是变量名&#xff1a;变量类型。函数参数允…...

渗透测试之XEE[外部实体注入]漏洞 原理 攻击手法 xml语言结构 防御手法

目录 原理 XML语言解释 什么是xml语言&#xff1a; 以PHP举例xml外部实体注入 XML语言结构 面试题目 如何寻找xxe漏洞 XEE漏洞修复域防御 提高版本 代码修复 php java python 手动黑名单过滤(不推荐) 一篇文章带你深入理解漏洞之 XXE 漏洞 - 先知社区 原理 XXE&…...

店铺营业状态设置(day05)

Redis入门 Redis简介 Redis 是一个基于内存的 key-value 结构数据库。Redis 是互联网技术领域使用最为广泛的存储中间件。 Redis是一个基于内存的 key-value 结构数据库。 主要特点&#xff1a; 1、基于内存存储&#xff0c;读写性能高 2、适合存储热点数据&#xff08;热点…...

游戏引擎学习第84天

仓库:https://gitee.com/mrxiao_com/2d_game_2 我们正在试图弄清楚如何完成我们的世界构建 上周做了一些偏离计划的工作&#xff0c;开发了一个小型的背景位图合成工具&#xff0c;这个工具做得还不错&#xff0c;虽然是临时拼凑的&#xff0c;但验证了背景构建的思路。这个过…...

快手SDK接入错误处理经验总结(WebGL方案)

1、打包时提示Assets\WebGLTemplates\ks路径下未找到Index.html文件错误 处理方法&#xff1a;直接使用Unity默认模板下的Index.html文件即可 文件所在路径&#xff1a;Unity安装路径\Editor\Data\PlaybackEngines\WebGLSupport\BuildTools\WebGLTemplates\Default 参考图&a…...

C语言 for 循环:解谜数学,玩转生活!

放在最前面的 &#x1f388; &#x1f388; 我的CSDN主页:OTWOL的主页&#xff0c;欢迎&#xff01;&#xff01;&#xff01;&#x1f44b;&#x1f3fc;&#x1f44b;&#x1f3fc; &#x1f389;&#x1f389;我的C语言初阶合集&#xff1a;C语言初阶合集&#xff0c;希望能…...

DeepSeek 赋能智能养老:情感陪伴机器人的温暖革新

目录 一、引言二、智能养老情感陪伴机器人的市场现状与需求2.1 市场现状2.2 老年人情感陪伴需求分析 三、DeepSeek 技术详解3.1 DeepSeek 的技术特点3.2 与其他类似技术的对比优势 四、DeepSeek 在智能养老情感陪伴机器人中的具体应用4.1 自然语言处理与对话交互4.2 情感识别与…...

Android学习总结-GetX库常见问题和解决方案

GetX库的常见问题 ​路由管理&#xff1a;Get.to() 后页面不跳转或卡顿&#xff1f;​​ ​问题&#xff1a;​​ 明明调用了 Get.to(NextPage())&#xff0c;但页面没反应&#xff0c;或者感觉有延迟卡顿。这可能发生在较复杂的页面树或低端设备上。​原因&#xff1a;​​ ​…...

Fullstack 面试复习笔记:Spring / Spring Boot / Spring Data / Security 整理

Fullstack 面试复习笔记&#xff1a;Spring / Spring Boot / Spring Data / Security 整理 之前的笔记&#xff1a; Fullstack 面试复习笔记&#xff1a;操作系统 / 网络 / HTTP / 设计模式梳理Fullstack 面试复习笔记&#xff1a;Java 基础语法 / 核心特性体系化总结Fullsta…...

Scade 语言概念 - 方程(equation)

在 Scade 6 程序中自定义算子(Operator)的定义、或数据流定义(data_def)的内容中&#xff0c;包含一种基本的语言结构&#xff1a;方程(equation)(注1)。在本篇中&#xff0c;将叙述 Scade 语言方程的文法形式&#xff0c;以及作用。 注1: 对 Scade 中的 equation, 或 equation…...

JavaScript 本地存储 (localStorage) 完全指南

文章目录 JavaScript 本地存储 (localStorage) 完全指南 &#x1f510;一、什么是 localStorage&#xff1f;&#x1f4a1;二、如何使用 localStorage&#xff1f;&#x1f527;1. 存储数据2. 读取数据3. 删除数据4. 清空所有数据 三、存储对象和数组的技巧 &#x1f3a8;1. 存…...

现代C++特性(一):基本数据类型扩展

文章目录 基础数据类型long long (C 11)numeric_limits()获取当前数据类型的最值warning C4309: “”: 截断常量值新字符类型char16_t和char32_tWindows编程常用字符类型wchar_tchar8_t (C 20) 基础数据类型 C中的基本类型是构建其他数据类型的基础&#xff0c;常见的基础类型…...

【读论文】U-Net: Convolutional Networks for Biomedical Image Segmentation 卷积神经网络

摘要1 Introduction2 Network Architecture3 Training3.1 Data Augmentation 4 Experiments5 Conclusion背景知识卷积激活函数池化上采样、上池化、反卷积softmax 归一化函数交叉熵损失 Olaf Ronneberger, Philipp Fischer, Thomas Brox Paper&#xff1a;https://arxiv.org/ab…...

【学习记录】在 Ubuntu 中将新硬盘挂载到 /home 目录的完整指南

文章目录 &#x1f4cb; 一、准备工作1. 备份重要数据2. 确认新硬盘设备信息 &#x1f6e0;️ 二、格式化新硬盘&#xff08;如未格式化&#xff09;1. 格式化为 ext4 文件系统&#xff08;推荐&#xff09; &#x1f501; 三、临时挂载并迁移数据1. 创建临时挂载点2. 挂载新硬…...

NVIDIA DRIVE AGX平台:引领智能驾驶安全新时代

随着科技的不断进步&#xff0c;汽车行业正迎来前所未有的变革&#xff0c;智能驾驶技术成为全球产业竞相布局的焦点之一。然而&#xff0c;这场技术革命的背后&#xff0c;最关键且被广泛关注的是安全性问题。近日&#xff0c;我认真研读了NVIDIA发布的《自动驾驶安全报告》白…...

【 java 集合知识 第一篇 】

目录 1.概念 1.1.集合与数组的区别 1.2.集合分类 1.3.Collection和Collections的区别 1.4.集合遍历的方法 2.List 2.1.List的实现 2.2.可以一边遍历一边修改List的方法 2.3.List快速删除元素的原理 2.4.ArrayList与LinkedList的区别 2.5.线程安全 2.6.ArrayList的扩…...