在现代大语言模型的推理与训练中,注意力机制的计算瓶颈始终是性能优化的核心议题。传统的注意力实现需要实例化 $N \times N$ 大小的注意力矩阵 $S$ 和 $P$,对于长序列场景,这不仅带来了 $O (N^2)$ 的显存占用,更因频繁的片下内存(High-Bandwidth Memory, HBM)读写而成为性能杀手。FlashAttention 系列通过 “核融合” 与 “分块计算” 重新定义了高效注意力的标准,而近期提出的 FlashAttention-T(Towards Fully Tensorized Attention)则更进一步,试图将计算完全张量化(Fully Tensorized),以榨取硬件的最后一滴算力。本文将深入剖析 FlashAttention-T 的内核设计哲学,对比其与 v1/v2 的演进差异,并重点解析完全张量化场景下的内存布局优化实战。
从融合到张量化:FlashAttention 演进脉络
FlashAttention v1 的核心突破在于 ** 阻塞式算法(Blocking Algorithm)与核融合(Kernel Fusion)** 的结合。它将 $Q, K, V$ 切分为适合 SRAM 的小块(Tiles),在芯片内完成 $S_{tile} = Q_{tile} K^T$、$P_{tile} = softmax (S_{tile})$ 到 $O_{tile} = P_{tile} V$ 的全流程,从而避免了中间结果 $S$ 和 $P$ 对 HBM 的写入。反向传播时,通过存储输出和 Softmax 统计量进行重计算(Recomputation),大幅降低了显存占用。
FlashAttention v2 在此基础上进行了 ** 工作分区(Work Partitioning)和减少非矩阵乘法(Non-Matmul)** 的优化,进一步提升了 occupancy 和指令吞吐量。然而,无论是 v1 还是 v2,其核心计算路径中仍包含大量在向量寄存器上执行的标量操作,尤其是 softmax 中的求最大值、求和、指数运算等,这些操作未能充分利用 GPU 的张量核心(Tensor Core)资源。
FlashAttention-T 的出现标志着这一领域的范式转变。它的目标不仅仅是 “融合”,而是完全张量化—— 即利用张量核心的 MMA(Matrix Multiply-Accumulate)指令来执行原本只能在向量单元上运行的 softmax 原语。这要求对算法进行根本性的重设计,以适配 MMA 指令的约束(数据形状、对齐方式),并通过架构感知的调度技术并行化张量与向量单元的计算。这种深层次的硬件映射,使得在相同硬件条件下,理论算力利用率得到进一步提升。
完全张量化内核的工程实现剖析
FlashAttention-T 的完全张量化实现,意味着要重新利用那些专为矩阵乘法设计的 MMA 指令来处理 softmax 流程。这不仅是代码层面的改写,更涉及数据依赖关系和计算逻辑的深度重构。
传统 softmax 需要逐行扫描计算最大值 $m (x) = \max (x_j)$ 和指数和 $\ell (x) = \sum_j \exp (x_j)$,这在向量编程模型中非常直观。但在张量模型中,数据被切分为 16x16 或类似的矩阵块进行处理。因此,FlashAttention-T 采用了张量化在线 softmax 算法(Tensorized Online Softmax Algorithm)。该算法必须处理块内的局部统计量,并在块之间进行正确的同步与累积,确保最终结果的数学等价性。这意味着 MMA 指令不再只负责 $QK^T$ 和 $PV$ 的 GEMM 运算,还要穿插执行用于 softmax 状态维护的张量操作。
另一个关键挑战是调度。GPU 的张量核心(Tensor Core)和向量单元(Vector Unit)通常可以并行工作。FlashAttention-T 的调度策略会识别哪些操作可以卸载到张量核心(利用 MMA 指令),哪些操作必须或更适合在向量单元执行,从而实现指令级的双发射(Dual Issue)或流水线重叠,最大化硬件利用率。这是一种比 v1/v2 更高维度的优化,涉及到 PTX/SASS 层面的指令排布。
内存布局优化实战:三级架构与数据局部性
无论算法如何演进,对 GPU 内存层次结构的精细管理始终是高性能计算的核心。FlashAttention 内核通常遵循 GMEM(Global Memory) -> SMEM(Shared Memory) -> RF(Register File) 的三级数据移动路径。
在 SMEM(Shared Memory) 层面,为了最大化带宽利用,FlashAttention-T 延续了高效的存储策略:所有张量 ($Q, K, V, S, O$) 均以行主序(Row-major)存储在 SMEM 中。这种布局有利于从 HBM 批量加载数据,并在 SMEM 中进行快速的随机访问或转置操作。
而在 RF(Register File) 层面,由于 MMA 指令对输入形状有严格要求,布局变得更加复杂。一个关键的工程细节是:注意力分数矩阵 $S$ 在 RF 中通常以转置(Transposed)形式存储。在 SMEM -> RF 的拷贝过程中,代码需要显式地执行这一转置操作(通常结合 ldmatrix 指令族的 ldmatrix.transpose 变体)。这种布局使得后续的 MMA 操作(无论是 $Q \cdot K^T$ 还是 $P \cdot V$)都能以最优的数据形状匹配指令的列 / 行定义。
此外,FlashAttention-T 在 Warp 级别的协作(Cooperative Loading) 上也有独特考量。对于 $K$ 和 $V$ 张量,单个 warp 通常无法独立完成计算,它需要访问整个 $K$ 或 $V$ 块的数据。因此,CTA(线程块)内的多个 warp 必须协同工作:首先将数据从 GMEM 加载到 SMEM(可能由不同 warp 负责不同的行块),然后通过同步屏障(__syncthreads())确保数据就绪,最后每个 warp 再从 SMEM 拷贝其所需的子块到 RF。这种协作模式虽然增加了编程复杂度,但通过减少总的 HBM 访问次数,显著提升了能效比。
工程权衡与落地考量
完全张量化虽然代表了性能的前沿方向,但也带来了不可忽视的工程代价。首先是硬件依赖性。这种深度绑定硬件特性的实现(如特定架构的张量核心指令集、Shared Memory 的 banks 冲突规避)使得代码在不同代际的 GPU(Volta/Ampere vs. Hopper/Blackwell)上可能需要重写或调优。其次是编程与调试难度。手动管理寄存器压力、同步屏障以及非对齐内存访问的边界条件,需要对 GPU 架构有极深的理解,这增加了内核开发和维护的成本。
尽管如此,对于追求极致性能的推理引擎(如 vLLM, TensorRT-LLM)而言,FlashAttention-T 所代表的优化路径是必须攻克的阵地。它揭示了一个趋势:未来的性能优化将不再局限于算法层面(如稀疏注意力),而是深入到指令级并行与异构计算单元协同的深水区。
参考资料:
- Zenodo: FlashAttention-T: Towards Fully Tensorized Attention by Exploiting Tensor–Vector Parallelism (PPoPP '26).
- Sonny's Blog: Flash Attention from Scratch Part 3: Kernel 1 (Memory hierarchy and RF layouts analysis).