PyTorch 中使用自动求导计算梯度
使用 PyTorch 进行自动求导和梯度计算
在 PyTorch 中,张量的 requires_grad
属性决定了是否需要计算该张量的梯度。设置为 True
的张量会在计算过程中记录操作,以便在调用 .backward()
方法时自动计算梯度。通过构建计算图,PyTorch 能够有效地追踪和计算梯度。
1、梯度的定义
在数学中,梯度是一个向量,表示函数在某一点的变化率。在深度学习中,我们通常关心的是损失函数相对于模型参数的梯度。具体来说,假设我们有一个输出 out
,我们计算的是损失函数对模型参数(如权重和偏置)的梯度,而不是直接对输出的梯度。
2、 简单例子
在我们接下来的例子中,我们将计算 out
相对于输入变量 x x x 和 y y y的梯度,通常表示为 ( d out d x ) ( \frac{d \text{out}}{dx}) (dxdout)和 ( d out d y ) ( \frac{d \text{out}}{dy}) (dydout)
import torch# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True) # 输入变量 x
y = torch.tensor(3.0, requires_grad=True) # 输入变量 y# 2. 定义第一个函数 f(z) = z^2
def f(z):return z**2# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):z = x + y # 中间变量 zz_no_grad = z.detach() # 创建不需要梯度的副本return f(z_no_grad) + y**3 # 输出 out = f(z_no_grad) + y^3# 4. 计算输出
out = g(x, y) # 计算输出# 5. 反向传播以计算梯度
out.backward() # 计算梯度# 6. 打印梯度
print(f"dz/dx: {x.grad}") # 输出 x 的梯度
print(f"dz/dy: {y.grad}") # 输出 y 的梯度
dout/dx: None
dout/dy: 27.0
import torch# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True) # 输入变量 x
y = torch.tensor(3.0, requires_grad=True) # 输入变量 y# 2. 定义第一个函数 f(z) = z^2
def f(z):return z ** 2# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):z = x + y # 中间变量 zreturn f(z) + y ** 3 # 输出 out = f(z_no_grad) + y^3# 4. 计算输出
out = g(x, y) # 计算输出# 5. 反向传播以计算梯度
out.backward() # 计算梯度# 6. 打印梯度
print(f"dout/dx: {x.grad}") # 输出 x 的梯度
print(f"dout/dy: {y.grad}") # 输出 y 的梯度
dout/dx: 10.0
dout/dy: 37.0
在这两个代码示例中,dout/dx
和 dout/dy
的值存在显著差异,主要原因在于如何处理中间变量 ( z ) 以及其对最终输出 out
的影响。
结果分析
-
第一部分代码:
-
在
g(x, y)
函数中,使用了 z . detach ( ) z.\text{detach}() z.detach() 创建了一个不需要梯度的副本 z no_grad z_{\text{no\_grad}} zno_grad。这意味着在计算 f ( z no_grad ) f(z_{\text{no\_grad}}) f(zno_grad) 时,PyTorch 不会将 z z z 的变化记录进计算图中。 -
因此, z z z 对 out \text{out} out 的影响被切断,导致
d out d x = None \frac{d \text{out}}{d x} = \text{None} dxdout=None
因为 x x x 的变化不会影响到 out \text{out} out 的计算。 -
对于 y y y,计算得到的梯度为
d out d y = 27.0 \frac{d \text{out}}{d y} = 27.0 dydout=27.0
这是通过以下步骤得到的: -
输出为
out = f ( z no_grad ) + y 3 \text{out} = f(z_{\text{no\_grad}}) + y^3 out=f(zno_grad)+y3 -
使用链式法则:
d out d y = 0 + 3 y 2 = 3 ( 3 2 ) = 27 \frac{d \text{out}}{d y} = 0 + 3y^2 = 3(3^2) = 27 dydout=0+3y2=3(32)=27
-
-
第二部分代码:
- 在
g(x, y)
函数中,直接使用了 z z z 而没有使用 z . detach ( ) z.\text{detach}() z.detach()。这使得 z z z 的变化会被记录在计算图中。 - 计算
d out d x \frac{d \text{out}}{d x} dxdout
时, z = x + y z = x + y z=x+y 的变化会影响到 out \text{out} out,因此计算得到的梯度为
d out d x = 10.0 \frac{d \text{out}}{d x} = 10.0 dxdout=10.0
这是因为: - f ( z ) = z 2 f(z) = z^2 f(z)=z2 的导数为
d f ( z ) d z = 2 z \frac{d f(z)}{d z} = 2z dzdf(z)=2z
当 z = 5 z = 5 z=5(当 x = 2 , y = 3 x=2, y=3 x=2,y=3 时),所以
2 z = 10 2z = 10 2z=10 - 对于 y y y,计算得到的梯度为
d out d y = 37.0 \frac{d \text{out}}{d y} = 37.0 dydout=37.0
这是因为
d out d y = d ( f ( z ) + y 3 ) d y = 2 z ⋅ d z d y + 3 y 2 = 2 ( 5 ) ( 1 ) + 3 ( 3 2 ) = 10 + 27 = 37 \frac{d \text{out}}{d y} = \frac{d (f(z) + y^3)}{d y} = 2z \cdot \frac{d z}{d y} + 3y^2 = 2(5)(1) + 3(3^2) = 10 + 27 = 37 dydout=dyd(f(z)+y3)=2z⋅dydz+3y2=2(5)(1)+3(32)=10+27=37
- 在
3、线性拟合及梯度计算
在深度学习中,线性回归是最基本的模型之一。通过线性回归,我们可以找到输入特征与输出之间的线性关系。在本文中,我们将使用 PyTorch 实现一个简单的线性拟合模型,定义模型为 y = a x + b x + c + d y = ax + bx + c + d y=ax+bx+c+d,并展示如何计算梯度,同时控制某些参数(如 b b b 和 d d d)不更新梯度。
在这个模型中,我们将定义以下参数:
- a a a:斜率,表示输入 x x x 对输出 y y y 的影响。
- b b b:另一个斜率,表示输入 x x x 对输出 y y y 的影响,但在训练过程中不更新。
- c c c:截距,表示当 x = 0 x=0 x=0 时的输出值。
- d d d:一个常数项,在训练过程中不更新。
3.1、完整代码
下面是实现线性拟合的完整代码:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 创建数据
# 假设我们有一些样本数据
x_data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = torch.tensor([3.0, 5.0, 7.0, 9.0, 11.0]) # 目标值# 2. 定义线性模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()self.a = nn.Parameter(torch.tensor(1.0)) # 需要更新的参数self.b = nn.Parameter(torch.tensor(0.5), requires_grad=False) # 不需要更新的参数self.c = nn.Parameter(torch.tensor(0.0)) # 需要更新的参数self.d = nn.Parameter(torch.tensor(0.5), requires_grad=False) # 不需要更新的参数def forward(self, x):return self.a * x + self.b * x + self.c + self.d# 3. 实例化模型
model = LinearModel()# 4. 定义损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.005) # 随机梯度下降优化器# 5. 训练模型
for epoch in range(5000):model.train() # 设置模型为训练模式# 计算模型输出y_pred = model(x_data)# 计算损失loss = criterion(y_pred, y_data)# 反向传播optimizer.zero_grad() # 清零梯度loss.backward() # 计算梯度optimizer.step() # 更新参数# 每10个epoch打印一次loss和参数值if (epoch + 1) % 500 == 0:print(f'Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}, a: {model.a.item():.4f}, b: {model.b.item():.4f}, c: {model.c.item():.4f}, d: {model.d.item():.4f}')# 6. 打印最终参数
print(f'Final parameters: a = {model.a.item()}, b = {model.b.item()}, c = {model.c.item()}, d = {model.d.item()}')# 7. 绘制拟合结果
with torch.no_grad():# 生成用于绘图的 x 值x_fit = torch.linspace(0, 6, 100) # 从 0 到 6 生成 100 个点y_fit = model(x_fit) # 计算对应的 y 值# 绘制真实数据点
plt.scatter(x_data.numpy(), y_data.numpy(), color='red', label='True Data')
# 绘制拟合曲线
plt.plot(x_fit.numpy(), y_fit.numpy(), color='blue', label='Fitted Curve')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Linear Fit Result')
plt.legend()
plt.grid()
plt.show()
3.2、梯度计算过程
在这个例子中,我们使用了 PyTorch 的自动求导功能来计算梯度。以下是对每个参数的梯度计算过程的解释:
-
参数定义:
- a a a 和 c c c 是需要更新的参数,因此它们的
requires_grad
属性默认为True
。 - b b b 和 d d d 是不需要更新的参数,设置了
requires_grad=False
,因此它们的梯度不会被计算。
- a a a 和 c c c 是需要更新的参数,因此它们的
-
损失计算:
- 在每个训练周期中,我们计算模型的预测值 y pred y_{\text{pred}} ypred,并与真实值 y data y_{\text{data}} ydata 计算均方误差损失:
loss = 1 n ∑ i = 1 n ( y pred , i − y i ) 2 \text{loss} = \frac{1}{n} \sum_{i=1}^{n} (y_{\text{pred},i} - y_{i})^2 loss=n1i=1∑n(ypred,i−yi)2
- 在每个训练周期中,我们计算模型的预测值 y pred y_{\text{pred}} ypred,并与真实值 y data y_{\text{data}} ydata 计算均方误差损失:
-
反向传播:
- 调用
loss.backward()
计算所有参数的梯度。由于 b b b 和 d d d 的requires_grad
被设置为False
,因此它们的梯度不会被计算和更新。
- 调用
-
参数更新:
- 使用优化器
optimizer.step()
更新参数。只有 a a a 和 c c c 会被更新。
- 使用优化器
Epoch [500/100], Loss: 0.0038, a: 1.5399, b: 0.5000, c: 0.3559, d: 0.5000
Epoch [1000/100], Loss: 0.0007, a: 1.5171, b: 0.5000, c: 0.4382, d: 0.5000
Epoch [1500/100], Loss: 0.0001, a: 1.5073, b: 0.5000, c: 0.4735, d: 0.5000
Epoch [2000/100], Loss: 0.0000, a: 1.5032, b: 0.5000, c: 0.4886, d: 0.5000
Epoch [2500/100], Loss: 0.0000, a: 1.5014, b: 0.5000, c: 0.4951, d: 0.5000
Epoch [3000/100], Loss: 0.0000, a: 1.5006, b: 0.5000, c: 0.4979, d: 0.5000
Epoch [3500/100], Loss: 0.0000, a: 1.5002, b: 0.5000, c: 0.4991, d: 0.5000
Epoch [4000/100], Loss: 0.0000, a: 1.5001, b: 0.5000, c: 0.4996, d: 0.5000
Epoch [4500/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4998, d: 0.5000
Epoch [5000/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4999, d: 0.5000
Final parameters: a = 1.5000202655792236, b = 0.5, c = 0.4999275505542755, d = 0.5
相关文章:

PyTorch 中使用自动求导计算梯度
使用 PyTorch 进行自动求导和梯度计算 在 PyTorch 中,张量的 requires_grad 属性决定了是否需要计算该张量的梯度。设置为 True 的张量会在计算过程中记录操作,以便在调用 .backward() 方法时自动计算梯度。通过构建计算图,PyTorch 能够有效…...
Oracle Instant Client 23.5安装配置完整教程
Oracle Instant Client 23.5安装配置完整教程 简介环境要求安装步骤1. 准备工作目录2. 下载Oracle Instant Client3. 解压Instant Client4. 安装依赖包5. 配置系统环境5.1 配置库文件路径5.2 配置环境变量 6. 配置Oracle钱包(可选) 验证安装常见问题解决…...
【jvm】方法区的理解
目录 1. 说明2. 方法区的演进3. 内部结构4. 作用5.内存管理 1. 说明 1.方法区用于存储已被虚拟机加载的类信息、常量、静态变量、即时编译器编译后的代码缓存等数据。它是各个线程共享的内存区域。2.尽管《Java虚拟机规范》中把方法区描述为堆的一个逻辑部分,但它却…...

ES-针对某个字段去重后-获取某个字段值的所有值
针对上面表的数据,现在想根据age分组,并获取每个分组后的name有哪些(去重后)。 select age, GROUP_CONCAT(DISTINCT(name)) from testtable group by age ; 结果: 如果想要增加排序: SELECT age, GROUP_CONCAT(DISTINCT name)…...
百度 2025届秋招提前批 文心一言大模型算法工程师
文章目录 个人情况一面/技术面 1h二面/技术面 1h三面/技术面 40min 个人情况 先说一下个人情况: 学校情况:211本中9硕,本硕学校都一般,本硕都是计算机科班,但研究方向并不是NLP,而是图表示学习论文情况&a…...

sglang 部署Qwen2VL7B,大模型部署,速度测试,深度学习
sglang 项目github仓库: https://github.com/sgl-project/sglang 项目说明书: https://sgl-project.github.io/start/install.html 资讯: https://github.com/sgl-project/sgl-learning-materials?tabreadme-ov-file#the-first-sglang…...
fastadmin操作数据库字段为json、查询遍历each、多级下拉、union、php密码设置、common常用函数的使用小技巧
数据库中遇到的操作 查询字段是json的某个值 //获取数据库中某个字段是json中得某个值,进行查询,goods是表中字段,brand_id是json中要查詢的字段。//数据类型一定要对应要不然查询不出来。$map[json_extract(goods, "$.brand_id")]…...
UniApp在Vue3的setup语法糖下自定义组件插槽详解
UniApp在 Vue3的 setup 语法糖下自定义组件插槽详解 UniApp 是一个基于 Vue.js 的跨平台开发框架,可以用来开发微信小程序、H5、App 等多种平台的应用。Vue 3 引入了 <script setup> 语法糖,使得组件的编写更加简洁和直观。本文将详细介绍如何在 …...
springboot上传下载文件
RequestMapping(“bigJson”) RestController Slf4j public class TestBigJsonController { Resource private BigjsonService bigjsonService;PostMapping("uploadJsonFile") public ResponseResult<Long> uploadJsonFile(RequestParam("file")Mul…...

Python学习从0到1 day29 Python 高阶技巧 ⑦ 正则表达式
目录 一、正则表达式 二、正则表达式的三个基础方法 1.match 从头匹配 2.search(匹配规则,被匹配字符串) 3.findall(匹配规则,被匹配字符串) 三、元字符匹配 单字符匹配: 注: 示例&a…...
机器学习-web scraping
Web Scraping,通常称为网络抓取或数据抓取,是一种通过自动化程序从网页中提取数据的技术。以下是对Web Scraping的详细解释: 一、定义与原理 Web Scraping是指采用技术手段从大量网页中提取结构化和非结构化信息,并按照一定的规…...

移远通信5G RedCap模组RG255C-CN通过中国电信5G Inside终端生态认证
近日,移远通信5G RedCap模组RG255C-CN荣获中国电信颁发的5G Inside终端生态认证证书。这表明,该产品在5G基本性能、网络兼容性、安全特性等方面已经过严格评测且表现优异,将进一步加速推动5G行业终端规模化应用。 中国电信5G Inside终端生态认…...

Javaweb梳理17——HTMLCSS简介
Javaweb梳理17——HTML&CSS简介 17 HTML&CSS简介17.1 HTML介绍17.2 快速入门17.3 基础标签17.3 .1 标题标签17.3.2 hr标签17.3.3 字体标签17.3.4 换行17.3.8 案例17.3.9 图片、音频、视频标签17.3.10 超链接标签17.3.11 列表标签17.3.12 表格标签17.3.11 布局标签17.3.…...
【Android、IOS、Flutter、鸿蒙、ReactNative 】自定义View
Android Java 自定义View 步骤 创建一个新的Java类,继承自View、ViewGroup或其他任何一个视图类。 如果需要,重写构造函数以支持不同的初始化方式。 重写onMeasure方法以提供正确的测量逻辑。 重写onDraw方法以实现绘制逻辑。 根据需要重写其他方法&…...

win11跳过联网激活步骤
win11跳过联网激活步骤 win11跳过联网激活步骤方法一:使用Shift F10快捷键(推荐)1. 启动Windows 112. 选择键盘布局或输入法3. 是否想要添加第二种键盘布局4. 让我们为你连接到网络5. 调出管理员模式CMD6. 耐心等待自动重启7. 启动Windows 1…...

利用c语言详细介绍下冒泡排序
软件开发过程中,排序算法是常规且使用众多的方法之一,而冒泡算法又是排序算法中最常规且基本的算法。今天我们利用c语言,图文详细介绍下冒泡算法。 一、图文介绍 我们输入一个数组,数组为【10,5,3…...

C# 面向对象
C# 面向对象编程 面向过程:一件事情分成多个步骤来完成。 把大象装进冰箱 (面向过程化设计思想)。走一步看一步。 1、打开冰箱门 2、把大象放进冰箱 3、关闭冰箱门 面向对象:以对象作为主体 把大象装进冰箱 1、抽取对象 大象 冰箱 门 ࿰…...
android wifi扫描的capability
混合型加密android11 8155与普通linux设备扫描到的安全字段差别 android应用拿到关于wifi安全的字段: systembar-WifiBroadcastReceiver---- scanResult SSID: Redmi_697B, BSSID: a4:39:b3:70:8c:20, capabilities: [WPA-PSK-TKIPCCMP][WPA2-PSK-TKIPCCMP][RSN-PSK…...

datawhale 2411组队学习:模型压缩4 模型量化理论(数据类型、int8量化方法、PTQ和QWT)
文章目录 一、数据类型1.1 整型1.2 定点数1.3 浮点数1.3.1 正规浮点数(fp32)1.3.2 非正规浮点数(fp32)1.3.3 其它数据类型1.3.4 浮点数误差1.3.5 浮点数导致的模型训练问题 二、量化基本方法2.1 int8量化2.1.1 k-means 量化2.1.2 …...

数据分析-48-时间序列变点检测之在线实时数据的CPD
文章目录 1 时间序列结构1.1 变化点的定义1.2 结构变化的类型1.2.1 水平变化1.2.2 方差变化1.3 变点检测1.3.1 离线数据检测方法1.3.2 实时数据检测方法2 模拟数据2.1 模拟恒定方差数据2.2 模拟变化方差数据3 实时数据CPD3.1 SDAR学习算法3.2 Changefinder模块3.3 恒定方差CPD3…...

基于FPGA的PID算法学习———实现PID比例控制算法
基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...
Leetcode 3577. Count the Number of Computer Unlocking Permutations
Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词
Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid,其中有多少个 3 3 的 “幻方” 子矩阵&am…...

ArcGIS Pro制作水平横向图例+多级标注
今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作:ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等(ArcGIS出图图例8大技巧),那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比
在机器学习的回归分析中,损失函数的选择对模型性能具有决定性影响。均方误差(MSE)作为经典的损失函数,在处理干净数据时表现优异,但在面对包含异常值的噪声数据时,其对大误差的二次惩罚机制往往导致模型参数…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
【Go语言基础【13】】函数、闭包、方法
文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数(函数作为参数、返回值) 三、匿名函数与闭包1. 匿名函数(Lambda函…...

安全突围:重塑内生安全体系:齐向东在2025年BCS大会的演讲
文章目录 前言第一部分:体系力量是突围之钥第一重困境是体系思想落地不畅。第二重困境是大小体系融合瓶颈。第三重困境是“小体系”运营梗阻。 第二部分:体系矛盾是突围之障一是数据孤岛的障碍。二是投入不足的障碍。三是新旧兼容难的障碍。 第三部分&am…...