Bahdanau注意力机制

2024-08-30 04:28
文章标签 机制 注意力 bahdanau

本文主要是介绍Bahdanau注意力机制,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

介绍

在Bahadanu注意力机制中,本质上是序列到序列学习的注意力机制实现,在编码器-解码器结构中,解码器的每一步解码过程都依赖着整个上下文变量,通过Bahdanau注意力,使得解码器在每一步解码时,对于整个上下文变量的不同部分产生不同程度的对齐,如在文本翻译时,将“I am studying”的“studying”与“我正在学习”的“学习”进行对齐,即注意力在解码时将绝大多数注意力放在“studying”处。

原理和结构

原理

Bahdanau注意力机制本质上是将上下文变量进行转换即可,其中转换后的上下文变量计算方式如下式所示:

c_{t^{'}}=\Sigma_{t=1}^T \alpha (s_{t^{'}-1},h_t)h_t

在传统注意力机制中,一般使用的公式形如c_{t^{'}}=\Sigma_{t=1}^T \alpha (s_{t^{'}-1},k) V,在Bahdanau中,键与值是同一个变量,都是t时刻的编码器隐状态,s表示该时刻的查询,即上一时刻的解码器隐状态。

架构

下图为Bahdanau注意力机制的编码器-解码器架构示意图:

为便于理解,对上述示意结构进行说明:首先将X依次输入GRU,之后在循环过程中依次产生len个隐状态,最后一个隐状态h_{len}直接作为解码器的初始隐状态。在每个解码步骤 (t),注意力机制计算当前解码器隐藏状态 (s_t) 和编码器所有隐藏状态 (h_i) 的相似度,即应用注意力机制编写新的上下文变量,之后在解码器的循环解码过程中,都计算带注意力的上下文变量,通过此变量和上一解码隐状态计算当前时刻t的输出。

代码实现

引入库

注:本blog使用mxnet进行训练学习。

from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2lnpx.set_np()

 定义注意力解码器

这里实现一个接口,只需重新定义解码器即可。 为了更方便地显示学习的注意力权重, 以下AttentionDecoder类定义了带有注意力机制解码器的基本接口。

class AttentionDecoder(d2l.Decoder):def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)def attention_weights(self):raise NotImplementedError

接下来,让我们在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。 首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;

  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;

  3. 编码器有效长度(排除在注意力池中填充词元)。

在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。 因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

对接下来的代码实现略作补充说明:在编码器中,对一个batch每个输入(采用one-hot编码,长度为Vocab_size)依次进行嵌入层运算,得到固定embed_size个结果之后进行RNN运算,RNN使用层数为num_layers的深层循环神经网络,进行forward运算得到状态,部分过程进行闭包和解包,对于对维度大小出现疑惑的点,大多是闭包和解包造成的。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)self.dense = nn.Dense(vocab_size, flatten=False)def init_state(self, enc_outputs, enc_valid_lens, *args):outputs, hidden_state = enc_outputsreturn (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens)def forward(self, X, state):enc_outputs, hidden_state, enc_valid_lens = stateX = self.embedding(X).swapaxes(0, 1)outputs, self._attention_weights = [], []for x in X:query = np.expand_dims(hidden_state[0][-1], axis=1)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1)out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)outputs = self.dense(np.concatenate(outputs, axis=0))return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,enc_valid_lens]def attention_weights(self):return self._attention_weights

训练

我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

结果 

采用BLEU计算困惑度,代码具有较好的表现。

go . => entre ., bleu 0.000
i lost . => j'ai gagné ., bleu 0.000
he's calm . => j'ai gagné ., bleu 0.000
i'm home . => je suis chez moi <unk> !, bleu 0.719

这篇关于Bahdanau注意力机制的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1119792

相关文章

Spring事务传播机制最佳实践

《Spring事务传播机制最佳实践》Spring的事务传播机制为我们提供了优雅的解决方案,本文将带您深入理解这一机制,掌握不同场景下的最佳实践,感兴趣的朋友一起看看吧... 目录1. 什么是事务传播行为2. Spring支持的七种事务传播行为2.1 REQUIRED(默认)2.2 SUPPORTS2

MySQL中的锁机制详解之全局锁,表级锁,行级锁

《MySQL中的锁机制详解之全局锁,表级锁,行级锁》MySQL锁机制通过全局、表级、行级锁控制并发,保障数据一致性与隔离性,全局锁适用于全库备份,表级锁适合读多写少场景,行级锁(InnoDB)实现高并... 目录一、锁机制基础:从并发问题到锁分类1.1 并发访问的三大问题1.2 锁的核心作用1.3 锁粒度分

Redis的持久化之RDB和AOF机制详解

《Redis的持久化之RDB和AOF机制详解》:本文主要介绍Redis的持久化之RDB和AOF机制,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录概述RDB(Redis Database)核心原理触发方式手动触发自动触发AOF(Append-Only File)核

PostgreSQL中MVCC 机制的实现

《PostgreSQL中MVCC机制的实现》本文主要介绍了PostgreSQL中MVCC机制的实现,通过多版本数据存储、快照隔离和事务ID管理实现高并发读写,具有一定的参考价值,感兴趣的可以了解一下... 目录一 MVCC 基本原理python1.1 MVCC 核心概念1.2 与传统锁机制对比二 Postg

Maven 配置中的 <mirror>绕过 HTTP 阻断机制的方法

《Maven配置中的<mirror>绕过HTTP阻断机制的方法》:本文主要介绍Maven配置中的<mirror>绕过HTTP阻断机制的方法,本文给大家分享问题原因及解决方案,感兴趣的朋友一... 目录一、问题场景:升级 Maven 后构建失败二、解决方案:通过 <mirror> 配置覆盖默认行为1. 配置示

Redis过期删除机制与内存淘汰策略的解析指南

《Redis过期删除机制与内存淘汰策略的解析指南》在使用Redis构建缓存系统时,很多开发者只设置了EXPIRE但却忽略了背后Redis的过期删除机制与内存淘汰策略,下面小编就来和大家详细介绍一下... 目录1、简述2、Redis http://www.chinasem.cn的过期删除策略(Key Expir

Go语言中Recover机制的使用

《Go语言中Recover机制的使用》Go语言的recover机制通过defer函数捕获panic,实现异常恢复与程序稳定性,具有一定的参考价值,感兴趣的可以了解一下... 目录引言Recover 的基本概念基本代码示例简单的 Recover 示例嵌套函数中的 Recover项目场景中的应用Web 服务器中

Jvm sandbox mock机制的实践过程

《Jvmsandboxmock机制的实践过程》:本文主要介绍Jvmsandboxmock机制的实践过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、背景二、定义一个损坏的钟1、 Springboot工程中创建一个Clock类2、 添加一个Controller

Dubbo之SPI机制的实现原理和优势分析

《Dubbo之SPI机制的实现原理和优势分析》:本文主要介绍Dubbo之SPI机制的实现原理和优势,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Dubbo中SPI机制的实现原理和优势JDK 中的 SPI 机制解析Dubbo 中的 SPI 机制解析总结Dubbo中

Java 的 Condition 接口与等待通知机制详解

《Java的Condition接口与等待通知机制详解》在Java并发编程里,实现线程间的协作与同步是极为关键的任务,本文将深入探究Condition接口及其背后的等待通知机制,感兴趣的朋友一起看... 目录一、引言二、Condition 接口概述2.1 基本概念2.2 与 Object 类等待通知方法的区别