SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)

本文主要是介绍SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

地址

https://arxiv.org/pdf/2404.14469

核心

这篇论文介绍了一种名为SnapKV的创新方法,旨在提高大型语言模型处理长上下文时的效率和内存利用率。主要贡献包括: 1. 设计实验探索在输出生成过程中注意力特征的模式,发现注意力分配具有一致性,可以提取重要信息。 2. 提出了SnapKV算法,利用观察窗口和投票机制选择每个注意力头的重要键值对,并使用池化进行细粒度聚类。 3. 在多个模型和数据集上评估SnapKV,结果显示其可以大幅压缩键值对缓存,提高解码速度,同时保持模型性能。 总之,SnapKV为长序列输入提供了一种高效压缩键值对缓存的方法,有助于降低内存和计算成本,同时保持了生成质量。

import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import math# perform qk calculation and get indices
# this version will not update in inference mode# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:"""This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)"""batch, num_key_value_heads, slen, head_dim = hidden_states.shapeif n_rep == 1:return hidden_stateshidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)class KVCluster():def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):# check if prefix phaseassert key_states.shape[-2] == query_states.shape[-2]bsz, num_heads, q_len, head_dim = query_states.shapeif q_len < self.max_capacity_prompt:return key_states, value_stateselse:attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(attn_weights.device)attention_mask = mask[None, None, :, :]attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_maskattn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2)if self.pooling == 'avgpool':attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)elif self.pooling == 'maxpool':attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)else:raise ValueError('Pooling method not supported')indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indicesindices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)k_cur = key_states[:, :, -self.window_size:, :]v_cur = value_states[:, :, -self.window_size:, :]key_states = torch.cat([k_past_compress, k_cur], dim = 2)value_states = torch.cat([v_past_compress, v_cur], dim = 2)return key_states, value_states

这段代码实现了一个KVCluster类,用于更新键值对(key-value pairs)。该类具有以下方法:

  • __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):初始化KVCluster对象,可以设置窗口大小、最大容量、卷积核大小和池化方法。

  • reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):重置KVCluster对象的参数。

  • update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):根据输入的键、查询和值的状态以及注意力掩码更新键值对。在查询状态长度小于最大容量时,直接返回原始键值对。否则,根据注意力权重计算窗口内的聚合权重,并根据聚合权重选择top-k的索引。然后将过去的键值对和当前的键值对进行拼接,返回更新后的键值对。

该类主要用于处理键值对的更新,其中关键的部分是计算注意力权重、选择top-k索引和拼接过去和当前的键值对。

根据文档内容,使用SnapKV方法可以按照以下步骤进行:

  1. 确定观测窗口大小:选择输入序列末尾的一部分作为观测窗口,以捕获重要的注意力特征。通常选择窗口大小为32。
  2. 计算注意力权重:对观测窗口的查询和输入序列的前缀进行注意力计算,得到注意力权重矩阵。
  3. 进行投票:对每个注意力头,将观测窗口的注意力权重相加,选出权重最大的前缀位置作为重要特征。
  4. 聚类:对选出的重要特征进行聚类,以保留相邻特征。可以通过1D最大池化实现聚类。
  5. 更新KV缓存:将聚类后的特征与前缀特征拼接,形成新的Key-Value对,并更新KV缓存。这可以将KV缓存的大小压缩到指定值。
  6. 生成:使用更新后的KV缓存进行解码生成。由于KV缓存大小不再随输入序列增长,因此可以显著提高解码速度和内存效率。
  7. 调整参数:根据需要调整观测窗口大小、聚类核大小、KV缓存压缩目标值等参数,以平衡性能和效率。 总的来说,SnapKV通过自动识别输入序列中重要的注意力特征,并仅保留这些特征来压缩KV缓存,实现高效的长序列生成。该方法无需训练,可直接应用于现有模型中。

是的,根据文档中对SnapKV方法的描述,该步骤是在softmax之前进行的。 文档中提到SnapKV包含两个阶段:

  1. 投票阶段:计算观测窗口内查询和前缀的注意力权重,并进行投票,以选择出重要的前缀特征。
  2. 更新和存储阶段:根据投票结果,选择重要特征进行聚类,并拼接这些特征与前缀特征,形成新的KV对,以更新KV缓存。 这一过程发生在softmax之前,也就是在计算注意力权重时进行的。文档中并未明确指出是在softmax之前,但从上下文来看,这一过程发生在注意力权重计算阶段,因此是在softmax之前进行的。
    总之,SnapKV是在计算注意力权重时,通过压缩模型中提示的KV缓存来提高生成效率的,因此是在softmax之前进行的。

这篇关于SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/942919

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

Python实现微信自动锁定工具

《Python实现微信自动锁定工具》在数字化办公时代,微信已成为职场沟通的重要工具,但临时离开时忘记锁屏可能导致敏感信息泄露,下面我们就来看看如何使用Python打造一个微信自动锁定工具吧... 目录引言:当微信隐私遇到自动化守护效果展示核心功能全景图技术亮点深度解析1. 无操作检测引擎2. 微信路径智能获

Python中pywin32 常用窗口操作的实现

《Python中pywin32常用窗口操作的实现》本文主要介绍了Python中pywin32常用窗口操作的实现,pywin32主要的作用是供Python开发者快速调用WindowsAPI的一个... 目录获取窗口句柄获取最前端窗口句柄获取指定坐标处的窗口根据窗口的完整标题匹配获取句柄根据窗口的类别匹配获取句

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

在 Spring Boot 中实现异常处理最佳实践

《在SpringBoot中实现异常处理最佳实践》本文介绍如何在SpringBoot中实现异常处理,涵盖核心概念、实现方法、与先前查询的集成、性能分析、常见问题和最佳实践,感兴趣的朋友一起看看吧... 目录一、Spring Boot 异常处理的背景与核心概念1.1 为什么需要异常处理?1.2 Spring B

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1

如何在 Spring Boot 中实现 FreeMarker 模板

《如何在SpringBoot中实现FreeMarker模板》FreeMarker是一种功能强大、轻量级的模板引擎,用于在Java应用中生成动态文本输出(如HTML、XML、邮件内容等),本文... 目录什么是 FreeMarker 模板?在 Spring Boot 中实现 FreeMarker 模板1. 环

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll