multiheadattention类原理及源码理解

2023-11-01 19:28

本文主要是介绍multiheadattention类原理及源码理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

网络找的一段代码如下:

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]#这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量#对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

python、pytorch、人工智能相关知识现阶段都是简单的了解,没有相关的实践。因此在学习的时候不要习惯性的扣代码细节。能把论文原理和代码逻辑对应即可、能总结代码块重点内容即可。

transformer中self-attention就是对一个输入序列计算每个位置的注意力,每个位置在论文原文中用d_model(512)维表示,多头就是每个位置用h(原文中8个)个头计算,这样每个头计算一个位置中的64维特征。

自注意力机制有什么好处呢?

自注意力机制的目的是让模型能够同时关注输入序列中的不同位置和信息,从而捕捉序列中的复杂模式和关系。通过计算每个位置的向量与其他位置的向量之间的相似度或相关性,模型可以学习到序列中每个元素对于输出结果的重要性,从而给予不同的权重。

为什么要使用多头呢?下面是我找到的解释:

多头计算可以让模型同时关注输入序列中的不同方面和细节,从而增强模型的表达能力和学习能力。每个注意力头可以捕捉输入序列中的不同模式和关系,而最终的线性变换可以将这些信息融合在一起。
多头计算可以降低模型的复杂度和计算成本。对于较大的 d_model 来说,如果只使用单头计算,那么 QK^T 的结果会非常大,导致 softmax 函数的梯度非常小,不利于网络的训练。而使用多头计算,可以将 d_model 分割成 h 个较小的子空间,从而减少计算量和内存消耗34。
多头计算还可以
提高模型的可解释性和泛化能力
。我们可以从模型中检查不同注意力头的分布,观察模型是如何关注不同位置和信息的。各个注意力头可以学会执行不同的任务,例如语法分析、实体识别等

MultiHeadedAttention类还做了什么事情?
1、通过4个线性层(通常是4)计算得到Q K V矩阵
在transformer中,Q、K、V是通过四个线性层得到的,分别是:
Q = XW^Q ,其中X是embedding输入矩阵,W^Q 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到Q空间。
K = XW^K ,其中X是embedding输入矩阵,W^K 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到K空间。
V = XW^V ,其中Xembedding是输入矩阵,W^V 是一个可训练的参数矩阵,大小为(d_model* d_model)用于将X映射到V空间。

这篇关于multiheadattention类原理及源码理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/325201

相关文章

Java Spring的依赖注入理解及@Autowired用法示例详解

《JavaSpring的依赖注入理解及@Autowired用法示例详解》文章介绍了Spring依赖注入(DI)的概念、三种实现方式(构造器、Setter、字段注入),区分了@Autowired(注入... 目录一、什么是依赖注入(DI)?1. 定义2. 举个例子二、依赖注入的几种方式1. 构造器注入(Con

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

深入理解Go语言中二维切片的使用

《深入理解Go语言中二维切片的使用》本文深入讲解了Go语言中二维切片的概念与应用,用于表示矩阵、表格等二维数据结构,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录引言二维切片的基本概念定义创建二维切片二维切片的操作访问元素修改元素遍历二维切片二维切片的动态调整追加行动态

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

MySQL中的表连接原理分析

《MySQL中的表连接原理分析》:本文主要介绍MySQL中的表连接原理分析,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、环境3、表连接原理【1】驱动表和被驱动表【2】内连接【3】外连接【4编程】嵌套循环连接【5】join buffer4、总结1、背景

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

Java Stream的distinct去重原理分析

《JavaStream的distinct去重原理分析》Javastream中的distinct方法用于去除流中的重复元素,它返回一个包含过滤后唯一元素的新流,该方法会根据元素的hashcode和eq... 目录一、distinct 的基础用法与核心特性二、distinct 的底层实现原理1. 顺序流中的去重

Spring @Scheduled注解及工作原理

《Spring@Scheduled注解及工作原理》Spring的@Scheduled注解用于标记定时任务,无需额外库,需配置@EnableScheduling,设置fixedRate、fixedDe... 目录1.@Scheduled注解定义2.配置 @Scheduled2.1 开启定时任务支持2.2 创建

Spring Boot 实现 IP 限流的原理、实践与利弊解析

《SpringBoot实现IP限流的原理、实践与利弊解析》在SpringBoot中实现IP限流是一种简单而有效的方式来保障系统的稳定性和可用性,本文给大家介绍SpringBoot实现IP限... 目录一、引言二、IP 限流原理2.1 令牌桶算法2.2 漏桶算法三、使用场景3.1 防止恶意攻击3.2 控制资源

Python中使用uv创建环境及原理举例详解

《Python中使用uv创建环境及原理举例详解》uv是Astral团队开发的高性能Python工具,整合包管理、虚拟环境、Python版本控制等功能,:本文主要介绍Python中使用uv创建环境及... 目录一、uv工具简介核心特点:二、安装uv1. 通过pip安装2. 通过脚本安装验证安装:配置镜像源(可