【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

本文主要是介绍【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

  • 报错详情
  • 错误产生背景
  • 原理
  • 解决方案

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

报错详情

  模型在backward时,发现如下报错:
请添加图片描述
  即RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

  其大概意思是说,当在计算梯度时,某个变量已经被操作修改了,这会导致随后的计算梯度的过程中该变量的值发生变化,从而导致计算梯度出现问题。

错误产生背景

  起因是我要复现一种层级多标签分类的网络结构:
在这里插入图片描述
  当输入序列 x x x经过一次BERT模型之后,得到当前预测的一级标签,然后拼接到输入序列 x x x上,再次输入到BERT模型里以预测二级标签。

  出错版本的模型结构如下:

def forward(self, x, label_A_emb):context = x[0]  # 输入的句子mask = x[2]  d1 = self.bert(context, attention_mask=mask)logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]extra = label_A_emb[idx]context[:, -3:] = extramask[:, -3:] = 1d2 = self.bert(context, attention_mask=mask)logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]return logit1, logit2

  在计算梯度时,由于contextmask的值被中间修改过一次,所以会报错。

原理

请添加图片描述
  图中 w 1 w_1 w1的梯度计算如上图,损失函数为 E t o t a l E_{total} Etotal,最终 w 1 w_1 w1的梯度里是需要用到原始输入 i 1 i_1 i1的。

  所以在上面贴的模型结构代码中,输入在经过神经网络之后,又作了一次改动,然后再经过神经网络。但是梯度计算会计算两次的梯度,可是发现输入只有改动后的值了,改动前的值已经被覆盖。

计算梯度时的版本号机制是PyTorch中用于跟踪张量操作历史的一种机制。它允许PyTorch在需要计算梯度时有效地管理和跟踪相关的操作,以便进行自动微分。每个张量都有一个版本号,记录了该张量的操作历史。当对一个张量执行就地操作(inplace operation)时,例如修改张量的值或重新排列元素的顺序,版本号会增加。这种就地操作可能导致计算梯度时出现问题,因为梯度计算依赖于操作历史。

解决方案

  把即将改动的变量深拷贝一份,最终优化的代码如下:

def forward(self, x, label_A_emb):context = x[0]  # 输入的句子mask = x[2]  d1 = self.bert(context, attention_mask=mask)logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]extra = label_A_emb[idx]context_B = copy.deepcopy(context)mask_B = copy.deepcopy(mask)context_B[:, -3:] = extramask_B[:, -3:] = 1d2 = self.bert_A(context_B, attention_mask=mask_B)logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]return logit1, logit2

这篇关于【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/349914

相关文章

Redis客户端连接机制的实现方案

《Redis客户端连接机制的实现方案》本文主要介绍了Redis客户端连接机制的实现方案,包括事件驱动模型、非阻塞I/O处理、连接池应用及配置优化,具有一定的参考价值,感兴趣的可以了解一下... 目录1. Redis连接模型概述2. 连接建立过程详解2.1 连php接初始化流程2.2 关键配置参数3. 最大连

使用Python构建智能BAT文件生成器的完美解决方案

《使用Python构建智能BAT文件生成器的完美解决方案》这篇文章主要为大家详细介绍了如何使用wxPython构建一个智能的BAT文件生成器,它不仅能够为Python脚本生成启动脚本,还提供了完整的文... 目录引言运行效果图项目背景与需求分析核心需求技术选型核心功能实现1. 数据库设计2. 界面布局设计3

SQL Server跟踪自动统计信息更新实战指南

《SQLServer跟踪自动统计信息更新实战指南》本文详解SQLServer自动统计信息更新的跟踪方法,推荐使用扩展事件实时捕获更新操作及详细信息,同时结合系统视图快速检查统计信息状态,重点强调修... 目录SQL Server 如何跟踪自动统计信息更新:深入解析与实战指南 核心跟踪方法1️⃣ 利用系统目录

Java.lang.InterruptedException被中止异常的原因及解决方案

《Java.lang.InterruptedException被中止异常的原因及解决方案》Java.lang.InterruptedException是线程被中断时抛出的异常,用于协作停止执行,常见于... 目录报错问题报错原因解决方法Java.lang.InterruptedException 是 Jav

kkFileView在线预览office的常见问题以及解决方案

《kkFileView在线预览office的常见问题以及解决方案》kkFileView在线预览Office常见问题包括base64编码配置、Office组件安装、乱码处理及水印添加,解决方案涉及版本适... 目录kkFileView在线预览office的常见问题1.base642.提示找不到OFFICE组件

SpringBoot监控API请求耗时的6中解决解决方案

《SpringBoot监控API请求耗时的6中解决解决方案》本文介绍SpringBoot中记录API请求耗时的6种方案,包括手动埋点、AOP切面、拦截器、Filter、事件监听、Micrometer+... 目录1. 简介2.实战案例2.1 手动记录2.2 自定义AOP记录2.3 拦截器技术2.4 使用Fi

Spring Security 单点登录与自动登录机制的实现原理

《SpringSecurity单点登录与自动登录机制的实现原理》本文探讨SpringSecurity实现单点登录(SSO)与自动登录机制,涵盖JWT跨系统认证、RememberMe持久化Token... 目录一、核心概念解析1.1 单点登录(SSO)1.2 自动登录(Remember Me)二、代码分析三、

Go语言并发之通知退出机制的实现

《Go语言并发之通知退出机制的实现》本文主要介绍了Go语言并发之通知退出机制的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1、通知退出机制1.1 进程/main函数退出1.2 通过channel退出1.3 通过cont

Spring Boot 中的默认异常处理机制及执行流程

《SpringBoot中的默认异常处理机制及执行流程》SpringBoot内置BasicErrorController,自动处理异常并生成HTML/JSON响应,支持自定义错误路径、配置及扩展,如... 目录Spring Boot 异常处理机制详解默认错误页面功能自动异常转换机制错误属性配置选项默认错误处理

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分