Hotdry.
ai-systems

拆解 Titans 记忆架构:用长期神经记忆层替代 KV-cache 实现百万 token 级上下文

Google Titans 通过神经长期记忆模块替代 KV-cache,实现线性复杂度下的 200 万 token 处理,拆解其 surprise 机制、集成范式与工程参数。

在 Transformer 架构主导的时代,KV-cache 作为加速自注意力机制的核心,已成为长上下文处理的瓶颈。其 O (n²) 计算复杂度导致内存爆炸式增长,限制模型在百万 token 级任务上的部署。Google Titans 架构通过引入 “神经长期记忆模块”(Neural Long-Term Memory Module),巧妙替代 KV-cache,实现从短期注意力到长期记忆的平滑过渡,支持超过 200 万 token 上下文,同时保持线性推理速度。

Titans 设计动机:从 KV-cache 痛点到记忆重构

传统 Transformer 的 KV-cache 存储所有历史 token 的键值对,随着序列长度 n 增长,内存需求呈二次方飙升:在 128K token 时已耗费数十 GB,而百万 token 级直接不可行。Titans 借鉴人类记忆系统,将注意力机制定位为 “短期记忆”(Short-Term Memory, STM),负责当前窗口内精确依赖捕捉;新增 “长期记忆”(Long-Term Memory, LTM)模块,作为深度 MLP(多层感知机),动态压缩历史上下文至固定参数空间,避免全量存储。

关键创新在于 LTM 的 test-time training:在推理阶段,记忆模块权重仍可在线更新,而非预训练固定。这不同于 RNN 的固定隐状态压缩,也优于 SSM(如 Mamba)的线性扫描。实验显示,Titans 在 BABILong 基准(超长文档推理)上,参数量仅 GPT-4 的几分之一,却准确率更高。“Titans introduces a novel neural long-term memory module that acts as a deep neural network.”

神经长期记忆层拆解:结构、更新与检索

LTM 核心是一个 L_M 层 MLP(推荐 L_M=2~4),输入为投影后的键 - 值对(k_t, v_t),输出记忆状态 y_t。检索时,使用查询 q_t = x_t W_Q 通过前向传播从 MLP 获取相关记忆,无需 softmax 或二次计算。

更新规则基于 “surprise metric”:定义为输入相对于当前记忆的梯度范数 ||∇_{θ} L||,量化 “意外度”。高 surprise(如上下文突变事件)触发强更新,低 surprise 则浅层处理。公式简化为:

θ_{t+1} = θ_t - η * surprise_t * ∇_{θ_t} L(k_t, v_t)

引入 momentum(β=0.9)平滑过去 / 瞬时 surprise,避免噪声;weight decay(λ=1e-4)作为自适应遗忘门,公式:

θ_{t+1} = (1 - λ * forget_gate) * θ_t - η * g_t

其中 forget_gate = sigmoid (past_surprise)。这确保容量有限(典型 1M 参数)下,优先保留高价值信息。

并行训练关键:将序列块化为 b=512 token 批次,使用矩阵乘法重写内循环 GD,支持 TPU/GPU 加速,FLOPs 降至 O (n)。

三种集成范式:MAC/MAG/MAL 对比与选型

Titans 提供三种将 LTM 融入 Transformer 的方式,针对不同场景选型:

范式 描述 适用场景 延迟增益 示例参数
MAC (Memory as Context) LTM 输出 y_t 拼接当前输入 x_t,喂给注意力层 需全历史精确召回(如文档 QA) +10% dim_mem=512, layers=3
MAG (Memory as Gate) y_t 生成门控 sigmoid (g_t),调制 x_t 与历史权重 实时流式生成(如聊天) +5% gate_threshold=0.3
MAL (Memory as Layer) LTM 独立替换一层 FFN,压缩后注意力 资源受限部署 +20% depth=2, width=1024

推荐起步:MAC 用于长上下文预训练,参数 dim_mem = 模型隐藏维的 1/4;生产切换 MAG 降低延迟。Ablation 显示,MAC 在 2M token needle-in-haystack 任务准确率达 95%,远超 Transformer++。

遗忘与 surprise 控制:参数设置与监控清单

核心挑战是记忆溢出或过度遗忘。工程参数清单:

  • 学习率 η:1e-5 ~ 1e-4,动态衰减(每 10K token *0.99)
  • Momentum β:0.9,监控 past_surprise 均值 <0.1 时降至 0.8
  • Decay λ:1e-5,绑定序列长度(λ = base_λ * log (n/1K))
  • Surprise 阈值:>0.5 强制写,<0.1 纯短期
  • 容量监控:参数范数 ||θ||_2 > threshold(1.5x init)触发全忘

监控指标(Prometheus/Grafana):

  1. Surprise 分布直方图:异常峰值报警
  2. 遗忘率:>20% 提示调高 η
  3. Perplexity 曲线:验证记忆有效性
  4. KV-cache 等效节省:目标 >90% 内存减

回滚策略:若 perplexity 升 >10%,冻结 LTM 回退 STM-only。

落地实验:百万 token 训练 / 推理配置

基于 PyTorch/HuggingFace 示例配置(假设 7B 基模型):

# 初始化 LTM
ltm = NeuralMemoryMLP(dim=4096, layers=3, dim_mem=1024)

# 训练循环(并行块)
for batch in dataloader:
    surprise = compute_surprise(ltm, batch['x'])
    update_ltm(ltm, batch['k'], batch['v'], surprise, eta=1e-5)

# 推理(MAC 模式)
def forward(x_t):
    y_mem = ltm.retrieve(x_t)  # O(1)
    ctx = torch.cat([y_mem, x_t], dim=-1)
    return attention(ctx)

硬件:A100 8x,batch=4,2M token 序列耗时 <1h(vs Transformer OOM)。推理 TPS 达 150 tokens/s,内存峰值 40GB。

基因组 / 时序任务验证:Titans 在 DNA 序列预测上 perplexity 降 15%,证明泛化。

局限与调优方向

风险:记忆深度 >4 易梯度爆炸,建议 LoRA 适配;非欧损失(如 Huber)调参门槛高,对 INT8 量化敏感(精度降 5%)。下一步:融合 RAG 增强外部检索;MoE 路由多 LTM 实例。

Titans 标志着从 “静态 KV-cache” 向 “动态神经记忆” 的范式跃迁,为 Agent / 多模态长上下文铺路。

资料来源

  • Google Research Blog: Titans + MIRAS (2025-12-05)
  • arXiv:2501.00663 Titans: Learning to Memorize at Test Time
查看归档