在 LLM 训练和推理的优化实践中,一个长期被忽视但日益凸显的瓶颈是:非 GEMM 操作(normalization、activation、residual update 等)虽然计算量小,却因频繁的显存访问而成为性能拖累。随着 FP8、FP4 等低精度格式加速矩阵乘法,GEMM 本身的成本持续下降,但中间张量的物化(materialization)成本并未同步改善。
CODA(Rewriting Transformer Blocks as GEMM-Epilogue Programs)提出了一种系统性的解决思路:将 Transformer 块重写为 GEMM-Epilogue 程序,把原本独立的内存受限算子融合进矩阵乘法的尾声阶段,在数据仍驻留于片上缓存时完成计算,避免额外的全局显存往返。
核心观察:Epilogue 是天然的融合点
高性能 GEMM 内核通常分为两个阶段:
- Mainloop:执行分块的矩阵乘累加计算,高度优化以压榨 Tensor Core 算力
- Epilogue:在输出分片(tile)仍驻留寄存器 / 共享内存时,对其进行后处理(scaling、bias addition、activation 等),然后写回全局显存
CODA 的关键洞察在于:许多 Transformer 操作可以通过代数重参数化,表达为 tile-local 的 epilogue 计算。当 GEMM 输出 tile 仍在芯片上时,直接在其上执行 residual addition、RMSNorm scaling、SwiGLU activation 等操作,无需将中间结果写回显存再读入。
五类 Epilogue 原语
CODA 保持 GEMM mainloop 固定,仅通过可组合的 epilogue 原语实现灵活表达:
- 逐元素与成对映射:residual update、activation、RoPE 旋转、SwiGLU gating
- 向量(秩 - 1 张量)加载 / 存储:加载行 / 列向量(如 RMSNorm weight)并在 tile 上广播
- Tile(秩 - 2 张量)加载 / 存储:读写 residual stream、保存激活值供反向传播使用
- Tile 归约:计算 tile 内的部分行 / 列归约结果,由轻量级辅助内核汇总
- 状态化转换:维护 tile 级状态,如 online log-sum-exp 和 cross-entropy 统计
这些原语刻意保持受限(仅操作 tile-local 数据),既确保编译生成高效代码,又足以覆盖标准 Transformer 块的前向与反向计算。
GEMM-Residual-RMSNorm-GEMM 重参数化
以预归一化 Transformer 中常见的 "GEMM → residual → RMSNorm → GEMM" 链为例:
y = RMSNorm(x @ W0 + z, γ) @ W1
传统实现中,RMSNorm 需要计算行级逆 RMS 因子 r,这引入了一个跨隐藏维度的归约,似乎必须在两个 GEMM 之间插入独立的归一化内核。
CODA 通过代数变换打破这一依赖:
y = r * ((x @ W0 + z) ⊙ γ) @ W1
由于 r 是行级共享的标量,它与矩阵乘法可交换。因此:
- 第一个 GEMM 的 epilogue:计算残差相加、RMSNorm weight 缩放,并输出 tile-local 的部分平方和(用于后续计算 r)
- 轻量级辅助归约:汇总各 tile 的部分结果得到行级 r
- 第二个 GEMM 的 epilogue:将 r 应用于输出 tile
原本独立的 RMSNorm 内核被消解,取而代之的是两个 GEMM epilogue 中的 tile-local 操作,加上一个读取少量部分值的轻量归约。
成对激活的寄存器级融合
对于 SwiGLU 这类成对激活(将相邻特征值配对处理),CODA 利用 Hopper Tensor Core 的累加器布局特性:每个线程在写回前持有一小簇相邻输出值。通过将配对特征安排在输出维度的相邻位置,epilogue 可直接在寄存器级应用激活函数,无需物化配对的中间张量。
这一技巧同样适用于:
- RoPE:维度保持的旋转操作
- SwiGLU:维度缩减的门控激活
- 反向传播:维度扩展的梯度计算
反向传播的 GEMM-Epilogue 结构
CODA 证明:若前向传播是 GEMM-epilogue 序列,则反向传播保持相同结构。关键观察是 tile-local 前向 epilogue 诱导出 tile-local 反向 epilogue,而周围的线性映射仍是 GEMM。
对于 RMSNorm 反向传播中需要的行级统计量,CODA 通过代数恒等式将其计算位置迁移到 GEMM 边界:
s = sum_cols(∇h2 ⊙ h2) = sum_cols(∇y ⊙ y)
这使得行级统计量可以在当前 GEMM 的 epilogue 中累积,供前一个 RMSNorm 反向传播使用,避免独立的 RMSNorm backward 内核读取完整激活张量。
实践要点
适用场景:
- Profiler 显示大量小型 post-GEMM 内核或层间显存流量过高
- 使用 FP8/FP4 等低精度格式时,GEMM 已非瓶颈,内存移动占比上升
- 需要极致单卡吞吐的推理场景
实施考量:
- CODA 当前专注于单 GPU 内核,分布式执行需额外工作
- 重参数化会模糊模块边界,与框架级抽象(如 PyTorch module)的集成需要权衡
- 代码基于 CuTeDSL 实现,需熟悉 CUTLASS 的编程模型
性能预期:
在 LLaMA-3 风格的 1B/7B/70B 模型配置下,CODA 的 kernel-level 加速显著,尤其在隐藏维度较大的层。完整的 block-level 基准测试(包含辅助归约和胶水操作)显示,重参数化的 Transformer 层在前后向传播中均能获得实质性吞吐提升。
与现有优化的关系
CODA 并非替代其他 LLM 优化技术,而是互补:
- FlashAttention:降低 attention 计算和显存开销
- KV-cache 优化:改善 decode 阶段的重用效率
- 量化:压缩权重和激活的位宽
- CODA:使 Transformer 块的线性代数流水线本身更高效
四者结合,可从不同层面共同压低 LLM 推理的延迟和成本。
参考来源
- Guo, H., et al. (2025). CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs. arXiv:2605.19269. https://arxiv.org/abs/2605.19269
内容声明:本文无广告投放、无付费植入。
如有事实性问题,欢迎发送勘误至 i@hotdrydog.com。