Hotdry.

Article

CODA:将 Transformer 块重写为 GEMM-Epilogue 程序以消除显存瓶颈

通过将 Transformer 块重新参数化为 GEMM-Epilogue 程序,CODA 将内存受限的归一化、激活等操作融合到矩阵乘法的尾声中,显著减少显存往返并提升推理效率。

2026-05-22ai-systems

在 LLM 训练和推理的优化实践中,一个长期被忽视但日益凸显的瓶颈是:非 GEMM 操作(normalization、activation、residual update 等)虽然计算量小,却因频繁的显存访问而成为性能拖累。随着 FP8、FP4 等低精度格式加速矩阵乘法,GEMM 本身的成本持续下降,但中间张量的物化(materialization)成本并未同步改善。

CODA(Rewriting Transformer Blocks as GEMM-Epilogue Programs)提出了一种系统性的解决思路:将 Transformer 块重写为 GEMM-Epilogue 程序,把原本独立的内存受限算子融合进矩阵乘法的尾声阶段,在数据仍驻留于片上缓存时完成计算,避免额外的全局显存往返。

核心观察:Epilogue 是天然的融合点

高性能 GEMM 内核通常分为两个阶段:

  • Mainloop:执行分块的矩阵乘累加计算,高度优化以压榨 Tensor Core 算力
  • Epilogue:在输出分片(tile)仍驻留寄存器 / 共享内存时,对其进行后处理(scaling、bias addition、activation 等),然后写回全局显存

CODA 的关键洞察在于:许多 Transformer 操作可以通过代数重参数化,表达为 tile-local 的 epilogue 计算。当 GEMM 输出 tile 仍在芯片上时,直接在其上执行 residual addition、RMSNorm scaling、SwiGLU activation 等操作,无需将中间结果写回显存再读入。

五类 Epilogue 原语

CODA 保持 GEMM mainloop 固定,仅通过可组合的 epilogue 原语实现灵活表达:

  1. 逐元素与成对映射:residual update、activation、RoPE 旋转、SwiGLU gating
  2. 向量(秩 - 1 张量)加载 / 存储:加载行 / 列向量(如 RMSNorm weight)并在 tile 上广播
  3. Tile(秩 - 2 张量)加载 / 存储:读写 residual stream、保存激活值供反向传播使用
  4. Tile 归约:计算 tile 内的部分行 / 列归约结果,由轻量级辅助内核汇总
  5. 状态化转换:维护 tile 级状态,如 online log-sum-exp 和 cross-entropy 统计

这些原语刻意保持受限(仅操作 tile-local 数据),既确保编译生成高效代码,又足以覆盖标准 Transformer 块的前向与反向计算。

GEMM-Residual-RMSNorm-GEMM 重参数化

以预归一化 Transformer 中常见的 "GEMM → residual → RMSNorm → GEMM" 链为例:

y = RMSNorm(x @ W0 + z, γ) @ W1

传统实现中,RMSNorm 需要计算行级逆 RMS 因子 r,这引入了一个跨隐藏维度的归约,似乎必须在两个 GEMM 之间插入独立的归一化内核。

CODA 通过代数变换打破这一依赖:

y = r * ((x @ W0 + z) ⊙ γ) @ W1

由于 r 是行级共享的标量,它与矩阵乘法可交换。因此:

  1. 第一个 GEMM 的 epilogue:计算残差相加、RMSNorm weight 缩放,并输出 tile-local 的部分平方和(用于后续计算 r)
  2. 轻量级辅助归约:汇总各 tile 的部分结果得到行级 r
  3. 第二个 GEMM 的 epilogue:将 r 应用于输出 tile

原本独立的 RMSNorm 内核被消解,取而代之的是两个 GEMM epilogue 中的 tile-local 操作,加上一个读取少量部分值的轻量归约。

成对激活的寄存器级融合

对于 SwiGLU 这类成对激活(将相邻特征值配对处理),CODA 利用 Hopper Tensor Core 的累加器布局特性:每个线程在写回前持有一小簇相邻输出值。通过将配对特征安排在输出维度的相邻位置,epilogue 可直接在寄存器级应用激活函数,无需物化配对的中间张量。

这一技巧同样适用于:

  • RoPE:维度保持的旋转操作
  • SwiGLU:维度缩减的门控激活
  • 反向传播:维度扩展的梯度计算

反向传播的 GEMM-Epilogue 结构

CODA 证明:若前向传播是 GEMM-epilogue 序列,则反向传播保持相同结构。关键观察是 tile-local 前向 epilogue 诱导出 tile-local 反向 epilogue,而周围的线性映射仍是 GEMM。

对于 RMSNorm 反向传播中需要的行级统计量,CODA 通过代数恒等式将其计算位置迁移到 GEMM 边界:

s = sum_cols(∇h2 ⊙ h2) = sum_cols(∇y ⊙ y)

这使得行级统计量可以在当前 GEMM 的 epilogue 中累积,供前一个 RMSNorm 反向传播使用,避免独立的 RMSNorm backward 内核读取完整激活张量。

实践要点

适用场景

  • Profiler 显示大量小型 post-GEMM 内核或层间显存流量过高
  • 使用 FP8/FP4 等低精度格式时,GEMM 已非瓶颈,内存移动占比上升
  • 需要极致单卡吞吐的推理场景

实施考量

  • CODA 当前专注于单 GPU 内核,分布式执行需额外工作
  • 重参数化会模糊模块边界,与框架级抽象(如 PyTorch module)的集成需要权衡
  • 代码基于 CuTeDSL 实现,需熟悉 CUTLASS 的编程模型

性能预期

在 LLaMA-3 风格的 1B/7B/70B 模型配置下,CODA 的 kernel-level 加速显著,尤其在隐藏维度较大的层。完整的 block-level 基准测试(包含辅助归约和胶水操作)显示,重参数化的 Transformer 层在前后向传播中均能获得实质性吞吐提升。

与现有优化的关系

CODA 并非替代其他 LLM 优化技术,而是互补:

  • FlashAttention:降低 attention 计算和显存开销
  • KV-cache 优化:改善 decode 阶段的重用效率
  • 量化:压缩权重和激活的位宽
  • CODA:使 Transformer 块的线性代数流水线本身更高效

四者结合,可从不同层面共同压低 LLM 推理的延迟和成本。


参考来源

ai-systems

内容声明:本文无广告投放、无付费植入。

如有事实性问题,欢迎发送勘误至 i@hotdrydog.com