当前位置: 首页 > news >正文

深度学习02-pytorch-08-自动微分模块

​​​​​​​

其实自动微分模块,就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。

反向传播的目的: 更新参数, 所以会使用到自动微分模块。

神经网络传输的数据都是 float32 类型。 

案例1:

代码功能概述:

该代码展示了如何在 PyTorch 中使用 自动微分(Autograd) 计算损失函数相对于权重 w 和偏置 b 的梯度。这是机器学习模型训练中非常重要的步骤,因为这些梯度将用于更新模型的参数,从而最小化损失函数

import torch# 1. 当x为标量时,梯度的计算
def test01():x = torch.tensor(5)  # 输入变量x为标量5# 目标值y = torch.tensor(0.)  # 目标输出y设置为0# 设置要更新的权重 和 偏置的初始值w = torch.tensor(1., requires_grad=True, dtype=torch.float32)  # 权重w初始化为1,并启用梯度计算b = torch.tensor(3., requires_grad=True, dtype=torch.float32)  # 偏置b初始化为3,并启用梯度计算# 设置网络的输出值z = x * w + b  # 计算线性模型的输出 z = x*w + b (等同于线性回归的公式)# 设置损失函数,并进行损失的计算loss = torch.nn.MSELoss()  # 使用均方误差(MSE)作为损失函数loss1 = loss(z, y)  # 计算损失,z 是模型的预测值,y 是目标值# 自动微分,计算损失函数相对于w和b的梯度loss1.backward()  # 反向传播计算梯度# backward 函数计算的梯度值会存储在张量的grad 变量中print("w的梯度", w.grad)  # 打印出损失函数对 w 的梯度print("b的梯度", b.grad)  # 打印出损失函数对 b 的梯度test01() 

w的梯度 tensor(80.)
b的梯度 tensor(16.)

代码讲解:

    1.    输入与目标值:
    •    x = torch.tensor(5):输入为 x = 5,表示输入的特征值。
    •    y = torch.tensor(0.):目标输出 y 设置为 0,这是我们希望模型最终预测得到的值。
    2.    参数的初始化:
    •    w = torch.tensor(1., requires_grad=True):初始化权重 w 为 1,requires_grad=True 启用对 w 的梯度计算。
    •    b = torch.tensor(3., requires_grad=True):初始化偏置 b 为 3,同样启用对 b 的梯度计算。
requires_grad=True 的作用是让 PyTorch 知道我们想对这些参数进行梯度计算。
    3.    模型计算:
    •    z = x * w + b:计算模型的输出,类似于线性回归的公式。z 是模型的预测输出。
    4.    损失函数:
    •    loss = torch.nn.MSELoss():选择均方误差(MSE)作为损失函数,用于衡量预测值 z 与目标值 y 之间的误差。
    •    loss1 = loss(z, y):计算损失值,z 是模型预测输出,y 是目标值。

MSE 的公式为:

\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (z_i - y_i)^2

在这个例子中,由于我们只使用了一个数据点,损失计算为:

\text{Loss} = (z - y)^2 = (x \cdot w + b - 0)^2

    5.    反向传播:
    •   loss1.backward():通过调用 backward(),PyTorch 会自动计算损失函数对 w 和 b 的梯度。这个过程称为反向传播(Backpropagation)。梯度的计算基于链式法则,PyTorch 会自动追踪所有的计算操作,计算各个参数对损失的导数。


    6.    梯度输出:
    •    w.grad:存储了损失函数对 w 的梯度。
    •    b.grad:存储了损失函数对 b 的梯度。

案例2:

import torchdef test02():# 输入张量 2x5,表示 2 个样本,每个样本有 5 个特征x = torch.ones(2, 5)  # 输入数据,全部初始化为 1# 目标输出张量 2x3,表示我们希望模型预测的输出有 3 个类别y = torch.zeros(2, 3)  # 目标输出,初始化为 0# 设置可更新的权重和偏置的初始值# 权重 w 的形状是 5x3,表示输入特征为 5,输出类别为 3w = torch.randn(5, 3, requires_grad=True)  # 随机初始化权重,启用梯度计算# 偏置 b 的形状是 3,表示每个输出类别有一个偏置b = torch.randn(3, requires_grad=True)  # 随机初始化偏置,启用梯度计算# 计算网络的输出,z = x * w + b# x 的形状是 2x5,w 的形状是 5x3,矩阵乘法后的结果 z 的形状是 2x3z = torch.matmul(x, w) + b  # 矩阵乘法和偏置加法# 设置损失函数,并计算损失# 这里使用均方误差(MSE),z 是预测值,y 是目标值loss_fn = torch.nn.MSELoss()  # 损失函数为均方误差loss = loss_fn(z, y)  # 计算损失,输出一个标量值# 自动微分,计算损失函数相对于 w 和 b 的梯度loss.backward()  # 反向传播,计算梯度# 打印权重和偏置的梯度,梯度值存储在 grad 属性中print("w 的梯度:\n", w.grad)  # 打印权重 w 的梯度print("b 的梯度:\n", b.grad)  # 打印偏置 b 的梯度# 调用函数进行计算
test02()

自动微分 (Autograd) 的工作原理:

    •    PyTorch 中的 Autograd 是自动微分引擎,它会记录所有张量的计算历史,并根据这些计算图自动执行反向传播,计算参数的梯度。
    •    在向前计算过程中,PyTorch 构建了一个动态计算图(计算图是有向无环图 DAG)。当你调用 .backward(),计算图会根据链式法则从损失开始计算每个变量的梯度。
    •    计算的梯度会存储在对应张量的 .grad 属性中,然后可以使用这些梯度来更新模型的参数。

总结:

    •    w.grad 和 b.grad 的值告诉我们,若我们改变 w 或 b,损失函数会如何变化。
    •    梯度的计算对于优化模型非常重要,因为我们会使用这些梯度来更新权重和偏置,使得损失函数最小化。

PyTorch 中的 自动微分模块 是通过 autograd 实现的,这是 PyTorch 中的核心功能之一,它可以帮助用户在神经网络的训练过程中自动计算梯度。autograd 模块使得实现反向传播和梯度计算变得非常简单和高效。

核心概念

  1. Tensor: PyTorch 的张量 (Tensor) 是自动微分系统的基本单位。如果将 Tensorrequires_grad 属性设置为 True,则 PyTorch 会开始跟踪所有与该张量相关的操作,并在反向传播时自动计算该张量的梯度。

  2. Computational Graph (计算图): PyTorch 会构建一个动态图,记录张量的所有操作。这个图是有向无环图(DAG),图中的每个节点代表一个变量,边代表该变量上发生的操作。当你调用 .backward() 时,PyTorch 会根据计算图自动计算每个张量的梯度。

  3. 梯度 (Gradient): 如果一个张量参与了计算并且 requires_grad=True,在反向传播时可以通过 .grad 属性获取其梯度值。

  4. 反向传播: 通过 tensor.backward() 来执行反向传播计算张量的梯度,默认情况下会对标量进行求导。

使用案例

  1. 创建一个张量并启用梯度跟踪:

    import torch
    ​
    # 创建一个张量,并启用梯度跟踪
    x = torch.tensor([[2.0, 3.0]], requires_grad=True)

  2. 执行一些操作:

    y = x * 3
    z = y.sum()
    print(z)

  3. 反向传播:

    z.backward()  # 对 z 求导
    print(x.grad)  # 查看 x 的梯度

    输出:

    tensor([[3., 3.]])

    在这个例子中,z = x * 3z.backward() 计算了 zx 的梯度,结果为 3

PyTorch 自动微分的几个重要点:

  1. requires_grad=True: 如果需要对某个张量求导,必须将其 requires_grad 属性设置为 True,否则在反向传播时 PyTorch 不会计算该张量的梯度。

  2. grad_fn: 每个跟踪计算的张量都有一个 grad_fn 属性,代表该张量的创建方式和跟踪的操作。例如,如果你对一个张量做了加法操作,它的 grad_fn 就会显示 AddBackward0

    print(y.grad_fn)  # <MulBackward0 object at 0x...>

  3. .backward(): backward() 方法会根据计算图反向传播,自动计算梯度。

  4. 梯度累加: 每次调用 backward() 时,梯度会被累加到 .grad 中,因此在多次反向传播之前,最好手动将 .grad 清零,使用 x.grad.zero_()

autograd 的典型使用场景

  • 神经网络训练:通过 autograd,我们可以在每次迭代时计算损失函数的梯度,然后使用这些梯度更新网络的参数。

  • 自定义梯度计算:可以通过创建复杂的操作来自动推导梯度。

Example: 简单的线性回归

import torch
​
# 生成数据
x = torch.randn(10, 1, requires_grad=True)
y = 3 * x + 2
​
# 定义损失函数
loss = (x - y).pow(2).mean()
​
# 反向传播
loss.backward()
​
# 查看 x 的梯度
print(x.grad)

在这个例子中,loss.backward() 会自动计算 xloss 的梯度。

总结

  • PyTorch 的自动微分机制通过 autograd 实现,用户只需要将张量的 requires_grad 设置为 True,在执行反向传播时,PyTorch 会自动计算张量的梯度。

  • 通过自动构建计算图,autograd 能够跟踪张量上的所有操作,动态计算梯度,极大地方便了深度学习模型的训练。

相关文章:

深度学习02-pytorch-08-自动微分模块

​​​​​​​ 其实自动微分模块&#xff0c;就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。 反向传播的目的&#xff1a; 更新参数&#xff0c; 所以会使用到自动微分模块。 神经网络传输的数据都是 float32 类型。 案例1: 代码功能概述&#xff1a; 该…...

使用Python实现深度学习模型:智能宠物监控与管理

在现代家庭中,宠物已经成为许多家庭的重要成员。为了更好地照顾宠物,智能宠物监控与管理系统应运而生。本文将详细介绍如何使用Python实现一个智能宠物监控与管理系统,并结合深度学习模型来提升其功能。 一、准备工作 在开始之前,我们需要准备以下工具和材料: Python环境…...

【HTTPS】对称加密和非对称加密

HTTPS 是什么 HTTPS 是在 HTTP 的基础上&#xff0c;引入了一个加密层&#xff08;SSL&#xff09;。HTTP 是明文传输的&#xff08;不安全&#xff09; 当下所见到的大部分网站都是 HTTPS 的&#xff0c;这都是拜“运营商劫持”所赐 运营商劫持 下载⼀个“天天动听“&…...

MySQL中的LIMIT与ORDER BY关键字详解

前言 众所周知&#xff0c;LIMIT和ORDER BY在数据库中&#xff0c;是两个非常关键并且经常一起使用的SQL语句部分&#xff0c;它们在数据处理和分页展示方面发挥着重要作用。 今天就结合工作中遇到的实际问题&#xff0c;回顾一下这块的知识点。同时希望这篇文章可以帮助到正…...

Java 编码系列:集合框架(List、Set、Map 及其常用实现类)

引言 在 Java 开发中&#xff0c;集合框架是不可或缺的一部分&#xff0c;它提供了存储和操作一组对象的工具。Java 集合框架主要包括 List、Set 和 Map 接口及其常用的实现类。正确理解和使用这些集合类不仅可以提高代码的可读性和性能&#xff0c;还能避免一些常见的错误。本…...

Go进阶概览 -【7.2 泛型的使用与实现分析】

7.2 泛型的使用与实现分析 泛型是Go 1.18引入的概念&#xff0c;在引入这个概念前经过了好几年的考量最终才将这这个特性加进去。 泛型在多种语言中都是存在的&#xff0c;比如C、Java等语言中都有泛型的概念。 本节我们将针对泛型的使用、实现原理进行整体的讲解。 本节代…...

罗德岛战记游戏源码(客户端+服务端+数据库+全套源码)游戏大小9.41G

罗德岛战记游戏源码&#xff08;客户端服务端数据库全套源码&#xff09;游戏大小9.41G 下载地址&#xff1a; 通过网盘分享的文件&#xff1a;【源码】罗德岛战记游戏源码&#xff08;客户端服务端数据库全套源码&#xff09;游戏大小9.41G 链接: https://pan.baidu.com/s/1y0…...

AI+教育|拥抱AI智能科技,让课堂更生动高效

AI在教育领域的应用正逐渐成为现实&#xff0c;提供互动性强的学习体验&#xff0c;正在改变传统教育模式。AI不仅改变了传统的教学模式&#xff0c;还为教育提供了更多的可能性和解决方案。从个性化学习体验到自动化管理任务&#xff0c;AI正在全方位提升教育质量和效率。随着…...

WebServer

一、服务器代码 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <unistd.h> #define PORT 80 #define BUFFER_SIZE 1024 void ha…...

java项目之基于spring boot的多维分类的知识管理系统的设计与实现源码

项目简介 基于spring boot的多维分类的知识管理系统的设计与实现实现了以下功能&#xff1a; 基于spring boot的多维分类的知识管理系统的设计与实现的主要使用者管理员可以管理用户信息&#xff0c;知识分类&#xff0c;知识信息等&#xff0c;用户可以查看和下载管理员发布…...

go的结构体、方法、接口

结构体&#xff1a; 结构体&#xff1a;不同类型数据集合 结构体成员是由一系列的成员变量构成&#xff0c;这些成员变量也被称为“字段” 先声明一下我们的结构体&#xff1a; type Person struct {name stringage intsex string } 定义结构体法1&#xff1a; var p1 P…...

力扣第一题——删除有序数组中的重复项

给你一个有序数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素只出现两次&#xff0c;返回删除后数组的新长度。不要使用额外的数组空间&#xff0c;你必须在 原地 修改输入数组 并在使用 O(1)额外空间的条件下完成。 示例 1&#x…...

Tuxera NTFS for Mac 2023绿色版

​ 在数字化时代&#xff0c;数据的存储和传输变得至关重要。Mac用户经常需要在Windows NTFS格式的移动硬盘上进行读写操作&#xff0c;然而&#xff0c;由于MacOS系统默认不支持NTFS的写操作&#xff0c;这就需要我们寻找一款高效的读写软件。Tuxera NTFS for Mac 2023便是其中…...

LeetCode[中等] 155. 最小栈

设计一个支持 push &#xff0c;pop &#xff0c;top 操作&#xff0c;并能在常数时间内检索到最小元素的栈。 实现 MinStack 类: MinStack() 初始化堆栈对象。void push(int val) 将元素val推入堆栈。void pop() 删除堆栈顶部的元素。int top() 获取堆栈顶部的元素。int get…...

Python青少年简明教程目录

Python青少年简明教程目录 学习编程语言时&#xff0c;会遇到“开头难”和“深入难”的问题&#xff0c;这是许多编程学习者都会经历的普遍现象。 学习Python对于青少年来说是一个很好的编程起点&#xff0c;相对容易上手入门&#xff0c;但语言特性复杂&#xff0c;应用较广&…...

Revit学习记录-版本2018【持续补充】

将墙面拆分并使用不同材料 【修改】>【几何图形】>【拆分面】(SF)&#xff0c; 然后再使用【修改】>【几何图形】>【填色】&#xff08;PT&#xff09;进行填充 楼板漏在墙外 绘制楼板时最好选择墙体绘制&#xff0c;如果标高上不显示墙体&#xff0c;可以先选…...

深度学习01-概述

深度学习是机器学习的一个子集。机器学习是实现人工智能的一种途径&#xff0c;而深度学习则是通过多层神经网络模拟人类大脑的方式进行学习和知识提取。 深度学习的关键特点&#xff1a; 1. 自动提取特征&#xff1a;与传统的机器学习方法不同&#xff0c;深度学习不需要手动…...

leetcode232. 用栈实现队列

leetcode232. 用栈实现队列 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作&#xff08;push、pop、peek、empty&#xff09;&#xff1a; 实现 MyQueue 类&#xff1a; void push(int x) 将元素 x 推到队列的末尾 int pop() 从队列的开头移除并返回元…...

智慧火灾应急救援航拍检测数据集(无人机视角)

智慧火灾应急救援。 无人机&#xff0c;直升机等航拍视角下火灾应急救援检测数据集&#xff0c;数据分别标注了火&#xff0c;人&#xff0c;车辆这三个要素内容&#xff0c;29810张高清航拍影像&#xff0c;共31GB&#xff0c;适合森林防火&#xff0c;应急救援等方向的学术研…...

eureka.client.service-url.defaultZone的坑

错误的配置 eureka: client: service-url: default-zone: http://192.168.100.10:8080/eureka正确的配置 eureka: client: service-url: defaultZone: http://192.168.100.10:8080/eureka根据错误日志堆栈打断电调试 出现两个key&#xff0c;也就是defaultZone不支持snake-c…...

统信服务器操作系统【d版字符系统升级到dde图形化】配置方法

统信服务器操作系统d版本上由字符系统升级到 dde 桌面系统的过程 文章目录 一、准备环境二、功能描述安装步骤1. lightdm 安装2. dde 安装 一、准备环境 适用版本&#xff1a;■UOS服务器操作系统d版 适用架构&#xff1a;■ARM64、AMD64、MIPS64 网络&#xff1a;连接互联网…...

学习IEC 62055付费系统标准

1.IEC 62055 国际标准 IEC 62055 是目前关于付费系统的唯一国际标准&#xff0c;涵盖了付费系统、CIS 用户信息系统、售电系统、传输介质、数据传输标准、预付费电能表以及接口标准等内容。 IEC 62055-21 标准化架构IEC 62055-31 1 级和 2 级有功预付费电能表IEC 62055-41 STS…...

如何在Markdown写文章上传到wordpress保证图片不丢失

如何在Markdown写文章上传到wordpress保证图片不丢失 写文日期,2023-11-16 引文 众所周知markdown是一款nb的笔记软件&#xff0c;本篇文章讲解如何在markdown编写文件后上传至wordpress论坛。并且保证图片不丢失&#xff08;将图片上传至云端而非本地方法&#xff09; 一&…...

html,css基础知识点笔记(二)

9.18&#xff08;二&#xff09; 本文主要教列表的样式设计 1&#xff09;文本溢出 效果图 文字限制一行显示几个字&#xff0c;多余打点 line-height: 1.8em; white-space: nowrap; width: 40em; overflow: hidden; text-overflow: ellipsis;em表示一个文字的大小单位&…...

(k8s)kubernetes 部署Promehteus学习之路

整个Prometheus生态包含多个组件&#xff0c;除了Prometheus server组件其余都是可选的 Prometheus Server&#xff1a;主要的核心组件&#xff0c;用来收集和存储时间序列数据。 Client Library:&#xff1a;客户端库&#xff0c;为需要监控的服务生成相应的 metrics 并暴露给…...

初写MySQL四张表:(3/4)

我们已经完成了四张表的创建&#xff0c;学会了创建表和查看表字段信息的语句。 初写MySQL四张表:(1/4)-CSDN博客 初写MySQL四张表:(2/4)-CSDN博客 接下来&#xff0c;我们来学点对数据的操作&#xff1a;增 删 查&#xff08;一部分&#xff09;改 先来看这四张表以及相关…...

【Java】线程暂停比拼:wait() 和 sleep()的较量

欢迎浏览高耳机的博客 希望我们彼此都有更好的收获 感谢三连支持&#xff01; 在Java多线程编程中&#xff0c;合理地控制线程的执行是至关重要的。wait()和sleep()是两个常用的方法&#xff0c;它们都可以用来暂停线程的执行&#xff0c;但它们之间存在着显著的差异。本文将详…...

CQRS模型解析

简介 CQRS中文意思为命令于查询职责分离&#xff0c;我们可以将其了解成读写分离的思想。分为两个部分 业务侧和数据侧&#xff0c;业务侧主要执行的就是数据的写操作&#xff0c;而数据侧主要执行的就是数据的读操作。当然两侧的数据库可以是不同的。目前最为常用的CQRS思想方…...

qt-C++笔记之作用等同的宏和关键字

qt-C笔记之作用等同的宏和关键字 code review! Q_SLOT 和 slots&#xff1a; Q_SLOT是slots的替代宏&#xff0c;用于声明槽函数。 Q_SIGNAL 和 signals&#xff1a; Q_SIGNAL类似于signals&#xff0c;用于声明信号。 Q_EMIT 和 emit&#xff1a; Q_EMIT 是 Qt 中用于发射…...

java(3)数组的定义与使用

目录 1.前言 2.正文 2.1数组的概念 2.2数组的创建与初始化 2.2.1数组的创建 2.2.1数组的静态初始化 2.2.2数组的动态初始化 2.3数组是引用类型 2.3.1引用类型与基本类型区别 2.3.2认识NULL 2.4二维数组 2.5数组的基本运用 2.5.1数组的遍历 2.5.2数组转字符串 2.…...