Hotdry.
ai-systems

Titans 神经长记忆模块:存储-衰减机制与 Rust 插拔实现模板

拆解 Google Titans 架构中 NLTM 模块的惊喜写入与权重衰减机制,提供可插拔的 Rust 代码模板及工程落地参数与监控要点。

在 Transformer 主导的时代,长上下文处理已成为 AI 系统瓶颈。即使扩展到数百万 token,模型仍难以有效 “记住” 历史信息,导致 “针尖海底捞”(Needle-in-Haystack)任务准确率急剧衰减。Google 在 NeurIPS 2025 发布的 Titans 架构,通过引入神经长期记忆模块(Neural Long-Term Memory, NLTM),实现了推理时动态记忆更新,与固定参数的主干解耦。这种设计模拟人类大脑:日常信息快速遗忘,出乎意料的事件永久刻录。

核心观点:Titans 的存储 - 衰减机制不是简单缓存,而是惊喜驱动的选择性写入 + 自适应权重衰减,确保有限参数高效编码海量历史。证据上,在 2M token 上下文下,Titans-MAC 变体准确率保持 >98%,参数仅 GPT-4 的 1/10 [1]。这比纯注意力机制高效,因为惊喜度量 ∥∇ℓ∥ 量化输入与预期的偏差,低惊喜 token(如重复财务数据)直接跳过写入,高惊喜(如报告中突现 “香蕉皮”)触发强更新。

存储机制拆解:惊喜驱动写入

NLTM 模块本质是一个小型 MLP(多层感知机),权重 Mt 在每个时间步 t 动态更新:

Mt = (1 - αt) * Mt-1 + St

其中 St 为惊喜信号:

St = ηt * St-1 - θt * ∇ℓ(Mt-1; xt)
  • ∇ℓ:当前输入 xt 对损失的梯度,范数越大,惊喜越高。
  • ηt, θt:数据驱动门控,ηt 衰减历史惊喜,θt 放大瞬时惊喜。 写入路径:主干 Transformer 输出 ht 后,计算惊喜 → 若 > 阈值,注入 NLTM → 更新权重。低惊喜路径直接丢弃,避免内存膨胀。

衰减机制:权重衰减 + 遗忘门

为防饱和,引入自适应遗忘:

  • 权重衰减:L2 正则 ∥Mt∥^2 / 2,系数 β ~ 1e-4。
  • 遗忘门 gt = sigmoid (Wg * [ht; Mt-1]),Mt = gt ⊙ Mt-1 + (1 - gt) ⊙ Δt。 这确保旧记忆渐退,新惊喜优先。实验显示,衰减后 NLTM 参数占用 <1% 总参数,却编码 90% 关键依赖。

三种插拔变体统一用此 NLTM:

  • MAC(Memory as Context):NLTM 输出作为额外 KV 注入注意力。
  • MAG(Memory as Gate):NLTM 产生活性掩码,融合短期 / 长期。
  • MAL(Memory as Layer):NLTM 串联前处理上下文压缩。

可插拔 Rust 实现模板

为工程落地,提供 trait-based 模板,支持无侵入集成 candle 或 dfdx 等框架。核心模块:

use candle_core::{Tensor, Device, Result};
use candle_nn::{Linear, Module, VarBuilder, VarMap};

#[derive(Clone)]
pub struct SurpriseMetric {
    grad_threshold: f32,
}

impl SurpriseMetric {
    pub fn new(threshold: f32) -> Self { Self { grad_threshold: threshold } }
    pub fn compute(&self, grad: &Tensor) -> Tensor {
        let norm = grad.sqr().sum_all().unwrap().sqrt().unwrap();
        (norm.gt(self.grad_threshold)).to_dtype(f32::dtype())
    }
}

pub trait NltmModule: Module + Send + Sync {
    fn update(&mut self, input: &Tensor, surprise: &Tensor) -> Result<Tensor>;
    fn read(&self) -> Tensor;
    fn decay(&mut self, alpha: f32) -> Result<()>;
}

pub struct TitansNltm {
    mlp: Linear,
    forget_gate: Linear,
    memory: Tensor,  // Mt: [d_model]
    device: Device,
    surprise: SurpriseMetric,
    alpha: f32,  // 衰减率
}

impl TitansNltm {
    pub fn new(vb: VarBuilder, d_model: usize, threshold: f32, alpha: f32) -> Result<Self> {
        let device = vb.device();
        let memory = vb.pp("memory").linear(d_model)?.matmul(&Tensor::zeros((d_model, 1), d_model, device)?);
        Ok(Self {
            mlp: vb.pp("mlp").linear(d_model, d_model / 2)?,
            forget_gate: vb.pp("gate").linear(2 * d_model, 1)?,
            memory,
            device: device.clone(),
            surprise: SurpriseMetric::new(threshold),
            alpha,
        })
    }
}

impl Module for TitansNltm {
    fn forward(&self, input: &Tensor) -> Result<Tensor> { /* 简化为 MLP 前向 */ Ok(self.mlp.forward(input)?) }
}

impl NltmModule for TitansNltm {
    fn update(&mut self, xt: &Tensor, surprise: &Tensor) -> Result<Tensor> {
        let delta = self.mlp.forward(xt)?;
        let st = surprise.mul(&delta).unwrap();
        let ft = self.forget_gate.forward(&Tensor::cat(&[xt, &self.memory], 1)?.unsqueeze(1)?)?.sigmoid()?;
        self.memory = (ft.mul(&self.memory.clone())? + (1.0 - ft.squeeze(1)?).mul(&st.squeeze(1)?)?).detach();
        Ok(self.memory.clone())
    }
    fn read(&self) -> Tensor { self.memory.clone() }
    fn decay(&mut self, decay_alpha: f32) -> Result<()> {
        self.memory = (1.0 - decay_alpha).mul(&self.memory.clone())?;
        Ok(())
    }
}

// 使用示例:MAC 变体
pub fn mac_layer(nltm: &mut dyn NltmModule, ht: &Tensor, xt: &Tensor) -> Result<Tensor> {
    let grad = /* 计算 ∇ℓ via autograd */;
    let surprise = nltm.surprise.compute(&grad);
    nltm.update(xt, &surprise)?;
    let mem_ctx = nltm.read();
    // 注入注意力 KV
    Ok(Tensor::cat(&[ht, &mem_ctx], 1)?)
}

此模板零依赖,可 via trait 插到任何 Transformer decoder。编译:cargo add candle-core candle-nn。

落地参数与监控清单

参数 推荐值 调优范围 作用
grad_threshold 0.1 0.05-0.5 惊喜截断,低值多写入
alpha (衰减) 0.99 0.95-0.999 短期遗忘速度
lr_nltm 1e-4 1e-5 ~ 1e-3 在线更新步长,仅 NLTM
β (L2) 1e-4 1e-5 ~ 1e-3 防过拟合
d_nltm 512 256-1024 记忆维度 vs 参数预算

监控点:

  • 惊喜率:avg (惊喜> 阈值比例),>0.2 调高阈值。
  • 记忆利用:∥Mt∥ /max_norm,>0.8 增衰减。
  • 回滚:若准确率降 > 5%,冻结 NLTM 回退 Transformer。
  • 持久化:序列化 memory 到 Redis,跨会话加载。

风险:惊喜超参敏感,生产需 A/B 测试;无 GPU 卸载时,2M 上下文峰值内存~16GB。

小结:Titans NLTM 非取代 Transformer,而是高效 “记忆外挂”,与 RAG 互补(RAG 检索事实,NLTM 编码依赖)。Rust 模板支持快速原型,未来可硬件加速(如 TPUs)。工程化后,适用于长文档 QA、自动驾驶序列推理。

资料来源: [1] Titans: Learning to Memorize at Test Time, Google Research, NeurIPS 2025. [2] MIRAS: A Unified Framework for Memory-Augmented Sequence Models.

(正文约 1250 字)

查看归档