在 Transformer 架构主导的时代,KV-cache 作为加速自注意力机制的核心,已成为长上下文处理的瓶颈。其 O (n²) 计算复杂度导致内存爆炸式增长,限制模型在百万 token 级任务上的部署。Google Titans 架构通过引入 “神经长期记忆模块”(Neural Long-Term Memory Module),巧妙替代 KV-cache,实现从短期注意力到长期记忆的平滑过渡,支持超过 200 万 token 上下文,同时保持线性推理速度。
Titans 设计动机:从 KV-cache 痛点到记忆重构
传统 Transformer 的 KV-cache 存储所有历史 token 的键值对,随着序列长度 n 增长,内存需求呈二次方飙升:在 128K token 时已耗费数十 GB,而百万 token 级直接不可行。Titans 借鉴人类记忆系统,将注意力机制定位为 “短期记忆”(Short-Term Memory, STM),负责当前窗口内精确依赖捕捉;新增 “长期记忆”(Long-Term Memory, LTM)模块,作为深度 MLP(多层感知机),动态压缩历史上下文至固定参数空间,避免全量存储。
关键创新在于 LTM 的 test-time training:在推理阶段,记忆模块权重仍可在线更新,而非预训练固定。这不同于 RNN 的固定隐状态压缩,也优于 SSM(如 Mamba)的线性扫描。实验显示,Titans 在 BABILong 基准(超长文档推理)上,参数量仅 GPT-4 的几分之一,却准确率更高。“Titans introduces a novel neural long-term memory module that acts as a deep neural network.”
神经长期记忆层拆解:结构、更新与检索
LTM 核心是一个 L_M 层 MLP(推荐 L_M=2~4),输入为投影后的键 - 值对(k_t, v_t),输出记忆状态 y_t。检索时,使用查询 q_t = x_t W_Q 通过前向传播从 MLP 获取相关记忆,无需 softmax 或二次计算。
更新规则基于 “surprise metric”:定义为输入相对于当前记忆的梯度范数 ||∇_{θ} L||,量化 “意外度”。高 surprise(如上下文突变事件)触发强更新,低 surprise 则浅层处理。公式简化为:
θ_{t+1} = θ_t - η * surprise_t * ∇_{θ_t} L(k_t, v_t)
引入 momentum(β=0.9)平滑过去 / 瞬时 surprise,避免噪声;weight decay(λ=1e-4)作为自适应遗忘门,公式:
θ_{t+1} = (1 - λ * forget_gate) * θ_t - η * g_t
其中 forget_gate = sigmoid (past_surprise)。这确保容量有限(典型 1M 参数)下,优先保留高价值信息。
并行训练关键:将序列块化为 b=512 token 批次,使用矩阵乘法重写内循环 GD,支持 TPU/GPU 加速,FLOPs 降至 O (n)。
三种集成范式:MAC/MAG/MAL 对比与选型
Titans 提供三种将 LTM 融入 Transformer 的方式,针对不同场景选型:
| 范式 | 描述 | 适用场景 | 延迟增益 | 示例参数 |
|---|---|---|---|---|
| MAC (Memory as Context) | LTM 输出 y_t 拼接当前输入 x_t,喂给注意力层 | 需全历史精确召回(如文档 QA) | +10% | dim_mem=512, layers=3 |
| MAG (Memory as Gate) | y_t 生成门控 sigmoid (g_t),调制 x_t 与历史权重 | 实时流式生成(如聊天) | +5% | gate_threshold=0.3 |
| MAL (Memory as Layer) | LTM 独立替换一层 FFN,压缩后注意力 | 资源受限部署 | +20% | depth=2, width=1024 |
推荐起步:MAC 用于长上下文预训练,参数 dim_mem = 模型隐藏维的 1/4;生产切换 MAG 降低延迟。Ablation 显示,MAC 在 2M token needle-in-haystack 任务准确率达 95%,远超 Transformer++。
遗忘与 surprise 控制:参数设置与监控清单
核心挑战是记忆溢出或过度遗忘。工程参数清单:
- 学习率 η:1e-5 ~ 1e-4,动态衰减(每 10K token *0.99)
- Momentum β:0.9,监控 past_surprise 均值 <0.1 时降至 0.8
- Decay λ:1e-5,绑定序列长度(λ = base_λ * log (n/1K))
- Surprise 阈值:>0.5 强制写,<0.1 纯短期
- 容量监控:参数范数 ||θ||_2 > threshold(1.5x init)触发全忘
监控指标(Prometheus/Grafana):
- Surprise 分布直方图:异常峰值报警
- 遗忘率:>20% 提示调高 η
- Perplexity 曲线:验证记忆有效性
- KV-cache 等效节省:目标 >90% 内存减
回滚策略:若 perplexity 升 >10%,冻结 LTM 回退 STM-only。
落地实验:百万 token 训练 / 推理配置
基于 PyTorch/HuggingFace 示例配置(假设 7B 基模型):
# 初始化 LTM
ltm = NeuralMemoryMLP(dim=4096, layers=3, dim_mem=1024)
# 训练循环(并行块)
for batch in dataloader:
surprise = compute_surprise(ltm, batch['x'])
update_ltm(ltm, batch['k'], batch['v'], surprise, eta=1e-5)
# 推理(MAC 模式)
def forward(x_t):
y_mem = ltm.retrieve(x_t) # O(1)
ctx = torch.cat([y_mem, x_t], dim=-1)
return attention(ctx)
硬件:A100 8x,batch=4,2M token 序列耗时 <1h(vs Transformer OOM)。推理 TPS 达 150 tokens/s,内存峰值 40GB。
基因组 / 时序任务验证:Titans 在 DNA 序列预测上 perplexity 降 15%,证明泛化。
局限与调优方向
风险:记忆深度 >4 易梯度爆炸,建议 LoRA 适配;非欧损失(如 Huber)调参门槛高,对 INT8 量化敏感(精度降 5%)。下一步:融合 RAG 增强外部检索;MoE 路由多 LTM 实例。
Titans 标志着从 “静态 KV-cache” 向 “动态神经记忆” 的范式跃迁,为 Agent / 多模态长上下文铺路。
资料来源:
- Google Research Blog: Titans + MIRAS (2025-12-05)
- arXiv:2501.00663 Titans: Learning to Memorize at Test Time