torch.einsum详解

2024-08-20 23:44
文章标签 详解 torch einsum

本文主要是介绍torch.einsum详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.einsum 是 PyTorch 中用于执行高效张量运算的函数,基于爱因斯坦求和约定(Einstein summation convention)。它能够处理复杂的张量操作,并简化代码书写。

基本语法

torch.einsum(subscripts, *operands)
  • subscripts:一个字符串,用于描述输入张量的维度如何结合。
  • *operands:待操作的张量。

爱因斯坦求和约定

爱因斯坦求和约定是一个简化张量运算的方式,省略了显式的求和符号。通过指定各维度的标签,可以直接描述复杂的张量运算。

语法结构

  • "nqhd,nkhd->nhqk": 这个字符串描述了如何对两个张量进行操作,并生成输出张量的维度。

    • n:批次大小(batch size)
    • q:查询序列长度(query length)
    • k:键序列长度(key length)
    • h:注意力头的数量(number of heads)
    • d:每个注意力头的维度(dimension per head)

示例代码

以下是使用 torch.einsum 计算多头注意力机制中点积相似性的示例代码:

import torch# 定义多头注意力机制的点积计算函数
def compute_attention_scores(queries, keys):# 计算点积相似性分数energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])return energy# 示例数据
N = 1            # 批次大小
q = 2            # 查询序列长度
k = 3            # 键序列长度
h = 2            # 注意力头数量
d = 4            # 每个注意力头的维度# 随机生成 queries 和 keys
queries = torch.rand((N, q, h, d))  # Shape (1, 2, 2, 4)
keys = torch.rand((N, k, h, d))    # Shape (1, 3, 2, 4)# 计算注意力分数
energy = compute_attention_scores(queries, keys)print("Energy shape:", energy.shape)
print(energy)

计算过程

  1. 维度解释

    • queries 的维度为 (1, 2, 2, 4)N = 1(批次大小),q = 2(查询序列长度),h = 2(注意力头数量),d = 4(每个头部的维度)。
    • keys 的维度为 (1, 3, 2, 4)N = 1(批次大小),k = 3(键序列长度),h = 2(注意力头数量),d = 4(每个头部的维度)。
  2. 点积计算

    • 对每个批次和每个头部,计算 querieskeysd 维度上的点积。
    • 结果的维度为 (N, h, q, k),其中:
      • N 是批次大小
      • h 是注意力头的数量
      • q 是查询序列的长度
      • k 是键序列的长度

    点积计算的实际操作是:

    • 对于每个批次(n)和每个头部(h),对 querieskeys 张量在 d 维度上进行点积运算,得到形状为 (q, k) 的张量。

简单计算示例

假设我们有如下示例数据:

queries = torch.tensor([[[[1.0, 0.5, 0.2, 1.5], [0.3, 0.7, 0.6, 0.8]], [[0.9, 0.4, 1.2, 0.5], [0.2, 0.6, 0.8, 0.7]]]])
keys = torch.tensor([[[[0.1, 1.0, 0.3, 0.5], [0.2, 0.4, 0.6, 0.7], [0.8, 1.0, 0.9, 0.5]], [[0.1, 0.5, 0.2, 0.8], [0.3, 0.4, 0.7, 0.9], [0.6, 0.8, 1.0, 0.2]]]])

点积计算

  • 对于第一个批次和第一个头部:

    • queries[0, :, 0, :]keys[0, :, 0, :] 的点积计算如下:

    计算:

    energy[0, 0, 0, 0] = (1.0*0.1 + 0.5*1.0 + 0.2*0.3 + 1.5*0.5) = 0.1 + 0.5 + 0.06 + 0.75 = 1.41
    energy[0, 0, 0, 1] = (1.0*0.2 + 0.5*0.4 + 0.2*0.6 + 1.5*0.7) = 0.2 + 0.2 + 0.12 + 1.05 = 1.59
    energy[0, 0, 0, 2] = (1.0*0.8 + 0.5*1.0 + 0.2*0.9 + 1.5*0.5) = 0.8 + 0.5 + 0.18 + 0.75 = 1.23
    energy[0, 0, 1, 0] = (0.3*0.1 + 0.7*1.0 + 0.6*0.3 + 0.8*0.5) = 0.03 + 0.7 + 0.18 + 0.4 = 1.31
    energy[0, 0, 1, 1] = (0.3*0.2 + 0.7*0.4 + 0.6*0.6 + 0.8*0.7) = 0.06 + 0.28 + 0.36 + 0.56 = 1.26
    energy[0, 0, 1, 2] = (0.3*0.8 + 0.7*1.0 + 0.6*0.9 + 0.8*0.5) = 0.24 + 0.7 + 0.54 + 0.4 = 1.88
    

总结

torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) 用于计算 querieskeys 张量在注意力机制中的点积,相似性得分。它通过爱因斯坦求和约定指定了如何在多维张量上执行这些操作,使得代码更简洁、效率更高。

Code

AI_With_NumPy
此项目汇集了很多AI相关的代码实现,供大家学习使用,欢迎点赞收藏👏🏻

备注

个人水平有限,有问题随时交流~

这篇关于torch.einsum详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1091516

相关文章

PHP轻松处理千万行数据的方法详解

《PHP轻松处理千万行数据的方法详解》说到处理大数据集,PHP通常不是第一个想到的语言,但如果你曾经需要处理数百万行数据而不让服务器崩溃或内存耗尽,你就会知道PHP用对了工具有多强大,下面小编就... 目录问题的本质php 中的数据流处理:为什么必不可少生成器:内存高效的迭代方式流量控制:避免系统过载一次性

MySQL的JDBC编程详解

《MySQL的JDBC编程详解》:本文主要介绍MySQL的JDBC编程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言一、前置知识1. 引入依赖2. 认识 url二、JDBC 操作流程1. JDBC 的写操作2. JDBC 的读操作总结前言本文介绍了mysq

Redis 的 SUBSCRIBE命令详解

《Redis的SUBSCRIBE命令详解》Redis的SUBSCRIBE命令用于订阅一个或多个频道,以便接收发送到这些频道的消息,本文给大家介绍Redis的SUBSCRIBE命令,感兴趣的朋友跟随... 目录基本语法工作原理示例消息格式相关命令python 示例Redis 的 SUBSCRIBE 命令用于订

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

SpringBoot日志级别与日志分组详解

《SpringBoot日志级别与日志分组详解》文章介绍了日志级别(ALL至OFF)及其作用,说明SpringBoot默认日志级别为INFO,可通过application.properties调整全局或... 目录日志级别1、级别内容2、调整日志级别调整默认日志级别调整指定类的日志级别项目开发过程中,利用日志

Java中的抽象类与abstract 关键字使用详解

《Java中的抽象类与abstract关键字使用详解》:本文主要介绍Java中的抽象类与abstract关键字使用详解,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友跟随小编一起看看吧... 目录一、抽象类的概念二、使用 abstract2.1 修饰类 => 抽象类2.2 修饰方法 => 抽象方法,没有

MySQL8 密码强度评估与配置详解

《MySQL8密码强度评估与配置详解》MySQL8默认启用密码强度插件,实施MEDIUM策略(长度8、含数字/字母/特殊字符),支持动态调整与配置文件设置,推荐使用STRONG策略并定期更新密码以提... 目录一、mysql 8 密码强度评估机制1.核心插件:validate_password2.密码策略级

从入门到精通详解Python虚拟环境完全指南

《从入门到精通详解Python虚拟环境完全指南》Python虚拟环境是一个独立的Python运行环境,它允许你为不同的项目创建隔离的Python环境,下面小编就来和大家详细介绍一下吧... 目录什么是python虚拟环境一、使用venv创建和管理虚拟环境1.1 创建虚拟环境1.2 激活虚拟环境1.3 验证虚

详解python pycharm与cmd中制表符不一样

《详解pythonpycharm与cmd中制表符不一样》本文主要介绍了pythonpycharm与cmd中制表符不一样,这个问题通常是因为PyCharm和命令行(CMD)使用的制表符(tab)的宽... 这个问题通常是因为PyCharm和命令行(CMD)使用的制表符(tab)的宽度不同导致的。在PyChar