当前位置: 首页 > news >正文

PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

import torch.nn as nnclass MyConvNet(nn.Module):def __init__(self):super(MyConvNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.conv1(x)return x# 实例化模型
model = MyConvNet()# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_conv_net(input_tensor, weight, bias=None):output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)return output_tensor# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.relu = nn.ReLU()def forward(self, x):x = self.relu(x)return x# 实例化模型
model = MyModel()# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = F.relu(input_tensor)return output_tensor# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 2)def forward(self, x):x = self.linear(x)return x# 实例化模型
model = MyModel()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = torch.matmul(input_tensor, weight.t()) + biasreturn output_tensor# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述

相关文章:

PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn 和 torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别: 1. 继承方式与结构 torch.nn torch.nn 中的模块大多数是通过继承 torch.nn…...

React的应用级框架推荐——Next、Modern、Blitz等,快速搭建React项目

在 React 企业级应用开发中,Next.js、Modern.js 和 Blitz 是三个常见的框架,它们提供了不同的特性和功能,旨在简化开发流程并提高应用的性能和扩展性。以下是它们的详解与比较: Next、Modern、Blitz 1. Next.js Next.js 是由 Ve…...

基于GRU实现股价多变量时间序列预测(PyTorch版)

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记…...

Java创建对象有几种方式?

大家好,我是锋哥。今天分享关于【Java创建对象有几种方式?】面试题。希望对大家有帮助; Java创建对象有几种方式? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在Java中,创建对象主要有以下几种方式&…...

Vue3初学之Element Plus Dialog对话框,Message组件,MessageBox组件

Dialog的使用&#xff1a; 控制弹窗的显示和隐藏 <template><div><el-button click"dialogVisible true">打开弹窗</el-button><el-dialogv-model"dialogVisible"title"提示"width"30%":before-close&qu…...

基于Python机器学习的双色球数据分析与预测

python统计分析2003-2024年所有的中奖记录,通过人工智能机器学习预测双色球,个人意见,仅供参考. 声明&#xff1a;双色球具有随机性&#xff0c;任何工具无法预测。本文章仅作为技术交流&#xff0c;提供学习参考。本文所涉及的代码均为python之机器学习的代码。双色球为公益事…...

微软Win10 RP 19045.5435(KB5050081)预览版发布!

系统之家1月20日最新报道&#xff0c;微软面向Release Preview频道的Windows Insider项目成员&#xff0c;发布了适用于Windows10 22H2版本的KB5050081更新&#xff0c;更新后系统版本号将升至19045.5435。本次更新增加了对GB18030-2022标准的支持&#xff0c;同时新版日历将为…...

使用 Parcel 和 NPM 脚本进行打包

使用 Parcel 和 NPM 脚本进行打包 Parcel Parcel 是一个零配置的网页应用程序打包工具&#xff0c;主要用于快速构建现代 JavaScript 应用。 我们可以使用npm直接安装它 npm install --save-dev parcel //这将把 Parcel 添加到 devDependencies 中&#xff0c;表明它是一个…...

HTML<center>标签

HTML5不支持。 <center>标签在HTML4中用于使文本居中对齐。 用什么来代替呢&#xff1f; 例子 居中对齐文本&#xff08;使用 CSS&#xff09;&#xff1a; <html> <head> <style> h1 {text-align: center;} p {text-align: center;} div {text-a…...

LatentSync本地部署教程:基于音频精准生成唇形高度同步视频

LatentSync 是字节跳动联合北京交通大学推出的一个端到端的唇形同步框架&#xff0c;以下是对其的详细介绍&#xff1a; 一、技术基础 LatentSync 基于音频条件的潜在扩散模型&#xff0c;无需任何中间的 3D 表示或 2D 特征点。它利用了 Stable Diffusion 的强大生成能力&…...

ES使用笔记,聚合分组后再分页,探索性能优化问题

之前分享过一篇文档,也是关于聚合分组后再分页的具体实现,当时只想着怎么实现,没有去主要探索ES性能优化的问题, 这篇我会换一种方式,重新实现这个聚合分组后再分页的操作,并且指出能优化性能点,可能我们再使用的时候,并没有注意过的点,希望对你有帮助!大佬的话,请忽略! 上源码…...

VUE3 vite下的axios跨域

在使用 Vite 开发时&#xff0c;如果你的前端项目需要请求后端 API&#xff0c;且后端和前端不在同一个域上&#xff0c;可能会遇到跨域问题。跨域是指浏览器出于安全考虑&#xff0c;阻止了前端网页向不同源&#xff08;域名、协议、端口&#xff09;发送请求。 解决跨域问题…...

Mac下安装ADB环境的三种方式

参考网址&#xff1a; Mac下安装ADB环境的三种方式-百度开发者中心 ADB&#xff0c;即Android Debug Bridge&#xff0c;是Android开发过程中不可或缺的工具。通过ADB&#xff0c;开发者可以在计算机上管理设备或模拟器上的应用&#xff0c;提供了丰富的调试功能。然而&#…...

在Vue中,<img> 标签的 src 值

1. 直接指定 src 的值&#xff08;适用于网络图片&#xff09; 如果你使用的是网络图片&#xff08;即图片的URL是完整的HTTP或HTTPS链接&#xff09;&#xff0c;可以直接指定 src 的值&#xff1a; vue 复制 <template><div><img src"https://exampl…...

Kotlin基础知识学习(三)

函数使用 基本用法 函数声明变化 如果函数是公开的&#xff0c;则public关键字可以省略。用fun关键字表示函数的定义。如果函数没有返回值可以不用声明。如果函数表示重载&#xff0c;直接在fun同一行用override修饰。函数参数格式是变量名&#xff1a;变量类型。函数参数允…...

渗透测试之XEE[外部实体注入]漏洞 原理 攻击手法 xml语言结构 防御手法

目录 原理 XML语言解释 什么是xml语言&#xff1a; 以PHP举例xml外部实体注入 XML语言结构 面试题目 如何寻找xxe漏洞 XEE漏洞修复域防御 提高版本 代码修复 php java python 手动黑名单过滤(不推荐) 一篇文章带你深入理解漏洞之 XXE 漏洞 - 先知社区 原理 XXE&…...

店铺营业状态设置(day05)

Redis入门 Redis简介 Redis 是一个基于内存的 key-value 结构数据库。Redis 是互联网技术领域使用最为广泛的存储中间件。 Redis是一个基于内存的 key-value 结构数据库。 主要特点&#xff1a; 1、基于内存存储&#xff0c;读写性能高 2、适合存储热点数据&#xff08;热点…...

游戏引擎学习第84天

仓库:https://gitee.com/mrxiao_com/2d_game_2 我们正在试图弄清楚如何完成我们的世界构建 上周做了一些偏离计划的工作&#xff0c;开发了一个小型的背景位图合成工具&#xff0c;这个工具做得还不错&#xff0c;虽然是临时拼凑的&#xff0c;但验证了背景构建的思路。这个过…...

快手SDK接入错误处理经验总结(WebGL方案)

1、打包时提示Assets\WebGLTemplates\ks路径下未找到Index.html文件错误 处理方法&#xff1a;直接使用Unity默认模板下的Index.html文件即可 文件所在路径&#xff1a;Unity安装路径\Editor\Data\PlaybackEngines\WebGLSupport\BuildTools\WebGLTemplates\Default 参考图&a…...

C语言 for 循环:解谜数学,玩转生活!

放在最前面的 &#x1f388; &#x1f388; 我的CSDN主页:OTWOL的主页&#xff0c;欢迎&#xff01;&#xff01;&#xff01;&#x1f44b;&#x1f3fc;&#x1f44b;&#x1f3fc; &#x1f389;&#x1f389;我的C语言初阶合集&#xff1a;C语言初阶合集&#xff0c;希望能…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

大话软工笔记—需求分析概述

需求分析&#xff0c;就是要对需求调研收集到的资料信息逐个地进行拆分、研究&#xff0c;从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要&#xff0c;后续设计的依据主要来自于需求分析的成果&#xff0c;包括: 项目的目的…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement 1. LAB环境2. L2公告策略2.1 部署Death Star2.2 访问服务2.3 部署L2公告策略2.4 服务宣告 3. 可视化 ARP 流量3.1 部署新服务3.2 准备可视化3.3 再次请求 4. 自动IPAM4.1 IPAM Pool4.2 …...

stm32wle5 lpuart DMA数据不接收

配置波特率9600时&#xff0c;需要使用外部低速晶振...

从实验室到产业:IndexTTS 在六大核心场景的落地实践

一、内容创作&#xff1a;重构数字内容生产范式 在短视频创作领域&#xff0c;IndexTTS 的语音克隆技术彻底改变了配音流程。B 站 UP 主通过 5 秒参考音频即可克隆出郭老师音色&#xff0c;生成的 “各位吴彦祖们大家好” 语音相似度达 97%&#xff0c;单条视频播放量突破百万…...

起重机起升机构的安全装置有哪些?

起重机起升机构的安全装置是保障吊装作业安全的关键部件&#xff0c;主要用于防止超载、失控、断绳等危险情况。以下是常见的安全装置及其功能和原理&#xff1a; 一、超载保护装置&#xff08;核心安全装置&#xff09; 1. 起重量限制器 功能&#xff1a;实时监测起升载荷&a…...

vxe-table vue 表格复选框多选数据,实现快捷键 Shift 批量选择功能

vxe-table vue 表格复选框多选数据&#xff0c;实现快捷键 Shift 批量选择功能 查看官网&#xff1a;https://vxetable.cn 效果 代码 通过 checkbox-config.isShift 启用批量选中,启用后按住快捷键和鼠标批量选取 <template><div><vxe-grid v-bind"gri…...

LeetCode 0386.字典序排数:细心总结条件

【LetMeFly】386.字典序排数&#xff1a;细心总结条件 力扣题目链接&#xff1a;https://leetcode.cn/problems/lexicographical-numbers/ 给你一个整数 n &#xff0c;按字典序返回范围 [1, n] 内所有整数。 你必须设计一个时间复杂度为 O(n) 且使用 O(1) 额外空间的算法。…...