【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

本文主要是介绍【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

  • 报错详情
  • 错误产生背景
  • 原理
  • 解决方案

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

报错详情

  模型在backward时,发现如下报错:
请添加图片描述
  即RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

  其大概意思是说,当在计算梯度时,某个变量已经被操作修改了,这会导致随后的计算梯度的过程中该变量的值发生变化,从而导致计算梯度出现问题。

错误产生背景

  起因是我要复现一种层级多标签分类的网络结构:
在这里插入图片描述
  当输入序列 x x x经过一次BERT模型之后,得到当前预测的一级标签,然后拼接到输入序列 x x x上,再次输入到BERT模型里以预测二级标签。

  出错版本的模型结构如下:

def forward(self, x, label_A_emb):context = x[0]  # 输入的句子mask = x[2]  d1 = self.bert(context, attention_mask=mask)logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]extra = label_A_emb[idx]context[:, -3:] = extramask[:, -3:] = 1d2 = self.bert(context, attention_mask=mask)logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]return logit1, logit2

  在计算梯度时,由于contextmask的值被中间修改过一次,所以会报错。

原理

请添加图片描述
  图中 w 1 w_1 w1的梯度计算如上图,损失函数为 E t o t a l E_{total} Etotal,最终 w 1 w_1 w1的梯度里是需要用到原始输入 i 1 i_1 i1的。

  所以在上面贴的模型结构代码中,输入在经过神经网络之后,又作了一次改动,然后再经过神经网络。但是梯度计算会计算两次的梯度,可是发现输入只有改动后的值了,改动前的值已经被覆盖。

计算梯度时的版本号机制是PyTorch中用于跟踪张量操作历史的一种机制。它允许PyTorch在需要计算梯度时有效地管理和跟踪相关的操作,以便进行自动微分。每个张量都有一个版本号,记录了该张量的操作历史。当对一个张量执行就地操作(inplace operation)时,例如修改张量的值或重新排列元素的顺序,版本号会增加。这种就地操作可能导致计算梯度时出现问题,因为梯度计算依赖于操作历史。

解决方案

  把即将改动的变量深拷贝一份,最终优化的代码如下:

def forward(self, x, label_A_emb):context = x[0]  # 输入的句子mask = x[2]  d1 = self.bert(context, attention_mask=mask)logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]extra = label_A_emb[idx]context_B = copy.deepcopy(context)mask_B = copy.deepcopy(mask)context_B[:, -3:] = extramask_B[:, -3:] = 1d2 = self.bert_A(context_B, attention_mask=mask_B)logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]return logit1, logit2

这篇关于【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/349914

相关文章

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代

Golang HashMap实现原理解析

《GolangHashMap实现原理解析》HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持高效的插入、查找和删除操作,:本文主要介绍GolangH... 目录HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持

C++如何通过Qt反射机制实现数据类序列化

《C++如何通过Qt反射机制实现数据类序列化》在C++工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作,所以本文就来聊聊C++如何通过Qt反射机制实现数据类序列化吧... 目录设计预期设计思路代码实现使用方法在 C++ 工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作。由于数据类

usb接口驱动异常问题常用解决方案

《usb接口驱动异常问题常用解决方案》当遇到USB接口驱动异常时,可以通过多种方法来解决,其中主要就包括重装USB控制器、禁用USB选择性暂停设置、更新或安装新的主板驱动等... usb接口驱动异常怎么办,USB接口驱动异常是常见问题,通常由驱动损坏、系统更新冲突、硬件故障或电源管理设置导致。以下是常用解决

Windows Docker端口占用错误及解决方案总结

《WindowsDocker端口占用错误及解决方案总结》在Windows环境下使用Docker容器时,端口占用错误是开发和运维中常见且棘手的问题,本文将深入剖析该问题的成因,介绍如何通过查看端口分配... 目录引言Windows docker 端口占用错误及解决方案汇总端口冲突形成原因解析诊断当前端口情况解

Vue3组件中getCurrentInstance()获取App实例,但是返回null的解决方案

《Vue3组件中getCurrentInstance()获取App实例,但是返回null的解决方案》:本文主要介绍Vue3组件中getCurrentInstance()获取App实例,但是返回nu... 目录vue3组件中getCurrentInstajavascriptnce()获取App实例,但是返回n

Spring Boot循环依赖原理、解决方案与最佳实践(全解析)

《SpringBoot循环依赖原理、解决方案与最佳实践(全解析)》循环依赖指两个或多个Bean相互直接或间接引用,形成闭环依赖关系,:本文主要介绍SpringBoot循环依赖原理、解决方案与最... 目录一、循环依赖的本质与危害1.1 什么是循环依赖?1.2 核心危害二、Spring的三级缓存机制2.1 三

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你