【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

本文主要是介绍【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

  • 报错详情
  • 错误产生背景
  • 原理
  • 解决方案

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

报错详情

  模型在backward时,发现如下报错:
请添加图片描述
  即RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

  其大概意思是说,当在计算梯度时,某个变量已经被操作修改了,这会导致随后的计算梯度的过程中该变量的值发生变化,从而导致计算梯度出现问题。

错误产生背景

  起因是我要复现一种层级多标签分类的网络结构:
在这里插入图片描述
  当输入序列 x x x经过一次BERT模型之后,得到当前预测的一级标签,然后拼接到输入序列 x x x上,再次输入到BERT模型里以预测二级标签。

  出错版本的模型结构如下:

def forward(self, x, label_A_emb):context = x[0]  # 输入的句子mask = x[2]  d1 = self.bert(context, attention_mask=mask)logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]extra = label_A_emb[idx]context[:, -3:] = extramask[:, -3:] = 1d2 = self.bert(context, attention_mask=mask)logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]return logit1, logit2

  在计算梯度时,由于contextmask的值被中间修改过一次,所以会报错。

原理

请添加图片描述
  图中 w 1 w_1 w1的梯度计算如上图,损失函数为 E t o t a l E_{total} Etotal,最终 w 1 w_1 w1的梯度里是需要用到原始输入 i 1 i_1 i1的。

  所以在上面贴的模型结构代码中,输入在经过神经网络之后,又作了一次改动,然后再经过神经网络。但是梯度计算会计算两次的梯度,可是发现输入只有改动后的值了,改动前的值已经被覆盖。

计算梯度时的版本号机制是PyTorch中用于跟踪张量操作历史的一种机制。它允许PyTorch在需要计算梯度时有效地管理和跟踪相关的操作,以便进行自动微分。每个张量都有一个版本号,记录了该张量的操作历史。当对一个张量执行就地操作(inplace operation)时,例如修改张量的值或重新排列元素的顺序,版本号会增加。这种就地操作可能导致计算梯度时出现问题,因为梯度计算依赖于操作历史。

解决方案

  把即将改动的变量深拷贝一份,最终优化的代码如下:

def forward(self, x, label_A_emb):context = x[0]  # 输入的句子mask = x[2]  d1 = self.bert(context, attention_mask=mask)logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]extra = label_A_emb[idx]context_B = copy.deepcopy(context)mask_B = copy.deepcopy(mask)context_B[:, -3:] = extramask_B[:, -3:] = 1d2 = self.bert_A(context_B, attention_mask=mask_B)logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]return logit1, logit2

这篇关于【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/349914

相关文章

Java中流式并行操作parallelStream的原理和使用方法

《Java中流式并行操作parallelStream的原理和使用方法》本文详细介绍了Java中的并行流(parallelStream)的原理、正确使用方法以及在实际业务中的应用案例,并指出在使用并行流... 目录Java中流式并行操作parallelStream0. 问题的产生1. 什么是parallelS

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Redis中Set结构使用过程与原理说明

《Redis中Set结构使用过程与原理说明》本文解析了RedisSet数据结构,涵盖其基本操作(如添加、查找)、集合运算(交并差)、底层实现(intset与hashtable自动切换机制)、典型应用场... 目录开篇:从购物车到Redis Set一、Redis Set的基本操作1.1 编程常用命令1.2 集

Redis中的有序集合zset从使用到原理分析

《Redis中的有序集合zset从使用到原理分析》Redis有序集合(zset)是字符串与分值的有序映射,通过跳跃表和哈希表结合实现高效有序性管理,适用于排行榜、延迟队列等场景,其时间复杂度低,内存占... 目录开篇:排行榜背后的秘密一、zset的基本使用1.1 常用命令1.2 Java客户端示例二、zse

Redis中的AOF原理及分析

《Redis中的AOF原理及分析》Redis的AOF通过记录所有写操作命令实现持久化,支持always/everysec/no三种同步策略,重写机制优化文件体积,与RDB结合可平衡数据安全与恢复效率... 目录开篇:从日记本到AOF一、AOF的基本执行流程1. 命令执行与记录2. AOF重写机制二、AOF的

java程序远程debug原理与配置全过程

《java程序远程debug原理与配置全过程》文章介绍了Java远程调试的JPDA体系,包含JVMTI监控JVM、JDWP传输调试命令、JDI提供调试接口,通过-Xdebug、-Xrunjdwp参数配... 目录背景组成模块间联系IBM对三个模块的详细介绍编程使用总结背景日常工作中,每个程序员都会遇到bu

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

Python之变量命名规则详解

《Python之变量命名规则详解》Python变量命名需遵守语法规范(字母开头、不使用关键字),遵循三要(自解释、明确功能)和三不要(避免缩写、语法错误、滥用下划线)原则,确保代码易读易维护... 目录1. 硬性规则2. “三要” 原则2.1. 要体现变量的 “实际作用”,拒绝 “无意义命名”2.2. 要让