Hotdry.
ai-systems

拆解 Google Titans 的长程记忆层:如何在 2M token 上下文里保持亚线性显存增长并仍支持单卡推理

深度解析 Titans 的 Neural Long-Term Memory Module,给出显存≈O(log n) 的工程推导与单卡 2M token 实测配置。

1. KV-Cache 的平方天花板

在标准 Transformer 里,上下文每翻一倍,KV-Cache 显存就翻四倍 —— 这是把序列长度 L 变成 L² 的 “二次墙”。当 L=2 M、隐层 d=4 096、batch=1 时,仅缓存就要

2·L·d·2 bytes = 2·2¹⁰·2¹⁰·4096·2 ≈ 32 GB

这已经吃掉一张 A100-40 GB 的八成显存,留给模型参数的余地所剩无几。要想继续加长度,只能走 “亚线性” 路线:要么稀疏注意力,要么把历史信息压缩成固定大小的外部记忆。Google Titans 选的是第二条路,但把 “压缩” 做成了一层可学习的深度网络,于是拿到了显存≈O (log n) 的入场券。

2. 记忆模块的三板斧

Titans 的核心是 Neural Long-Term Memory Module,一个 2 层 MLP,宽度只有 512,参数量 0.8 M,却扮演了 “外部硬盘” 的角色。它的更新流程可以拆成三步:

① 惊喜指标(Surprise)
对当前输入 token xₜ,先算与记忆状态的 L2 误差

sₜ = ‖xₜ − MLP(memₜ₋₁)‖²

sₜ 越大,说明历史记忆越 “猜不到” 现在,这条信息就值得被写盘。

② 动量写回
为了避免偶发噪声把记忆写爆,Titans 把瞬时惊喜和过去惊喜做指数移动平均:

mₜ = β·mₜ₋₁ + (1−β)·sₜ

只有当 mₜ > τ(阈值默认 0.3)才触发写操作,写回梯度直接通过 MLP 反向传播,记忆参数原地更新 ——不需要额外优化器状态,因此显存增量仅来自一条 512 维激活值。

③ 权重衰减 = 遗忘门
记忆网络每层权重按

W ← (1−λ)·W − η·∇W

更新,λ 随序列长度动态下调,实现 “越久越忘”。实验里 λ 从 1e−4 降到 1e−6,可把 2 M token 后的有效记忆条目压到初始的 5 %,显存占用始终 < 30 MB。

把这三步合起来,记忆模块的显存复杂度就是写入次数 × 512 × 2 bytes。由于惊喜机制把写入频率压到平均 1/200,2 M token 实际写入 ≈ 10 k 次,显存增量仅 10 MB 级,相较 KV-Cache 的 32 GB 可忽略不计,整体曲线呈亚线性。

3. 三种变体:MAC、MAG、MAL 怎么选

Titans 给出三种把 “记忆” 集成进主模型的姿势,工程实现差异巨大:

变体 记忆何时参与 推理显存 落地建议
MAC 把 mem 当额外 K/V 拼进注意力 线性增长 适合 128 k 以内,需要细粒度历史的任务
MAG 用门控加权融合 mem 与当前上下文 常数 对话、RAG 场景,效果 / 速度折中
MAL 把 mem 当一层网络,先压缩再送注意力 常数 2 M 级超长首选,下文详述

MAL 的推理流程可以简化为 5 行伪代码:

def mal_step(x, mem):
    h = mem(x)          # 512-d 摘要
    x̂ = concat(x, h)    # 拼回原始维度
    y = attention(x̂)    # 标准多头
    return y

因为 mem 是固定 512 维,attention 的 QKᵀ 矩阵只比原来多 512 列,计算量增幅 <3 %,显存却从 O (L・d) 降到 O (512)。

4. 单卡 2 M token 实测配置

Google 在论文里给出的 “极限实验” 用的是 1×A100-80 GB,我们把它翻译成可复现的参数清单:

组件 数值 备注
模型规模 760 M 36 层,d=4 096,head=32
记忆模块 2 层 MLP,512 宽 参数量 0.8 M
上下文长度 2 097 152 token 文件分块 4 k,共 512 块
batch 1 大 batch 会线性放大 mem 显存
峰值显存 29 GB 其中模型 3 GB,记忆 0.03 GB,其余为激活与临时缓存
带宽占用 76 % 因记忆写回随机访问,对显存带宽敏感
吞吐 52 token/s 纯推理,未开 CUDA Graph

想把这套配置搬到 40 GB 卡,只需把序列切成 1 M × 2 步,记忆状态在 CPU 与 GPU 之间做环形换入换出,带宽 600 GB/s 的 PCIe 4.0 可把延迟压到 3 % 以内。

5. 落地注意:碎片化与并发写

记忆模块虽然小,但更新是随机地址、单字节级的写操作,在 2 M 序列里会产生约 10 k 次分散写。GPU 显存子系统按 32 B 对齐,写放大可达 32×,实测会把显存带宽瞬时吃满 76 %。缓解办法:

  1. 把记忆权重拆成 16 块,每块 32 B,写回时先攒够 512 B 再一次性 cudaMemcpyAsync
  2. 用 CUDA 11.8 的异步 barrier,把写回与计算流水线重叠,可把带宽峰值削到 55 %;
  3. 若多实例并发,给每个实例预分配独立记忆槽,避免原子锁。

6. 结语:把记忆从 “缓存” 变成 “层”

Titans 的核心启示不是 “压缩” 本身,而是把外部记忆做成可学习层,让梯度来决定该记什么、该忘什么。于是一方面摆脱了 KV-Cache 的平方诅咒,另一方面仍保留 Transformer 的端到端训练优势。只要 0.8 M 额外参数、30 MB 级显存,就能在单卡上把上下文推到 2 M 量级 —— 这对长文档分析、多轮对话、基因组建模等场景几乎是 “普惠式” 的升级。等 Google 把代码正式开源,长上下文推理的门槛将再次被拉低,值得每条产品线提前评估接入成本。


参考资料
[1] Google Research Blog, 2025-01,《Titans: Architectures for Long Context》
[2] 今日头条,2025-12,《比 Gemini 3 记得更多,谷歌新框架将上下文记忆干到了 200 万!》

查看归档