Bahdanau注意力机制

2024-08-30 04:28
文章标签 机制 注意力 bahdanau

本文主要是介绍Bahdanau注意力机制,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

介绍

在Bahadanu注意力机制中,本质上是序列到序列学习的注意力机制实现,在编码器-解码器结构中,解码器的每一步解码过程都依赖着整个上下文变量,通过Bahdanau注意力,使得解码器在每一步解码时,对于整个上下文变量的不同部分产生不同程度的对齐,如在文本翻译时,将“I am studying”的“studying”与“我正在学习”的“学习”进行对齐,即注意力在解码时将绝大多数注意力放在“studying”处。

原理和结构

原理

Bahdanau注意力机制本质上是将上下文变量进行转换即可,其中转换后的上下文变量计算方式如下式所示:

c_{t^{'}}=\Sigma_{t=1}^T \alpha (s_{t^{'}-1},h_t)h_t

在传统注意力机制中,一般使用的公式形如c_{t^{'}}=\Sigma_{t=1}^T \alpha (s_{t^{'}-1},k) V,在Bahdanau中,键与值是同一个变量,都是t时刻的编码器隐状态,s表示该时刻的查询,即上一时刻的解码器隐状态。

架构

下图为Bahdanau注意力机制的编码器-解码器架构示意图:

为便于理解,对上述示意结构进行说明:首先将X依次输入GRU,之后在循环过程中依次产生len个隐状态,最后一个隐状态h_{len}直接作为解码器的初始隐状态。在每个解码步骤 (t),注意力机制计算当前解码器隐藏状态 (s_t) 和编码器所有隐藏状态 (h_i) 的相似度,即应用注意力机制编写新的上下文变量,之后在解码器的循环解码过程中,都计算带注意力的上下文变量,通过此变量和上一解码隐状态计算当前时刻t的输出。

代码实现

引入库

注:本blog使用mxnet进行训练学习。

from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2lnpx.set_np()

 定义注意力解码器

这里实现一个接口,只需重新定义解码器即可。 为了更方便地显示学习的注意力权重, 以下AttentionDecoder类定义了带有注意力机制解码器的基本接口。

class AttentionDecoder(d2l.Decoder):def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)def attention_weights(self):raise NotImplementedError

接下来,让我们在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。 首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;

  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;

  3. 编码器有效长度(排除在注意力池中填充词元)。

在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。 因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

对接下来的代码实现略作补充说明:在编码器中,对一个batch每个输入(采用one-hot编码,长度为Vocab_size)依次进行嵌入层运算,得到固定embed_size个结果之后进行RNN运算,RNN使用层数为num_layers的深层循环神经网络,进行forward运算得到状态,部分过程进行闭包和解包,对于对维度大小出现疑惑的点,大多是闭包和解包造成的。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)self.dense = nn.Dense(vocab_size, flatten=False)def init_state(self, enc_outputs, enc_valid_lens, *args):outputs, hidden_state = enc_outputsreturn (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens)def forward(self, X, state):enc_outputs, hidden_state, enc_valid_lens = stateX = self.embedding(X).swapaxes(0, 1)outputs, self._attention_weights = [], []for x in X:query = np.expand_dims(hidden_state[0][-1], axis=1)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1)out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)outputs = self.dense(np.concatenate(outputs, axis=0))return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,enc_valid_lens]def attention_weights(self):return self._attention_weights

训练

我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

结果 

采用BLEU计算困惑度,代码具有较好的表现。

go . => entre ., bleu 0.000
i lost . => j'ai gagné ., bleu 0.000
he's calm . => j'ai gagné ., bleu 0.000
i'm home . => je suis chez moi <unk> !, bleu 0.719

这篇关于Bahdanau注意力机制的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1119792

相关文章

基于Redis自动过期的流处理暂停机制

《基于Redis自动过期的流处理暂停机制》基于Redis自动过期的流处理暂停机制是一种高效、可靠且易于实现的解决方案,防止延时过大的数据影响实时处理自动恢复处理,以避免积压的数据影响实时性,下面就来详... 目录核心思路代码实现1. 初始化Redis连接和键前缀2. 接收数据时检查暂停状态3. 检测到延时过

Redis中哨兵机制和集群的区别及说明

《Redis中哨兵机制和集群的区别及说明》Redis哨兵通过主从复制实现高可用,适用于中小规模数据;集群采用分布式分片,支持动态扩展,适合大规模数据,哨兵管理简单但扩展性弱,集群性能更强但架构复杂,根... 目录一、架构设计与节点角色1. 哨兵机制(Sentinel)2. 集群(Cluster)二、数据分片

深入理解go中interface机制

《深入理解go中interface机制》本文主要介绍了深入理解go中interface机制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录前言interface使用类型判断总结前言go的interface是一组method的集合,不

C# async await 异步编程实现机制详解

《C#asyncawait异步编程实现机制详解》async/await是C#5.0引入的语法糖,它基于**状态机(StateMachine)**模式实现,将异步方法转换为编译器生成的状态机类,本... 目录一、async/await 异步编程实现机制1.1 核心概念1.2 编译器转换过程1.3 关键组件解析

Redis客户端连接机制的实现方案

《Redis客户端连接机制的实现方案》本文主要介绍了Redis客户端连接机制的实现方案,包括事件驱动模型、非阻塞I/O处理、连接池应用及配置优化,具有一定的参考价值,感兴趣的可以了解一下... 目录1. Redis连接模型概述2. 连接建立过程详解2.1 连php接初始化流程2.2 关键配置参数3. 最大连

Spring Security 单点登录与自动登录机制的实现原理

《SpringSecurity单点登录与自动登录机制的实现原理》本文探讨SpringSecurity实现单点登录(SSO)与自动登录机制,涵盖JWT跨系统认证、RememberMe持久化Token... 目录一、核心概念解析1.1 单点登录(SSO)1.2 自动登录(Remember Me)二、代码分析三、

Go语言并发之通知退出机制的实现

《Go语言并发之通知退出机制的实现》本文主要介绍了Go语言并发之通知退出机制的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1、通知退出机制1.1 进程/main函数退出1.2 通过channel退出1.3 通过cont

Spring Boot 中的默认异常处理机制及执行流程

《SpringBoot中的默认异常处理机制及执行流程》SpringBoot内置BasicErrorController,自动处理异常并生成HTML/JSON响应,支持自定义错误路径、配置及扩展,如... 目录Spring Boot 异常处理机制详解默认错误页面功能自动异常转换机制错误属性配置选项默认错误处理

Java中的xxl-job调度器线程池工作机制

《Java中的xxl-job调度器线程池工作机制》xxl-job通过快慢线程池分离短时与长时任务,动态降级超时任务至慢池,结合异步触发和资源隔离机制,提升高频调度的性能与稳定性,支撑高并发场景下的可靠... 目录⚙️ 一、调度器线程池的核心设计 二、线程池的工作流程 三、线程池配置参数与优化 四、总结:线程

Android ClassLoader加载机制详解

《AndroidClassLoader加载机制详解》Android的ClassLoader负责加载.dex文件,基于双亲委派模型,支持热修复和插件化,需注意类冲突、内存泄漏和兼容性问题,本文给大家介... 目录一、ClassLoader概述1.1 类加载的基本概念1.2 android与Java Class