SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)

本文主要是介绍SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

地址

https://arxiv.org/pdf/2404.14469

核心

这篇论文介绍了一种名为SnapKV的创新方法,旨在提高大型语言模型处理长上下文时的效率和内存利用率。主要贡献包括: 1. 设计实验探索在输出生成过程中注意力特征的模式,发现注意力分配具有一致性,可以提取重要信息。 2. 提出了SnapKV算法,利用观察窗口和投票机制选择每个注意力头的重要键值对,并使用池化进行细粒度聚类。 3. 在多个模型和数据集上评估SnapKV,结果显示其可以大幅压缩键值对缓存,提高解码速度,同时保持模型性能。 总之,SnapKV为长序列输入提供了一种高效压缩键值对缓存的方法,有助于降低内存和计算成本,同时保持了生成质量。

import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import math# perform qk calculation and get indices
# this version will not update in inference mode# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:"""This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)"""batch, num_key_value_heads, slen, head_dim = hidden_states.shapeif n_rep == 1:return hidden_stateshidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)class KVCluster():def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):# check if prefix phaseassert key_states.shape[-2] == query_states.shape[-2]bsz, num_heads, q_len, head_dim = query_states.shapeif q_len < self.max_capacity_prompt:return key_states, value_stateselse:attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(attn_weights.device)attention_mask = mask[None, None, :, :]attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_maskattn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2)if self.pooling == 'avgpool':attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)elif self.pooling == 'maxpool':attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)else:raise ValueError('Pooling method not supported')indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indicesindices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)k_cur = key_states[:, :, -self.window_size:, :]v_cur = value_states[:, :, -self.window_size:, :]key_states = torch.cat([k_past_compress, k_cur], dim = 2)value_states = torch.cat([v_past_compress, v_cur], dim = 2)return key_states, value_states

这段代码实现了一个KVCluster类,用于更新键值对(key-value pairs)。该类具有以下方法:

  • __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):初始化KVCluster对象,可以设置窗口大小、最大容量、卷积核大小和池化方法。

  • reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):重置KVCluster对象的参数。

  • update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):根据输入的键、查询和值的状态以及注意力掩码更新键值对。在查询状态长度小于最大容量时,直接返回原始键值对。否则,根据注意力权重计算窗口内的聚合权重,并根据聚合权重选择top-k的索引。然后将过去的键值对和当前的键值对进行拼接,返回更新后的键值对。

该类主要用于处理键值对的更新,其中关键的部分是计算注意力权重、选择top-k索引和拼接过去和当前的键值对。

根据文档内容,使用SnapKV方法可以按照以下步骤进行:

  1. 确定观测窗口大小:选择输入序列末尾的一部分作为观测窗口,以捕获重要的注意力特征。通常选择窗口大小为32。
  2. 计算注意力权重:对观测窗口的查询和输入序列的前缀进行注意力计算,得到注意力权重矩阵。
  3. 进行投票:对每个注意力头,将观测窗口的注意力权重相加,选出权重最大的前缀位置作为重要特征。
  4. 聚类:对选出的重要特征进行聚类,以保留相邻特征。可以通过1D最大池化实现聚类。
  5. 更新KV缓存:将聚类后的特征与前缀特征拼接,形成新的Key-Value对,并更新KV缓存。这可以将KV缓存的大小压缩到指定值。
  6. 生成:使用更新后的KV缓存进行解码生成。由于KV缓存大小不再随输入序列增长,因此可以显著提高解码速度和内存效率。
  7. 调整参数:根据需要调整观测窗口大小、聚类核大小、KV缓存压缩目标值等参数,以平衡性能和效率。 总的来说,SnapKV通过自动识别输入序列中重要的注意力特征,并仅保留这些特征来压缩KV缓存,实现高效的长序列生成。该方法无需训练,可直接应用于现有模型中。

是的,根据文档中对SnapKV方法的描述,该步骤是在softmax之前进行的。 文档中提到SnapKV包含两个阶段:

  1. 投票阶段:计算观测窗口内查询和前缀的注意力权重,并进行投票,以选择出重要的前缀特征。
  2. 更新和存储阶段:根据投票结果,选择重要特征进行聚类,并拼接这些特征与前缀特征,形成新的KV对,以更新KV缓存。 这一过程发生在softmax之前,也就是在计算注意力权重时进行的。文档中并未明确指出是在softmax之前,但从上下文来看,这一过程发生在注意力权重计算阶段,因此是在softmax之前进行的。
    总之,SnapKV是在计算注意力权重时,通过压缩模型中提示的KV缓存来提高生成效率的,因此是在softmax之前进行的。

这篇关于SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/942919

相关文章

Java实现字节字符转bcd编码

《Java实现字节字符转bcd编码》BCD是一种将十进制数字编码为二进制的表示方式,常用于数字显示和存储,本文将介绍如何在Java中实现字节字符转BCD码的过程,需要的小伙伴可以了解下... 目录前言BCD码是什么Java实现字节转bcd编码方法补充总结前言BCD码(Binary-Coded Decima

python获取指定名字的程序的文件路径的两种方法

《python获取指定名字的程序的文件路径的两种方法》本文主要介绍了python获取指定名字的程序的文件路径的两种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要... 最近在做项目,需要用到给定一个程序名字就可以自动获取到这个程序在Windows系统下的绝对路径,以下

SpringBoot全局域名替换的实现

《SpringBoot全局域名替换的实现》本文主要介绍了SpringBoot全局域名替换的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录 项目结构⚙️ 配置文件application.yml️ 配置类AppProperties.Ja

JavaScript中的高级调试方法全攻略指南

《JavaScript中的高级调试方法全攻略指南》什么是高级JavaScript调试技巧,它比console.log有何优势,如何使用断点调试定位问题,通过本文,我们将深入解答这些问题,带您从理论到实... 目录观点与案例结合观点1观点2观点3观点4观点5高级调试技巧详解实战案例断点调试:定位变量错误性能分

Python实现批量CSV转Excel的高性能处理方案

《Python实现批量CSV转Excel的高性能处理方案》在日常办公中,我们经常需要将CSV格式的数据转换为Excel文件,本文将介绍一个基于Python的高性能解决方案,感兴趣的小伙伴可以跟随小编一... 目录一、场景需求二、技术方案三、核心代码四、批量处理方案五、性能优化六、使用示例完整代码七、小结一、

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

Java实现将HTML文件与字符串转换为图片

《Java实现将HTML文件与字符串转换为图片》在Java开发中,我们经常会遇到将HTML内容转换为图片的需求,本文小编就来和大家详细讲讲如何使用FreeSpire.DocforJava库来实现这一功... 目录前言核心实现:html 转图片完整代码场景 1:转换本地 HTML 文件为图片场景 2:转换 H

C#使用Spire.Doc for .NET实现HTML转Word的高效方案

《C#使用Spire.Docfor.NET实现HTML转Word的高效方案》在Web开发中,HTML内容的生成与处理是高频需求,然而,当用户需要将HTML页面或动态生成的HTML字符串转换为Wor... 目录引言一、html转Word的典型场景与挑战二、用 Spire.Doc 实现 HTML 转 Word1

C#实现一键批量合并PDF文档

《C#实现一键批量合并PDF文档》这篇文章主要为大家详细介绍了如何使用C#实现一键批量合并PDF文档功能,文中的示例代码简洁易懂,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言效果展示功能实现1、添加文件2、文件分组(书签)3、定义页码范围4、自定义显示5、定义页面尺寸6、PDF批量合并7、其他方法

SpringBoot实现不同接口指定上传文件大小的具体步骤

《SpringBoot实现不同接口指定上传文件大小的具体步骤》:本文主要介绍在SpringBoot中通过自定义注解、AOP拦截和配置文件实现不同接口上传文件大小限制的方法,强调需设置全局阈值远大于... 目录一  springboot实现不同接口指定文件大小1.1 思路说明1.2 工程启动说明二 具体实施2