当前位置: 首页 > news >正文

Pytorch 反向传播 计算图被修改的报错

先看看报错的内容

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

报错中说,一个需要梯度计算的变量已经被原地修改了,这引发了报错。

torch.set_grad_enabled(True)

然后我使用上述语句开启了梯度跟踪,发现问题出在我的标签计算函数:

def get_label(net, X):return net(X).reshape((-1, 1))

为什么会出错呢?在这种情况下,由于 label 是从网络输出直接计算得到的,它与网络的计算图相连接。如果在 label 上进行了原地操作(上述的修改形状操作),就可能破坏计算图,使其不可导或其他,总之是导致反向传播时无法正确计算梯度,从而引发报错。

那怎么解决这个问题?将该结果与计算图进行分离就行了,此刻如果再进行反向传播,梯度就不会传播到此处。修改后,代码如下;

def get_label(net, X):return net(X).detach().reshape((-1, 1))

detach()函数的作用是将数据和计算图分离开来,得到数据部分,与计算图再无瓜葛。

举一个更形象的例子,看下面的代码:

label = net(X)  # 计算标签
# 对 label 或 label 的某个部分进行了原地操作,比如:
# label[0, 0] = label[0, 0] * 2
# 或
# label += 1
loss = Loss(label, y)  # 计算损失

在这个例子中,label由第一条语句前向传播得到,是直接与网络的输出连在一起,后面我却对label的值进行了手动修改。

这些操作可能导致计算图的结构不完整或不可导,从而影响反向传播的计算。为了避免这样的问题,一般建议避免在计算标签或损失时对张量进行原地操作。如果需要修改张量的值,最好创建一个新的张量,而不是直接在原有张量上进行修改。

下面是我的整个程序,大家也可以调试代码来理解其中的含义:

import torch.nn as nn
import matplotlib.pyplot as plt
import torch
from torch.utils import data
def get_label(net, X):#计算标签,计算完后必须要使用detach()分离计算图,否则代码将报计算图被修改的错误return net(X).detach().reshape((-1, 1))def train(net, trainer, Loss, train_data, train_label, epochs, batch_size):#将训练数据和标签捆在一起,便于后面一起便利data_iter = data.DataLoader(list(zip(train_data, train_label)), batch_size=batch_size)#用来存储数据的变化值,前者为训练轮次,后者为每一轮训练平均损失draw_x, draw_y = [], []for epoch in range(epochs):#每次处理一个批次的数据for X, y in data_iter:trainer.zero_grad()  # 清除梯度pre_y = net(X)  # 前向传播loss = Loss(pre_y, y)  # 计算损失loss.backward()  # 反向传播,计算梯度trainer.step()  # 更新权重,进行优化#添加绘图需要的数据draw_x.append(epoch)draw_y.append(torch.mean(Loss(net(train_data),train_label)).data)#设置绘图参数plt.figure(figsize=(5, 4), dpi=150)#设置图像大小和分辨率plt.plot(draw_x, draw_y, label='train_loss')#设置要绘制的数据,被给出图例plt.xlabel('epoch')#设置X轴标题plt.ylabel('loss')#设置y轴标题plt.legend()#显示图例#显示最终图像plt.show()def test(net, Loss, test_data, test_label):loss_sum = torch.zeros_like(test_label)data_iter = data.DataLoader(list(zip(test_data, test_label)), batch_size=batch_size, shuffle=False)for X, y in data_iter:pre_y = net(X)  # 前向传播loss = Loss(pre_y, y)  # 计算损失loss_sum += loss  # 累加损失return torch.sum(loss_sum) / len(loss_sum)  # 返回平均损失def init_weight(m):if type(m) == nn.Linear:#权重使用何凯明正态初始化方法进行初始化nn.init.kaiming_normal_(m.weight)#偏置使用0偏置nn.init.zeros_(m.bias)lr = 0.01  # 学习率
epochs = 100  # 训练轮数
batch_size = 5  # 批大小
shared = nn.Linear(5, 5)  # 共享层
net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(),  # 输入层到隐藏层1的线性层,ReLU激活函数shared, nn.ReLU(),  # 共享层,ReLU激活函数shared, nn.ReLU(),  # 共享层,ReLU激活函数nn.Linear(5, 1))  # 从隐藏层到输出层的线性层,无激活函数(线性回归)#显示真实参数(我们的标签就是用这个参数跑出来的),这也是我们最终需要拟合的参数
for name, param in net.named_parameters():print(name, param)#获取随机数作为样本
X = torch.randn((200, 10))
# 通过网络得到真实标签
True_label = get_label(net, X)
#一开始自动随机生成了参数已经被我当作真实参数了,此刻我需要另重新初始化参数
net.apply(init_weight)
#获取训练器
trainer = torch.optim.SGD(net.parameters(), lr=lr)
#获取损失函数
Loss = nn.MSELoss()  # 定义损失函数,使用均方误差。#开始训练模型发
train(net, trainer, Loss, X[:50], True_label[:50], epochs, batch_size=batch_size)
#打印测试损失
print(f'测试损失{test(net, Loss, X[50:], True_label[50:])}')

相关文章:

Pytorch 反向传播 计算图被修改的报错

先看看报错的内容 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable an…...

android studio设置gradle和gradle JDK版本

文章目录 1.gradle JDK版本2.gradle版本 1.gradle JDK版本 file -> project structure -> SDK Location -> Gradle Settings -> Gradle JDK -> Download JDK 2.gradle版本 file -> project structure -> Project...

Android 15即将到来,或将推出5大新功能特性

Android15 OneUI电池优化 三星最近完成了对其所有设备的稳定版 One UI 6.0 更新的推出,引起了用户的极大兴奋。据新出现的互联网统计数据显示,即将发布的基于 Android 15 的 One UI 7 将通过优化电池和功耗来重新定义用户体验,这是一项具有突…...

sqlalchemy 事务自动控制(类java aop)

最近使用它交互数据库,想实现类似java aop那种自动事务控制,不用手动commit或者rollback。我是用的是flaskdenpendency-injecter 这是我的db的配置类,里面会初始化一些session配置,里面比较重要的是把autocommit和autoflush关闭了…...

vue2-手写轮播图

轮播图5长展示&#xff0c;点击指示器向右移动一个图片&#xff0c;每隔2秒移动一张照片&#xff01; <template><div class"top-app"><div class"carousel-container"><div class"carousel" ref"carousel">&…...

Google I/O大会:Android 13

3个体验升级的方向 以智能手机为场景核心、 扩大智能终端的应用边界以及实现多设备间更好地协同。具体到系统体验层&#xff0c;安卓13将支持图标颜色随主题更换、为不同应用设定使用的语言、新的媒体中心界面等等&#xff0c;同时谷歌也推出了自家的钱包应用&#xff08;Goog…...

VUE指令(一)

vue会根据不同的指令&#xff0c;针对不同的标签实现不同的功能。指令是带有 v- 前缀的特殊标签属性。指令的职责是&#xff0c;当表达式的值改变时&#xff0c;将其产生的连带影响&#xff0c;响应式地作用于 DOM。 1、v-text&#xff1a;设置元素的文本内容&#xff0c;不会解…...

微信小程序开发学习笔记《7》全局配置以及小程序窗口

微信小程序开发学习笔记《7》全局配置以及小程序窗口 博主正在学习微信小程序开发&#xff0c;希望记录自己学习过程同时与广大网友共同学习讨论。全局配置官方文档 一、全局配置文件及常用的配置项 小程序根目录下的app.json 文件是小程序的全局配置文件。 常用的配置项如…...

Vue、uniApp、微信小程序、Html5等实现数缓存

此文章带你实现前端缓存&#xff0c;利用时间戳封装一个类似于Redis可以添加过期时间的缓存工具 不仅可以实现对缓存数据设置过期时间&#xff0c;还可以自定义是否需要对缓存数据进行加密处理 工具介绍说明 对缓存数据进行非对称加密处理 对必要数据进行缓存&#xff0c;并…...

如何将ArcGIS工程文件迁移到ArcGIS Pro内

当你刚接触ArcGIS Pro的时候&#xff0c;尝试新建一个工程文件会发现工程文件的后缀已经改变&#xff0c;那么以前在ArcGIS内辛苦制作的工程文件是否就不能在ArcGIS Pro内使用了&#xff0c;答案是否定的&#xff0c;对此Esri也给出了解决方案&#xff0c;这里为大家介绍一下迁…...

Jenkins基础篇--添加用户和用户权限设置

添加用户 点击系统管理&#xff0c;点击管理用户&#xff0c;然后点击创建用户&#xff08;Create User&#xff09; 用户权限管理 点击系统管理&#xff0c;点击全局安全配置&#xff0c;找到授权策略&#xff0c;选择安全矩阵&#xff0c;配置好用户权限后&#xff0c;点击…...

C语言基础内容(七)——第08章_C语言常用函数

文章目录 第08章_C语言常用函数本章专题脉络1、字符串相关函数1.1 字符串的表示方式1.2 两种方式的区别1.2 字符串常用函数strlen()strcpy()strncpy()strcat()strncat()strcmp()strlwr()/strupr()1.3 基本数据类型和字符串的转换基本数据类型 -> 字符串字符串 -> 基本数据…...

CRM系统针对销售管理有哪些功能?如何帮助销售效率增长?

从长远来看&#xff0c;有效的CRM管理系统可以帮助您的企业达到甚至超过收入目标。现代大多数企业都依靠CRM系统来管理其销售周期并增加收入。但是&#xff0c;当大多数人提到CRM时&#xff0c;他们指的是使能够改善业务关系并轻松管理不断团队的软件或工具。合格的CRM系统能够…...

基于Pixhawk和ROS搭建自主无人车(一):底盘控制篇

参考 ArduPilot Development超维空间科技乐迪MiniPix车船使用说明书 1. 硬件篇 1.1 底盘构成一览 1.2 底盘接线示意 2. 软件篇 2.1 APM 固件下载 pixhawk 是硬件平台&#xff0c;PX4 是 pixhawk 的原生固件&#xff0c;APM&#xff08;Ardupilot Mega&#xff09;是硬件平台…...

部署 Spring Boot 应用中文文档

本文为官方文档直译版本。原文链接 部署 Spring Boot 应用中文文档 引言部署到云Cloud Foundry与服务绑定 KubernetesKubernetes 容器生命周期 HerokuOpenShift亚马逊网络服务&#xff08;AWS&#xff09;AWS Elastic Beanstalk使用 Tomcat 平台使用 Java SE 平台 总结 CloudCa…...

【数据库原理】(23)实际应用中的查询优化方法

一.基于索引的优化 索引是数据库查询优化的关键工具之一。合理地使用索引可以显著提高查询速度&#xff0c;降低全表扫描的成本。以下是建立和使用索引的一些基本原则和最佳实践。 索引的建立与使用原则 数据量规模与查询频率: 值得建立索引的表通常具有较多的记录&#xff0…...

MySQL中datetime和timestamp的区别

datetime和timestamp的区别 相同点: 存储格式相同 datetime和timestamp两者的时间格式都是YYYY-MM-DD HH:MM:SS 不同点: 存储范围不同. datetime的范围是1000-01-01到9999-12-31. 而timestamp是从1970-01-01到2038-01-19, 即后者的时间范围很小. 与时区关系. datetime是存储…...

2024年如何使用WordPress构建克隆Udemy市场

您想创建像 Udemy 这样的学习管理 (LMS) 网站吗&#xff1f;最好的学习管理系统工具LifterLMS将帮助您制作像Udemy市场这样的 LMS 网站。 目录 Udemy市场是什么&#xff1f; 创建 Udemy 克隆所需的几项强制性技术&#xff1a; 步骤 1) 注册您的域名 步骤 2) 获取虚拟主…...

(leetcode)Z字形变换 -- 模拟算法

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 题目链接 . - 力扣&#xff08;LeetCode&#xff09; 输入描述 string convert(string s, int numRows)&#xff0c;输入一个字符串s&#xff0c;以及一个行数numRows&#xff0c;将字符串按照这个行数进行Z字形排列&…...

STM32--基于STM32F103的MAX30102心率血氧测量

本文介绍基于STM32F103ZET6MAX30102心率血氧测量0.96寸OLED&#xff08;7针&#xff09;显示&#xff08;完整程序代码见文末链接&#xff09; 一、简介 MAX30102是一个集成的脉搏血氧仪和心率监测仪生物传感器的模块。它集成了一个红光LED和一个红外光LED、光电检测器、光器…...

【读书笔记】《在远远的背后带领》

《在远远的背后带领》书话整理书名由来 "在远远的背后带领"这个书名&#xff0c;源于作者对十余年养育实践的回顾与思考。她发现&#xff0c;父母养育孩子容易走两个极端&#xff1a; 过度控制&#xff1a;强迫孩子按照自己的想法行事&#xff0c;结果双方俱疲&#…...

springboot+vue基于web的美食外卖点餐平台的设外卖员商家

目录同行可拿货,招校园代理 ,本人源头供货商外卖员功能分析商家功能分析技术实现要点项目技术支持源码获取详细视频演示 &#xff1a;文章底部获取博主联系方式&#xff01;同行可合作同行可拿货,招校园代理 ,本人源头供货商 外卖员功能分析 外卖员在美食外卖点餐平台中的核心…...

认知内耗:在亚马逊,为何品牌名内部的“关键词”正在相互厮杀

在亚马逊的品牌丛林中&#xff0c;最隐蔽的悲剧莫过于&#xff1a;你精心构思的品牌名&#xff0c;其内部的各个组成部分&#xff08;如“欧文斯”、“康宁”、“玻璃纤维”&#xff09;&#xff0c;并未协同指向你&#xff0c;反而各自激活了消费者心智中其他更强大品牌的“认…...

千问3.5-2B效果对比评测:与Qwen-VL-Chat基础版在OCR精度和响应速度上的实测差异

千问3.5-2B效果对比评测&#xff1a;与Qwen-VL-Chat基础版在OCR精度和响应速度上的实测差异 1. 评测背景与模型介绍 视觉语言模型正在改变我们与图像交互的方式。作为Qwen系列的最新成员&#xff0c;千问3.5-2B以其轻量级架构和高效性能引起了广泛关注。本次评测将聚焦于两个…...

致翔智慧校园招生迎新系统正式上线!一键解锁「零跑腿」入学新体验!

告别排长队、告别填不完的纸质表、告别来回跑、告别信息反复核对&#xff01;致翔智慧校园招生迎新管理系统重磅上线啦&#xff01;从招生报名到迎新报到&#xff0c;全流程数字化、一站式智能化&#xff0c;轻松搞定所有环节&#xff01;✨ 告别繁琐&#xff0c;新生入学超丝滑…...

CYBER-VISION零号协议互联网舆情智能监测与分析系统

CYBER-VISION零号协议&#xff1a;构建你的互联网舆情智能监测雷达 最近和几个做市场、公关的朋友聊天&#xff0c;他们都在抱怨同一个问题&#xff1a;每天花大量时间刷新闻、看社交媒体&#xff0c;就为了捕捉行业动态和用户反馈&#xff0c;生怕错过什么重要信息。人工监测…...

python异常模拟工具类(异常生成工具类)

文章目录创建代码类使用主要是做测试的时候方便&#xff0c;创建代码类 1、新建python文件exception_mock_utils.py&#xff0c;代码为&#xff1a; import random import time from typing import Any, Optionalclass ExceptionMockUtils:"""异常模拟工具类用…...

企业开始用 AI 后,最容易被忽略的其实是这件事!

这两年&#xff0c;越来越多企业开始尝试把 AI 用到日常办公中。从写邮件、整理纪要&#xff0c;到查询知识库、生成文档&#xff0c;AI 正在从个人工具变成企业工作的一部分。但很多企业在推进 AI 时&#xff0c;首先关注的往往是功能和效率&#xff0c;比如“能不能写”“能不…...

【Matlab】MATLAB教程:图形属性修改(案例:set(h,‘Color‘,‘red‘),应用:自定义图形样式)

MATLAB教程:图形属性修改(案例:set(h,Color,red),应用:自定义图形样式) 在MATLAB数据可视化、实验报告绘图、工程结果展示等场景中,默认绘制的图形往往难以满足个性化需求和规范要求。无论是调整线条颜色、粗细,还是优化坐标轴、图例样式,核心目标都是通过图形属性修…...

设计标注工具:解决团队协作痛点的高效解决方案

设计标注工具&#xff1a;解决团队协作痛点的高效解决方案 【免费下载链接】sketch-measure Make it a fun to create spec for developers and teammates 项目地址: https://gitcode.com/gh_mirrors/sk/sketch-measure 设计标注是连接设计与开发的重要环节&#xff0c;…...