当前位置: 首页 > news >正文

pytorch求导

pytorch求导的初步认识

requires_grad

tensor(data, dtype=None, device=None, requires_grad=False)

requires_grad是torch.tensor类的一个属性。如果设置为True,它会告诉PyTorch跟踪对该张量的操作,允许在反向传播期间计算梯度。

x.requires_grad    判断一个tensor是否可以求导,返回布尔值

叶子变量-leaf variable

  • 对于requires_grad=False 的张量,我们约定俗成地把它们归为叶子张量。
  • 对于requires_grad为True的张量,如果他们是由用户创建的,则它们是叶张量。

 如果某一个叶子变量,开始时不可导的,后面想设置它可导,该怎么办?

x.requires_grad_(True/False)   设置tensor的可导与不可导

注意:这种方法只适用于设置叶子变量,否则会出现如下错误

x = torch.tensor(2.0, requires_grad=True)
y = torch.pow(x, 2)
z = torch.add(y, 3)
z.backward()
print(x.grad)
print(y.grad)
tensor(4.)
None
  1. 创建一个浮点型张量x,其值为2.0,并设置requires_grad=True,使PyTorch可以跟踪x的计算历史并允许计算它的梯度。

  2. 创建一个新张量y,y是x的平方。

  3. 创建一个新张量z,z是y和3的和。

  4. 调用z.backward()进行反向传播,计算z关于x的梯度。

  5. 打印x的梯度,应该是2*x=4.0。

  6. 试图打印y的梯度。但是,PyTorch默认只计算并保留叶子节点的梯度非叶子节点的梯度在计算过程中会被释放掉,因此y的梯度应该为None。

保留中间变量的梯度

tensor.retain_grad()

 retain_grad()retain_graph是用来处理两个不同的情况

  1. retain_grad(): 用于保留非叶子节点的梯度。如果你想在反向传播结束后查看或使用非叶子节点的梯度,你应该在非叶子节点上调用.retain_grad()

  2. retain_graph: 当你调用.backward()时,PyTorch会自动清除计算图以释放内存。这意味着你不能在同一个计算图上多次调用.backward()。但是,如果你需要多次调用.backward()(例如在某些特定的优化算法中),你可以在调用.backward()时设置retain_graph=True保留计算图

.grad

通过tensor的grad属性查看所求得的梯度值。

.grad_fn

在PyTorch中,.grad_fn属性是一个引用到创建该Tensor的Function对象。也就是说,这个属性可以告诉你这个张量是如何生成的。对于由用户直接创建的张量,它的.grad_fnNone。对于由某个操作创建的张量,.grad_fn将引用到一个与这个操作相关的对象

import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2
z = y.mean()print(x.grad_fn)
print(y.grad_fn)
print(z.grad_fn)

这里,x是由用户直接创建的,所以x.grad_fnNoney是通过乘法操作创建的,所以y.grad_fn是一个MulBackward0对象,这表明y是通过乘法操作创建的。z是通过求平均数操作创建的,所以z.grad_fn是一个MeanBackward0对象。

 pytorch自动求导实现神经网络

numpy手动实现

import numpy as np
import matplotlib.pyplot as pltN, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)W1 = np.random.randn(D_in, H)  # 1000维转成100维
W2 = np.random.randn(H, D_out)  # 100维转成10维learning_rate = 1e-6all_loss = []epoch = 500for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''h = X.dot(W1)  # N * Hh_relu = np.maximum(h, 0)  # 激活函数,N * Hy_hat = h_relu.dot(W2)  # N * D_out'''计算损失函数(compute loss)'''loss = np.square(y_hat - y).sum()  # 均方误差,忽略了÷Nprint("Epoch:{}   Loss:{}".format(t, loss))  # 打印每个迭代的损失all_loss.append(loss)'''后向传播(backward pass)'''# 计算梯度(此处没用torch,用最普通的链式求导,最终要得到 d{loss}/dX)grad_y_hat = 2.0 * (y_hat - y)  # d{loss}/d{y_hat},N * D_outgrad_W2 = h_relu.T.dot(grad_y_hat)  # 看前向传播中的第三个式子,d{loss}/d{W2},H * D_outgrad_h_relu = grad_y_hat.dot(W2.T)  # 看前向传播中的第三个式子,d{loss}/d{h_relu},N * Hgrad_h = grad_h_relu.copy()  # 这是h>0时的情况,d{h_relu}/d{h}=1grad_h[h < 0] = 0  # d{loss}/d{h}grad_W1 = X.T.dot(grad_h)  # 看前向传播中的第一个式子,d{loss}/d{W1}'''参数更新(update weights of W1 and W2)'''W1 -= learning_rate * grad_W1W2 -= learning_rate * grad_W2plt.plot(all_loss)
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()

pytorch自动实现

import torchN, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)W1 = torch.randn(D_in, H, requires_grad=True)  # 1000维转成100维
W2 = torch.randn(H, D_out, requires_grad=True)  # 100维转成10维learning_rate = 1e-6for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''y_hat = X.mm(W1).clamp(min=0).mm(W2)  # N * D_out'''计算损失函数(compute loss)'''loss = (y_hat - y).pow(2).sum()  # 均方误差,忽略了÷N,loss就是一个计算图(computation graph)print("Epoch:{}   Loss:{}".format(t, loss.item()))  # 打印每个迭代的损失'''后向传播(backward pass)'''loss.backward()'''参数更新(update weights of W1 and W2)'''with torch.no_grad():W1 -= learning_rate * W1.gradW2 -= learning_rate * W2.gradW1.grad.zero_()W2.grad.zero_()

pytorch手动实现

import torch
import matplotlib.pyplot as pltN, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)W1 = torch.randn(D_in, H)  # 1000维转成100维
W2 = torch.randn(H, D_out)  # 100维转成10维learning_rate = 1e-6all_loss = []for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''h = X.mm(W1)  # N * Hh_relu = h.clamp(min=0)  # 激活函数,N * Hy_hat = h_relu.mm(W2)  # N * D_out'''计算损失函数(compute loss)'''loss = (y_hat - y).pow(2).sum().item()  # 均方误差,忽略了÷Nprint("Epoch:{}   Loss:{}".format(t, loss))  # 打印每个迭代的损失all_loss.append(loss)'''后向传播(backward pass)'''# 计算梯度(此处没用torch,用最普通的链式求导,最终要得到 d{loss}/dX)grad_y_hat = 2.0 * (y_hat - y)  # d{loss}/d{y_hat},N * D_outgrad_W2 = h_relu.t().mm(grad_y_hat)  # 看前向传播中的第三个式子,d{loss}/d{W2},H * D_outgrad_h_relu = grad_y_hat.mm(W2.t())  # 看前向传播中的第三个式子,d{loss}/d{h_relu},N * Hgrad_h = grad_h_relu.clone()  # 这是h>0时的情况,d{h_relu}/d{h}=1grad_h[h < 0] = 0  # d{loss}/d{h}grad_W1 = X.t().mm(grad_h)  # 看前向传播中的第一个式子,d{loss}/d{W1}'''参数更新(update weights of W1 and W2)'''W1 -= learning_rate * grad_W1W2 -= learning_rate * grad_W2plt.plot(all_loss)
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()

torch.nn实现

import torch
import torch.nn as nn  # 各种定义 neural network 的方法N, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)model = torch.nn.Sequential(torch.nn.Linear(D_in, H, bias=True),  # W1 * X + b,默认Truetorch.nn.ReLU(),torch.nn.Linear(H, D_out)
)# model = model.cuda()  #这是使用GPU的情况loss_fn = nn.MSELoss(reduction='sum')learning_rate = 1e-4for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''y_hat = model(X)  # model(X) = model.forward(X), N * D_out'''计算损失函数(compute loss)'''loss = loss_fn(y_hat, y)  # 均方误差,忽略了÷N,loss就是一个计算图(computation graph)print("Epoch:{}   Loss:{}".format(t, loss.item()))  # 打印每个迭代的损失'''后向传播(backward pass)'''loss.backward()'''参数更新(update weights of W1 and W2)'''with torch.no_grad():for param in model.parameters():param -= learning_rate * param.grad  # 模型中所有的参数更新model.zero_grad()

torch.nn的继承类

import torch
import torch.nn as nn  # 各种定义 neural network 的方法
from torchsummary import summary
# pip install torchsummary
N, D_in, H, D_out = 64, 1000, 100, 10  # 64个训练数据(只是一个batch),输入是1000维,hidden是100维,输出是10维'''随机创建一些训练数据'''
X = torch.randn(N, D_in)
y = torch.randn(N, D_out)'''定义两层网络'''class TwoLayerNet(torch.nn.Module):def __init__(self, D_in, H, D_out):super(TwoLayerNet, self).__init__()# 定义模型结构self.linear1 = torch.nn.Linear(D_in, H, bias=False)self.linear2 = torch.nn.Linear(H, D_out, bias=False)def forward(self, x):y_hat = self.linear2(self.linear1(X).clamp(min=0))return y_hatmodel = TwoLayerNet(D_in, H, D_out)loss_fn = nn.MSELoss(reduction='sum')
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)for t in range(500):  # 做500次迭代'''前向传播(forward pass)'''y_hat = model(X)  # model.forward(), N * D_out'''计算损失函数(compute loss)'''loss = loss_fn(y_hat, y)  # 均方误差,忽略了÷N,loss就是一个计算图(computation graph)print("Epoch:{}   Loss:{}".format(t, loss.item()))  # 打印每个迭代的损失optimizer.zero_grad()  # 求导之前把 gradient 清空'''后向传播(backward pass)'''loss.backward()'''参数更新(update weights of W1 and W2)'''optimizer.step()  # 一步把所有参数全更新print(summary(model, (64, 1000)))

相关文章:

pytorch求导

pytorch求导的初步认识 requires_grad tensor(data, dtypeNone, deviceNone, requires_gradFalse)requires_grad是torch.tensor类的一个属性。如果设置为True&#xff0c;它会告诉PyTorch跟踪对该张量的操作&#xff0c;允许在反向传播期间计算梯度。 x.requires_grad 判…...

Java基础异常详解

Java基础异常详解 文章目录 Java基础异常详解编译时异常&#xff08;Checked Exception&#xff09;&#xff1a;运行时异常&#xff08;Unchecked Exception&#xff09;: Java中的异常是用于处理程序运行时出现的错误或异常情况的一种机制。 异常本身也是一个类。 异常分为…...

vue3+vue-i18n 监听语言的切换

最近在用 vue3 做一个后台管理系统&#xff0c;之前是只考虑中文&#xff0c;现在加了个需求是多语言。 本来也不是太难的需求&#xff0c;但是我用的并不熟悉&#xff0c;并且除了页面展示不同的语言&#xff0c;需求是在切换语言的时候在几个页面中需要做出一些自定义的行为&…...

【考研复习】24王道数据结构课后习题代码|2.3线性表的链式表示

文章目录 总结01 递归删除结点02 删除结点03 反向输出04 删除最小值05 逆置06 链表递增排序07 删除区间值08 找公共结点09 增序输出链表10 拆分链表--尾插11 拆分链表--头插12 删除相同元素13 合并链表14 生成含有公共元素的链表C15 求并集16 判断子序列17 判断循环链表是否对称…...

娇滴滴的一朵花(Python实现)

目录 1 娇滴滴的她 2 Python代码实现 1 娇滴滴的她 娇滴滴。双眉敛破春山色。春山色。 为君含笑,为君愁蹙。多情别後无消息。 此时更有谁知得。谁知得。夜深无寐&#xff0c;度江横笛。 2 Python代码实现 import turtle from turtle import * turtle.title(春天送她一朵小花)…...

Android AccessibilityService研究

AccessibilityService流程分析 AccessibilityService开启方式AccessibilityService 开启原理 AccessibilityService开启方式 . 在Framework里直接添加对应用app 服务component。 loadSetting(stmt, Settings.Secure.ACCESSIBILITY_ENABLED,1); loadSetting(stmt, Settings.Se…...

华为OD机试(含B卷)真题2023 算法分类版,58道20个算法分类,如果距离机考时间不多了,就看这个吧,稳稳的

目录 一、数据结构1、线性表2、优先队列3、滑动窗口4、二叉树5、并查集6、栈 二、算法1、基础算法2、字符串3、图4、动态规划5、数学 三、漫画算法2&#xff1a;小灰的算法进阶参与方式 很多小伙伴问我&#xff0c;华为OD机试算法题太多了&#xff0c;知识点繁杂&#xff0c;如…...

JMeter命令行执行+生成HTML报告

1、为什么用命令行模式 使用GUI方式启动jmeter&#xff0c;运行线程较多的测试时&#xff0c;会造成内存和CPU的大量消耗&#xff0c;导致客户机卡死&#xff1b; 所以一般采用的方式是在GUI模式下调整测试脚本&#xff0c;再用命令行模式执行&#xff1b; 命令行方式支持在…...

学习Boost二:从附录3来看编码习惯

附录C 关键字浅谈 在C11标准中&#xff08;C11.2.12&#xff09;总共定义了73个关键字&#xff08;keyword&#xff09;、2个“准”关键字&#xff08;identifiers with special meaning&#xff09;和11个操作符替代字&#xff08;alternative representation&#xff09;[1]。…...

STM32基础入门学习笔记:核心板 电路原理与驱动编程

文章目录&#xff1a; 一&#xff1a;LED灯操作 1.LED灯的点亮和熄灭 延迟闪烁 main.c led.c led.h BitAction枚举 2.LED呼吸灯&#xff08;灯的强弱交替变化&#xff09; main.c delay.c 3.按键控制LED灯 key.h key.c main.c 二&#xff1a;FLASH读写程序(有…...

最后一次模拟考试题解

哦我想这不用看都知道是为了水任务 T1 黑白染色 其实这题有原 什么手写体 md (指 markdown) 分析 首先这题如果你题目没看错的话 ,会发现其实他是 n m n \times m nm 让你求 n n n \times n nn 的区域内的点&#xff08;不会只有我一个人题目看错了罢 然后我们会发现…...

Mac 创建和删除 Automator 工作流程,设置 Terminal 快捷键

1. 创建 Automator 流程 本文以创建一个快捷键启动 Terminal 的自动操作为示例。 点击打开 自动操作&#xff1b; 点击 新建文稿 点击 快速操作 选择 运行 AppleScript 填入以下内容 保存名为 “Open Terminal” 打开 设置 > 键盘&#xff0c;选择 键盘快捷键 以此选择 服…...

2023华为OD机试真题B卷 Java 实现【最长的元音串】

前言 本题使用Java解答,如果需要Python代码,链接 题目 给定一个只由英文字母(a-z, A-Z)组成的字符串,找出其中最长的只包含元音字母(a, e, i, o, u, A, E, I, O, U)的子串,并返回其长度。如果不存在元音子串,则返回0。 输入: 一个由英文字母组成的字符串,长度大…...

网络防御之传输安全

1.什么是数据认证&#xff0c;有什么作用&#xff0c;有哪些实现的技术手段? 数据认证是一种权威的电子文档 作用&#xff1a;它能保证数据的完整性、可靠性、真实性 技术手段有数字签名、加密算法、哈希函数等 2.什么是身份认证&#xff0c;有什么作用&#xff0c;有哪些…...

【css】组合器

组合器是解释选择器之间关系的某种机制。在简单选择器器之间&#xff0c;可以包含一个组合器&#xff0c;从而实现简单选择器难以达到的效果。 CSS 中有四种组合器&#xff1a; 后代选择器 (空格)&#xff1a;匹配属于指定元素后代的所有元素&#xff0c;示例&#xff1a;div …...

HTTPS、TLS加密传输

HTTPS、TLS加密传输 HTTPS、TLS加密传输1、HTTPS&#xff08;HyperText Transfer Protocol Secure&#xff09;2、TLS HTTPS、TLS加密传输 1、HTTPS&#xff08;HyperText Transfer Protocol Secure&#xff09; HTTPS&#xff08;HyperText Transfer Protocol Secure&#x…...

docker frp 搭建 http + stcp 代理

所需服务器 2台 一台具有国外公网ip 一台具有国内 ip 内网外网都可以 外公网ip服务器配置如下 cat docker-compose.yamlversion: "2" services:frps:image: alpine:latesthostname: frpsrestart: alwayscontainer_name: frpsprivileged: trueuser: rootcommand: […...

项目出bug,找不到bug,如何拉回之前的版本

1.用gitee如何拉取代码 本文为转载于「闪耀太阳a」的原创文章原文链接&#xff1a;https://blog.csdn.net/Gufang617/article/details/119929145 怎么从gitee上拉取代码 1.首先找到gitee上想要拉取得代码URL地址 点击复制这里的https地址 1 ps:&#xff08;另外一种方法&…...

vue-cli

vue-cli脚手架 案例一&#xff1a; 案例二&#xff1a; 案例三&#xff1a; ​ 一、脚手架简介 Vue脚手架是Vue官方提供的标准化开发工具&#xff08;开发平台&#xff09;&#xff0c;它提供命令行和UI界面&#xff0c;方便创建vue工程、配置第三方依赖、编译vue工程 1. …...

android获取屏幕分辨率的正确方法;获取到分辨率(垂直方向像素)的不正确

我通过下面的方法去获取屏幕分辨率的&#xff0c;但获取到的分辨率有时会不准确。原因是此方法有时候会忽略一些布局或控件的高度&#xff0c;从而得不到正确的高度。 public static String getDeviceResolution(Context context){//从系统服务中获取窗口管理器WindowManager w…...

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

通过Wrangler CLI在worker中创建数据库和表

官方使用文档&#xff1a;Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后&#xff0c;会在本地和远程创建数据库&#xff1a; npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库&#xff1a; 现在&#xff0c;您的Cloudfla…...

基于当前项目通过npm包形式暴露公共组件

1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹&#xff0c;并新增内容 3.创建package文件夹...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

12.找到字符串中所有字母异位词

&#x1f9e0; 题目解析 题目描述&#xff1a; 给定两个字符串 s 和 p&#xff0c;找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义&#xff1a; 若两个字符串包含的字符种类和出现次数完全相同&#xff0c;顺序无所谓&#xff0c;则互为…...

NFT模式:数字资产确权与链游经济系统构建

NFT模式&#xff1a;数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新&#xff1a;构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议&#xff1a;基于LayerZero协议实现以太坊、Solana等公链资产互通&#xff0c;通过零知…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具&#xff0c;可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板&#xff0c;允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板&#xff0c;并通…...

音视频——I2S 协议详解

I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议&#xff0c;专门用于在数字音频设备之间传输数字音频数据。它由飞利浦&#xff08;Philips&#xff09;公司开发&#xff0c;以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...