当代人工智能服务的核心技术是 Transformer 模型中的自注意力机制。然而,标准自注意力的计算复杂度会随上下文长度线性增长,这一特性正在推动存储、计算和能源需求以超出社会供给能力的速度攀升。面对这一挑战,研究人员提出了一种全新的数学方法,通过对称感知泰勒近似(Symmetry-Aware Taylor Approximation)将自注意力的计算成本固定为常数级别,无论上下文扩展到多长,每个令牌的处理成本都保持不变。这项技术突破不仅为长上下文推理提供了新的可能性,更从根本上重新思考了注意力机制的计算范式。
标准自注意力的计算困境
理解对称感知泰勒近似的创新价值,首先需要认识标准自注意力机制的计算瓶颈。在传统形式中,自注意力的空间和时间复杂度均为 O (n),其中 n 表示序列中的令牌数量。每处理一个额外令牌,所需的内存和计算量都会按序列长度成比例增加。从整体来看,整个序列的处理复杂度达到 O (n²) 的二次方级别。这意味着当上下文长度从一千令牌扩展到一百万令牌时,单个令牌的处理成本将增长一千倍,内存占用也将呈现相同的增长趋势。
标准自注意力的核心操作是查询向量与键向量的点积结果经过缩放后送入指数函数,然后通过 Softmax 归一化与值向量加权求和得到输出。这一过程中,每个新令牌都需要与之前所有令牌重新计算注意力权重,导致 KV 缓存(Key-Value Cache)必须存储所有历史令牌的键和值向量。当上下文长度达到数百万令牌时,KV 缓存的大小可能超过数十吉字节,远超单个 GPU 的高速内存容量。这一限制已经成为长上下文应用的主要瓶颈,限制了模型处理长文档、复杂推理链和个性化历史的能力。
为了缓解这一问题,社区已经提出了多种修改方案,包括数据与参数复用策略、局部和结构化上下文窗口、低秩近似和稀疏注意力等。另一类重要的替代方案是循环神经网络,特别是可以通过并行扫描算法计算的线性循环状态空间模型。然而,这些方法大多通过牺牲注意力表达能力或引入近似误差来换取效率提升,在性能和效率之间存在固有的权衡关系。
泰勒近似的数学突破
对称感知泰勒近似方法的独特之处在于,它不是修改注意力的定义本身,而是通过重新组织计算方式来逼近原有的注意力模式。核心思想是将 Softmax 中的指数核函数展开为泰勒级数,然后利用对称张量的数学性质高效计算各级展开项。
对于查询向量 q 和键向量 k,以及缩放常数 c = √d_K(d_K 为键维度),指数核函数的泰勒展开为:exp (qᵀk/c) = Σ(p=0 到∞) _p (qᵀk)ᵖ,其中 _p = 1/(p!・cᵖ) 是第 p 项的缩放系数。这一展开将指数运算转化为多项式求和问题,而多项式计算可以利用线性代数工具高效实现。
然而,直接计算 (qᵀk)ᵖ 会遇到组合爆炸问题。对于阶数为 p 的项,完整展开会产生 d_Kᵖ 个单项式,即使对于中等规模的键维度(如 d_K = 64)和适度的阶数(如 p = 4),这个数字也会达到惊人的 16777216 项,远超任何可行的计算能力。这里正是对称感知方法发挥关键作用的地方。
研究人员观察到,(qᵀk)ᵖ 本质上是 q 和 k 对应元素乘积的 p 次幂求和,这可以表示为对称张量 q⊗p 和 k⊗p 的逐元素乘积之和。更重要的是,这些对称张量具有高度冗余性:大量元素是重复的单项式排列。例如,当 d_K = 2 且 p = 2 时,(q₁k₁ + q₂k₂)² 展开后包含四项,但 q₁q₂k₁k₂ 和 q₂q₁k₂k₁ 实际上是相同的单项式,可以合并。
通过识别这种对称性,可以只保留上三角区域中的唯一单项式作为最小基。基的大小由组合数 binom (d_K + p - 1, p) 给出,对于 d_K = 64 且 p = 4,这个数字从 16777216 降至 47905,减少超过 350 倍。更高的阶数带来更大的压缩收益,因为完整张量大小按指数增长,而最小基按多项式增长。
特征映射与线性注意力
泰勒展开的另一种理解角度是特征映射视角。对于每个阶数 p,可以定义一个特征映射函数 φ_p,将 d_K 维的查询或键向量映射到 m_p = binom (d_K + p - 1, p) 维的特征空间中的某个坐标。原始的 (qᵀk)ᵖ 可以表示为该特征空间中两个向量的加权内积:〈φ_p (q), φ_p (k)〉_C,其中 C_p 是对应对角权重矩阵,其元素表示每个唯一单项式在完整展开中出现的次数。
通过这种方式,指数核函数的泰勒近似转化为特征空间中的内积和近似。截断到 P 项后的近似为:exp (qᵀk/c) ≈ Σ(p=0 到 P-1) _p〈φ_p (q), φ_p (k)〉_C_p。这一表达式与线性注意力的形式非常相似,但关键区别在于这里需要计算 P 个独立的内积,每个内积对应泰勒级数的一个项。
在线性注意力的框架下,每个令牌的输出可以通过累积状态高效计算。定义累积特征状态 Z_{p,T} = Σ(t=1 到 T) φ_p (k_t) 和累积加权状态 S_{p,T} = Σ(t=1 到 T) φ_p (k_t)・v_tᵀ,其中 v_t 是第 t 个令牌的值向量。这些状态满足简单的线性递推关系:Z_{p,T} = Z_{p,T-1} + φ_p (k_T) 和 S_{p,T} = S_{p,T-1} + φ_p (k_T)・v_Tᵀ。这种递推关系允许通过顺序计算或并行前缀扫描算法高效更新状态。
第 T 个令牌的最终输出为 y_T ≈ S_T / Z_T,其中 S_T = Σ(p=0 到 P-1) S_{p,T}ᵀφ_p (q_T) 和 Z_T = Σ(p=0 到 P-1) 〈φ_p (q_T), Z_{p,T}〉_C_p。值得注意的是,所有这些计算的复杂度都只取决于头维度 d_K、值维度 d_V 和泰勒项数 P,与序列长度 n 完全无关。
隐藏状态大小与计算量分析
对称感知泰勒近似方法的核心优势可以通过具体的数学公式量化描述。对于单一泰勒项 p,隐藏状态大小为 (d_V + 1)・binom (d_K + p - 1, p)。这里的 +1 对应于常数项(p = 0)的贡献,它在数值稳定性中扮演重要角色。所有泰勒项的总体隐藏状态大小为 (d_V + 1)・binom (d_K + P - 1, P - 1),其中 P 是截断的泰勒项总数。
前向传播中每令牌的浮点运算次数(FLOPs)同样可以精确计算。对于单一泰勒项,FLOPs 为 (4d_V + 2p + 4)・binom (d_K + p - 1, p)。所有泰勒项的总体 FLOPs 为 (4d_V + 2 (P・d_K + 1)/(d_K + 1) + 2)・binom (d_K + P - 1, P - 1)。
以典型的多头注意力配置为例,假设头维度 d_K = 64、值维度 d_V = 64,使用 P = 4 项泰勒近似。隐藏状态大小约为 (64 + 1)・binom (64 + 4 - 1, 4 - 1) = 65・binom (67, 3) = 65・47905 ≈ 3.1 百万元素。对于 Float16 精度,这相当于约 6.2 MB 的状态大小。相比之下,标准注意力在 100 万令牌上下文下的 KV 缓存需要 1000000・(64 + 64) = 1.28 亿元素,约 256 MB(Float16)。当上下文继续增长时,标准注意力的缓存将迅速超过 GPU 内存容量,而泰勒近似方法的状态大小始终保持不变。
更重要的是,这些成本与头维度成反比关系。通过减小头维度并增加头数量,可以在固定资源预算下实现更强的注意力表达能力。在传统多头注意力中,由于每个头的成本随上下文长度线性增长,头数量的增加会成比例放大内存和计算压力。而泰勒近似方法将成本固定后,这一限制不复存在,为探索新型架构设计开辟了道路。
精度与效率的权衡
选择合适的泰勒截断项数 P 是在精度和效率之间权衡的关键决策。研究人员的实验表明,P = 4 通常足以使近似误差达到与 Float16 浮点精度相当的水平,这对于大多数推理场景是可接受的。这是因为泰勒系数的缩放因子 _p 会随 p 增大而迅速减小,第四项的系数已经接近 Float16 的分辨率极限。
具体而言,当 c = √d_K 时,第 p 项的缩放常数为 1/(p!・cᵖ)。对于 d_K = 64,c = 8,第四项的缩放常数约为 1/(24・4096) ≈ 1/98304,远小于 Float16 的机器精度(约 2⁻¹⁰ ≈ 1/1024)。这意味着更高阶项对最终结果的贡献在浮点精度下可以忽略不计。
然而,需要注意的是,近似精度与具体的输入统计特性有关。对于某些分布的查询和键向量,更高阶项可能仍然携带显著的信息。因此,在实际部署中,P 应该被视为可调的超参数,根据目标应用的具体精度要求进行选择。对于对精度要求极高的场景(如科学计算应用),可以增加 P 至 6 或 8;对于延迟敏感但对精度要求相对宽松的场景(如实时语音助手),P = 2 或 3 可能已经足够。
硬件优化实现路径
论文作者实现的当前版本是基于 PyTorch 的概念验证,尚未针对特定硬件进行深度优化。性能测试显示,在 NVIDIA GPU 上,当上下文长度从 1K 扩展到 1 亿令牌时,每令牌的峰值内存占用和运行时间均下降约三个数量级。然而,作者指出这些结果可能被当前实现的几个非最优因素所掩盖。
第一个优化空间是避免不必要的临时数据复制。当前实现使用高级索引从查询和键向量中提取元素来构建特征向量,但 PyTorch 返回的视图在底层会复制数据。在理想情况下,这些操作应该是零拷贝的索引访问,因为被索引的数据不会被修改。实现这一优化需要编写自定义 CUDA 内核,直接在全局内存中访问所需元素。
第二个优化方向是利用对称索引的层次结构。特征映射矩阵 M_p 的行按非递减顺序排列(i₁ ≤ i₂ ≤ ... ≤ i_p),这种层次结构可以用于优化内存访问模式和计算调度。例如,可以按索引的第一维度分组计算,利用 GPU 的并行处理能力同时计算多个共享前缀的单项式。
第三个重要优化是将泰勒项的评估从顺序执行改为并行执行。虽然当前实现在单个 CUDA 流上顺序处理四个泰勒项,但理论上这些项之间相互独立,可以同时在多个流上计算,或使用单个流中的并发内核。这将更好地利用现代 GPU 的大规模并行能力。
最根本的优化需要编写专门的设备端内核,精细管理数据在不同内存层级间的移动。现代 GPU 的高带宽内存(HBM)容量有限,而片上静态随机存取存储器(SRAM)速度更快但容量更小。精心设计的内核可以将经常访问的数据保持在更快的内存中,减少全局内存带宽瓶颈带来的性能损失。
实践应用建议
对于希望在实际系统中采用这一技术的开发者,以下几点建议可能有所帮助。首先,当前实现最适合的应用场景是长上下文推理,特别是那些需要处理数万甚至数十万令牌上下文的任务。对于短上下文(数千令牌以下),标准注意力的效率可能更高,因为泰勒近似引入的固定开销在小规模问题上可能占据主导地位。
其次,在模型架构设计方面,由于成本与头维度成反比,可以考虑使用更多的小型头替代少量的大型头。这种配置在传统注意力中会因为线性增长的开销而不切实际,但在泰勒近似框架下完全可行。更小的头维度还会降低特征空间的维度,进一步减少内存占用和计算量。
第三,位置编码方案需要特别关注。泰勒近似方法本身不改变位置编码的处理方式,但如果目标应用需要处理极长上下文(超过数十亿令牌),现有的位置编码方案可能需要进行相应调整。某些相对位置编码方案在极端长度下可能遇到数值不稳定问题,这在选择部署方案时需要考虑。
第四,数值稳定性是部署时需要仔细验证的方面。虽然 P = 4 的近似通常足够精确,但在某些输入分布下(如异常大的点积值),泰勒展开的截断误差可能累积。建议在部署前使用代表性数据样本进行全面的数值验证,确认输出与标准注意力的差异在可接受范围内。
与其他方法的比较
对称感知泰勒近似方法与现有的线性注意力方法既有相似之处,也有本质区别。与 Katharopoulos 等人提出的经典线性注意力相比,核心差异在于核函数的近似方式。传统线性注意力通常使用单个特征映射 φ,将指数核近似为某种单一的非线性变换。而泰勒近似使用 P 个独立的特征映射,每个映射对应泰勒级数的一项,然后将结果线性组合。
这种分解带来的好处是精度可控:通过增加 P,可以使近似误差任意小,直到达到浮点精度的极限。与之相对,大多数现有线性注意力方法引入的是系统性的偏差,无法通过简单的参数调整来消除。此外,P 个项可以独立累积,这意味着在计算资源充足时,可以并行处理所有项,进一步提升吞吐量。
与状态空间模型(如 Mamba)相比,泰勒近似方法保留了标准 Transformer 的整体架构,可以更容易地集成到现有训练和推理流程中。状态空间模型虽然也提供线性复杂度的序列处理,但需要专门的训练方法,且在某些任务上可能表现出不同的归纳偏置。泰勒近似方法可以直接复用预训练的 Transformer 权重,只需替换注意力实现,大幅降低了迁移成本。
未来发展方向
对称感知泰勒近似方法的提出开启了多个有价值的研究方向。首先是端到端训练的研究。目前的工作验证了注意力机制本身的正确性,但在完整模型上的训练动态和下游任务性能尚未得到充分探索。由于泰勒近似引入了可学习的精度参数(通过 P 的选择),研究如何自适应地调整 P 以平衡不同层的精度需求可能带来进一步的效率提升。
其次是更高阶特征空间的压缩技术。当前实现计算完整的最小基,但某些基向量的系数可能很小,对最终输出的贡献有限。探索随机采样、随机投影或低秩逼近等技术来压缩特征空间,在保持精度的同时进一步降低计算成本,是一个有前景的方向。
第三是将类似技术扩展到其他核函数。除了指数核(如 Softmax 中的核),许多深度学习操作涉及其他形式的核函数,如高斯核或拉普拉斯核。研究对称感知分解是否能为这些操作带来类似的效率提升,可能具有广泛的应用价值。
最后是与现有硬件加速器的协同设计。泰勒近似方法的可并行结构和规则的张量运算模式非常适合现代 AI 加速器的架构特征。与芯片厂商合作设计专门的支持指令或内核,可以进一步释放这一方法的性能潜力。
结论
对称感知泰勒近似代表了自注意力计算效率领域的重要进展。通过识别泰勒展开中的对称张量结构,并利用最小基表示来消除组合冗余,这一方法将原本随上下文长度线性增长的成本固定为常数级别。实验结果表明,在概念验证实现下,每令牌内存占用和运行时间均可实现三个数量级的缩减。随着更优化的工程实现出现,这些数字可能进一步提升。
对于长上下文推理应用,这一技术提供了保持 Transformer 架构不变性的同时获得线性复杂度注意力的途径。随着上下文窗口继续扩展成为 AI 系统的核心需求,这类高效注意力机制的重要性只会持续增加。对称感知泰勒近似的数学框架不仅解决了当前的效率瓶颈,更为未来更高效的序列模型设计提供了新的思路和工具。
资料来源:arXiv:2602.00294