在 Transformer 架构的深度堆叠过程中,残差连接(Residual Connection)扮演着至关重要的角色。然而,标准残差连接采用固定单位权重进行加法累积的做法,在层数急剧增长时暴露出明显的局限性。MoonshotAI 于近期开源的 Attention Residuals(以下简称 AttnRes)技术,提供了一种基于注意力机制的选择性残差聚合方案,为这一经典问题给出了系统性的解答。
传统残差连接的困境
标准残差连接的核心表达式为 h_{l+1} = h_l + F(h_l),其中每一层的输出都以固定权重 1 累加到后续层。这种设计的初衷是缓解深层网络的梯度消失问题,但在实际训练中会产生两个显著副作用。其一是表示稀释效应(Representation Dilution):随着层数加深,每个早期层的贡献被后续层层累积的加法运算所稀释,信息在传播过程中不断被衰减和覆盖。其二是 PreNorm 架构下的幅度膨胀:由于每一层都无条件地添加自身的输出,导致隐藏状态的幅度随深度线性增长,在极端情况下可能引发数值不稳定。
这些问题在大规模语言模型训练中尤为突出。当模型深度达到数十甚至上百层时,早期层学习到的关键表示往往难以有效传递到深层,限制了模型对长距离依赖关系的建模能力。
Attention Residuals 的核心机制
AttnRes 的创新之处在于用 softmax 注意力机制替代了固定权重的加法累积。每一层不再简单地将其输出追加到累积和中,而是通过一个可学习的伪查询向量(pseudo-query)来计算对所有前期表示的注意力权重。其数学表达式为:
$$\mathbf{h}l = \sum{i=0}^{l-1} \alpha_{i \to l} \cdot \mathbf{v}_i$$
其中权重 $\alpha_{i \to l}$ 由层特定的伪查询 $\mathbf {w}_l \in \mathbb {R}^d$ 与前期表示的键向量通过点积计算并经 softmax 归一化得到。这意味着每一层都能够动态地决定应该更多地关注早期哪一层的输出,实现了输入依赖的内容感知型聚合。与标准残差的均匀累积相比,AttnRes 让深层能够选择性回溯早期的重要表示,有效缓解了信息稀释问题。
Block AttnRes:工程落地的务实选择
原始的完整 AttnRes 在理论上优雅,但在实际大模型训练中面临显存瓶颈。直接对所有前期层输出计算注意力需要 O (Ld) 的显存开销,其中 L 为总层数,d 为隐藏维度,这在数百层模型中是不可接受的。
Block AttnRes 提出了分块聚合策略:将 transformer 层划分为 N 个块(block),块内仍使用标准残差连接进行累积,只在块边界处应用跨块的注意力聚合。当配置约 8 个块时,Block AttnRes 能够在仅引入边际开销的前提下,恢复完整 AttnRes 绝大部分的性能收益。这种设计使得该技术可以作为即插即用的替代方案,融入到现有的大模型训练流程中。
从工程实现角度来看,Block AttnRes 的关键超参数是块大小(block_size),建议值为 8 到 12 层一个块。在块内累积时,需要维护一个 partial_block 张量来记录当前块内的残差累积结果;当到达块边界时,将该张量压入块列表,并开启新块的累积。跨块注意力通过一个共享的线性投影和 RMSNorm 完成,伪查询权重仅与块数量相关,而非总层数,从而将显存复杂度从 O (Ld) 降低到 O (Nd)。
性能收益与实验数据
根据 MoonshotAI 在 Kimi Linear 48B 模型(激活参数 3B,训练 1.4T tokens)上的实验结果,AttnRes 在多个基准测试上实现了稳定提升。在通用理解任务上,MMLU 从 73.5 提升至 74.6,GPQA-Diamond 从 36.9 大幅提升至 44.4(增幅达 7.5 分),BBH 从 76.3 提升至 78.0。在数学与代码任务上,Math 从 53.5 提升至 57.1,HumanEval 从 59.1 提升至 62.2,MBPP 从 72.0 提升至 73.9。中文基准同样受益,CMMLU 从 82.0 提升至 82.9,C-Eval 从 79.6 提升至 82.5。
更值得关注的是 scaling law 实验结论:Block AttnRes 的 loss 曲线能够匹配基线模型在 1.25 倍计算预算下的表现。这意味着在相同的训练 compute 下,AttnRes 能够训练出更强的模型;或者说,要达到相同的模型质量,AttnRes 可以节省约 20% 的计算资源。
在训练动态层面,AttnRes 有效缓解了 PreNorm 稀释问题。实验数据显示,应用 AttnRes 后,输出幅度在整个深度范围内保持有界,不再随层数线性增长;梯度范数在各层之间的分布也更加均匀,消除了深层梯度消失或爆炸的隐患。
工程实践建议
对于希望在自有模型中尝试 AttnRes 的团队,以下参数可作为起点。块数量建议配置为 8 左右,可根据总层数灵活调整(总层数除以 8 得到每块层数)。伪查询投影使用单层线性变换,输出维度与隐藏维度一致。归一化方式推荐采用 RMSNorm 以保持与现代大模型架构的兼容性。集成方式为在每个 transformer 块的注意力层和 MLP 层之前分别插入一次 Block AttnRes 计算。
需要注意的是,AttnRes 引入了额外的跨块注意力计算,虽然 overhead 相对可控,但在极端规模(如数百层)时仍需监控其对训练速度的影响。此外,伪查询的初始化策略可能需要针对不同模型架构进行微调,建议从标准正态初始化开始,并根据收敛情况进行自适应调整。
资料来源
本文核心事实来源于 MoonshotAI 官方开源仓库 Attention-Residuals,该仓库同时提供了论文、PyTorch 伪代码实现示例及完整的下游任务评测结果。