Hotdry.
ai-systems

DeepSeek FlashMLA 共享内存分块与 NVIDIA Hopper/Blackwell 架构优化实践

深入分析 DeepSeek 开源的 FlashMLA CUDA kernel,聚焦共享内存分块策略与 TMA 预取流水线在 NVIDIA Hopper 与 Blackwell 架构上的工程化调优参数。

在大型语言模型推理系统中,注意力机制的计算效率直接决定了服务的吞吐量和延迟。DeepSeek 开源的 FlashMLA 库针对其独创的 Multi-head Latent Attention(MLA)架构提供了高度优化的 CUDA kernel 实现在 H800 SXM5 GPU 上实现了 3000 GB/s 的内存带宽利用率(内存受限场景)以及 660 TFlops 的计算吞吐量(计算受限场景)。本文将从共享内存分块策略与架构感知优化的角度,剖析其工程化实现中的关键技术决策与可复用的调优参数。

计算特性与架构约束分析

理解 MLA kernel 的计算特性是优化的前提。与传统的 Multi-Head Attention(MHA)不同,MLA 通过将 Key 和 Value 压缩到低维潜在空间来减少 KV Cache 的内存占用,但这也改变了计算访存比(FLOPs/Byte)。设查询头数为 $h_q$,每请求的查询 token 数为 $s_q$( speculative decoding 禁用时为 1),KV token 数为 $s_k$,头维度为 $d_k$ 和 $d_v$,则计算量为 $2 h_q s_q s_k (d_k + d_v)$,内存访问量约为 $2 s_k d_k$(BF16 格式),计算访存比约为 $2 h_q s_q (d_k + d_v) /d_k$。在 DeepSeek 的线上推理系统中,解码实例不使用 Tensor Parallelism,即 $h_q = 128$,因此该 kernel 属于典型的计算受限场景。

NVIDIA H800 SXM5 的峰值内存带宽为 3.35 TB/s,受限后的峰值计算能力约为 865 TFlops。根据计算,当 $h_q s_q \ge 128$ 时 kernel 进入计算受限状态。Blackwell 架构(B200)的 SM100 带来了新的硬件特性,包括增强的异步复制引擎和改进的 Tensor Core 调度能力,这为 kernel 优化提供了新的空间,同时也要求针对新架构重新审视内存布局与流水线设计。

共享内存分块与 TMA 预取流水线

FlashMLA 的核心优化策略是将 KV 数据块分块加载到共享内存(称为 tiling),并利用 Tensor Memory Accelerator(TMA)指令实现异步预取,从而隐藏内存访问延迟。对于 $64 \times 576$ 的 K block,kernel 会触发 9 次 TMA 复制操作,每次搬运 $64 \times 64$ 的数据块。这种细粒度的分块策略使得 GEMM 运算可以在第一块数据到达共享内存后立即启动,而剩余块仍在传输中,实现了内存访问与计算的重叠。

TMA 缓存提示的配置对 L2 Cache 命中率有显著影响。FlashMLA 使用 cute::TMA::CacheHintSm90::EVICT_FIRST 提示,这指示硬件在加载新数据时优先驱逐缓存中的旧数据,避免在长序列场景下缓存污染导致的有效带宽下降。在 Hopper 架构上,这一配置可将有效内存带宽提升约 5% 至 8%,尤其在 KV Cache 长度超过 4096 token 时效果更为明显。

分块大小的选择需要平衡多个因素。64x64 的 tile 尺寸使得每个 tile 恰好对应一个 warpgroup 的 MMA(Matrix Multiply-Accumulate)操作规模,充分利用 Tensor Core 的 64x64x4 调度粒度。同时,64 行的高度与 SM 的 L1 Cache 行大小对齐,避免了跨行的部分写放大问题。对于 Blackwell 架构,由于共享内存容量提升至 256 KB(SM100),可以考虑将 tile 高度增加到 128 行以减少 TMA 发起次数,但需要与寄存器压力权衡。

Seesaw 调度与双 Warpgroup 协作

计算受限场景下的关键挑战是如何让 Tensor Core 保持持续忙碌。FlashAttention 系列算法通过在线 softmax 和分块累积来避免全局归约,但 MLA kernel 面临的额外约束是 WGMMA(Warpgroup Matrix Multiply-Accumulate)指令要求输出矩阵必须驻留在寄存器中。每个 $64 \times 512$ 的输出矩阵需要 32,768 个 32 位寄存器,而 SM 的总寄存器容量仅为 65,536,这意味着同一 SM 上无法同时持有两个完整的输出矩阵进行交替计算。

FlashMLA 提出的 Seesaw 调度巧妙地解决了这一约束。其核心思想是将输出矩阵垂直拆分为 $O_L$ 和 $O_R$(各 $64 \times 256$),分别由两个 warpgroup 负责,并在时序上交错执行。具体而言,在一个调度周期内,两个 warpgroup 分别处理不同的 KV 子块:WG0 负责左侧子块 $K_0, V_0$,WG1 负责右侧子块 $K_1, V_1$。通过精心设计的同步点,两个 warpgroup 可以交替执行 GEMM 和 CUDA Core 操作(softmax 计算),从而实现计算资源的充分利用。

从时间线的角度分析,完整的 seesaw 调度包含以下步骤:第一步,WG0 计算 $P_0 = Q K_0^T$;第二步,WG1 计算 $P_1 = Q K_1^T$;第三步,WG0 执行 softmax 计算 $P_0 = \exp (P_0 - m_{new})$ 并更新 $O_L = O_L \cdot scale + P_0 V_{0L}$;第四步,WG1 执行 softmax 计算 $P_1 = \exp (P_1 - m_{new})$ 并更新 $O_R = O_R \cdot (scale_0 \cdot scale_1) + P_1 V_{1R}$;第五步和第六步,两个 warpgroup 分别完成剩余的矩阵向量乘法和累加。这种交错模式确保了 Tensor Core(执行 WGMMA)和 CUDA Core(执行 softmax 的指数与归约操作)在每个时钟周期都有工作可做。

Tile Scheduler 与负载均衡

在多请求并发的解码场景下,不同请求的序列长度差异可能导致 SM 负载不均。FlashMLA 实现了一个 Tile Scheduler 用于在运行时将请求和 KV Cache 块动态分配给空闲 SM。调度策略的核心是维护一个待处理块队列,每个 SM 完成当前块后从队列中获取下一个块,从而保证即使在序列长度分布不均匀时,各 SM 的利用率也能保持在较高水平。

Tile Scheduler 的另一个重要作用是处理 split KV 场景。当 KV Cache 长度超过单次 GEMM 能够处理的范围时,需要将计算拆分为多个子问题,并在最后合并结果。调度器会根据当前 SM 的负载情况和剩余工作量,动态决定是否将一个请求的多个子问题分配给同一个 SM(以减少中间结果传输)或分配给不同 SM(以增加并行度)。在 H800 SXM5 上,这一机制可将长序列场景下的端到端延迟降低约 12% 至 18%。

架构差异与参数适配

Hopper(SM90)和 Blackwell(SM100)架构在内存子系统上的差异要求 kernel 使用不同的调优参数。Hopper 的 L2 Cache 容量为 80 MB,采用两路组相联结构;Blackwell 增加到 120 MB 并改为四路组相联。这意味着在 Hopper 上有效的 TMA 缓存提示策略在 Blackwell 上需要调整。对于 Blackwell,建议将 CacheHintSm90::EVICT_FIRST 替换为 CacheHintSm100::EVICT_LAST,因为更大的缓存容量允许保留更多中间结果供后续块复用。

寄存器分配策略也需要针对架构调整。Blackwell 的寄存器文件带宽更高,但单线程可用寄存器数量限制更严格。在 H800 上,FlashMLA 为每个 warpgroup 分配约 24,000 个寄存器;而在 B200 上,建议将这一数字降低至 20,000 左右,以容纳更多的活跃线程束(从 8 个增加到 10 个),从而增加内存访问的并发度。

FP8 KV Cache 的使用是另一个需要根据架构权衡的选择。FlashMLA 支持 FP8 格式的 KV Cache,每 token 的存储空间从 BF16 的 1152 字节(576 字节 K + 576 字节 V)压缩至 656 字节。在 H800 上,FP8 解码的开销(约 3% 的延迟增加)换来了 40% 的 KV Cache 容量提升,这在长上下文场景下收益显著。而在 B200 上,由于新增的 FP8 Tensor Core 加速单元,解码开销可降低至 1% 以下,因此强烈建议在 B200 部署时启用 FP8 KV Cache。

工程化落地建议

将 FlashMLA 集成到生产推理系统时,建议采用以下配置参数。对于 Hopper 架构(H800),设置 TMA 预取深度为 4(即在计算当前块时异步预取 4 个后续块),使用 BF16 KV Cache(除非序列长度超过 8192),并将工作块大小设置为 $64 \times 576$。对于 Blackwell 架构(B200),建议启用 FP8 KV Cache,将 TMA 预取深度增加到 6,并将工作块高度调整为 128 行以利用更大的共享内存容量。

监控层面,建议跟踪以下指标以验证优化效果:SM 利用率应稳定在 85% 以上;Tensor Core 利用率应达到理论峰值的 70% 以上(B200)或 65% 以上(H800);L2 Cache 命中率在长序列场景下应高于 80%;内存带宽应接近 3 TB/s(内存受限场景)。若 SM 利用率持续偏低,首先检查是否存在 TMA 指令与 WGMMA 的结构 hazards;其次验证 Tile Scheduler 是否正确处理了变长序列的边界情况。

FlashMLA 的开源实现为高效部署 MLA 架构的模型提供了可直接复用的工程模板。其共享内存分块策略与 seesaw 调度机制不仅适用于 MLA,也可迁移到其他计算受限的注意力变体。通过针对目标架构调整分块尺寸、缓存提示和寄存器分配,可以在 NVIDIA Hopper 和 Blackwell GPU 上获得接近硬件峰值的性能表现。

资料来源:FlashMLA GitHub 仓库(https://github.com/deepseek-ai/FlashMLA)以及 DeepSeek 官方技术博客(https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md)。

查看归档