当前位置: 首页 > news >正文

PyTorch 中使用自动求导计算梯度

使用 PyTorch 进行自动求导和梯度计算

在 PyTorch 中,张量的 requires_grad 属性决定了是否需要计算该张量的梯度。设置为 True 的张量会在计算过程中记录操作,以便在调用 .backward() 方法时自动计算梯度。通过构建计算图,PyTorch 能够有效地追踪和计算梯度。

1、梯度的定义

在数学中,梯度是一个向量,表示函数在某一点的变化率。在深度学习中,我们通常关心的是损失函数相对于模型参数的梯度。具体来说,假设我们有一个输出 out,我们计算的是损失函数对模型参数(如权重和偏置)的梯度,而不是直接对输出的梯度。

2、 简单例子

在我们接下来的例子中,我们将计算 out 相对于输入变量 x x x y y y的梯度,通常表示为 ( d out d x ) ( \frac{d \text{out}}{dx}) (dxdout) ( d out d y ) ( \frac{d \text{out}}{dy}) (dydout)

import torch# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True)  # 输入变量 x
y = torch.tensor(3.0, requires_grad=True)  # 输入变量 y# 2. 定义第一个函数 f(z) = z^2
def f(z):return z**2# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):z = x + y  # 中间变量 zz_no_grad = z.detach()  # 创建不需要梯度的副本return f(z_no_grad) + y**3  # 输出 out = f(z_no_grad) + y^3# 4. 计算输出
out = g(x, y)  # 计算输出# 5. 反向传播以计算梯度
out.backward()  # 计算梯度# 6. 打印梯度
print(f"dz/dx: {x.grad}")  # 输出 x 的梯度
print(f"dz/dy: {y.grad}")  # 输出 y 的梯度
dout/dx: None
dout/dy: 27.0
import torch# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True)  # 输入变量 x
y = torch.tensor(3.0, requires_grad=True)  # 输入变量 y# 2. 定义第一个函数 f(z) = z^2
def f(z):return z ** 2# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):z = x + y  # 中间变量 zreturn f(z) + y ** 3  # 输出 out = f(z_no_grad) + y^3# 4. 计算输出
out = g(x, y)  # 计算输出# 5. 反向传播以计算梯度
out.backward()  # 计算梯度# 6. 打印梯度
print(f"dout/dx: {x.grad}")  # 输出 x 的梯度
print(f"dout/dy: {y.grad}")  # 输出 y 的梯度
dout/dx: 10.0
dout/dy: 37.0

在这两个代码示例中,dout/dxdout/dy 的值存在显著差异,主要原因在于如何处理中间变量 ( z ) 以及其对最终输出 out 的影响。

结果分析

  1. 第一部分代码

    • g(x, y) 函数中,使用了 z . detach ( ) z.\text{detach}() z.detach() 创建了一个不需要梯度的副本 z no_grad z_{\text{no\_grad}} zno_grad。这意味着在计算 f ( z no_grad ) f(z_{\text{no\_grad}}) f(zno_grad) 时,PyTorch 不会将 z z z 的变化记录进计算图中。

    • 因此, z z z out \text{out} out 的影响被切断,导致
      d out d x = None \frac{d \text{out}}{d x} = \text{None} dxdout=None
      因为 x x x 的变化不会影响到 out \text{out} out 的计算。

    • 对于 y y y,计算得到的梯度为
      d out d y = 27.0 \frac{d \text{out}}{d y} = 27.0 dydout=27.0
      这是通过以下步骤得到的:

    • 输出为
      out = f ( z no_grad ) + y 3 \text{out} = f(z_{\text{no\_grad}}) + y^3 out=f(zno_grad)+y3

    • 使用链式法则:
      d out d y = 0 + 3 y 2 = 3 ( 3 2 ) = 27 \frac{d \text{out}}{d y} = 0 + 3y^2 = 3(3^2) = 27 dydout=0+3y2=3(32)=27

  2. 第二部分代码

    • g(x, y) 函数中,直接使用了 z z z 而没有使用 z . detach ( ) z.\text{detach}() z.detach()。这使得 z z z 的变化会被记录在计算图中。
    • 计算
      d out d x \frac{d \text{out}}{d x} dxdout
      时, z = x + y z = x + y z=x+y 的变化会影响到 out \text{out} out,因此计算得到的梯度为
      d out d x = 10.0 \frac{d \text{out}}{d x} = 10.0 dxdout=10.0
      这是因为:
    • f ( z ) = z 2 f(z) = z^2 f(z)=z2 的导数为
      d f ( z ) d z = 2 z \frac{d f(z)}{d z} = 2z dzdf(z)=2z
      z = 5 z = 5 z=5(当 x = 2 , y = 3 x=2, y=3 x=2,y=3 时),所以
      2 z = 10 2z = 10 2z=10
    • 对于 y y y,计算得到的梯度为
      d out d y = 37.0 \frac{d \text{out}}{d y} = 37.0 dydout=37.0
      这是因为
      d out d y = d ( f ( z ) + y 3 ) d y = 2 z ⋅ d z d y + 3 y 2 = 2 ( 5 ) ( 1 ) + 3 ( 3 2 ) = 10 + 27 = 37 \frac{d \text{out}}{d y} = \frac{d (f(z) + y^3)}{d y} = 2z \cdot \frac{d z}{d y} + 3y^2 = 2(5)(1) + 3(3^2) = 10 + 27 = 37 dydout=dyd(f(z)+y3)=2zdydz+3y2=2(5)(1)+3(32)=10+27=37

3、线性拟合及梯度计算

在深度学习中,线性回归是最基本的模型之一。通过线性回归,我们可以找到输入特征与输出之间的线性关系。在本文中,我们将使用 PyTorch 实现一个简单的线性拟合模型,定义模型为 y = a x + b x + c + d y = ax + bx + c + d y=ax+bx+c+d,并展示如何计算梯度,同时控制某些参数(如 b b b d d d)不更新梯度。
在这个模型中,我们将定义以下参数:

  • a a a:斜率,表示输入 x x x 对输出 y y y 的影响。
  • b b b:另一个斜率,表示输入 x x x 对输出 y y y 的影响,但在训练过程中不更新。
  • c c c:截距,表示当 x = 0 x=0 x=0 时的输出值。
  • d d d:一个常数项,在训练过程中不更新。

3.1、完整代码

下面是实现线性拟合的完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 创建数据
# 假设我们有一些样本数据
x_data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = torch.tensor([3.0, 5.0, 7.0, 9.0, 11.0])  # 目标值# 2. 定义线性模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()self.a = nn.Parameter(torch.tensor(1.0))  # 需要更新的参数self.b = nn.Parameter(torch.tensor(0.5), requires_grad=False)  # 不需要更新的参数self.c = nn.Parameter(torch.tensor(0.0))  # 需要更新的参数self.d = nn.Parameter(torch.tensor(0.5), requires_grad=False)  # 不需要更新的参数def forward(self, x):return self.a * x + self.b * x + self.c + self.d# 3. 实例化模型
model = LinearModel()# 4. 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.005)  # 随机梯度下降优化器# 5. 训练模型
for epoch in range(5000):model.train()  # 设置模型为训练模式# 计算模型输出y_pred = model(x_data)# 计算损失loss = criterion(y_pred, y_data)# 反向传播optimizer.zero_grad()  # 清零梯度loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 每10个epoch打印一次loss和参数值if (epoch + 1) % 500 == 0:print(f'Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}, a: {model.a.item():.4f}, b: {model.b.item():.4f}, c: {model.c.item():.4f}, d: {model.d.item():.4f}')# 6. 打印最终参数
print(f'Final parameters: a = {model.a.item()}, b = {model.b.item()}, c = {model.c.item()}, d = {model.d.item()}')# 7. 绘制拟合结果
with torch.no_grad():# 生成用于绘图的 x 值x_fit = torch.linspace(0, 6, 100)  # 从 0 到 6 生成 100 个点y_fit = model(x_fit)  # 计算对应的 y 值# 绘制真实数据点
plt.scatter(x_data.numpy(), y_data.numpy(), color='red', label='True Data')
# 绘制拟合曲线
plt.plot(x_fit.numpy(), y_fit.numpy(), color='blue', label='Fitted Curve')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Linear Fit Result')
plt.legend()
plt.grid()
plt.show()

3.2、梯度计算过程

在这个例子中,我们使用了 PyTorch 的自动求导功能来计算梯度。以下是对每个参数的梯度计算过程的解释:

  1. 参数定义

    • a a a c c c 是需要更新的参数,因此它们的 requires_grad 属性默认为 True
    • b b b d d d 是不需要更新的参数,设置了 requires_grad=False,因此它们的梯度不会被计算。
  2. 损失计算

    • 在每个训练周期中,我们计算模型的预测值 y pred y_{\text{pred}} ypred,并与真实值 y data y_{\text{data}} ydata 计算均方误差损失:
      loss = 1 n ∑ i = 1 n ( y pred , i − y i ) 2 \text{loss} = \frac{1}{n} \sum_{i=1}^{n} (y_{\text{pred},i} - y_{i})^2 loss=n1i=1n(ypred,iyi)2
  3. 反向传播

    • 调用 loss.backward() 计算所有参数的梯度。由于 b b b d d drequires_grad 被设置为 False,因此它们的梯度不会被计算和更新。
  4. 参数更新

    • 使用优化器 optimizer.step() 更新参数。只有 a a a c c c 会被更新。
Epoch [500/100], Loss: 0.0038, a: 1.5399, b: 0.5000, c: 0.3559, d: 0.5000
Epoch [1000/100], Loss: 0.0007, a: 1.5171, b: 0.5000, c: 0.4382, d: 0.5000
Epoch [1500/100], Loss: 0.0001, a: 1.5073, b: 0.5000, c: 0.4735, d: 0.5000
Epoch [2000/100], Loss: 0.0000, a: 1.5032, b: 0.5000, c: 0.4886, d: 0.5000
Epoch [2500/100], Loss: 0.0000, a: 1.5014, b: 0.5000, c: 0.4951, d: 0.5000
Epoch [3000/100], Loss: 0.0000, a: 1.5006, b: 0.5000, c: 0.4979, d: 0.5000
Epoch [3500/100], Loss: 0.0000, a: 1.5002, b: 0.5000, c: 0.4991, d: 0.5000
Epoch [4000/100], Loss: 0.0000, a: 1.5001, b: 0.5000, c: 0.4996, d: 0.5000
Epoch [4500/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4998, d: 0.5000
Epoch [5000/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4999, d: 0.5000
Final parameters: a = 1.5000202655792236, b = 0.5, c = 0.4999275505542755, d = 0.5

在这里插入图片描述

相关文章:

PyTorch 中使用自动求导计算梯度

使用 PyTorch 进行自动求导和梯度计算 在 PyTorch 中,张量的 requires_grad 属性决定了是否需要计算该张量的梯度。设置为 True 的张量会在计算过程中记录操作,以便在调用 .backward() 方法时自动计算梯度。通过构建计算图,PyTorch 能够有效…...

Oracle Instant Client 23.5安装配置完整教程

Oracle Instant Client 23.5安装配置完整教程 简介环境要求安装步骤1. 准备工作目录2. 下载Oracle Instant Client3. 解压Instant Client4. 安装依赖包5. 配置系统环境5.1 配置库文件路径5.2 配置环境变量 6. 配置Oracle钱包(可选) 验证安装常见问题解决…...

【jvm】方法区的理解

目录 1. 说明2. 方法区的演进3. 内部结构4. 作用5.内存管理 1. 说明 1.方法区用于存储已被虚拟机加载的类信息、常量、静态变量、即时编译器编译后的代码缓存等数据。它是各个线程共享的内存区域。2.尽管《Java虚拟机规范》中把方法区描述为堆的一个逻辑部分,但它却…...

ES-针对某个字段去重后-获取某个字段值的所有值

针对上面表的数据,现在想根据age分组,并获取每个分组后的name有哪些(去重后)。 select age, GROUP_CONCAT(DISTINCT(name)) from testtable group by age ; 结果: 如果想要增加排序: SELECT age, GROUP_CONCAT(DISTINCT name)…...

百度 2025届秋招提前批 文心一言大模型算法工程师

文章目录 个人情况一面/技术面 1h二面/技术面 1h三面/技术面 40min 个人情况 先说一下个人情况: 学校情况:211本中9硕,本硕学校都一般,本硕都是计算机科班,但研究方向并不是NLP,而是图表示学习论文情况&a…...

sglang 部署Qwen2VL7B,大模型部署,速度测试,深度学习

sglang 项目github仓库: https://github.com/sgl-project/sglang 项目说明书: https://sgl-project.github.io/start/install.html 资讯: https://github.com/sgl-project/sgl-learning-materials?tabreadme-ov-file#the-first-sglang…...

fastadmin操作数据库字段为json、查询遍历each、多级下拉、union、php密码设置、common常用函数的使用小技巧

数据库中遇到的操作 查询字段是json的某个值 //获取数据库中某个字段是json中得某个值,进行查询,goods是表中字段,brand_id是json中要查詢的字段。//数据类型一定要对应要不然查询不出来。$map[json_extract(goods, "$.brand_id")]…...

UniApp在Vue3的setup语法糖下自定义组件插槽详解

UniApp在 Vue3的 setup 语法糖下自定义组件插槽详解 UniApp 是一个基于 Vue.js 的跨平台开发框架&#xff0c;可以用来开发微信小程序、H5、App 等多种平台的应用。Vue 3 引入了 <script setup> 语法糖&#xff0c;使得组件的编写更加简洁和直观。本文将详细介绍如何在 …...

springboot上传下载文件

RequestMapping(“bigJson”) RestController Slf4j public class TestBigJsonController { Resource private BigjsonService bigjsonService;PostMapping("uploadJsonFile") public ResponseResult<Long> uploadJsonFile(RequestParam("file")Mul…...

Python学习从0到1 day29 Python 高阶技巧 ⑦ 正则表达式

目录 一、正则表达式 二、正则表达式的三个基础方法 1.match 从头匹配 2.search&#xff08;匹配规则&#xff0c;被匹配字符串&#xff09; 3.findall&#xff08;匹配规则&#xff0c;被匹配字符串&#xff09; 三、元字符匹配 单字符匹配&#xff1a; 注&#xff1a; 示例&a…...

机器学习-web scraping

Web Scraping&#xff0c;通常称为网络抓取或数据抓取&#xff0c;是一种通过自动化程序从网页中提取数据的技术。以下是对Web Scraping的详细解释&#xff1a; 一、定义与原理 Web Scraping是指采用技术手段从大量网页中提取结构化和非结构化信息&#xff0c;并按照一定的规…...

移远通信5G RedCap模组RG255C-CN通过中国电信5G Inside终端生态认证

近日&#xff0c;移远通信5G RedCap模组RG255C-CN荣获中国电信颁发的5G Inside终端生态认证证书。这表明&#xff0c;该产品在5G基本性能、网络兼容性、安全特性等方面已经过严格评测且表现优异&#xff0c;将进一步加速推动5G行业终端规模化应用。 中国电信5G Inside终端生态认…...

Javaweb梳理17——HTMLCSS简介

Javaweb梳理17——HTML&CSS简介 17 HTML&CSS简介17.1 HTML介绍17.2 快速入门17.3 基础标签17.3 .1 标题标签17.3.2 hr标签17.3.3 字体标签17.3.4 换行17.3.8 案例17.3.9 图片、音频、视频标签17.3.10 超链接标签17.3.11 列表标签17.3.12 表格标签17.3.11 布局标签17.3.…...

【Android、IOS、Flutter、鸿蒙、ReactNative 】自定义View

Android Java 自定义View 步骤 创建一个新的Java类&#xff0c;继承自View、ViewGroup或其他任何一个视图类。 如果需要&#xff0c;重写构造函数以支持不同的初始化方式。 重写onMeasure方法以提供正确的测量逻辑。 重写onDraw方法以实现绘制逻辑。 根据需要重写其他方法&…...

win11跳过联网激活步骤

win11跳过联网激活步骤 win11跳过联网激活步骤方法一&#xff1a;使用Shift F10快捷键&#xff08;推荐&#xff09;1. 启动Windows 112. 选择键盘布局或输入法3. 是否想要添加第二种键盘布局4. 让我们为你连接到网络5. 调出管理员模式CMD6. 耐心等待自动重启7. 启动Windows 1…...

利用c语言详细介绍下冒泡排序

软件开发过程中&#xff0c;排序算法是常规且使用众多的方法之一&#xff0c;而冒泡算法又是排序算法中最常规且基本的算法。今天我们利用c语言&#xff0c;图文详细介绍下冒泡算法。 一、图文介绍 我们输入一个数组&#xff0c;数组为【10&#xff0c;5&#xff0c;3&#xf…...

C# 面向对象

C# 面向对象编程 面向过程&#xff1a;一件事情分成多个步骤来完成。 把大象装进冰箱 (面向过程化设计思想)。走一步看一步。 1、打开冰箱门 2、把大象放进冰箱 3、关闭冰箱门 面向对象&#xff1a;以对象作为主体 把大象装进冰箱 1、抽取对象 大象 冰箱 门 &#xff0…...

android wifi扫描的capability

混合型加密android11 8155与普通linux设备扫描到的安全字段差别 android应用拿到关于wifi安全的字段&#xff1a; systembar-WifiBroadcastReceiver---- scanResult SSID: Redmi_697B, BSSID: a4:39:b3:70:8c:20, capabilities: [WPA-PSK-TKIPCCMP][WPA2-PSK-TKIPCCMP][RSN-PSK…...

datawhale 2411组队学习:模型压缩4 模型量化理论(数据类型、int8量化方法、PTQ和QWT)

文章目录 一、数据类型1.1 整型1.2 定点数1.3 浮点数1.3.1 正规浮点数&#xff08;fp32&#xff09;1.3.2 非正规浮点数&#xff08;fp32&#xff09;1.3.3 其它数据类型1.3.4 浮点数误差1.3.5 浮点数导致的模型训练问题 二、量化基本方法2.1 int8量化2.1.1 k-means 量化2.1.2 …...

数据分析-48-时间序列变点检测之在线实时数据的CPD

文章目录 1 时间序列结构1.1 变化点的定义1.2 结构变化的类型1.2.1 水平变化1.2.2 方差变化1.3 变点检测1.3.1 离线数据检测方法1.3.2 实时数据检测方法2 模拟数据2.1 模拟恒定方差数据2.2 模拟变化方差数据3 实时数据CPD3.1 SDAR学习算法3.2 Changefinder模块3.3 恒定方差CPD3…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0&#xff1a;开发环境同步测试 cookie 至 localhost&#xff0c;便于本地请求服务携带 cookie 参考地址&#xff1a;https://juejin.cn/post/7139354571712757767 里面有源码下载下来&#xff0c;加在到扩展即可使用FeHelp…...

【入坑系列】TiDB 强制索引在不同库下不生效问题

文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1

每日一言 生活的美好&#xff0c;总是藏在那些你咬牙坚持的日子里。 硬件&#xff1a;OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写&#xff0c;"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

云原生玩法三问:构建自定义开发环境

云原生玩法三问&#xff1a;构建自定义开发环境 引言 临时运维一个古董项目&#xff0c;无文档&#xff0c;无环境&#xff0c;无交接人&#xff0c;俗称三无。 运行设备的环境老&#xff0c;本地环境版本高&#xff0c;ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...

R语言速释制剂QBD解决方案之三

本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)

安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...

【若依】框架项目部署笔记

参考【SpringBoot】【Vue】项目部署_no main manifest attribute, in springboot-0.0.1-sn-CSDN博客 多一个redis安装 准备工作&#xff1a; 压缩包下载&#xff1a;http://download.redis.io/releases 1. 上传压缩包&#xff0c;并进入压缩包所在目录&#xff0c;解压到目标…...

Java并发编程实战 Day 11:并发设计模式

【Java并发编程实战 Day 11】并发设计模式 开篇 这是"Java并发编程实战"系列的第11天&#xff0c;今天我们聚焦于并发设计模式。并发设计模式是解决多线程环境下常见问题的经典解决方案&#xff0c;它们不仅提供了优雅的设计思路&#xff0c;还能显著提升系统的性能…...