深度学习代码|Multi-Headed Attention (MHA)多头注意力机制的代码实现

本文主要是介绍深度学习代码|Multi-Headed Attention (MHA)多头注意力机制的代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

相关文章

李沐《动手学深度学习》注意力机制

文章目录

  • 相关文章
  • 一、导入相关库
  • 二、准备工作
    • (一)理论基础
    • (二)定义PrepareForMultiHeadAttention模块
  • 三、多头注意模块
    • (一)理论基础
    • (二)创建MultiHeadAttention模块


一、导入相关库

import math
from typing import Optional, List# 从 typing 模块中导入 Optional 和 List 类型,用于类型提示import torch
from torch import nnfrom labml import tracker #跟踪实验中的指标和损失值等信息

二、准备工作

(一)理论基础

在多头自注意力机制中,每个注意力头都需要一个查询(Query)、一个键(Key)和一个值(Value)向量。这些向量通过线性变换从输入特征中提取而来,然后用于计算注意力权重和加权求和。具体来说,给定输入特征张量 X X X,我们首先通过三个线性变换来计算查询 Q Q Q、键 K K K和值向量 V V V
Q = X ⋅ W Q K = X ⋅ W K V = X ⋅ W V Q = X \cdot W_Q\\ K = X \cdot W_K\\ V = X \cdot W_V Q=XWQK=XWKV=XWV

其中, W Q W_Q WQ W K W_K WK W V W_V WV 是学习到的权重矩阵,用于将输入特征 X X X 映射到查询、键和值向量的空间中。这些线性变换可以通过 PyTorch 中的 nn.Linear 层来实现。

(二)定义PrepareForMultiHeadAttention模块

该模块用于准备多头自注意力机制中的查询、键和值向量:

  • 定义多头自注意力机制中的线性变换操作(在自注意力机制中,需要将输入的特征向量通过线性变换映射到不同的空间中,以便进行多头注意力的计算。)
  • 将向量拆分为给定数量的头部,以获得多头注意。
class PrepareForMultiHeadAttention(nn.module):'''d_model:模型输入的特征维度;heads:注意力机制中的头数;d_k:每个头部中以向量表示的维度数;bias:是否使用偏置项'''def __init__(self,d_model:int,heads:int,d_k:int,bias:bool):super().__init__()#线性变换的线性层,输入为d_model,输出为heads*d_kself.linear=nn.Linear(d_model,heads*d_k,bias=bias) self.heads=heads self.d_k=d_k '''x的形状为:seq_len,batch_size,d_model或batch_size,d_model输出形状:seq_len,batch_size,heads,d_k或batch_size,heads,d_k'''def forward(self,x:torch.Tensor):#获取输入张量 x 的形状,去掉最后一个维度,得到一个形状为 (seq_len, batch_size) 或 (batch_size,) 的元组head_shape=x.shape[:-1]x=self.linear(x)#将线性变换后的张量进行重塑操作,将最后一个维度拆分为heads个头部,每个头部的维度为d_kx=x.view(*head_shape,self.heads,self.d_k)return x

pytorch中的view方法:用于对张量进行重塑(reshape)。其作用是将张量的形状变换为指定的形状,但是要求变换后的形状与原始形状的元素数量保持一致。

* 表示解包(unpacking)操作符。在函数调用或函数定义中,*args 表示将传入的参数打包成一个元组,而在函数调用时,*args 则表示将元组解包为独立的参数。

三、多头注意模块

(一)理论基础

通过计算查询 Q Q Q和键 K K K之间的点积,并应用缩放因子 1 d k \frac{1}{\sqrt{d_k}} dk 1,得到注意力权重:
α = s o f t m a x ( Q K T d k ) \alpha=softmax(\frac{QK^T}{\sqrt{d_k}}) α=softmax(dk QKT)

利用注意力权重对值向量 V V V进行加权求和,得到最终的输出向量
A t t e n t i o n ( Q , K , V ) = α V = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=\alpha V=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=αV=softmax(dk QKT)V
这样,对于每个注意力头,都可以得到一个输出向量,在多头自注意力机制中,会并行地进行多个这样的注意力头的计算,最后将他们的输出向量连接起来,形成最终的输出。

(二)创建MultiHeadAttention模块

  • __init__:初始化对象的属性
  • get_score:计算查询和键之间的分数 S i j b h = ∑ d Q i b h d K j b h d S_{ijbh}=\sum_dQ_{ibhd}K_{jbhd} Sijbh=dQibhdKjbhd
  • prepare_mask:对注意力遮罩进行处理,使其与后续计算注意力权重的张量形状相匹配
  • forward:前向计算过程

注意力遮罩(Attention Mask)用于指示哪些位置的信息是有效的,哪些位置是无效的。在自注意力机制中,有时需要对注意力权重进行调整,以便在计算注意力时忽略某些位置的信息,或者对某些位置的信息赋予特定的权重。
注意力遮罩主要应用在查询和键之间的相似度计算过程中,用于调整或者限制查询和键之间的关系。

class MultiHeadAttention(nn.Module):'''heads:头的数量。d_model:query 、key 和value 向量中的要素数d_k:每头特征数'''def __init__(self,heads:int,d_model:int,dropout_prob:float=0.1,bias=True):super().__init__()self.d_k=d_model // heads #计算每个头部的查询、键和值的维度self.heads=headsself.query=PrepareForMultiHeadAttention(d_model,heads,self.d_k,bias=bias)self.key=PrepareForeMultiHeadAttention(d_model,heads,self.d_k,bias=bias)self.value=PrepareForMultiHeadAttention(d_model,heads,self.d_k,bias=True)#创建一个 Softmax 层,用于计算注意力权重,dim=1 表示在时间维度上进行 Softmax 计算。self.softmax=nn.Softmax(dim=1)#创建一个 Dropout 层,用于在训练过程中进行随机失活self.dropout=nn.Dropout(dropout_prob)#计算缩放因子,这里使用了倒数的平方根进行缩放self.scale=1/math.sqrt(self.d_k)#初始化一个属性 attn,用于存储注意力权重self.attn=Nonedef get_scores(self,query:torch.Tensor,key:torch.Tensor):return torch.einsum('ibhd,jbhd -> ijbh',query,key)'''mask: 输入的注意力遮罩,形状为(seq_len_q, seq_len_k, batch_size)。query_shape: 查询张量的形状,其中包含序列长度和批次大小。key_shape: 键张量的形状,其中包含序列长度和批次大小。'''def prepare_mask(self,mask:torch.Tensor,query_shape:List[int],key_shape:List[int]):#确保输入的注意力遮罩的第一个维度的大小与查询张量的序列长度维度大小匹配或者为1assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]#确保输入的注意力遮罩的第二个维度大小与键张量的序列长度维度大小匹配assert mask.shape[1] == key_shape[0]#确保输入的注意力遮罩的第三个维度的大小与查询张量的批次大小维度大小匹配或者为1assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]#将注意力遮罩的最后一个维度扩展一个维度#形状从(seq_len_q, seq_len_k, batch_size)扩展为(seq_len_q, seq_len_k, batch_size, heads)mask = mask.unsqueeze(-1)return mask'''query、key、value的形状:[seq_len, batch_size, d_model]mask的形状:[seq_len, seq_len, batch_size] '''def forward(self,*,query:torch.Tensor,key:torch.Tensor,value:torch.Tensor,mask:Optional[torch.Tensor] = None):seq_len,batch_size,_ = query.shapeif mask is not None:mask=self.prepare_mask(mask,query.shape,key.shape)#query、key、value的形状经过处理后变为:[seq_len,batch_size,heads,d_k]query=self.query(query)key=self.key(key)value=self.value(value)#计算注意力分数scores=self.get_scores(query,key)#应用缩放因子scores *= self.scale#应用maskif mask is not None:#将遮罩中值为 0 的位置替换为负无穷,这样在计算 Softmax 时对应位置的注意力权重将为 0scores=scores.masked_fill(mask==0,float('-inf'))#计算注意力权重attn=self.softmax(scores)#在调试过程中输出注意力权重 attn 的信息tracker.debug('attn',attn)#应用dropoutattn=self.dropout(attn)#加权求和x=torch.einsum("ijbh,jbhd->ibhd",attn,value)#将注意力权重从计算图中分离出来,以避免在反向传播过程中对注意力权重的梯度进行更新self.attn=attn.detach()#连接多个头:将经过加权求和得到的向量按照特定形状重新排列x=x.reshape(seq_len,batch_size,-1)return self.output(x)

参考:https://github.com/labmlai/annotated_deep_learning_paper_implementations?tab=readme-ov-file

这篇关于深度学习代码|Multi-Headed Attention (MHA)多头注意力机制的代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/693316

相关文章

基于 HTML5 Canvas 实现图片旋转与下载功能(完整代码展示)

《基于HTML5Canvas实现图片旋转与下载功能(完整代码展示)》本文将深入剖析一段基于HTML5Canvas的代码,该代码实现了图片的旋转(90度和180度)以及旋转后图片的下载... 目录一、引言二、html 结构分析三、css 样式分析四、JavaScript 功能实现一、引言在 Web 开发中,

SpringBoot中使用Flux实现流式返回的方法小结

《SpringBoot中使用Flux实现流式返回的方法小结》文章介绍流式返回(StreamingResponse)在SpringBoot中通过Flux实现,优势包括提升用户体验、降低内存消耗、支持长连... 目录背景流式返回的核心概念与优势1. 提升用户体验2. 降低内存消耗3. 支持长连接与实时通信在Sp

Conda虚拟环境的复制和迁移的四种方法实现

《Conda虚拟环境的复制和迁移的四种方法实现》本文主要介绍了Conda虚拟环境的复制和迁移的四种方法实现,包括requirements.txt,environment.yml,conda-pack,... 目录在本机复制Conda虚拟环境相同操作系统之间复制环境方法一:requirements.txt方法

Spring Boot 实现 IP 限流的原理、实践与利弊解析

《SpringBoot实现IP限流的原理、实践与利弊解析》在SpringBoot中实现IP限流是一种简单而有效的方式来保障系统的稳定性和可用性,本文给大家介绍SpringBoot实现IP限... 目录一、引言二、IP 限流原理2.1 令牌桶算法2.2 漏桶算法三、使用场景3.1 防止恶意攻击3.2 控制资源

Python如何去除图片干扰代码示例

《Python如何去除图片干扰代码示例》图片降噪是一个广泛应用于图像处理的技术,可以提高图像质量和相关应用的效果,:本文主要介绍Python如何去除图片干扰的相关资料,文中通过代码介绍的非常详细,... 目录一、噪声去除1. 高斯噪声(像素值正态分布扰动)2. 椒盐噪声(随机黑白像素点)3. 复杂噪声(如伪

springboot下载接口限速功能实现

《springboot下载接口限速功能实现》通过Redis统计并发数动态调整每个用户带宽,核心逻辑为每秒读取并发送限定数据量,防止单用户占用过多资源,确保整体下载均衡且高效,本文给大家介绍spring... 目录 一、整体目标 二、涉及的主要类/方法✅ 三、核心流程图解(简化) 四、关键代码详解1️⃣ 设置

Java Spring ApplicationEvent 代码示例解析

《JavaSpringApplicationEvent代码示例解析》本文解析了Spring事件机制,涵盖核心概念(发布-订阅/观察者模式)、代码实现(事件定义、发布、监听)及高级应用(异步处理、... 目录一、Spring 事件机制核心概念1. 事件驱动架构模型2. 核心组件二、代码示例解析1. 事件定义

Nginx 配置跨域的实现及常见问题解决

《Nginx配置跨域的实现及常见问题解决》本文主要介绍了Nginx配置跨域的实现及常见问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来... 目录1. 跨域1.1 同源策略1.2 跨域资源共享(CORS)2. Nginx 配置跨域的场景2.1

Python中提取文件名扩展名的多种方法实现

《Python中提取文件名扩展名的多种方法实现》在Python编程中,经常会遇到需要从文件名中提取扩展名的场景,Python提供了多种方法来实现这一功能,不同方法适用于不同的场景和需求,包括os.pa... 目录技术背景实现步骤方法一:使用os.path.splitext方法二:使用pathlib模块方法三

CSS实现元素撑满剩余空间的五种方法

《CSS实现元素撑满剩余空间的五种方法》在日常开发中,我们经常需要让某个元素占据容器的剩余空间,本文将介绍5种不同的方法来实现这个需求,并分析各种方法的优缺点,感兴趣的朋友一起看看吧... css实现元素撑满剩余空间的5种方法 在日常开发中,我们经常需要让某个元素占据容器的剩余空间。这是一个常见的布局需求