Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络

本文主要是介绍Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文是对Transfomer重要模块的源码解析,完整笔记链接点这里!

缩放点积自注意力 (Scaled Dot-Product Attention)

缩放点积自注意力是一种自注意力机制,它通过查询(Query)、键(Key)和值(Value)的关系来计算注意力权重。该机制的核心在于先计算查询和所有键的点积,然后进行缩放处理,应用softmax函数得到最终的注意力权重,最后用这些权重对值进行加权求和。

源码解析:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.1):super().__init__()self.temperature = temperature  # 温度参数,用于缩放点积self.dropout = nn.Dropout(attn_dropout)  # Dropout层def forward(self, q, k, v, mask=None):attn = torch.matmul(q / self.temperature, k.transpose(2, 3))  # 计算缩放后的点积if mask is not None:attn = attn.masked_fill(mask == 0, -1e9)  # 掩码操作,将需要忽略的位置设置为一个非常小的值attn = self.dropout(F.softmax(attn, dim=-1))  # 应用softmax函数并进行dropoutoutput = torch.matmul(attn, v)  # 使用注意力权重对值(v)进行加权求和return output, attn
  • __init__ 方法中的 temperature 参数用于缩放点积,通常设置为键(Key)维度的平方根。attn_dropout 是在应用softmax函数后进行dropout的比例。
  • forward 方法计算缩放点积自注意力。首先,它计算查询(q)和键(k)的点积,并通过除以 temperature 进行缩放。如果提供了 mask,则会使用 masked_fill 将掩码位置的注意力权重设为一个非常小的负数(这里是 -1e9),使得softmax后这些位置的权重接近于0。之后,应用dropout和softmax函数得到最终的注意力权重。最后,使用这些权重对值(v)进行加权求和得到输出。

多头注意力 (Multi-Head Attention)

多头注意力通过将输入分割成多个头,让每个头在不同的子空间表示上计算注意力,然后将这些头的输出合并。这样做可以让模型在多个子空间中捕获丰富的信息。

源码解析:
import torch.nn as nn
import torch.nn.functional as F
from transformer.Modules import ScaledDotProductAttentionclass MultiHeadAttention(nn.Module):''' Multi-Head Attention module '''def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):super().__init__()self.n_head = n_head  # 头的数量self.d_k = d_k  # 键/查询的维度self.d_v = d_v  # 值的维度self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)  # 查询的线性变换self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)  # 键的线性变换self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)  # 值的线性变换self.fc = nn.Linear(n_head * d_v, d_model, bias=False)  # 输出的线性变换self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)  # 缩放点积注意力模块self.dropout = nn.Dropout(dropout)  # Dropout层self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)  # 层归一化def forward(self, q, k, v, mask=None):# 保存输入以便后面进行残差连接residual = q# 线性变换并重塑以准备多头计算q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)# 转置以将头维度提前,便于并行计算q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)# 如果存在掩码,则扩展掩码以适应头维度if mask is not None:mask = mask.unsqueeze(1)   # 为头维度广播掩码# 调用缩放点积注意力模块q, attn = self.attention(q, k, v, mask=mask)# 转置并重塑以合并多头q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)# 应用线性变换和dropoutq = self.dropout(self.fc(q))# 添加残差连接并进行层归一化q += residualq = self.layer_norm(q)# 返回多头注意力的输出和注意力权重return q, attn
  • __init__ 方法初始化了多头注意力的参数,包括头的数量 n_head,查询/键/值的维度 d_kd_v,以及线性层 w_qsw_ksw_vsfc
  • forward 方法首先将输入 qkv 通过线性层映射到多头的维度,然后重塑并转置以便进行并行计算。如果存在掩码,它会被扩展以适应头维度。调用缩放点积注意力模块计算注意力,之后合并多头输出,并应用线性变换和dropout。最后,添加残差连接和层归一化。

前馈网络 (Positionwise FeedForward)

前馈网络(FFN)在自注意力层之后应用,用于进行非线性变换,增加模型的复杂度和表达能力。

源码解析:
import torch.nn as nn
import torch.nn.functional as Fclass PositionwiseFeedForward(nn.Module):''' A two-feed-forward-layer module '''def __init__(self, d_in, d_hid, dropout=0.1):super().__init__()self.w_1 = nn.Linear(d_in, d_hid)  # 第一个线性层self.w_2 = nn.Linear(d_hid, d_in)  # 第二个线性层self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)  # 层归一化self.dropout = nn.Dropout(dropout)  # Dropout层def forward(self, x):# 保存输入以便后面进行残差连接residual = x# 通过第一个线性层,然后应用ReLU激活函数x = self.w_1(x)x = F.relu(x)# 通过第二个线性层x = self.w_2(x)# 应用dropoutx = self.dropout(x)# 添加残差连接并进行层归一化x += residualx = self.layer_norm(x)# 返回输出return x
  • __init__ 方法初始化了两个线性层 w_1w_2,层归一化 layer_norm,以及dropout层。
  • forward 方法首先通过第一个线性层和ReLU激活函数,然后通过第二个线性层。应用dropout层后,添加残差连接并进行层归一化。

这篇关于Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/541623

相关文章

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

一文解析C#中的StringSplitOptions枚举

《一文解析C#中的StringSplitOptions枚举》StringSplitOptions是C#中的一个枚举类型,用于控制string.Split()方法分割字符串时的行为,核心作用是处理分割后... 目录C#的StringSplitOptions枚举1.StringSplitOptions枚举的常用

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

MyBatis延迟加载与多级缓存全解析

《MyBatis延迟加载与多级缓存全解析》文章介绍MyBatis的延迟加载与多级缓存机制,延迟加载按需加载关联数据提升性能,一级缓存会话级默认开启,二级缓存工厂级支持跨会话共享,增删改操作会清空对应缓... 目录MyBATis延迟加载策略一对多示例一对多示例MyBatis框架的缓存一级缓存二级缓存MyBat

前端缓存策略的自解方案全解析

《前端缓存策略的自解方案全解析》缓存从来都是前端的一个痛点,很多前端搞不清楚缓存到底是何物,:本文主要介绍前端缓存的自解方案,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录一、为什么“清缓存”成了技术圈的梗二、先给缓存“把个脉”:浏览器到底缓存了谁?三、设计思路:把“发版”做成“自愈”四、代码

Java集合之Iterator迭代器实现代码解析

《Java集合之Iterator迭代器实现代码解析》迭代器Iterator是Java集合框架中的一个核心接口,位于java.util包下,它定义了一种标准的元素访问机制,为各种集合类型提供了一种统一的... 目录一、什么是Iterator二、Iterator的核心方法三、基本使用示例四、Iterator的工

Java JDK Validation 注解解析与使用方法验证

《JavaJDKValidation注解解析与使用方法验证》JakartaValidation提供了一种声明式、标准化的方式来验证Java对象,与框架无关,可以方便地集成到各种Java应用中,... 目录核心概念1. 主要注解基本约束注解其他常用注解2. 核心接口使用方法1. 基本使用添加依赖 (Maven

Java中的分布式系统开发基于 Zookeeper 与 Dubbo 的应用案例解析

《Java中的分布式系统开发基于Zookeeper与Dubbo的应用案例解析》本文将通过实际案例,带你走进基于Zookeeper与Dubbo的分布式系统开发,本文通过实例代码给大家介绍的非常详... 目录Java 中的分布式系统开发基于 Zookeeper 与 Dubbo 的应用案例一、分布式系统中的挑战二