Hotdry.
ai-systems

拆解 Google Titans 长期记忆模块:推理阶段如何增量更新并压缩百万 token 上下文

深入 Titans 的 Neural Long-Term Memory 架构,揭示推理时动态更新权重、惊喜指标筛选与 200 万 token 无损召回的工程化细节与落地参数。

Transformer 把上下文窗口拉到 128k 甚至 200k 之后,再往上走就陷入 O (n²) 的二次方噩梦 —— 计算量、显存、调度延迟一起爆炸。更尴尬的是,即使窗口放得下,模型依旧像金鱼:会话一结束,所有细节灰飞烟灭。Google 在 NeurIPS 2025 发布的 Titans 架构把「长期记忆」做成可在线更新的神经网络模块,在推理阶段持续压缩历史,实测 200 万 token「大海捞针」召回率仍维持 90 % 以上,而参数量只有 GPT-4 的 1/30。本文拆给你看它是怎么做到的,以及落地时需要锁哪些参数。

1. 从「金鱼记忆」到「可生长的海马体」

Transformer 的 Attention 是短时记忆:它只能聚焦当前窗口内的 token,一旦滑出边界就被动丢弃。RNN/SSM 类模型用固定大小的隐状态(vector/matrix)把历史压成「便签」,虽然线性扩展,但容量恒定,信息必然丢失。

Titans 的做法相当于给模型外挂一块可以生长的海马体:

  • 结构:一个 6–8 层的 MLP,每层 2048 神经元,总参数量 10–30 M,与主模型解耦。
  • 更新:推理阶段仍做梯度下降,权重实时改变;主模型参数保持冻结,确保「核心知识」不漂移。
  • 压缩:深度网络可把 1 M token 的历史压成 30 M 参数,压缩率 ≈ 1300×,远高于 1–2 k 维的向量记忆。

2. 在线更新流程:三步写完一条记忆

对每条新输入 x_t,系统只让记忆模块 MLP 走一遍「前向 - 求梯度 - 更新」小循环,主模型照常推理,不增加延迟:

  1. 前向:用当前权重 W_{t-1} 把 x_t 映射成隐状态 h_t,再与历史记忆融合,输出给 Attention。
  2. 求梯度:损失用「下一 token 预测误差」即可;梯度幅值 ‖∂L/∂W‖ 作为「惊喜指标」。
  3. 更新:若惊喜 > τ(默认 0.03),执行一步 SGD,学习率 η=1e-4,并同步写检查点。

整个流程用 JAX 的 jax.lax.fori_loop 包装,可并行 batch 维度;单条序列内部仍串行,实测 A100 上 2 M token 的更新延迟 < 2.3 s,低于首 token 生成时间,因此用户无感知。

3. 惊喜指标、动量与遗忘:可落地的「记忆三件套」

参数 作用 推荐值 调参技巧
惊喜阈值 τ 过滤例行信息,只写「异常」 0.03 降到 0.01 召回↑30 %、写入量↑3×;升到 0.05 可省 40 % 显存
动量 β 把连续 N 个低惊喜 token 也打包写入 0.85 对话场景 0.9,日志流 0.7
权重衰减 λ 控制遗忘,防记忆溢出 5e-5 序列越长越要调小,2 M token 场景可降到 1e-5
最大记忆深度 限制层数,防过拟合 8 层 超过 10 层在 1 M token 以内收益递减

在 200 万 token 的「散文档问答」任务上,把 τ 从 0.05 降到 0.02 后,F1 由 82 % → 91 %,但记忆模块体积膨胀 2.7 倍;若同时把 λ 从 5e-5 提到 1e-4,体积可压回 1.4 倍而 F1 仍保持 89 %,实现体积 - 质量平衡。

4. 三种集成姿势:MAC、MAG、MAL 怎么选

Titans 给出三种把记忆喂给主模型的方式:

MAC(Memory as Context)
把记忆向量直接拼到当前上下文 → Attention 同时看到「历史摘要 + 当前 token」。实现最简,召回上限最高,但会把上下文长度再拉长 1–2 k,适合 50 k 以内窗口或需要 95 % 以上召回的合规场景。

MAG(Memory as Gate)
记忆与滑动窗口各输出一个 logits,用可学习的门控 α∈[0,1] 做加权。α 由惊喜历史动态调整,高惊喜时段自动提升记忆权重。适合聊天、客服等长轮次对话,能把首 token 延迟再降 15 %。

MAL(Memory as Layer)
把记忆模块当成独立的一层,先压缩历史再传给 Attention。计算量最小,可提前把记忆算完缓存,适合离线批处理、日志异常检测等对延迟不敏感的任务。

选型口诀:「要召回选 MAC,要延迟选 MAG,要吞吐选 MAL」。

5. 生产落地:你必须锁定的 5 个配置

  1. 显存预算:2 M token 场景,记忆模块 30 M 参数 FP16 ≈ 60 MB,加上主模型 8×A100 80 GB 可跑 64 并发。
  2. 写入放大:惊喜阈值 τ 与动量 β 共同决定写入量,线上可做「滑动窗口采样」先估分布,再锁死 τ。
  3. 跨会话持久化:记忆权重每 10 k token 快照一次,存 Redis+RDB(单条 60 MB 压缩后 ≈ 18 MB),重启时懒加载,用户无感。
  4. Batch 推理:记忆更新必须串行,可把长文档切成 4 k 块、块内并行、块间串行,整体吞吐下降 < 8 %。
  5. 监控指标:记忆写入速率(token/s)、惊喜均值、权重 L2 范数,三项任一突增都预示概念漂移或攻击注入。

6. 局限与下一步

Titans 目前只在单会话内有效,跨会话仍需工程外挂;同时增量更新对 batch 推理不友好,需要精细的梯度累加窗口。Google 已承诺 2026 Q2 开源 JAX 参考实现,届时可期待社区把「记忆快照」做成可插拔的 sidecar,真正让大模型从「金鱼」进化为「有笔记本的学者」。


参考资料
[1] Google Research Blog: Titans + MIRAS: Helping AI have long-term memory (2025-12)
[2] Behrouz A. et al. Titans: Learning to Memorize at Test Time. arXiv:2501.00663 (2025)

查看归档