当前位置: 首页 > news >正文

深度学习02-pytorch-08-自动微分模块

​​​​​​​

其实自动微分模块,就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。

反向传播的目的: 更新参数, 所以会使用到自动微分模块。

神经网络传输的数据都是 float32 类型。 

案例1:

代码功能概述:

该代码展示了如何在 PyTorch 中使用 自动微分(Autograd) 计算损失函数相对于权重 w 和偏置 b 的梯度。这是机器学习模型训练中非常重要的步骤,因为这些梯度将用于更新模型的参数,从而最小化损失函数

import torch# 1. 当x为标量时,梯度的计算
def test01():x = torch.tensor(5)  # 输入变量x为标量5# 目标值y = torch.tensor(0.)  # 目标输出y设置为0# 设置要更新的权重 和 偏置的初始值w = torch.tensor(1., requires_grad=True, dtype=torch.float32)  # 权重w初始化为1,并启用梯度计算b = torch.tensor(3., requires_grad=True, dtype=torch.float32)  # 偏置b初始化为3,并启用梯度计算# 设置网络的输出值z = x * w + b  # 计算线性模型的输出 z = x*w + b (等同于线性回归的公式)# 设置损失函数,并进行损失的计算loss = torch.nn.MSELoss()  # 使用均方误差(MSE)作为损失函数loss1 = loss(z, y)  # 计算损失,z 是模型的预测值,y 是目标值# 自动微分,计算损失函数相对于w和b的梯度loss1.backward()  # 反向传播计算梯度# backward 函数计算的梯度值会存储在张量的grad 变量中print("w的梯度", w.grad)  # 打印出损失函数对 w 的梯度print("b的梯度", b.grad)  # 打印出损失函数对 b 的梯度test01() 

w的梯度 tensor(80.)
b的梯度 tensor(16.)

代码讲解:

    1.    输入与目标值:
    •    x = torch.tensor(5):输入为 x = 5,表示输入的特征值。
    •    y = torch.tensor(0.):目标输出 y 设置为 0,这是我们希望模型最终预测得到的值。
    2.    参数的初始化:
    •    w = torch.tensor(1., requires_grad=True):初始化权重 w 为 1,requires_grad=True 启用对 w 的梯度计算。
    •    b = torch.tensor(3., requires_grad=True):初始化偏置 b 为 3,同样启用对 b 的梯度计算。
requires_grad=True 的作用是让 PyTorch 知道我们想对这些参数进行梯度计算。
    3.    模型计算:
    •    z = x * w + b:计算模型的输出,类似于线性回归的公式。z 是模型的预测输出。
    4.    损失函数:
    •    loss = torch.nn.MSELoss():选择均方误差(MSE)作为损失函数,用于衡量预测值 z 与目标值 y 之间的误差。
    •    loss1 = loss(z, y):计算损失值,z 是模型预测输出,y 是目标值。

MSE 的公式为:

\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (z_i - y_i)^2

在这个例子中,由于我们只使用了一个数据点,损失计算为:

\text{Loss} = (z - y)^2 = (x \cdot w + b - 0)^2

    5.    反向传播:
    •   loss1.backward():通过调用 backward(),PyTorch 会自动计算损失函数对 w 和 b 的梯度。这个过程称为反向传播(Backpropagation)。梯度的计算基于链式法则,PyTorch 会自动追踪所有的计算操作,计算各个参数对损失的导数。


    6.    梯度输出:
    •    w.grad:存储了损失函数对 w 的梯度。
    •    b.grad:存储了损失函数对 b 的梯度。

案例2:

import torchdef test02():# 输入张量 2x5,表示 2 个样本,每个样本有 5 个特征x = torch.ones(2, 5)  # 输入数据,全部初始化为 1# 目标输出张量 2x3,表示我们希望模型预测的输出有 3 个类别y = torch.zeros(2, 3)  # 目标输出,初始化为 0# 设置可更新的权重和偏置的初始值# 权重 w 的形状是 5x3,表示输入特征为 5,输出类别为 3w = torch.randn(5, 3, requires_grad=True)  # 随机初始化权重,启用梯度计算# 偏置 b 的形状是 3,表示每个输出类别有一个偏置b = torch.randn(3, requires_grad=True)  # 随机初始化偏置,启用梯度计算# 计算网络的输出,z = x * w + b# x 的形状是 2x5,w 的形状是 5x3,矩阵乘法后的结果 z 的形状是 2x3z = torch.matmul(x, w) + b  # 矩阵乘法和偏置加法# 设置损失函数,并计算损失# 这里使用均方误差(MSE),z 是预测值,y 是目标值loss_fn = torch.nn.MSELoss()  # 损失函数为均方误差loss = loss_fn(z, y)  # 计算损失,输出一个标量值# 自动微分,计算损失函数相对于 w 和 b 的梯度loss.backward()  # 反向传播,计算梯度# 打印权重和偏置的梯度,梯度值存储在 grad 属性中print("w 的梯度:\n", w.grad)  # 打印权重 w 的梯度print("b 的梯度:\n", b.grad)  # 打印偏置 b 的梯度# 调用函数进行计算
test02()

自动微分 (Autograd) 的工作原理:

    •    PyTorch 中的 Autograd 是自动微分引擎,它会记录所有张量的计算历史,并根据这些计算图自动执行反向传播,计算参数的梯度。
    •    在向前计算过程中,PyTorch 构建了一个动态计算图(计算图是有向无环图 DAG)。当你调用 .backward(),计算图会根据链式法则从损失开始计算每个变量的梯度。
    •    计算的梯度会存储在对应张量的 .grad 属性中,然后可以使用这些梯度来更新模型的参数。

总结:

    •    w.grad 和 b.grad 的值告诉我们,若我们改变 w 或 b,损失函数会如何变化。
    •    梯度的计算对于优化模型非常重要,因为我们会使用这些梯度来更新权重和偏置,使得损失函数最小化。

PyTorch 中的 自动微分模块 是通过 autograd 实现的,这是 PyTorch 中的核心功能之一,它可以帮助用户在神经网络的训练过程中自动计算梯度。autograd 模块使得实现反向传播和梯度计算变得非常简单和高效。

核心概念

  1. Tensor: PyTorch 的张量 (Tensor) 是自动微分系统的基本单位。如果将 Tensorrequires_grad 属性设置为 True,则 PyTorch 会开始跟踪所有与该张量相关的操作,并在反向传播时自动计算该张量的梯度。

  2. Computational Graph (计算图): PyTorch 会构建一个动态图,记录张量的所有操作。这个图是有向无环图(DAG),图中的每个节点代表一个变量,边代表该变量上发生的操作。当你调用 .backward() 时,PyTorch 会根据计算图自动计算每个张量的梯度。

  3. 梯度 (Gradient): 如果一个张量参与了计算并且 requires_grad=True,在反向传播时可以通过 .grad 属性获取其梯度值。

  4. 反向传播: 通过 tensor.backward() 来执行反向传播计算张量的梯度,默认情况下会对标量进行求导。

使用案例

  1. 创建一个张量并启用梯度跟踪:

    import torch
    ​
    # 创建一个张量,并启用梯度跟踪
    x = torch.tensor([[2.0, 3.0]], requires_grad=True)

  2. 执行一些操作:

    y = x * 3
    z = y.sum()
    print(z)

  3. 反向传播:

    z.backward()  # 对 z 求导
    print(x.grad)  # 查看 x 的梯度

    输出:

    tensor([[3., 3.]])

    在这个例子中,z = x * 3z.backward() 计算了 zx 的梯度,结果为 3

PyTorch 自动微分的几个重要点:

  1. requires_grad=True: 如果需要对某个张量求导,必须将其 requires_grad 属性设置为 True,否则在反向传播时 PyTorch 不会计算该张量的梯度。

  2. grad_fn: 每个跟踪计算的张量都有一个 grad_fn 属性,代表该张量的创建方式和跟踪的操作。例如,如果你对一个张量做了加法操作,它的 grad_fn 就会显示 AddBackward0

    print(y.grad_fn)  # <MulBackward0 object at 0x...>

  3. .backward(): backward() 方法会根据计算图反向传播,自动计算梯度。

  4. 梯度累加: 每次调用 backward() 时,梯度会被累加到 .grad 中,因此在多次反向传播之前,最好手动将 .grad 清零,使用 x.grad.zero_()

autograd 的典型使用场景

  • 神经网络训练:通过 autograd,我们可以在每次迭代时计算损失函数的梯度,然后使用这些梯度更新网络的参数。

  • 自定义梯度计算:可以通过创建复杂的操作来自动推导梯度。

Example: 简单的线性回归

import torch
​
# 生成数据
x = torch.randn(10, 1, requires_grad=True)
y = 3 * x + 2
​
# 定义损失函数
loss = (x - y).pow(2).mean()
​
# 反向传播
loss.backward()
​
# 查看 x 的梯度
print(x.grad)

在这个例子中,loss.backward() 会自动计算 xloss 的梯度。

总结

  • PyTorch 的自动微分机制通过 autograd 实现,用户只需要将张量的 requires_grad 设置为 True,在执行反向传播时,PyTorch 会自动计算张量的梯度。

  • 通过自动构建计算图,autograd 能够跟踪张量上的所有操作,动态计算梯度,极大地方便了深度学习模型的训练。

相关文章:

深度学习02-pytorch-08-自动微分模块

​​​​​​​ 其实自动微分模块&#xff0c;就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。 反向传播的目的&#xff1a; 更新参数&#xff0c; 所以会使用到自动微分模块。 神经网络传输的数据都是 float32 类型。 案例1: 代码功能概述&#xff1a; 该…...

使用Python实现深度学习模型:智能宠物监控与管理

在现代家庭中,宠物已经成为许多家庭的重要成员。为了更好地照顾宠物,智能宠物监控与管理系统应运而生。本文将详细介绍如何使用Python实现一个智能宠物监控与管理系统,并结合深度学习模型来提升其功能。 一、准备工作 在开始之前,我们需要准备以下工具和材料: Python环境…...

【HTTPS】对称加密和非对称加密

HTTPS 是什么 HTTPS 是在 HTTP 的基础上&#xff0c;引入了一个加密层&#xff08;SSL&#xff09;。HTTP 是明文传输的&#xff08;不安全&#xff09; 当下所见到的大部分网站都是 HTTPS 的&#xff0c;这都是拜“运营商劫持”所赐 运营商劫持 下载⼀个“天天动听“&…...

MySQL中的LIMIT与ORDER BY关键字详解

前言 众所周知&#xff0c;LIMIT和ORDER BY在数据库中&#xff0c;是两个非常关键并且经常一起使用的SQL语句部分&#xff0c;它们在数据处理和分页展示方面发挥着重要作用。 今天就结合工作中遇到的实际问题&#xff0c;回顾一下这块的知识点。同时希望这篇文章可以帮助到正…...

Java 编码系列:集合框架(List、Set、Map 及其常用实现类)

引言 在 Java 开发中&#xff0c;集合框架是不可或缺的一部分&#xff0c;它提供了存储和操作一组对象的工具。Java 集合框架主要包括 List、Set 和 Map 接口及其常用的实现类。正确理解和使用这些集合类不仅可以提高代码的可读性和性能&#xff0c;还能避免一些常见的错误。本…...

Go进阶概览 -【7.2 泛型的使用与实现分析】

7.2 泛型的使用与实现分析 泛型是Go 1.18引入的概念&#xff0c;在引入这个概念前经过了好几年的考量最终才将这这个特性加进去。 泛型在多种语言中都是存在的&#xff0c;比如C、Java等语言中都有泛型的概念。 本节我们将针对泛型的使用、实现原理进行整体的讲解。 本节代…...

罗德岛战记游戏源码(客户端+服务端+数据库+全套源码)游戏大小9.41G

罗德岛战记游戏源码&#xff08;客户端服务端数据库全套源码&#xff09;游戏大小9.41G 下载地址&#xff1a; 通过网盘分享的文件&#xff1a;【源码】罗德岛战记游戏源码&#xff08;客户端服务端数据库全套源码&#xff09;游戏大小9.41G 链接: https://pan.baidu.com/s/1y0…...

AI+教育|拥抱AI智能科技,让课堂更生动高效

AI在教育领域的应用正逐渐成为现实&#xff0c;提供互动性强的学习体验&#xff0c;正在改变传统教育模式。AI不仅改变了传统的教学模式&#xff0c;还为教育提供了更多的可能性和解决方案。从个性化学习体验到自动化管理任务&#xff0c;AI正在全方位提升教育质量和效率。随着…...

WebServer

一、服务器代码 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <unistd.h> #define PORT 80 #define BUFFER_SIZE 1024 void ha…...

java项目之基于spring boot的多维分类的知识管理系统的设计与实现源码

项目简介 基于spring boot的多维分类的知识管理系统的设计与实现实现了以下功能&#xff1a; 基于spring boot的多维分类的知识管理系统的设计与实现的主要使用者管理员可以管理用户信息&#xff0c;知识分类&#xff0c;知识信息等&#xff0c;用户可以查看和下载管理员发布…...

go的结构体、方法、接口

结构体&#xff1a; 结构体&#xff1a;不同类型数据集合 结构体成员是由一系列的成员变量构成&#xff0c;这些成员变量也被称为“字段” 先声明一下我们的结构体&#xff1a; type Person struct {name stringage intsex string } 定义结构体法1&#xff1a; var p1 P…...

力扣第一题——删除有序数组中的重复项

给你一个有序数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素只出现两次&#xff0c;返回删除后数组的新长度。不要使用额外的数组空间&#xff0c;你必须在 原地 修改输入数组 并在使用 O(1)额外空间的条件下完成。 示例 1&#x…...

Tuxera NTFS for Mac 2023绿色版

​ 在数字化时代&#xff0c;数据的存储和传输变得至关重要。Mac用户经常需要在Windows NTFS格式的移动硬盘上进行读写操作&#xff0c;然而&#xff0c;由于MacOS系统默认不支持NTFS的写操作&#xff0c;这就需要我们寻找一款高效的读写软件。Tuxera NTFS for Mac 2023便是其中…...

LeetCode[中等] 155. 最小栈

设计一个支持 push &#xff0c;pop &#xff0c;top 操作&#xff0c;并能在常数时间内检索到最小元素的栈。 实现 MinStack 类: MinStack() 初始化堆栈对象。void push(int val) 将元素val推入堆栈。void pop() 删除堆栈顶部的元素。int top() 获取堆栈顶部的元素。int get…...

Python青少年简明教程目录

Python青少年简明教程目录 学习编程语言时&#xff0c;会遇到“开头难”和“深入难”的问题&#xff0c;这是许多编程学习者都会经历的普遍现象。 学习Python对于青少年来说是一个很好的编程起点&#xff0c;相对容易上手入门&#xff0c;但语言特性复杂&#xff0c;应用较广&…...

Revit学习记录-版本2018【持续补充】

将墙面拆分并使用不同材料 【修改】>【几何图形】>【拆分面】(SF)&#xff0c; 然后再使用【修改】>【几何图形】>【填色】&#xff08;PT&#xff09;进行填充 楼板漏在墙外 绘制楼板时最好选择墙体绘制&#xff0c;如果标高上不显示墙体&#xff0c;可以先选…...

深度学习01-概述

深度学习是机器学习的一个子集。机器学习是实现人工智能的一种途径&#xff0c;而深度学习则是通过多层神经网络模拟人类大脑的方式进行学习和知识提取。 深度学习的关键特点&#xff1a; 1. 自动提取特征&#xff1a;与传统的机器学习方法不同&#xff0c;深度学习不需要手动…...

leetcode232. 用栈实现队列

leetcode232. 用栈实现队列 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作&#xff08;push、pop、peek、empty&#xff09;&#xff1a; 实现 MyQueue 类&#xff1a; void push(int x) 将元素 x 推到队列的末尾 int pop() 从队列的开头移除并返回元…...

智慧火灾应急救援航拍检测数据集(无人机视角)

智慧火灾应急救援。 无人机&#xff0c;直升机等航拍视角下火灾应急救援检测数据集&#xff0c;数据分别标注了火&#xff0c;人&#xff0c;车辆这三个要素内容&#xff0c;29810张高清航拍影像&#xff0c;共31GB&#xff0c;适合森林防火&#xff0c;应急救援等方向的学术研…...

eureka.client.service-url.defaultZone的坑

错误的配置 eureka: client: service-url: default-zone: http://192.168.100.10:8080/eureka正确的配置 eureka: client: service-url: defaultZone: http://192.168.100.10:8080/eureka根据错误日志堆栈打断电调试 出现两个key&#xff0c;也就是defaultZone不支持snake-c…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

【Go】3、Go语言进阶与依赖管理

前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课&#xff0c;做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程&#xff0c;它的核心机制是 Goroutine 协程、Channel 通道&#xff0c;并基于CSP&#xff08;Communicating Sequential Processes&#xff0…...

如何在网页里填写 PDF 表格?

有时候&#xff0c;你可能希望用户能在你的网站上填写 PDF 表单。然而&#xff0c;这件事并不简单&#xff0c;因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件&#xff0c;但原生并不支持编辑或填写它们。更糟的是&#xff0c;如果你想收集表单数据&#xff…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...