Hotdry.
ai-systems

拆解 Titans 记忆模块:如何用「神经长期记忆+短期上下文」在推理阶段实现百万 token 级上下文无损召回

从惊奇度写入到动量遗忘,详解 Google Titans 如何在推理阶段动态维护一个可更新的 MLP 记忆体,把上下文窗口推至 200 万 token 仍保持 90%+ 召回,并给出可直接落地的超参卡与工程 checklist。

Transformer 把上下文做到 128 k 已让 GPU 喘不上气,但谷歌在 NeurIPS 2025 丢出的 Titans 架构直接官宣:200 万 token 上下文下「大海捞针」准确率仍维持 90% 以上。关键不是堆算力,而是让模型在推理阶段长出自己的「海马体」—— 一个可在测试时持续更新权重的神经长期记忆模块(Neural Long-Term Memory, LTM)。下面按「原理→实现→调参→落地」四层拆给你看。

一、Transformer 的 O (N²) 墙与 Titans 破墙思路

标准自注意力每增加 1 倍序列长度,计算量翻 4 倍,显存翻 2 倍;当长度 >100 k 时,batch-size 被迫压到 1,吞吐掉到不可接受。业界解法无非三派:

  1. 稀疏 / 线性注意力:把 O (N²) 压到 O (N),但精度掉得比显存还快;
  2. 分块 + 滑窗:把长序切成片段,每段内做局部注意力,跨段用状态传递,工程复杂且召回率随深度指数衰减;
  3. 外挂 RAG:用向量库存历史,推理时先检索再生成,延迟和一致性不可控。

Titans 走第四条路 ——在模型内部新增一个可学习的记忆体,推理阶段实时写、实时读,而主干参数完全不动。相当于给 Transformer 外挂一块「可刷新 SSD」,但读写接口与注意力原生融合,延迟只增加 5–12%。

二、神经长期记忆模块:动态 MLP + 惊奇度写入

1. 记忆体结构

  • 一个 4–8 层的 MLP,隐藏层维度与主模型一致(7 B 模型常用 4096),参数量≈ 1% 主模型;
  • 权重在训练后不完全冻结,而是在推理阶段根据输入持续做一步梯度下降;
  • 采用并行化训练技巧:把「记忆更新」拆成 batch×token 的矩阵运算,GPU 利用率 >85%。

2. 惊奇度(surprise metric)(写入触发器

对当前输入 x_t 计算:

surprise_t = ‖∇_θ_mem loss(x_t, y_t)‖₂

即记忆参数对当前 token 的梯度 L2 范数。梯度越大 → 模型「越意外」,越值得写入长期记忆。

工程实现里用「动量累积」防止偶发噪音触发:

S_t = β·S_{t-1} + (1-β)·surprise_t

当 S_t > τ(阈值)才执行写入,β 一般取 0.7–0.9。

3. 遗忘机制(容量管理

记忆 MLP 容量固定,写入无限长序列必然溢出。Titans 用自适应权重衰减

θ_mem ← θ_mem − η·λ_t·θ_mem

λ_t 与当前写入频率成反比:写入越频繁,λ_t 越小,保证新信息不被立即抹掉;写入稀疏时加大遗忘,清理旧知识。λ 初始值 1e-4,按验证集召回率自动微调。

三、三种集成范式:MAC、MAG、MAL 怎么选

范式 原理 延迟 召回率 适用场景
MAC (Memory as Context) 把记忆输出当额外 K,V 塞进注意力 +12% 最高 超长文档问答、 genomic 分析
MAG (Memory as Gate) 记忆与滑窗注意力做门控融合 +5% 中高 流式对话、实时字幕
MAL (Memory as Layer) 记忆模块当独立层,压缩后再给注意力 +8% 中等长度、批量大

生产建议

  • 200 k–2 M token 单任务召回 → 直接上 MAC;
  • 10 k–200 k 且延迟敏感 → 用 MAG,窗口 4 k 即可;
  • 想「零改图」复用现有推理框架 → 选 MAL,只插一层,不改注意力代码。

四、推理阶段参数更新:冻结主干 vs. 在线记忆

关键保证生产稳定性

  • 主干 transformer 权重完全冻结,杜绝「灾难性遗忘」;
  • 仅 θ_mem 在推理时做一步 SGD,学习率 1e-3–3e-4,单步更新耗时 <1 ms(7 B/4096 隐维,A100);
  • 每 token 更新一次太频繁,实际按「惊奇窗口」批处理:每 64 个 token 算一次累积惊奇,满足阈值才触发更新,GPU 利用率提升 30%。

五、MIRAS 框架:把「记忆」做成可插拔组件

如果你不想用 Titans 默认的 MLP 记忆,MIRAS 给出 4 个可替换维度:

  1. Memory architecture:向量、矩阵、深度网络随你挑;
  2. Attentional bias:用 Huber Loss 替代 MSE,对异常输入更鲁棒(YAAD 变体);
  3. Retention gate:把「遗忘」 reinterpret 成正则项,可换成 L1、L2、DropConnect;
  4. Memory algorithm:支持 Adam、AdaGrad、甚至二阶拟牛顿。

工程价值

  • 已存在的线性 RNN(Mamba、Gated DeltaNet)可被 MIRAS 统一成「不同正则化形式的记忆模块」,方便在统一 benchmark 里做 AB 测试;
  • 非欧目标函数(如 p - 范数、KL 散度)可直接塞进记忆更新,无需重写 CUDA kernel。

六、落地 checklist & 超参卡

参数 搜索区间 经验默认值 调节信号
记忆深度 2–12 层 4 层 层数↑ 召回↑,但延迟线性增长
惊奇阈值 τ 0.1–2.0 0.5 验证集召回率下降 2% 即调低
动量 β 0.5–0.95 0.8 噪音多→β 大,防止误写入
初始遗忘 λ₀ 1e-5–1e-3 1e-4 训练集遗忘过快→λ 减半
更新批大小 16–256 token 64 GPU 利用率 < 60% 时加大
最大上下文 ≤2 M 按显存设定 超过 2 M 需模型并行

部署注意

  1. 记忆权重 θ_mem 需会话级持久化(存 checkpoint),否则跨请求失忆;
  2. 多租户场景下每用户独立 θ_mem,可用 LRU 把冷用户换出 NVMe,热用户常驻显存;
  3. 监控指标:每千 token 更新次数、记忆参数 L2 范数、验证召回 @k,出现「更新爆炸」或「记忆饱和」及时下调学习率或加大 λ。

七、总结

Titans 用「可更新的深度记忆体」把长上下文问题从「算不起」变成「记得住」:

  • 推理阶段只动 1% 参数,稳定可回溯;
  • 惊奇度 + 动量 + 遗忘三件套,保证写得快、写得准、不溢出;
  • MAC 范式已验证 200 万 token 无损召回,可直接替换现有超长文本管线。

下一步,把 MIRAS 的「记忆即正则」思路搬到多模态、Agent 会话持久化,甚至让模型自己决定「要不要长新的记忆器官」,才是真正的「测试时学习」范式爆发点。


参考资料
[1] Google Research Blog, 2025-12: Titans + MIRAS: Helping AI have long-term memory
[2] Ali Behrouz et al., arXiv:2501.00663, Titans: Learning to Memorize at Test Time, 2025

查看归档