当前位置: 首页 > news >正文

Pytorch 网络冻结的三种方法区别:detach、requires_grad、with_no_grad

1、requires_grad

requires_grad=True # 要求计算梯度;
requires_grad=False # 不要求计算梯度;

在pytorch中,tensor有一个 requires_grad参数,如果设置为True,那么它会追踪对于该张量的所有操作。在完成计算时可以通过调用backward()自动计算所有的梯度,并且,该张量的所有梯度会自动累加到张量的.grad属性;反之,如果设置为False,则不会记录这些操作过程,自然而然就不会进行计算梯度的工作 。 tensorrequires_grad的属性默认为False.

x = torch.tensor([1.0, 2.0])
x.requires_grad"""
结果:
False
"""

我们可以先看一下requires_grad参数设置分别为True和False时的情况。

# 设置好requires_grad的值为True
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=False)
y1 = 2.0 * x + 2.0 * yprint(x, x.requires_grad)
print(y, y.requires_grad)
print(y1, y1.requires_grad)y1.backward(torch.tensor([1.0, 1.0]))  
print(x.grad)
print(y.grad)"""
结果:
tensor([1., 2.], requires_grad=True) True
tensor([3., 4.]) False
tensor([ 8., 12.], grad_fn=<AddBackward0>) True
tensor([2., 2.])
None
"""

在上面的实验中,发现在计算中如果存在tensor张量xrequires_gradTrue的情况,那么计算之后的结果y1requires_grad也为True,且计算梯度时仅会计算x的梯度,因为前面设置了y张量的requires_gradFalse,所以最后y张量的grad属性值为None

关于张量tensor的梯度计算可以参考另一篇博客:Tensor及其梯度

所以在深度学习训练时,要冻结部分权值参数不进行参数更新的话,可以在优化器初始化之前将参数组进行筛选,将不想进行训练的参数的requires_grad设置为False。代码示例参考如下:

cnn = CNN() #构建网络for n,p in cnn.named_parameters():print(n,p.requires_grad)if n=="conv1.0.weight":p.requires_grad = Falseoptimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,cnn.parameters()), lr=learning_rate)

也可以把requires_grad属性置为 False这个操作放在optimizer之后,参数都不会进行更新。但是区别在于,先进行requires_grad属性置为False的操作,再optimizer初始化,不会将该层的参数放进优化器中更新,而先进行optimizer初始化,再进行requires_grad属性置为False的操作,会将所有的参数放进优化器中,但不更新该指定层参数,只更新剩下的参数。对比看来,optimizer中的参数量会相比前者会更大一点。

所以一般最好是将requires_grad属性置为 False这个操作放在optimizer之前。

注意事项:

1、requires_grad属性置为 False或者默认时,不能在对该tensor计算梯度,否则会进行报错。因为并没有追踪到任何计算历史,所以就不存在梯度的计算了。

import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=False)
y1 = 2.0 * x + 2.0 * y
# y1.backward(torch.tensor([1.0, 1.0])) 
# print(x.grad)
y.backward(torch.tensor([1.0, 1.0]))"""
结果:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
"""

2、整数型的tensor并没有requires_grad这个属性,只有浮点类型的tensor可以计算梯度

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
d:\test.ipynb Cell 9 in <cell line: 2>()1 a = torch.tensor([1,2])
----> 2 b = torch.tensor([3,4], requires_grad=True)3 c = a+b4 print(a.requires_grad)RuntimeError: Only Tensors of floating point and complex dtype can require gradients

2、detach()

detach方法就是返回了一个新的张量,该张量与当前计算图完全分离,且该张量的计算将不会记录到梯度当中。

import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
z = torch.tensor([3.0, 2.0], requires_grad=True)
x = x * 2
z1 = z.detach()
x1 = x.detach() 
y1 = 2.0 * x1 + 2.0 * y + 3 * z1
y1.backward(torch.tensor([1.0, 1.0])) 
print(x.requires_grad)
print(x.grad)
print(y.requires_grad)
print(y.grad)
print(z.requires_grad)
print(z.grad)
print(z1.requires_grad)
print(z1.grad)"""
结果:
True
None
True
tensor([2., 2.])
True
None
False
None
C:\Users\26973\AppData\Local\Temp\ipykernel_37516\3652236761.py:12: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.print(x.grad)
"""

从上面实验可以看到,使用detach()方法后,可以截断反向传播的梯度流,其作用有点类似于将requires_grad属性置为False的情况。

requires_grad_()requires_grad属性置为False不同的是 detach()函数会返回一个新的Tensor对象b , 并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。 ba共享数据的存储空间,二者指向同一块内存。

requires_grad_()函数会改变Tensorrequires_grad属性并返回Tensor修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=Truerequires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。

关于detach()返回的张量与原张量共享数据的存储空间,二者指向同一块内存可以由以下代码看出:

z1[0] = 5.0
print(z)
print(z1)"""
结果:
tensor([5., 2.], requires_grad=True)
tensor([5., 2.])
"""

当我们修改z1中的数据时,z中的数据也随之修改。

**总结:**当我们在计算到某一步时,不需要在记录某一个张量的梯度时,就可以使用detach()将其从追踪记录当中分离出来,这样一来该张量对应计算产生的梯度就不会被考虑了。比较常见的就是在GAN生成模型中,当训练一次生成器后,再训练判别器时,需要对生成器生成的fake进行损失计算,但是又不希望这部分损失对生成器进行权值的更新,这个时候需要冻结生成器那部分的权值,因此通常将生成器生成的fake张量使用fetch()进行阶段,再输入到判别器进行运算,这样最后使用loss.backward()时仅会对判别器部分的梯度进行计算

import torch
import numpy as npy = torch.tensor([3.0, 4.0], requires_grad=True)
z = torch.tensor([3.0, 2.0], requires_grad=True)
z1 = z.detach()
z2 = z1 + y
y1 = torch.sum(3 * z2)
y1.backward() 
print(z2, z2.requires_grad)
print(y.grad)
print(z2.grad)
print(z1.grad)"""
结果:
tensor([6., 6.], grad_fn=<AddBackward0>) True
tensor([3., 3.])
None
None
"""

这里假设z是生成器生成的图片,z1表示的是使用detch()截断后的张量,y表示的判别器内部的一些运算张量,z2表示经过判别器后的结果,y1假设是计算loss的损失函数,我们可以看到,使用y1.backward() 后,不会对生成器生成的z产生任何的梯度,在优化器优化时自然而然不会对其进行优化。而对于后面的步骤仍然会跟踪记录其所有的计算过程,比如对于z1在判别器中进行运算仍然会记录其过程,并仅会对判别器内部的参数y进行梯度计算,从而进行优化

(注:这里z2.grad之所以也为None,是因为z2节点不是叶子节点,它是由z1y进行累加而来的,所以在z2处不会有grad属性,这部分可以看其给出的警告)

UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

3、with_no_grad

torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。

# 设置好requires_grad的值为True
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2with torch.no_grad():  # 这里使用了no_grad()包裹不需要被追踪的计算过程y2 = y1 * 2y3 = x ** 5y4 = y1 + y2 + y3print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)
print(y4, y4.requires_grad)y4.backward(torch.ones(y4.shape))  # y1.backward() y2.backward()
print(x.grad)"""
结果:
tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.]) False
tensor([ 1., 32.]) False
tensor([ 4., 44.], grad_fn=<AddBackward0>) True
tensor([2., 4.])
"""

可以看出,其实使用with torch.no_grad()这个后,被其包裹的所有运算都是不计算梯度的,其效果与detach()类似,所以使用下列代码的运行结果是一样的:

# 设置好requires_grad的值为True
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2# with torch.no_grad():  # 这里使用了no_grad()包裹不需要被追踪的计算过程
y2 = y1 * 2
y3 = x ** 5
y2 = y2.detach()
y3 = y3.detach()y4 = y1 + y2 + y3print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)
print(y4, y4.requires_grad)y4.backward(torch.ones(y4.shape))  # y1.backward() y2.backward()
print(x.grad)
"""
结果:
tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.]) False
tensor([ 1., 32.]) False
tensor([ 4., 44.], grad_fn=<AddBackward0>) True
tensor([2., 4.])
"""

detach()是考虑将单个张量从追踪记录当中脱离出来;
torch.no_grad()是一个warper,可以将多个计算步骤的张量计算脱离出去,本质上没啥区别。

4、总结:

  1. requires_grad:在最开始创建Tensor时候可以设置的属性,用于表明是否追踪当前Tensor的计算操作。后面也可以通过requires_grad_()方法设置该参数,但是只有叶子节点才可以设置该参数。
  2. detach()方法:则是用于将某一个Tensor从计算图中分离出来。返回的是一个内存共享的Tensor,一变都变。
  3. torch.no_grad():对所有包裹的计算操作进行分离。但是torch.no_grad()将会使用更少的内存,因为从包裹的开始,就表明不需要计算梯度了,因此就不需要保存中间结果

相关文章:

Pytorch 网络冻结的三种方法区别:detach、requires_grad、with_no_grad

1、requires_grad requires_gradTrue # 要求计算梯度&#xff1b; requires_gradFalse # 不要求计算梯度&#xff1b;在pytorch中&#xff0c;tensor有一个 requires_grad参数&#xff0c;如果设置为True&#xff0c;那么它会追踪对于该张量的所有操作。在完成计算时可以通过调…...

如何定位el-tree中的树节点当父元素滚动时如何定位子元素

使用到的方法 Element 接口的 scrollIntoView() 方法会滚动元素的父容器&#xff0c;使被调用 scrollIntoView() 的元素对用户可见。 参数 alignToTop可选 一个布尔值&#xff1a; 如果为 true&#xff0c;元素的顶端将和其所在滚动区的可视区域的顶端对齐。相应的 scrollIntoV…...

【WiFI问题自助】解决WiFi能连上但是没有网的问题

WiFi能连上但是没有网的问题 背景&#xff1a;wifi能连上&#xff0c;但是没有网 解决 遇事不决&#xff0c;先重启啊&#xff01;怎么重启&#xff1f;拔掉电源再插上&#xff01;拔掉网线再插上&#xff01; 直接ok了。 思考记录 今天WiFi又上不了网了&#xff0c;昨天报…...

论文阅读:JINA EMBEDDINGS: A Novel Set of High-Performance Sentence Embedding Models

Abstract JINA EMBEDINGS构成了一组高性能的句子嵌入模型&#xff0c;擅长将文本输入转换为数字表示&#xff0c;捕捉文本的语义。这些模型在密集检索和语义文本相似性等应用中表现出色。文章详细介绍了JINA EMBEDINGS的开发&#xff0c;从创建高质量的成对&#xff08;pairwi…...

计数排序.

一.定义&#xff1a; 计数排序&#xff08;Counting Sort&#xff09;是一种非比较性质的排序算法&#xff0c;其时间复杂度为O(nk)&#xff08;其中n为待排序的元素个数&#xff0c;k为不同值的个数&#xff09;。这意味着在数据值范围不大并且离散分布的情况下&#xff0c;规…...

flink中配置Rockdb的重要配置项

背景 由于我们在flink中使用了状态比较大&#xff0c;无法完全把状态数据存放到tm的堆内存中&#xff0c;所以我们选择了把状态存放到rockdb上&#xff0c;也就是使用rockdb作为状态后端存储,本文就是简单记录下使用rockdb状态后端存储的几个重要的配置项 使用rockdb状态后端…...

代码随想录二刷 | 数组 | 有序数组的平方

代码随想录二刷 &#xff5c; 数组 &#xff5c; 有序数组的平方 题目描述题目分析 & 代码实现暴力排序双指针法 题目描述 977.有序数组的平方 给你一个按 非递减顺序 排序的整数数组 nums&#xff0c;返回 每个数字的平方 组成的新数组&#xff0c;要求也按 非递减顺序 …...

基于单片机C51全自动洗衣机仿真设计

**单片机设计介绍&#xff0c; 基于单片机C51全自动洗衣机仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机C51的全自动洗衣机仿真设计是一个复杂的项目&#xff0c;它涉及到硬件和软件的设计和实现。以下是对这…...

「Verilog学习笔记」实现3-8译码器①

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 分析 ① 本题要求根据38译码器的功能表实现该电路&#xff0c;同时要求采用基础逻辑门实现&#xff0c;那么就需要将功能表转换为逻辑表达式。 timescale 1ns/1nsmodule d…...

Centos(Linux)服务器安装Dotnet8 及 常见问题解决

1. 下载dotnet8 sdk 下载 .NET 8.0 SDK (v8.0.100) - Linux x64 Binaries 拿到 dotnet-sdk-8.0.100-linux-x64.tar.gz 文件 2. 把文件上传到 /usr/local/software 目录 mkdir -p /usr/local/software/dotnet8 把文件拷贝过去 mv dotnet-sdk-8.0.100-linux-x64.tar.gz /usr/loc…...

最强人工智能ChatGPT引领AIGC发展

从公众号转载&#xff0c;关注微信公众号掌握更多技术动态 --------------------------------------------------------------- ——AI不会淘汰所有人&#xff0c;但会淘汰不懂AI的人 一、最强人工智能GPT-4 Turbo 在前不久的OpenAI开发者大会&#xff0c;正值Chatgpt3.5发布一…...

10.Oracle的同义词与序列

oracle11g的同义词与序列 一、Oracle同义词&#xff1a;1、同义词的基本使用2、同义词的相关权限3、同义词的作用范围 二、Oracle序列&#xff1a;1、序列的基本操作2、序列的相关权限 一、Oracle同义词&#xff1a; 同义词是一个数据库对象的别名&#xff0c;它允许用户通过不…...

【周报2023-11-10】

周报2023-11-10 本周的主要工作下周工作计划 本周的主要工作 本周的主要工作就有三个 第一个是进行对我们目前的高企项目的完善情况第二个是对于高企项目的接口对接情况以及细节的把控第三个为新的小程序项目做准备工作 首先第一个高企项目的完善情况得话主要是页面上 对于原…...

搜维尔科技:业内普遍选择Varjo头显作为医疗VR/AR/XR解决方案

Varjo 的人眼分辨率混合现实和虚拟现实头显将医疗专业人员的注意力和情感投入提升到更高水平。借助逼真的 XR/VR&#xff0c;医疗和保健人员可以为最具挑战性的现实场景做好准备&#xff01; 在虚拟、增强和混合现实中进行最高水平的训练和表现 以逼真的 3D 方式可视化医疗数据…...

数据结构02附录01:顺序表考研习题[C++]

图源&#xff1a;文心一言 考研笔记整理~&#x1f95d;&#x1f95d; 之前的博文链接在此&#xff1a;数据结构02&#xff1a;线性表[顺序表链表]_线性链表-CSDN博客~&#x1f95d;&#x1f95d; 本篇作为线性表的代码补充&#xff0c;每道题提供了优解和暴力解算法&#xf…...

ClientDateSet:Cannot perform this operation on a closed dataset

一、问题表现 Delphi 三层DataSnap&#xff0c;使用AlphaControls控件优化界面&#xff0c;一窗口编辑时&#xff0c;出现下列错误提示&#xff1a; 编译通过&#xff0c;该窗口中&#xff0c;重新显示数据&#xff0c;下图&#xff1a; 相关代码&#xff1a; procedure…...

python中列表的基础解释

列表&#xff1a; 一种可以存放多种类型数据的数据结构 列表的创建&#xff1a; 1.用【】创建列表 #创建一个空列表 list1[] #创建一个非空列表 list2 [zhang,li,ying,1,2,3] #输出内容及类型 print(list1,type(list1)) print(list2,type(list2))结果&#xff1a; 2.使用list…...

『力扣刷题本』:链表分割

一、题目 现有一链表的头指针 ListNode* pHead&#xff0c;给一定值x&#xff0c;编写一段代码将所有小于x的结点排在其余结点之前&#xff0c;且不能改变原来的数据顺序&#xff0c;返回重新排列后的链表的头指针。 二、思路解析 首先&#xff0c;让我们列出我们需要做的事情&…...

FISCOBCOS入门(十)Truffle测试helloworld智能合约

本文带你从零开始搭建truffle以及编写迁移脚本和测试文件,并对测试文件的代码进行解释,让你更深入的理解truffle测试智能合约的原理,制作不易,望一键三连 在windos终端内安装truffle npm install -g truffle 安装truffle时可能出现网络报错,多试几次即可 truffle --vers…...

Unity Text文本首行缩进两个字符的方法

Text文本首行缩进两个字符的方法比较简单。通过代码把"\u3000\u3000"加到文本字符串前面即可。 参考如下代码&#xff1a; TMPtext1.text "\u3000\u3000" "这是一段有首行缩进的文本内容。\n这是第二行"; 运行效果如下图所示&#xff1a; 虽…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

Java多线程实现之Callable接口深度解析

Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

unix/linux,sudo,其发展历程详细时间线、由来、历史背景

sudo 的诞生和演化,本身就是一部 Unix/Linux 系统管理哲学变迁的微缩史。来,让我们拨开时间的迷雾,一同探寻 sudo 那波澜壮阔(也颇为实用主义)的发展历程。 历史背景:su的时代与困境 ( 20 世纪 70 年代 - 80 年代初) 在 sudo 出现之前,Unix 系统管理员和需要特权操作的…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具&#xff0c;在大规模数据获取中发挥着关键作用。然而&#xff0c;传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时&#xff0c;常出现数据质…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作&#xff1a;验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化&#xff1a;测试aof和aof持久化机制&#xff0c;确保数据在开启后正确恢复。 事务&#xff1a;检查事务的原子性和回滚机制。 发布订阅&#xff1a;确保消息正确传递。 2、性…...

es6+和css3新增的特性有哪些

一&#xff1a;ECMAScript 新特性&#xff08;ES6&#xff09; ES6 (2015) - 革命性更新 1&#xff0c;记住的方法&#xff0c;从一个方法里面用到了哪些技术 1&#xff0c;let /const块级作用域声明2&#xff0c;**默认参数**&#xff1a;函数参数可以设置默认值。3&#x…...

HTTPS证书一年多少钱?

HTTPS证书作为保障网站数据传输安全的重要工具&#xff0c;成为众多网站运营者的必备选择。然而&#xff0c;面对市场上种类繁多的HTTPS证书&#xff0c;其一年费用究竟是多少&#xff0c;又受哪些因素影响呢&#xff1f; 首先&#xff0c;HTTPS证书通常在PinTrust这样的专业平…...

嵌入式面试常问问题

以下内容面向嵌入式/系统方向的初学者与面试备考者,全面梳理了以下几大板块,并在每个板块末尾列出常见的面试问答思路,帮助你既能夯实基础,又能应对面试挑战。 一、TCP/IP 协议 1.1 TCP/IP 五层模型概述 链路层(Link Layer) 包括网卡驱动、以太网、Wi‑Fi、PPP 等。负责…...

零基础在实践中学习网络安全-皮卡丘靶场(第十一期-目录遍历模块)

经过前面几期的内容我们学习了很多网络安全的知识&#xff0c;而这期内容就涉及到了前面的第六期-RCE模块&#xff0c;第七期-File inclusion模块&#xff0c;第八期-Unsafe Filedownload模块。 什么是"遍历"呢&#xff1a;对学过一些开发语言的朋友来说应该知道&…...