Causal Attention论文详解

2023-12-20 01:10
文章标签 详解 论文 attention causal

本文主要是介绍Causal Attention论文详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 背景介绍

Causal Attention论文是一篇因果推断(causal inference)和注意力(attention)结合的一篇文章,主要用在视觉和文本结合的领域,如VQA(Visual Question Answering)视觉问答。

VQA(Visual Question Answering)视觉问答的一个基本流程如下,对输入图进行self-attn编程得到K和V的向量,从文本得到Q的向量进行Attn计算,得到填空的结果(riding)。这个过程可以看成是一个因果推断的过程,对应的示意图如下X->Z->Y,X是输入,Z是模型过程,Y是输出,箭头表示相互依赖的关系。

在这里插入图片描述

实际中由于训练数据中存在bias偏差会导致结果不对,比如下图,看图回答问题(在屏幕上显示的是什么运动),由于预训练数据中Sport+ManSport+Screen出现次数多的话,在回答时self-attn关注点会更注意Sport+Man(即下图红框部分,导致回答错误结果为跳舞)。为此这篇论文中提出了Causal Attention的方法。

2. 详细说明

2.1 因果推断confounder

在因果推断(causal inference)中有一个概念是confounder(也有叫Confounding factor), 中文意思是干扰因子,在因果推断中表示影响推导的不可知因素,举个例子如下,药物Drug会帮助恢复Recovery,但隐藏的因素是一个人的性别Gender可能会同时影响使用什么样的药物恢复效果。这里的性别就是confounder

在这里插入图片描述

这里的推断流程从 X → Y X \rightarrow Y XY 变为了 X ← Z → Y X \leftarrow Z \rightarrow Y XZY,用 P ( y ∣ d o ( x ) ) P(y|do(x)) P(ydo(x)) 表示无偏估计的结果,也就是针对了相关的confounder因素进行了调整后的结果。公式表示如下,当且仅当没有confounder时, P ( y ∣ d o ( x ) ) = P ( y ∣ x ) P(y|do(x)) = P(y|x) P(ydo(x))=P(yx)

P ( y ∣ d o ( x ) ) = ∑ z P ( y ∣ x , z ) P ( z ) \begin{gather*} P(y | do(x)) = \sum_zP(y|x, z) P(z) \end{gather*} P(ydo(x))=zP(yx,z)P(z)

针对上面例子,对应的 P ( Y = r e c o v e r e d ∣ d o ( X = g i v e d r u g ) ) P(Y=recovered | do(X=give\ drug)) P(Y=recovereddo(X=give drug)) 等于如下:

P ( Y = r e c o v e r e d ∣ d o ( X = g i v e d r u g ) ) = P ( Y = r e c o v e r e d ∣ X = g i v e d r u g , Z = m a l e ) P ( Z = m a l e ) + P ( Y = r e c o v e r e d ∣ X = g i v e d r u g , Z = f e m a l e ) P ( Z = f e m a l e ) \begin{gather*} P(Y=recovered | do(X=give\ drug)) = P(Y=recovered | X=give\ drug, Z=male) P(Z=male) + P(Y=recovered | X=give\ drug, Z=female) P(Z=female) \end{gather*} P(Y=recovereddo(X=give drug))=P(Y=recoveredX=give drug,Z=male)P(Z=male)+P(Y=recoveredX=give drug,Z=female)P(Z=female)

在训练过程中数据bias就是由于cofounder(这里也被称为common sense的常识)引起的,如下图,C表示常识,常识存在多种,person can ride horse是常识中的一种, X表示通过person can ride horse产生的一个图片和对应的prompt(person can ride ___),M表示通过Faster-RCNN检测出来的物体object(personhorse), Y表示语言模型产生的推理结果person can ride horse。在训练中一个理想合法的推导是 X → M → Y X \rightarrow M \rightarrow Y XMY,但实际中常识C也会对最终的结果Y有影响,即 X ← C → M → Y X \leftarrow C \rightarrow M \rightarrow Y XCMY。训练中计算的是按 P ( Y ∣ X ) P(Y|X) P(YX),而实际中应该按 P ( Y ∣ d o ( X ) ) P(Y|do(X)) P(Ydo(X)) 来计算。

在这里插入图片描述

2.2 Causal Attention公式表示

之前的attention机制可以看成是一个前向的因果推理图(X->Z->Y)。基于这个图Causal Attention中把attention拆为两部分,一个是选择器(selector),用于从数据X中选择合适的知识Z;另一个是推理器(predictor),通过选择的Z去探索推理结果Y

以VQA为例,训练集是已知的,也就是计算的可观测的P(Y|X), Z表示训练中已有的知识,由于Z可以看成是从X中抽样出来一部分数据,所以计算的部分也叫为IS-Sampling。公式如下:

在这里插入图片描述

在训练过程中抽样的数据集存在潜在的偏差(bias),即Z <- X <-> Y, 需要进行修正,ZY之前的因果影响表示为 P ( Y ∣ d o ( Z ) ) P(Y|do(Z)) P(Ydo(Z)), X -> Z的这部分可以通过对X进行拆解为多个不同的 { x } \{x\} {x} 来表示,公式如下, x表示可能的输入,这里叫做CS-Samping
在这里插入图片描述

最终公式(2)代入公式(1)得到如下结果:
在这里插入图片描述

2.3 Causal Attention网络实现

P ( Y ∣ Z , X ) P(Y|Z,X) P(YZ,X) 使用一个softmax层进行计算;如公式(3)所示,为了计算 P ( Y ∣ d o ( X ) ) P(Y|do(X)) P(Ydo(X)) 要对XZ进行采样,但是前向代价过大,所以采用了Normalized Weighted Geometric Mean (NWGM) 的近似方法,近似后公式如下, f ( ⋅ ) 、 h ( ⋅ ) f(\cdot)、h(\cdot) f()h() 表示把输入X进行embedding后成为两个query set。
在这里插入图片描述

使用attention进行表示上述计算的话,In-Sample attention (IS-ATT)的结果 Z ^ \hat{Z} Z^如下, K I 和 V I K_I 和 V_I KIVI来自当前的输入样本,如RoI的特征; Q I Q_I QI自于 h ( X ) h(X) h(X),在top-down attention中 q I q_I qI为上下文的embedding,在self-attention中 q I q_I qI也是RoI的特征。
在这里插入图片描述

Cross-Sample attention (CS-ATT)的结果 X ^ \hat{X} X^如下, K C 和 V C K_C 和 V_C KCVC来自训练集中的其他样本, Q C Q_C QC自于 f ( X ) f(X) f(X)
在这里插入图片描述

对应的网络图如下:
在这里插入图片描述

2.4 Causal Attention在堆叠attention网络中的应用

2.4.1 Transformer+CATT

在transformer中encoder和decoder实现如下图, [ V I ] E [V_I]_E [VI]E [ V C ] E [V_C]_E [VC]E分别表示为IS-ATTCS-ATT的encoder输出, Z ^ \hat{Z} Z^ X ^ \hat{X} X^表示IS-ATTCS-ATT的decoder输出。
在这里插入图片描述

2.4.2 LXMERT+CATT

在这里插入图片描述

3. 参考

  • Causal Attention for Vision-Language Tasks
  • 论文笔记:Causal Attention for Vision-Language Tasks
  • LXMERT: Learning Cross-Modality Encoder Representations from Transformers
  • LXMERT
  • Confounding
  • Confounders: machine learning’s blindspot

这篇关于Causal Attention论文详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/514350

相关文章

JVisualVM之Java性能监控与调优利器详解

《JVisualVM之Java性能监控与调优利器详解》本文将详细介绍JVisualVM的使用方法,并结合实际案例展示如何利用它进行性能调优,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全... 目录1. JVisualVM简介2. JVisualVM的安装与启动2.1 启动JVisualVM2

Redis中的Lettuce使用详解

《Redis中的Lettuce使用详解》Lettuce是一个高级的、线程安全的Redis客户端,用于与Redis数据库交互,Lettuce是一个功能强大、使用方便的Redis客户端,适用于各种规模的J... 目录简介特点连接池连接池特点连接池管理连接池优势连接池配置参数监控常用监控工具通过JMX监控通过Pr

MySQL 添加索引5种方式示例详解(实用sql代码)

《MySQL添加索引5种方式示例详解(实用sql代码)》在MySQL数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中,下面给大家分享MySQL添加索引5种方式示例详解(实用sql代码),... 在mysql数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中。索引可以在创建表时定义,也可

C++ RabbitMq消息队列组件详解

《C++RabbitMq消息队列组件详解》:本文主要介绍C++RabbitMq消息队列组件的相关知识,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. RabbitMq介绍2. 安装RabbitMQ3. 安装 RabbitMQ 的 C++客户端库4. A

MySQL 存储引擎 MyISAM详解(最新推荐)

《MySQL存储引擎MyISAM详解(最新推荐)》使用MyISAM存储引擎的表占用空间很小,但是由于使用表级锁定,所以限制了读/写操作的性能,通常用于中小型的Web应用和数据仓库配置中的只读或主要... 目录mysql 5.5 之前默认的存储引擎️‍一、MyISAM 存储引擎的特性️‍二、MyISAM 的主

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元

mybatis的mapper对应的xml写法及配置详解

《mybatis的mapper对应的xml写法及配置详解》这篇文章给大家介绍mybatis的mapper对应的xml写法及配置详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录前置mapper 对应 XML 基础配置mapper 对应 xml 复杂配置Mapper 中的相

MySQL 事务的概念及ACID属性和使用详解

《MySQL事务的概念及ACID属性和使用详解》MySQL通过多线程实现存储工作,因此在并发访问场景中,事务确保了数据操作的一致性和可靠性,下面通过本文给大家介绍MySQL事务的概念及ACID属性和... 目录一、什么是事务二、事务的属性及使用2.1 事务的 ACID 属性2.2 为什么存在事务2.3 事务

MySQL表空间结构详解表空间到段页操作

《MySQL表空间结构详解表空间到段页操作》在MySQL架构和存储引擎专题中介绍了使用不同存储引擎创建表时生成的表空间数据文件,在本章节主要介绍使用InnoDB存储引擎创建表时生成的表空间数据文件,对... 目录️‍一、什么是表空间结构1.1 表空间与表空间文件的关系是什么?️‍二、用户数据在表空间中是怎么

python3 pip终端出现错误解决的方法详解

《python3pip终端出现错误解决的方法详解》这篇文章主要为大家详细介绍了python3pip如果在终端出现错误该如何解决,文中的示例方法讲解详细,感兴趣的小伙伴可以跟随小编一起了解一下... 目录前言一、查看是否已安装pip二、查看是否添加至环境变量1.查看环境变量是http://www.cppcns