LLM - Transformer Multi-Head Attention 维度变化与源码详解

2024-02-21 09:28

本文主要是介绍LLM - Transformer Multi-Head Attention 维度变化与源码详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一.引言

前面我们基于 LLM 大模型源码介绍了 Causal Mask 以及 ROPE 旋转位置编码的实现,本文介绍源码中 Transformer 的实现流程,我们基于代码逐行分析维度变化与代码含义,希望能够清晰的了解 LLM 中 Transformer 运行的流程。

二.Transformer 分层维度

上面这个 Transformer 的基础结构我们在之前已经提到过很多次,这里结合维度变化再啰嗦一次,更详细的介绍可以参考: LLM - Transformer && LLaMA2 结构分析与 LoRA 详解。

1.单条样本

- Embedding Layer

对于一个典型的 LLM 大模型,输入 Embedding 层的维度 d_model 通常指的是将输入的标记 token 通过一个 embedding 层映射转换为连续向量的维度。例如,在 BERT-base 模型中,d_model 是 768,而在当下大模型中 d_model 为 8192。

- Transformer Layer

Transformer 层的输出维度通常和输入 Embedding 层的维度一致,即 d_model。如果我们持续使用 BERT-base 的例子,那么每个 Transformer 层 [ BERT中称为encoder层,LLM 中多为 decoder 层 ] 的输出也将是维度为 768 / 8192 的向量。

- lm_head Layer 

最后的 lm_head(语言模型头)的维度通常等于词汇表的大小 vocab_size,因为 lm_head 的作用是将 Transformer 层的输出转换成每个词汇的概率分布。举例来说,如果模型处理的语言的词汇表大小为 30000 个单词,那么 lm_head 的输出维度就是 30000。

- hidden_states

hidden_states 是 Transformer 模型处理过程中的一个术语,常见于模型的中间输出和内部分析。其记录了隐层的激活值,对于每个输入标记 token,Transformer 的每个层都会有一个输出向量,它表示的是在该层的特定深度上输入的表示。对于一个 N 层堆叠的 Transformer 模型,对于一个给定的输入序列,模型将会有 N 个这样的隐藏状态集。其中每个隐藏状态也会包含注意力分布,这是 Transformer 的自注意力机制的一个关键组成部分,它允许模型在处理输入时衡量不同部分之间的相互依赖性。

Tips:

假设我们有一个 BERT-base 模型,它使用 12 层 Transformer,每层的输出维度为 768,若输入一个有 5 个 tokens 的序列,每个 token 会首先被转换成一个 768 维的 embedding 向量。因此,hidden_states 在模型刚开始时会是一个形状为 (5, 768) 的张量。经过 12 层 Transformer 层处理后最后输出的 hidden_states 将会是一个形状为 (12, 5, 768) 的 3 维张量,其中包含了序列中每个token 在各个层上的表征。

2.批次样本

上面给出了单条样本的转换流程,下面我们分析下 batch_size 情况下维度的变换流程。假设我们有一个 BERT-base 模型:

词汇表大小 vocab_size = 30000

嵌入层维度 d_model = 768

堆叠层数量 N = 12

最大序列长度 max_seq_length = 128

批次大小 batch_size = 32

以下是数据通过模型时维度的具体变化过程:

- Input Layer

输入层维度为 (batch_size, max_seq_length) 即 (32, 128),每一个 128 的张量表示批次中每个序列的 token_id,即 text 通过 tokenizer 处理后的结果。 

- Embedding Layer

(bsz, max_seq_length) 的整数张量会被送入 Embedding 层,以 Bert 为例,其会被映射到 (bsz, max_seq_length, d_model) 的维度,即 (32, 128, 768)。这表明我们现在有 32 条样本,每个序列有 128 个 768 维的词嵌入向量。

- Transformer Layer

每个 Transformer 层接受一个 (bsz, max_seq_length, d_model) 的张量,经过 multi_head_attention 后输出一个相同形状的张量,这是因为 transformer 层通常会保持输入输出的维度相同,因此经过本层映射后,维度依然为 (bsz, max_seq_length, d_model) 即 (32, 128, 768)。

- lm_head Layer

lm_head 线性层将 Transformer 层的输出 (bsz, max_seq_length, d_model) 转换为 (bsz, max_seq_length, vocab_size) 的张量,即 (32, 128, 30000)。这一层一般是通过 Linear 实现的,对于复杂的 LLM,还会有 MLP 层,但最终 lm_head 的目的都是将 d_model 映射到 vocab_size,即生成一个与词汇表大小匹配的权重矩阵,代表每个 token 可能性的分布。

Tips:

如果考虑中间的 hidden_states,那么对于序列中的每个 token,在每个 Transformer 层中,我们都会得到一个 768 维的向量。因此,对于整个 batch 来说,每一层的 hidden_states 的形状为(batch_size, max_seq_length, d_model),即 (32, 128, 768)。如果我们保存所有层的hidden_states,那么我们就得到了一个形状为 (num_layers, batch_size, max_seq_length, d_model) 的 4 维张量,即 (12, 32, 128, 768),这里 num_layers 就是前面提到的 N,即 LLM 中 transformer 层堆叠的数量,这样,你就可以看到不同维度如何随着数据流通过模型而变化。这里需要注意的是真实情况下由于序列化长度可能不同,还会涉及到填充 padding 和掩码 masking 来确保批量处理是有效的,然而这并不影响上述维度变化的基本流程。

三.Transformer 维度变换

为了大家可以在本机 debug 快速测试,下面的示例我们以 Bert 及其 tokenizer 作为基模型构建 token_id 以及 Embedding,后续的 Multi-Head Attention 我们基于 Qwen 的逻辑进行了迁移,保持主体实现风格不变,更完整的代码可以参考 HF 上 modeling.py。

1.Input Layer 

输入层以及嵌入层我们通过 Bert 模型的 tokenizer 获取:

#!/usr/bin/python
# -*- coding: UTF-8 -*-import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerif __name__ == '__main__':tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')pretrained_bert = BertModel.from_pretrained('bert-base-uncased')input_texts = ["This is a test sentence.", "Here is another test sentence."]input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,return_tensors='pt') for text in input_texts]input_ids = torch.cat(input_ids, dim=0)  # Concatenate and add batch dimension

为了方便我们 input_texts 构造两条样本,所以 bsz = 2、max_length = 10、d_model = 768,input_ids 维度为 (10, ):

通过 concat 得到 (bsz, max_length) = (2, 10) 的初始维度:

tensor([[ 101, 2023, 2003, 1037, 3231, 6251, 1012,  102,    0,    0],[ 101, 2182, 2003, 2178, 3231, 6251, 1012,  102,    0,    0]])

2.Embedding Layer

    with torch.no_grad():embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT modelprint(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)

这里通过 bert 的 Embedding 层获取 input_id 对应的 Embedding,由于 d_model = 768,所以前面 token_id 的 (bsz, max_length) 转换为 (bsz, max_length, d_model) 即 (2, 10, 768):

tensor([[[-3.7545e-02,  5.3234e-04, -1.3553e-02,  ..., -1.9545e-01,2.3569e-01,  4.7479e-01],[-7.1746e-01, -2.8763e-01,  1.4100e-01,  ..., -5.5593e-01,6.1830e-01,  3.9255e-01],[-1.9318e-01, -4.0202e-01,  3.2924e-01,  ..., -1.5206e-01,3.4014e-01,  1.0233e+00],...,[ 1.5273e-01,  1.1651e-01,  1.5754e-01,  ...,  6.9833e-02,-8.5732e-01, -4.3875e-02],[ 7.0679e-02, -2.3521e-01,  6.1713e-01,  ..., -7.3852e-02,2.5070e-01, -6.3240e-02],[-1.3249e-01, -3.6026e-01,  3.5025e-01,  ..., -5.5981e-02,1.0420e-01, -4.3954e-01]],[[-2.9592e-02, -1.4164e-01, -2.2295e-03,  ..., -1.3087e-01,2.9421e-01,  5.5132e-01],[-1.0146e+00, -6.8757e-01,  1.9959e-01,  ..., -4.2000e-01,1.7332e-01,  9.2754e-02],[-1.3425e-01, -8.1044e-01,  2.6674e-01,  ...,  4.6978e-02,-1.0026e-01,  4.5293e-01],...,[ 4.5527e-01,  2.2234e-02, -3.6816e-01,  ...,  4.3154e-01,-8.6396e-01, -2.8542e-01],[ 1.4188e-01, -2.4001e-01,  6.5681e-01,  ..., -5.7224e-02,3.1025e-01, -9.0286e-02],[-3.9205e-02, -3.2815e-01,  4.7910e-01,  ..., -4.7641e-02,2.9916e-02, -4.5328e-01]]])

3.Multi-Head Attention

        embed_dim = embedded_output.size(-1)num_heads = 4model = BITDDDAttention(embed_dim, num_heads)output = model(embedded_output)print(output.size())  # Output shape should be (2, 10, embed_dim)

本层我们从 LLM modeling.py 中将 Atention 的核心部分迁移到 BITDDDAtention Class 中:

class BITDDDAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(BITDDDAttention, self).__init__()self.embed_dim = embed_dim  # embedding 维度self.num_heads = num_heads  # head 数量assert embed_dim % num_heads == 0self.head_dim = embed_dim // num_heads# 构建 Q/K/V 向量以及最后的全连接 MLPself.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, _ = x.size()# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# Compute the attention scoresattention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5attention_probs = torch.softmax(attention_scores, dim=-1)# Apply the attention weights to the valueattention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)# Apply a linear layer to the outputx = self.fc_out(attention_output)return x

下面我们逐行看下 Mutil-Head Attention 的执行流程与维度变化:

- Size

batch_size, seq_len, _ = x.size()

这一步解析 Attention 层输入的 batch 样本的 bsz、seq_len,由于 init 方法中已经给出了 emd_dim,所以这里使用 '_' 忽略。

- Q/K/V 获取

# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

self.query、key、value 都是 nn.Linear(embed_dim, embed_dim) 的线性转换层,Q/K/V 的处理逻辑是相同的,这里通过 view 即 resize 方法将线性转换后的向量 (bsz, seq_len, embed_dim) 转换为 (bsz, seq_len, num_heads, head_dim),最后通过 permute 交换位置得到 (bsz, num_heads, seq_len, head_dim) 的输出向量,用于后续 multi-head 的计算。这里通过 assert 判断是否整除:

self.head_dim = embed_dim // num_heads

根据上面 init 给出的 heads 以及 embed_dim,可以得到最终维度为: (2, 4, 10, 192):

tensor([[[[ 0.0185, -0.1872,  0.1827,  ..., -0.7914, -0.0074, -0.6228],[-0.1390,  0.4675,  0.0325,  ...,  0.0187,  0.0912, -0.2692],[-0.1342,  0.2904, -0.2637,  ...,  0.1130, -0.0226, -0.3510],...,[-0.1151, -0.2627, -0.6453,  ...,  0.4885, -0.1982, -0.1538],[-0.1281,  0.2321, -0.0815,  ..., -0.1740,  0.4909, -0.1373],[-0.1364,  0.2844, -0.0728,  ..., -0.0620,  0.3605, -0.2292]],[[ 0.3641,  0.1707, -0.0567,  ...,  0.0267,  0.3272,  0.1560],[-0.1206,  0.6853,  0.0990,  ..., -0.0875,  0.2414,  0.5490],[-0.4080,  0.0679,  0.3174,  ...,  0.0970, -0.0127,  0.1664],...,[-0.2878,  0.2856,  0.0777,  ..., -0.0791,  0.0847,  0.0545],[ 0.2381, -0.1032,  0.2887,  ...,  0.2219,  0.2837,  0.0345],[ 0.1421, -0.0956,  0.1983,  ...,  0.1784,  0.1827,  0.0776]],[[-0.2031, -0.2496, -0.0072,  ..., -0.1553, -0.0441,  0.0200],[-0.2028, -0.4097,  0.1779,  ...,  0.0333, -0.4005, -0.3453],[ 0.0926, -0.1818,  0.0492,  ...,  0.3059, -0.6175, -0.2858],...,[ 0.3494, -0.4813,  0.7086,  ...,  0.6181,  0.1515, -0.1279],[-0.0542,  0.3148,  0.0172,  ...,  0.0037, -0.2878, -0.1582],[-0.1381,  0.2450,  0.0490,  ..., -0.0824, -0.2504, -0.2464]],[[ 0.6905, -0.1202,  0.6489,  ...,  0.6069,  0.2634, -0.0595],[ 0.3937, -0.2795,  0.7692,  ...,  0.1321, -0.0240, -0.1484],[ 0.2260, -0.4332,  0.4651,  ..., -0.1797, -0.1127, -0.3294],...,[ 0.0168, -0.2892,  0.4032,  ..., -0.4515,  0.3833, -0.7699],[ 0.1970, -0.3264,  0.4196,  ...,  0.3044, -0.0819, -0.2083],[ 0.2492, -0.3419,  0.5813,  ...,  0.1855, -0.2431, -0.1149]]],[[[ 0.0225, -0.2359,  0.0754,  ..., -0.7577,  0.0936, -0.6233],[ 0.0479,  0.5459, -0.3047,  ..., -0.3134,  0.0416,  0.0397],[ 0.1172,  0.2506, -0.5461,  ...,  0.1287, -0.0441, -0.2074],...,[-0.3586, -0.3827, -0.6436,  ...,  0.3915, -0.2485, -0.1576],[-0.1502,  0.1852, -0.1007,  ..., -0.1310,  0.5079, -0.1868],[-0.1622,  0.2055, -0.1428,  ..., -0.0887,  0.3516, -0.2383]],[[ 0.4338,  0.2326, -0.0661,  ...,  0.0309,  0.3088,  0.1711],[-0.4011,  0.9250,  0.2983,  ..., -0.4108,  0.4223,  0.6880],[-0.2721,  0.4383,  0.6376,  ..., -0.0888, -0.0647, -0.0073],...,[ 0.1742,  0.2020, -0.1020,  ..., -0.1444,  0.2459,  0.1079],[ 0.2608, -0.0978,  0.2557,  ...,  0.2132,  0.2125,  0.0010],[ 0.1041, -0.1335,  0.1523,  ...,  0.1797,  0.1323,  0.0036]],[[-0.1826, -0.2200, -0.0026,  ..., -0.1664, -0.0773,  0.0607],[-0.1257, -0.2642,  0.6933,  ...,  0.4202, -0.1153, -0.3960],[-0.1353, -0.4837,  0.3527,  ...,  0.3592, -0.5616, -0.3685],...,[ 0.6056, -0.3298,  0.7872,  ...,  0.3984,  0.4775,  0.2213],[-0.1211,  0.3394, -0.0247,  ...,  0.0251, -0.3108, -0.1656],[-0.1572,  0.3040,  0.0164,  ..., -0.1026, -0.2737, -0.2175]],[[ 0.7236, -0.1187,  0.6491,  ...,  0.6230,  0.2401, -0.0061],[-0.0402, -0.0318,  0.7717,  ..., -0.0389,  0.1465, -0.3047],[ 0.2734, -0.4473,  0.6278,  ..., -0.3827, -0.0412, -0.7133],...,[-0.0418,  0.0670,  0.1462,  ..., -0.6109,  0.4838, -0.4277],[ 0.2340, -0.3250,  0.4256,  ...,  0.3217, -0.0688, -0.1837],[ 0.1940, -0.2878,  0.5281,  ...,  0.2155, -0.1810, -0.0649]]]])

- Attention Score 计算

# Compute the attention scores
attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
attention_probs = torch.softmax(attention_scores, dim=-1)

Attention Score 的计算依赖于 Q/K,这里把 key 的维度通过 permute 做了转换,由  (bsz, num_heads, seq_len, head_dim) 变换为  (bsz, num_heads, head_dim, seq_len),matmul 相乘后得到 attention_scores 的维度为 (bsz, num_heads, seq_len, seq_len) 即 (2, 4, 10, 10),除以 sqrt(head_dim) 是在应用 scale_dot 防止 matmul 的乘积过大,而最后 softmax(dim=-1) 则将 Attention Score 的最后一维的 10 个数字进行了归一化:

tensor([[[[0.1188, 0.1000, 0.0912, 0.0963, 0.0984, 0.0862, 0.1004, 0.0955,0.1040, 0.1093],[0.1077, 0.0949, 0.0963, 0.1035, 0.0961, 0.0940, 0.0946, 0.0890,0.1090, 0.1149],[0.0965, 0.1010, 0.0930, 0.0972, 0.1032, 0.1031, 0.0989, 0.0982,0.1062, 0.1026],[0.0932, 0.1033, 0.0970, 0.0977, 0.0961, 0.1050, 0.0990, 0.1082,0.1006, 0.0999],[0.0947, 0.1033, 0.0949, 0.0945, 0.0957, 0.1036, 0.0964, 0.0985,0.1083, 0.1102],[0.0941, 0.1026, 0.0939, 0.0942, 0.0953, 0.1008, 0.1001, 0.1089,0.1038, 0.1063],[0.1017, 0.1019, 0.1008, 0.0937, 0.1095, 0.1007, 0.0913, 0.0832,0.1092, 0.1079],[0.0926, 0.1121, 0.1009, 0.0991, 0.0955, 0.1017, 0.0964, 0.1030,0.1004, 0.0982],[0.1010, 0.1021, 0.0948, 0.0954, 0.0976, 0.1024, 0.0916, 0.1032,0.1049, 0.1071],[0.1046, 0.1077, 0.0932, 0.0948, 0.1006, 0.1002, 0.0934, 0.0983,0.1042, 0.1032]],......         [[0.1047, 0.0999, 0.1045, 0.1054, 0.0979, 0.1071, 0.0863, 0.0884,0.1012, 0.1046],[0.1019, 0.1113, 0.1010, 0.0990, 0.0981, 0.1060, 0.0872, 0.0915,0.1017, 0.1022],[0.0977, 0.0996, 0.0993, 0.1027, 0.0970, 0.0985, 0.0977, 0.0990,0.1069, 0.1018],[0.1042, 0.1125, 0.1049, 0.1022, 0.0981, 0.0950, 0.0864, 0.0876,0.1045, 0.1046],[0.0987, 0.1151, 0.1018, 0.0956, 0.0923, 0.0955, 0.0938, 0.0910,0.1073, 0.1087],[0.0985, 0.1143, 0.0936, 0.1029, 0.0954, 0.1028, 0.0857, 0.0901,0.1076, 0.1092],[0.0961, 0.0908, 0.1013, 0.1055, 0.0992, 0.1035, 0.0919, 0.0971,0.1090, 0.1056],[0.0874, 0.0935, 0.1012, 0.1057, 0.1044, 0.0968, 0.0936, 0.0950,0.1132, 0.1091],[0.1041, 0.1118, 0.0958, 0.0968, 0.0971, 0.1044, 0.0925, 0.0906,0.1028, 0.1041],[0.1028, 0.1123, 0.0971, 0.1005, 0.1013, 0.1020, 0.0896, 0.0894,0.1026, 0.1023]]]])

- Attention Output 

# Apply the attention weights to the value
attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

Attention Probs 的维度为  ​​(2, 4, 10, 10) ,value 的维度为 (2, 4, 10, 192),相乘后得到 (2, 4, 10, 192) 即 (bsz, num_heads, seq_len, head_dim),通过 permute 转换为 (bsz, seq_len, num_heads, head_dim),再通过 view 将后两维 num_heads x head_dim 合并为 d_model,从而最终 attention_output 的维度为 (bsz, seq_len, d_model) 与原始 token_ids 通过 Embedding 层映射后的向量维度保持一致。

tensor([[[-0.0421,  0.0127, -0.3383,  ...,  0.1617, -0.2079, -0.3181],[-0.0401,  0.0172, -0.3488,  ...,  0.1581, -0.2098, -0.3260],[-0.0397,  0.0130, -0.3420,  ...,  0.1581, -0.2081, -0.3266],...,[-0.0428,  0.0137, -0.3372,  ...,  0.1589, -0.2145, -0.3305],[-0.0424,  0.0120, -0.3419,  ...,  0.1574, -0.2198, -0.3344],[-0.0426,  0.0140, -0.3463,  ...,  0.1562, -0.2191, -0.3338]],[[ 0.0502,  0.0926, -0.3300,  ...,  0.1376, -0.1264, -0.3689],[ 0.0369,  0.0917, -0.3339,  ...,  0.1439, -0.1089, -0.3571],[ 0.0419,  0.0915, -0.3328,  ...,  0.1480, -0.1168, -0.3654],...,[ 0.0438,  0.0946, -0.3290,  ...,  0.1435, -0.1302, -0.3702],[ 0.0417,  0.0898, -0.3281,  ...,  0.1358, -0.1374, -0.3759],[ 0.0428,  0.0906, -0.3306,  ...,  0.1345, -0.1330, -0.3752]]])

- Linear 浅层 MLP

# Apply a linear layer to the output
x = self.fc_out(attention_output)

fc_out 的维度是 nn.Linear(embed_dim, embed_dim),所有 attention_output 经过处理后 (bsz, seq_len, d_model) x (d_model, d_model) = (bsz, seq_len, d_model)。

tensor([[[ 1.0228e-01,  1.6250e-01, -1.4914e-01,  ..., -1.7511e-01,-2.1751e-03, -2.0877e-02],[ 9.9930e-02,  1.6427e-01, -1.4394e-01,  ..., -1.7894e-01,1.9605e-03, -2.4290e-02],[ 1.0188e-01,  1.6577e-01, -1.4313e-01,  ..., -1.7274e-01,5.3616e-03, -1.8874e-02],...,[ 1.0584e-01,  1.6541e-01, -1.4315e-01,  ..., -1.7077e-01,-4.8522e-04, -2.2207e-02],[ 1.0028e-01,  1.6638e-01, -1.3908e-01,  ..., -1.7138e-01,-4.0303e-05, -2.2604e-02],[ 1.0054e-01,  1.6448e-01, -1.4135e-01,  ..., -1.7086e-01,2.8514e-03, -1.9951e-02]],[[ 4.9912e-02,  1.3306e-01, -1.2705e-01,  ..., -1.2117e-01,3.5498e-02,  3.8191e-03],[ 4.8556e-02,  1.3361e-01, -1.2207e-01,  ..., -1.2270e-01,3.7410e-02, -3.3710e-03],[ 4.9592e-02,  1.3507e-01, -1.2446e-01,  ..., -1.2247e-01,4.3996e-02,  2.0591e-03],...,[ 5.2688e-02,  1.3105e-01, -1.2519e-01,  ..., -1.1373e-01,3.7038e-02,  2.5118e-03],[ 4.8786e-02,  1.3443e-01, -1.1793e-01,  ..., -1.1811e-01,3.4455e-02,  3.0611e-04],[ 4.7252e-02,  1.3401e-01, -1.1889e-01,  ..., -1.1601e-01,3.6708e-02,  2.7476e-03]]])

4.完整代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerclass BITDDDAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(BITDDDAttention, self).__init__()self.embed_dim = embed_dim  # embedding 维度self.num_heads = num_heads  # head 数量assert embed_dim % num_heads == 0self.head_dim = embed_dim // num_heads# 构建 Q/K/V 向量以及最后的全连接 MLPself.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, _ = x.size()# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# Compute the attention scoresattention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5attention_probs = torch.softmax(attention_scores, dim=-1)# Apply the attention weights to the valueattention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)# Apply a linear layer to the outputx = self.fc_out(attention_output)return xif __name__ == '__main__':tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')pretrained_bert = BertModel.from_pretrained('bert-base-uncased')input_texts = ["This is a test sentence.", "Here is another test sentence."]input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,return_tensors='pt') for text in input_texts]input_ids = torch.cat(input_ids, dim=0)  # Concatenate and add batch dimensionwith torch.no_grad():embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT modelprint(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)embed_dim = embedded_output.size(-1)num_heads = 4model = BITDDDAttention(embed_dim, num_heads)output = model(embedded_output)print(output.size())  # Output shape should be (2, 10, embed_dim)

四.总结

上述代码可以在本地 CPU/GPU 环境跑起来,大家可以自己打断点熟悉整个过程维度的变化,计算的流程,Multi-Head Attention 分多个 head 计算不同 token 的注意力权重并加权求和,对于 Decoder-Only 的架构,其还会添加 Causal Mask 保证前面的文字看不到后面的文字。本文先介绍到 Transformer 的输出,后续我们介绍如何通过 Transformer 最后一层 lm_head 的输出计算 next_token 的概率并计算交叉熵 loss。

这篇关于LLM - Transformer Multi-Head Attention 维度变化与源码详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/731293

相关文章

Java内存分配与JVM参数详解(推荐)

《Java内存分配与JVM参数详解(推荐)》本文详解JVM内存结构与参数调整,涵盖堆分代、元空间、GC选择及优化策略,帮助开发者提升性能、避免内存泄漏,本文给大家介绍Java内存分配与JVM参数详解,... 目录引言JVM内存结构JVM参数概述堆内存分配年轻代与老年代调整堆内存大小调整年轻代与老年代比例元空

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

mysql表操作与查询功能详解

《mysql表操作与查询功能详解》本文系统讲解MySQL表操作与查询,涵盖创建、修改、复制表语法,基本查询结构及WHERE、GROUPBY等子句,本文结合实例代码给大家介绍的非常详细,感兴趣的朋友跟随... 目录01.表的操作1.1表操作概览1.2创建表1.3修改表1.4复制表02.基本查询操作2.1 SE

MySQL中的锁机制详解之全局锁,表级锁,行级锁

《MySQL中的锁机制详解之全局锁,表级锁,行级锁》MySQL锁机制通过全局、表级、行级锁控制并发,保障数据一致性与隔离性,全局锁适用于全库备份,表级锁适合读多写少场景,行级锁(InnoDB)实现高并... 目录一、锁机制基础:从并发问题到锁分类1.1 并发访问的三大问题1.2 锁的核心作用1.3 锁粒度分

MySQL数据库中ENUM的用法是什么详解

《MySQL数据库中ENUM的用法是什么详解》ENUM是一个字符串对象,用于指定一组预定义的值,并可在创建表时使用,下面:本文主要介绍MySQL数据库中ENUM的用法是什么的相关资料,文中通过代码... 目录mysql 中 ENUM 的用法一、ENUM 的定义与语法二、ENUM 的特点三、ENUM 的用法1

MySQL count()聚合函数详解

《MySQLcount()聚合函数详解》MySQL中的COUNT()函数,它是SQL中最常用的聚合函数之一,用于计算表中符合特定条件的行数,本文给大家介绍MySQLcount()聚合函数,感兴趣的朋... 目录核心功能语法形式重要特性与行为如何选择使用哪种形式?总结深入剖析一下 mysql 中的 COUNT

一文详解Git中分支本地和远程删除的方法

《一文详解Git中分支本地和远程删除的方法》在使用Git进行版本控制的过程中,我们会创建多个分支来进行不同功能的开发,这就容易涉及到如何正确地删除本地分支和远程分支,下面我们就来看看相关的实现方法吧... 目录技术背景实现步骤删除本地分支删除远程www.chinasem.cn分支同步删除信息到其他机器示例步骤

Go语言数据库编程GORM 的基本使用详解

《Go语言数据库编程GORM的基本使用详解》GORM是Go语言流行的ORM框架,封装database/sql,支持自动迁移、关联、事务等,提供CRUD、条件查询、钩子函数、日志等功能,简化数据库操作... 目录一、安装与初始化1. 安装 GORM 及数据库驱动2. 建立数据库连接二、定义模型结构体三、自动迁

mysql中的服务器架构详解

《mysql中的服务器架构详解》:本文主要介绍mysql中的服务器架构,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、mysql服务器架构解释3、总结1、背景简单理解一下mysqphpl的服务器架构。2、mysjsql服务器架构解释mysql的架

ModelMapper基本使用和常见场景示例详解

《ModelMapper基本使用和常见场景示例详解》ModelMapper是Java对象映射库,支持自动映射、自定义规则、集合转换及高级配置(如匹配策略、转换器),可集成SpringBoot,减少样板... 目录1. 添加依赖2. 基本用法示例:简单对象映射3. 自定义映射规则4. 集合映射5. 高级配置匹