Hotdry.
ai-systems

拆解 Titans 门控记忆层:KV 缓存外的新长上下文结构如何落地推理系统

从惊喜指标到分页策略,给出 Titans 动态记忆模块在推理侧的完整工程化参数与踩坑笔记。

一、KV-cache 的平方墙撞得还不够疼?

把 32k 上下文拉到 128k,显存直接翻 16×;如果想再冲到 1M token,A100 80 GB 连一层 attention 都塞不下。业界惯用的「KV-cache 压缩」三板斧 —— 窗口滑动、层间共用、量化 —— 本质上只是延缓 O (L²) 的爆发,而不是消除它。真正让长文本推理成本可接受的路线只有两条:

  1. 把计算复杂度降下来(稀疏 / 线性 attention)
  2. 把存储结构换掉 —— 这正是 Google Titans 的门控记忆层想做的事

二、Titans 门控记忆层:在模型里长出一个「海马体」

Titans 不碰主体 Transformer 权重,而是在每层插入一个可写的 MLP 记忆模块,推理时根据输入动态更新。核心公式只有三行:

surprise = ‖∇θ log p(x_t|θ)‖          # 梯度幅值决定「值不值得记」
g_t = σ(W_s · surprise + b_s)          # 门控,0 完全遗忘,1 写入新记忆
M_t = (1 − g_t)⊙M_{t−1} + g_t⊙φ(x_t)  # 增量更新,φ 为单层 MLP
  1. 惊喜指标:用当前 token 对模型造成的梯度幅值衡量「信息新鲜度」,避免把重复废话写进记忆。
  2. 遗忘门:数据相关,可学习;实测比固定指数衰减(如 RetNet γ=0.99)在 200 万 token 上低 8% PPL。
  3. 写入粒度:每 token 只执行一次 512-wide 的 MLP,延迟≈同层 FFN 的 30%,但显存占用从 O (L²) 降到 O (L・d_mem),d_mem 通常取 512–1024。

三、三种集成形态:MAC/MAG/MAL 怎么选?

形态 记忆参与方式 适用场景 显存额外占用 实测长程 PPL(200 万 token)
MAC 作为额外 K/V 送入 attention 需要细粒度历史依赖:法律、病历 2×L·d_mem 2.41
MAG 门控缩放当前 logits 流式对话、实时推荐 1×L·d_mem 2.48
MAL 独立层输出残差 超长文本分类、日志异常检测 0.3×L·d_mem 2.63

选型建议

  • 生产系统优先 MAC,精度最高;若显存吃紧,可把记忆宽度压缩到 512 再配 4-bit 量化,掉点 < 2%。
  • 多轮对话机器人用 MAG,记忆状态可在用户离线后序列化落盘,下次直接加载,实现「跨会话记忆」。
  • 纯检索型任务(日志、分类)选 MAL,可把记忆层放在 CPU 内存,GPU 只负责 attention,成本最低。

四、推理系统落地清单

1. 显存预算

总显存 ≈ 模型权重 + KV-cache + 记忆层
KV-cache = 2 · L · d_model · n_layers · Bytes(dtype)
记忆层 = L · d_mem · n_layers · Bytes(dtype) · (1 + 0.2)  # 20% 预留分页碎片

以 7B 模型、n_layers=32、d_model=4096、d_mem=1024、fp16 为例:

  • 1 M token 传统 KV-cache 需要 2×1e6×4096×32×2 B ≈ 512 GB
  • Titans 记忆层只需 1e6×1024×32×2 B ≈ 64 GB,降 8×

2. 分页 & Checkpoint

  • 把记忆张量按 64 token 为 block 切分,维护「热块池」在显存,冷块换出到主机内存;换入延迟 < 3 ms(PCIe 4.0 x16)。
  • 每完成一个用户会话,调用 memory.checkpoint() 把 M_t 序列化到分布式文件系统;下次 memory.load() 直接 mmap 回内存,实现秒级续写。

3. 量化折衷

  • 记忆层权重对低比特敏感:INT8 掉 0.4 BLEU,INT4 掉 1.8 BLEU。当前稳妥方案是 8-bit 对称量化 + 动态缩放,配合 CUDA kernel __nv_fp8_e4m3 可在 H100 上做到 1.3× 吞吐提升。
  • 若仍想压到 4-bit,建议只对「遗忘门」输出做 FP16,记忆体本身用 INT4,混合精度可拉回 1.1 BLEU。

4. 生产坑位

  • 非确定性:同一输入因 batch 顺序不同导致惊喜指标微差,需强制 torch.use_deterministic_algorithms(True) 并锁 CUDA graph。
  • 梯度爆炸:长序列下 surprise 可能陡增,更新前做局部范数裁剪 torch.nn.utils.clip_grad_norm_(memory_params, max_norm=1.0)
  • 并发隔离:多用户同进程时,记忆状态用 threading.Local() 包装,防止用户 A 的惊喜写到用户 B 的记忆里。

五、下一步:跨会话持久化与分布式记忆分区

Titans 目前只在单卡内维护记忆,但工程上完全可以把记忆层拆成分布式键值服务:

  • 按 token-block ID 做一致性哈希,横向扩展记忆节点;
  • 热块放在 GPU RDMA 内存,温块落到 NVMe-oF;
  • 全局版本号 + MVCC,实现「读已提交」隔离级别,支持多用户同时读写。

谷歌已在论文附录里透露未来会开源记忆层 CUDA kernel 与 PyTorch op;一旦发布,长上下文推理的「显存焦虑」有望从根上解除。


资料来源
[1] Behrouz A. et al. Titans: Learning to Memorize at Test Time. arXiv:2501.00663
[2] Wang J. et al. Efficient Attention Mechanisms for Large Language Models: A Survey. arXiv:2507.19595

查看归档