PyTorch nn.MultiHead() 参数理解

2024-01-27 01:08

本文主要是介绍PyTorch nn.MultiHead() 参数理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

之前一直是自己实现MultiHead Self-Attention程序,代码段又臭又长。后来发现Pytorch 早已经有API nn.MultiHead()函数,但是使用时我却遇到了很大的麻烦。

首先放上官网说明:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) MultiHead(Q,K,V)=Concat(head_1,…,head_h)W_O\quad where\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) MultiHead(Q,K,V)=Concat(head1,,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)

# 模型初始化
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
'''
embed_dim – 嵌入向量总长度.num_heads – 并行的head数目,即同时做多少次不同语义的attention.dropout – dropout的概率.bias – 是否添加偏置.默认: True.
'''# 模型运算
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
'''
Inputs:query: (L, N, E) where L is the target sequence length, N is the batch size, E is the embedding dimension.key: (S, N, E) , where S is the source sequence length, N is the batch size, E is the embedding dimension.value: (S, N, E) where S is the source sequence length, N is the batch size, E is the embedding dimension.key_padding_mask: (N, S)(N,S) , ByteTensor, where N is the batch size, S is the source sequence length.attn_mask: 2D mask (L, S)(L,S) where L is the target sequence length, S is the source sequence length. 3D mask (N*num_heads, L, S)(N∗num_heads,L,S) where N is the batch size, L is the target sequence length, S is the source sequence length.Outputs:attn_output: (L, N, E)(L,N,E) where L is the target sequence length, N is the batch size, E is the embedding dimension.attn_output_weights: (N, L, S)(N,L,S) where N is the batch size, L is the target sequence length, S is the source sequence length.
'''

值得注意的一点是,query,key,value的输入形状一定是 [sequence_size, batch_size, emb_size]

官网例子:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

embed_dim, num_heads参数

但是我执行程序却报错了:

A = torch.arange(1,25).view(4,3,2)
A = A.float()self_attn = torch.nn.MultiheadAttention(embed_dim=2, num_heads=4, dropout=0.0)
res,weight = self_attn(A,A,A)

报错信息:

    self_attn = torch.nn.MultiheadAttention(embed_dim=2, num_heads=4, dropout=0.0)File "E:\Anaconda3\lib\site-packages\torch\nn\modules\activation.py", line 740, in __init__assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
AssertionError: embed_dim must be divisible by num_heads

切到源码看不懂,而且我用pycharm 一直ctrl+鼠标进入不了最底层的代码,只有前几层的代码。(求相关领域大佬教教我)

经过自己的尝试,nn.MultiheadAttention(embed_dim, num_heads)中的要满足两点约束:

  • embed_dim == input_dim ,即query,key,value的embedding_size必须等于embed_dim
  • embed_dim%num_heads==0

上面的约束,也就说明了在使用nn.MultiheadAttention(embed_dim, num_heads)时, num_heads不是我们想设多少就设定多少。

我的看法:

nn.MultiheadAttention(embed_dim, num_heads) 中的embed_dim 是输入的embeddingsize,即query输入形状(L, N, E)的E数值,nn.MultiheadAttention 想要实现的是无论head 数目设置成多少,输出的向量大小都是不变的。写完这句话,发现自己表达能力真是弱,自己都看不懂。可以结合下面公式理解。
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) Q ∈ R ( L , N , E ) , K ∈ R ( S , N , E ) , V ∈ R ( S , N , E ) , W i ∈ R ( E , E / h ) , W O ∈ R ( E , E ) MultiHead(Q,K,V)=Concat(head_1,…,head_h)W_O\quad where\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)\\ Q\in R^{(L, N, E)},\ K\in R^{(S, N, E)},\ V\in R^{(S, N, E)},\ W_i\in R^{(E,E/h)},W_O\in R^{(E,E)} MultiHead(Q,K,V)=Concat(head1,,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)QR(L,N,E), KR(S,N,E), VR(S,N,E), WiR(E,E/h),WOR(E,E)

attn_mask参数

self-attention公式: SA ⁡ ( Q , K , V ) = s o f t m a x ( Q K T d k ) V \operatorname{SA}(Q, K, V)=softmax\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V SA(Q,K,V)=softmax(dk QKT)V

Q K T Q K^{T} QKT生成权重分布,但是在应用中有些位置的权重是不可见的,比如在时间序列中,第t天时,我们并不知道t+1天之后的信息。这时就需要传入attn_mask参数,屏蔽这些不合理的权重。attn_mask要求是booltensor,某个位置true表示掩盖该位置。

import torch
import torch.nn as nnA = torch.Tensor(5,2,4)
nn.init.xavier_normal_(A)
print(A)# tensor([[[ 0.3688,  0.0391,  0.2048, -0.0906],
#          [-0.0654,  0.1193, -0.1792,  0.0470]],
#
#         [[ 0.0812, -0.4180, -0.1353, -0.2670],
#          [ 0.0433,  0.1442,  0.1733,  0.0535]],
#
#         [[ 0.2352, -0.3314, -0.0238,  0.4116],
#          [ 0.1062,  0.5122,  0.1572, -0.2991]],
#
#         [[ 0.3381,  0.4004, -0.1936, -0.1553],
#          [-0.0168,  0.5914,  0.7389, -0.1740]],
#
#         [[ 0.0446, -0.1739, -0.2020,  0.2580],
#          [-0.0109,  0.0854,  0.2634, -0.4735]]])M = nn.MultiheadAttention(embed_dim=4, num_heads=2)
attention_mask = ~torch.tril(torch.ones([A.shape[0],A.shape[0]])).bool()
print(attention_mask)# tensor([[False,  True,  True,  True,  True],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])attn_output, attn_output_weights=M(A,A,A, attn_mask=attention_mask)
print(attention_mask)# tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#          [0.5067, 0.4933, 0.0000, 0.0000, 0.0000],
#          [0.3350, 0.3276, 0.3374, 0.0000, 0.0000],
#          [0.2523, 0.2511, 0.2549, 0.2417, 0.0000],
#          [0.2004, 0.1962, 0.2039, 0.1981, 0.2013]],
# 
#         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#          [0.5025, 0.4975, 0.0000, 0.0000, 0.0000],
#          [0.3325, 0.3312, 0.3363, 0.0000, 0.0000],
#          [0.2535, 0.2429, 0.2633, 0.2404, 0.0000],
#          [0.2002, 0.1986, 0.2008, 0.1976, 0.2028]]], grad_fn=<DivBackward0>)

这篇关于PyTorch nn.MultiHead() 参数理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/648621

相关文章

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

一文详解PostgreSQL复制参数

《一文详解PostgreSQL复制参数》PostgreSQL作为一款功能强大的开源关系型数据库,其复制功能对于构建高可用性系统至关重要,本文给大家详细介绍了PostgreSQL的复制参数,需要的朋友可... 目录一、复制参数基础概念二、核心复制参数深度解析1. max_wal_seChina编程nders:WAL

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

Linux高并发场景下的网络参数调优实战指南

《Linux高并发场景下的网络参数调优实战指南》在高并发网络服务场景中,Linux内核的默认网络参数往往无法满足需求,导致性能瓶颈、连接超时甚至服务崩溃,本文基于真实案例分析,从参数解读、问题诊断到优... 目录一、问题背景:当并发连接遇上性能瓶颈1.1 案例环境1.2 初始参数分析二、深度诊断:连接状态与

spring IOC的理解之原理和实现过程

《springIOC的理解之原理和实现过程》:本文主要介绍springIOC的理解之原理和实现过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、IoC 核心概念二、核心原理1. 容器架构2. 核心组件3. 工作流程三、关键实现机制1. Bean生命周期2.

史上最全nginx详细参数配置

《史上最全nginx详细参数配置》Nginx是一个轻量级高性能的HTTP和反向代理服务器,同时也是一个通用代理服务器(TCP/UDP/IMAP/POP3/SMTP),最初由俄罗斯人IgorSyso... 目录基本命令默认配置搭建站点根据文件类型设置过期时间禁止文件缓存防盗链静态文件压缩指定定错误页面跨域问题

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

SpringBoot请求参数接收控制指南分享

《SpringBoot请求参数接收控制指南分享》:本文主要介绍SpringBoot请求参数接收控制指南,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring Boot 请求参数接收控制指南1. 概述2. 有注解时参数接收方式对比3. 无注解时接收参数默认位置