Hotdry.
ai-systems

权重稀疏 Transformer 实现可解释神经电路:剪枝与桥接工程参数

通过权重稀疏训练揭示 Transformer 内部电路,提供 L0 正则、mean ablation 验证及桥接对齐的具体参数与监控清单。

Transformer 模型的黑箱性质一直是机制可解释性的核心障碍,传统稠密权重导致激活叠加,难以逆向工程具体算法。权重稀疏训练提供了一种从源头解耦计算路径的范式,通过强制 L0 范数极小(典型目标 0.1%–1% 非零权重),模型被迫使用少数残差通道和神经元,形成解耦电路。这些电路不仅规模小,还对应人类可理解的概念,如 “引号检测器” 或 “嵌套深度计数器”,从而实现对模型行为的精确干预。

证据显示,在相同预训练损失下,稀疏模型的最小电路规模约为稠密模型的 1/16。例如,在字符串闭合任务中,电路仅需第 0 层 2 个 MLP 神经元(通道 985 检测任意引号,通道 460 区分单双引号)和第 10 层 1 个注意力头(QK 通道 1、V 通道 0),总计 12 个节点和 9 条边。通过 mean ablation 验证:保留电路节点平均损失不变,移除则损失激增 10 倍以上,证明电路必要且充分。

更复杂任务如括号嵌套深度计数,电路分为嵌入(左括号写入通道 759/826/1711)、计数(第 2 层头 125 平均激活写入通道 1249)和阈值化(第 4 层头 80 使用 sink 注意力阈值通道 1079)。此机制易受长上下文稀释影响,可预测稠密模型在序列 >512 时错误率升 20%。

桥接技术进一步扩展实用性:在每层插入线性 encoder/decoder(维度匹配隐藏大小 d_model=512–2048),以 NMSE loss 对齐稀疏与稠密表示。实验中,桥接后混合路径损失 <5% 退化,通过编辑稀疏电路(如修改引号通道激活),稠密输出概率偏移 15%–30%,实现可控干预。

工程落地参数清单如下:

1. 稀疏训练超参数

  • L0 目标:每矩阵非零比例 0.05%–0.5%(从 0.1% 起步,根据损失曲线调至前沿)。
  • 投影策略:forward/backward 后,top-k 保留绝对值最大权重(k = L0 * 矩阵大小),其余硬零。
  • 激活稀疏:AbsTopK 保留 top-2%–5% 激活,结合 L0 提升特征纯度。
  • 优化器:AdamW,lr=6e-4,warmup 10%,权重衰减 0.1;批次 512,序列长 1024。
  • 规模扩展:固定 L0 增 d_model(512→4096),电路大小降 2–4 倍,能力升 5–10x perplexity 改善。

2. 电路提取与验证

  • 剪枝阈值:迭代移除 <1e-4 贡献节点,直至损失升 2x;目标电路边数 < 总节点 20%。
  • Mean ablation:节点替换预训练分布均值,监控任务损失(电路内 <0.1 增,移除>2 增)。
  • 电路规模监控:目标 < 稠密 1/10;若超标,增稀疏度 20% 或换任务子集微调。
  • 定性检查:手动映射节点语义(e.g., 正激活对应单引号 token),覆盖率 >80%。

3. 桥接实现

  • 线性层:encoder/decoder 初始化 Xavier,冻结稠密权重,仅训桥接(lr=1e-4)。
  • Loss:NMSE = mean ((dense - bridge (sparse))^2 /dense_var) < 0.05;加 KL 散度对齐 logit。
  • 混合比例:50% 稀疏路径起步,渐增至 80%;验证集桥接一致性 >95%。
  • 干预协议:编辑电路节点(e.g., +0.5 激活深度通道),桥接回稠密,观察 logit 偏移 >10%。

监控与回滚策略

  • 指标仪表盘:电路大小、ablation delta loss、桥接 NMSE、预训练 perplexity。
  • 阈值警报:电路 >1/8 稠密 → 暂停增 L0 20%;桥接 NMSE >0.1 → 重训 encoder。
  • 回滚:若效率 < 稠密 100x,切换知识蒸馏(稀疏 teacher,稠密 student);或仅桥接关键层(前 30%)。
  • 部署清单:PyTorch 实现 gate/residual mask;HuggingFace 集成 sparsity hooks;GPU 利用率 >80% 需 sparse kernel(如 FlashAttention)。

风险控制:训练慢 100–1000 倍,主因稀疏矩阵乘法;缓解用自定义 CUDA kernel 或从稠密蒸馏初始化。当前限小规模(~10M 参数),但 scaling 定律显示更大模型电路更解耦。未来聚焦关键行为如安全拒绝电路,实现生产级审计。

此方法不复述新闻,而是提供可复制参数,推动从玩具模型向 GPT-3 规模可解释体演进。

资料来源

(正文字数:1268)

查看归档