Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络

本文主要是介绍Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文是对Transfomer重要模块的源码解析,完整笔记链接点这里!

缩放点积自注意力 (Scaled Dot-Product Attention)

缩放点积自注意力是一种自注意力机制,它通过查询(Query)、键(Key)和值(Value)的关系来计算注意力权重。该机制的核心在于先计算查询和所有键的点积,然后进行缩放处理,应用softmax函数得到最终的注意力权重,最后用这些权重对值进行加权求和。

源码解析:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.1):super().__init__()self.temperature = temperature  # 温度参数,用于缩放点积self.dropout = nn.Dropout(attn_dropout)  # Dropout层def forward(self, q, k, v, mask=None):attn = torch.matmul(q / self.temperature, k.transpose(2, 3))  # 计算缩放后的点积if mask is not None:attn = attn.masked_fill(mask == 0, -1e9)  # 掩码操作,将需要忽略的位置设置为一个非常小的值attn = self.dropout(F.softmax(attn, dim=-1))  # 应用softmax函数并进行dropoutoutput = torch.matmul(attn, v)  # 使用注意力权重对值(v)进行加权求和return output, attn
  • __init__ 方法中的 temperature 参数用于缩放点积,通常设置为键(Key)维度的平方根。attn_dropout 是在应用softmax函数后进行dropout的比例。
  • forward 方法计算缩放点积自注意力。首先,它计算查询(q)和键(k)的点积,并通过除以 temperature 进行缩放。如果提供了 mask,则会使用 masked_fill 将掩码位置的注意力权重设为一个非常小的负数(这里是 -1e9),使得softmax后这些位置的权重接近于0。之后,应用dropout和softmax函数得到最终的注意力权重。最后,使用这些权重对值(v)进行加权求和得到输出。

多头注意力 (Multi-Head Attention)

多头注意力通过将输入分割成多个头,让每个头在不同的子空间表示上计算注意力,然后将这些头的输出合并。这样做可以让模型在多个子空间中捕获丰富的信息。

源码解析:
import torch.nn as nn
import torch.nn.functional as F
from transformer.Modules import ScaledDotProductAttentionclass MultiHeadAttention(nn.Module):''' Multi-Head Attention module '''def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):super().__init__()self.n_head = n_head  # 头的数量self.d_k = d_k  # 键/查询的维度self.d_v = d_v  # 值的维度self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)  # 查询的线性变换self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)  # 键的线性变换self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)  # 值的线性变换self.fc = nn.Linear(n_head * d_v, d_model, bias=False)  # 输出的线性变换self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)  # 缩放点积注意力模块self.dropout = nn.Dropout(dropout)  # Dropout层self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)  # 层归一化def forward(self, q, k, v, mask=None):# 保存输入以便后面进行残差连接residual = q# 线性变换并重塑以准备多头计算q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)# 转置以将头维度提前,便于并行计算q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)# 如果存在掩码,则扩展掩码以适应头维度if mask is not None:mask = mask.unsqueeze(1)   # 为头维度广播掩码# 调用缩放点积注意力模块q, attn = self.attention(q, k, v, mask=mask)# 转置并重塑以合并多头q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)# 应用线性变换和dropoutq = self.dropout(self.fc(q))# 添加残差连接并进行层归一化q += residualq = self.layer_norm(q)# 返回多头注意力的输出和注意力权重return q, attn
  • __init__ 方法初始化了多头注意力的参数,包括头的数量 n_head,查询/键/值的维度 d_kd_v,以及线性层 w_qsw_ksw_vsfc
  • forward 方法首先将输入 qkv 通过线性层映射到多头的维度,然后重塑并转置以便进行并行计算。如果存在掩码,它会被扩展以适应头维度。调用缩放点积注意力模块计算注意力,之后合并多头输出,并应用线性变换和dropout。最后,添加残差连接和层归一化。

前馈网络 (Positionwise FeedForward)

前馈网络(FFN)在自注意力层之后应用,用于进行非线性变换,增加模型的复杂度和表达能力。

源码解析:
import torch.nn as nn
import torch.nn.functional as Fclass PositionwiseFeedForward(nn.Module):''' A two-feed-forward-layer module '''def __init__(self, d_in, d_hid, dropout=0.1):super().__init__()self.w_1 = nn.Linear(d_in, d_hid)  # 第一个线性层self.w_2 = nn.Linear(d_hid, d_in)  # 第二个线性层self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)  # 层归一化self.dropout = nn.Dropout(dropout)  # Dropout层def forward(self, x):# 保存输入以便后面进行残差连接residual = x# 通过第一个线性层,然后应用ReLU激活函数x = self.w_1(x)x = F.relu(x)# 通过第二个线性层x = self.w_2(x)# 应用dropoutx = self.dropout(x)# 添加残差连接并进行层归一化x += residualx = self.layer_norm(x)# 返回输出return x
  • __init__ 方法初始化了两个线性层 w_1w_2,层归一化 layer_norm,以及dropout层。
  • forward 方法首先通过第一个线性层和ReLU激活函数,然后通过第二个线性层。应用dropout层后,添加残差连接并进行层归一化。

这篇关于Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/541623

相关文章

Spring三级缓存解决循环依赖的解析过程

《Spring三级缓存解决循环依赖的解析过程》:本文主要介绍Spring三级缓存解决循环依赖的解析过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、循环依赖场景二、三级缓存定义三、解决流程(以ServiceA和ServiceB为例)四、关键机制详解五、设计约

Redis实现分布式锁全解析之从原理到实践过程

《Redis实现分布式锁全解析之从原理到实践过程》:本文主要介绍Redis实现分布式锁全解析之从原理到实践过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、背景介绍二、解决方案(一)使用 SETNX 命令(二)设置锁的过期时间(三)解决锁的误删问题(四)Re

Android实现一键录屏功能(附源码)

《Android实现一键录屏功能(附源码)》在Android5.0及以上版本,系统提供了MediaProjectionAPI,允许应用在用户授权下录制屏幕内容并输出到视频文件,所以本文将基于此实现一个... 目录一、项目介绍二、相关技术与原理三、系统权限与用户授权四、项目架构与流程五、环境配置与依赖六、完整

Android实现定时任务的几种方式汇总(附源码)

《Android实现定时任务的几种方式汇总(附源码)》在Android应用中,定时任务(ScheduledTask)的需求几乎无处不在:从定时刷新数据、定时备份、定时推送通知,到夜间静默下载、循环执行... 目录一、项目介绍1. 背景与意义二、相关基础知识与系统约束三、方案一:Handler.postDel

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

Android使用ImageView.ScaleType实现图片的缩放与裁剪功能

《Android使用ImageView.ScaleType实现图片的缩放与裁剪功能》ImageView是最常用的控件之一,它用于展示各种类型的图片,为了能够根据需求调整图片的显示效果,Android提... 目录什么是 ImageView.ScaleType?FIT_XYFIT_STARTFIT_CENTE

Golang HashMap实现原理解析

《GolangHashMap实现原理解析》HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持高效的插入、查找和删除操作,:本文主要介绍GolangH... 目录HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Python利用ElementTree实现快速解析XML文件

《Python利用ElementTree实现快速解析XML文件》ElementTree是Python标准库的一部分,而且是Python标准库中用于解析和操作XML数据的模块,下面小编就来和大家详细讲讲... 目录一、XML文件解析到底有多重要二、ElementTree快速入门1. 加载XML的两种方式2.

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组