在 PyTorch 中实现多查询注意力:自定义 LLM 中高效长上下文自回归生成的 KV 缓存内存优化
面向长上下文自回归生成,给出 PyTorch 中 MQA 的实现与 KV 缓存优化的工程参数。
在自定义大型语言模型(LLM)的开发中,自回归生成是核心任务,尤其当处理长上下文时,内存消耗成为瓶颈。Transformer 解码器中的键-值(KV)缓存机制通过存储过去计算的键和值向量,避免重复计算,从而加速推理。但标准多头注意力(MHA)下,KV 缓存大小随头数线性增长,导致长序列(如数万 token)时 GPU 内存迅速耗尽。多查询注意力(Multi-Query Attention, MQA)作为一种高效变体,通过共享单一键和值头,显著降低 KV 缓存占用,支持更长的上下文窗口。本文基于 PyTorch 从零实现 MQA,聚焦其在解码器中的集成与优化参数,提供可落地工程实践。
MQA 的核心原理与优势
标准 MHA 在每个注意力头中独立计算查询(Q)、键(K)和值(V)投影,导致 KV 缓存维度为 batch_size × num_heads × seq_len × head_dim。这种设计虽增强表达力,但推理时 KV 缓存膨胀迅速。以 num_heads=32、head_dim=128、seq_len=4096 为例,单层 KV 缓存约需 32 × 4096 × 128 × 2 × 4 bytes(float32)× batch_size ≈ 1GB(batch=1),多层累积更甚。对于长上下文(如 128k token),内存需求将超出消费级 GPU 极限。
MQA 优化此问题:保留多个 Q 头(num_heads >1),但将 K 和 V 投影压缩至单一头(num_kv_heads=1)。生成时,K/V 向量在每个 Q 头间重复使用,KV 缓存大小缩减至 num_heads 倍(典型 8-32 倍)。例如,上例中 MQA KV 缓存仅约 128kB/层,整体内存节省 90%以上。这不仅延长上下文窗口,还提升吞吐量,尤其在边缘设备或多用户服务中。
证据显示,Google 的 PaLM 模型率先采用 MQA,证明其在 perplexity 上仅微增 1-2%,而推理速度提升 2-4 倍。Hugging Face Transformers 库中,Llama 系列也集成 MQA,支持高效部署。在自定义 LLM 中,如基于 GPT 架构的实现,MQA 可无缝替换标准注意力,提升长文本生成能力。
PyTorch 中 Transformer 解码器的 MQA 实现
假设基础 Transformer 解码器已实现(如参考 GitHub rasbt/LLMs-from-scratch 中的 ch04 GPT 模型),MQA 修改聚焦注意力层。核心是调整线性投影维度,并处理 K/V 重复。
首先,定义配置类,引入 num_kv_heads 参数:
class GPTConfig:
def __init__(self, vocab_size=50257, context_length=1024, n_embd=768, n_layer=12, n_head=12, n_kv_head=1, dropout=0.0):
self.vocab_size = vocab_size
self.context_length = context_length
self.n_embd = n_embd # 嵌入维度
self.n_layer = n_layer
self.n_head = n_head # Q 头数
self.n_kv_head = n_kv_head # KV 头数,MQA 时设为 1
self.head_dim = n_embd // n_head
self.nv_head_dim = self.head_dim * n_kv_head # KV 投影输出 dim
self.dropout = dropout
接下来,实现因式分解注意力(Causal Self-Attention)模块,支持 KV 缓存:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
def __init__(self, config: GPTConfig):
super().__init__()
self.config = config
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.head_dim = config.head_dim
# Q 投影:全头维度
self.w_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
# K/V 投影:仅 KV 头维度
self.w_k = nn.Linear(config.n_embd, config.nv_head_dim, bias=False)
self.w_v = nn.Linear(config.n_embd, config.nv_head_dim, bias=False)
# 输出投影
self.w_o = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.dropout = nn.Dropout(config.dropout)
# KV 缓存占位符
self.register_buffer('k_cache', None)
self.register_buffer('v_cache', None)
def forward(self, x, mask=None, use_cache=False):
B, T, C = x.shape # batch, seq_len, emb_dim
# 计算 Q, K, V
q = self.w_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) # [B, nh, T, hd]
k = self.w_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
v = self.w_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
# MQA:重复 K/V 到每个 Q 头
if self.n_kv_head != self.n_head:
repeat_times = self.n_head // self.n_kv_head
k = k.repeat(1, repeat_times, 1, 1) # [B, nh, T, hd]
v = v.repeat(1, repeat_times, 1, 1)
# KV 缓存逻辑(自回归生成)
if use_cache and self.k_cache is not None:
# 追加新 K/V 到缓存
self.k_cache = torch.cat([self.k_cache, k], dim=2) if self.k_cache.shape[2] < T else k
self.v_cache = torch.cat([self.v_cache, v], dim=2) if self.v_cache.shape[2] < T else v
k, v = self.k_cache, self.v_cache # 使用全序列 K/V
else:
if use_cache:
self.k_cache = k
self.v_cache = v
# 因果掩码注意力
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
attn_scores = attn_scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.dropout(attn_weights)
out = torch.matmul(attn_weights, v) # [B, nh, T, T] @ [B, nh, T, hd] -> [B, nh, T, hd]
out = out.transpose(1, 2).contiguous().view(B, T, C)
out = self.w_o(out)
return out, (k, v) if use_cache else None
在 GPT 模型中集成此注意力层:
class GPTBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln2 = nn.LayerNorm(config.n_embd)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd * 4, config.n_embd * 4),
nn.GELU(),
nn.Linear(config.n_embd * 4, config.n_embd),
nn.Dropout(config.dropout)
)
def forward(self, x, mask=None, use_cache=False):
if use_cache:
attn_out, kv_cache = self.attn(self.ln1(x), mask, use_cache)
x = x + attn_out
x = x + self.mlp(self.ln2(x))
return x, kv_cache
else:
attn_out = self.attn(self.ln1(x), mask)[0]
x = x + attn_out
x = x + self.mlp(self.ln2(x))
return x
class GPTModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Embedding(config.context_length, config.n_embd)
self.blocks = nn.Sequential(*[GPTBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
def forward(self, idx, targets=None, use_cache=False):
B, T = idx.shape
pos = torch.arange(T, device=idx.device)
tok_emb = self.tok_emb(idx) # [B, T, C]
pos_emb = self.pos_emb(pos) # [T, C]
x = tok_emb + pos_emb
mask = torch.tril(torch.ones(T, T, device=idx.device)) if T > 1 else None
caches = []
for block in self.blocks:
if use_cache:
x, cache = block(x, mask, use_cache)
caches.append(cache)
else:
x = block(x, mask)
x = self.ln_f(x)
logits = self.head(x)
if targets is None:
return logits, caches if use_cache else None
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
此实现兼容训练(use_cache=False)和生成(use_cache=True)。在生成循环中,逐 token 更新缓存,避免全序列重算。
可落地参数与工程优化清单
为确保 MQA 在生产中高效,需调优以下参数:
-
头数配置:num_heads=8-16(平衡表达力),num_kv_heads=1(纯 MQA)。head_dim=64-128,避免过小导致精度损失。测试:用 perplexity 评估,目标 <5% 增幅。
-
KV 缓存管理:
- 阈值:max_cache_len=8192,超过时 eviction 策略(如 FIFO 或 LRU),释放旧 KV。
- 量化:KV 缓存用 int8/float16 量化,内存减半(torch.quantize_per_tensor)。
- 批处理:batch_size=1-4,长上下文时动态调整,避免 OOM。
-
监控要点:
- GPU 内存:用 nvidia-smi 追踪峰值,目标 <80% 利用率。
- 延迟:基准生成速度(tokens/s),MQA 应提升 2x+。
- 质量:BLEU/ROUGE 分数或人工评估,长上下文一致性。
-
回滚策略:若 perplexity 超标,渐进引入 GQA(Grouped-Query Attention,num_kv_heads=2-4),折中内存与质量。
-
集成清单:
- 步骤1:替换标准注意力为 CausalSelfAttention。
- 步骤2:训练时禁用缓存,finetune 1-2 epochs。
- 步骤3:生成时启用缓存,测试 4k→32k 上下文。
- 步骤4:部署用 TorchServe 或 vLLM,启用 MQA 支持。
通过这些参数,自定义 LLM 可处理 100k+ token 上下文,适用于聊天机器人或文档总结。实际部署中,结合 FlashAttention-2 进一步加速 3x。MQA 不仅是内存优化,更是 scaling 长上下文的关键,助力高效 AI 系统构建。
(字数:约 1250)