Hotdry.

Article

TorchTPU 架构解析:Google 如何实现 PyTorch 原生对接 TPU 硬件

深入解析 TorchTPU 项目架构,揭示 Google 如何在不使用 XLA 作为前端的前提下,实现 PyTorch 原生运行于 TPU 集群的工程路径与性能优化策略。

2026-04-23mlops

现代 AI 基础设施的构建挑战已发生根本性转变。当前机器学习前沿需要利用分布式系统,跨越数千台加速器运行。当模型扩展到在数万片芯片的集群上运行时,为这些模型提供动力的软件必须在性能、硬件可移植性和可靠性方面满足新的需求。在 Google,Tensor Processing Unit(TPU)是超级计算基础设施的核心。这些定制 ASIC 驱动着 Gemini 和 Veo 等 Google 自有 AI 平台的训练与服务,同时也支撑着 Cloud 客户的大规模工作负载。由于潜在用户中有许多使用 PyTorch 构建模型,因此提供一种允许 PyTorch 原生且高效地运行在 TPU 上的集成方案至关重要。

TorchTPU 正是为解决这一需求而诞生。作为一个工程团队,我们的使命是构建一个以可用性、可移植性和卓越性能为导向的技术栈。我们希望使开发者能够以最少的代码修改迁移现有 PyTorch 工作负载,同时为他们提供从硬件中提取每一丝算力的 API 和工具。本文将深入探讨驱动 TorchTPU 的工程原则、我们构建的技术架构,以及 2026 年的发展路线图。

TPU 硬件架构与软件挑战

要理解 TorchTPU,首先需要理解它所针对的硬件。TPU 系统不仅仅是一颗芯片,而是一个集成网络。主机连接到多颗芯片,每颗芯片通过芯片间互连(Inter-Chip Connect,ICI)连接到主机和其他芯片。这种 ICI 将芯片连接成高效的二维或三维环面拓扑(Torus Topology),允许大规模横向扩展而无需传统网络瓶颈。在每个芯片内,执行分为 TensorCore 和 SparseCore。TensorCore 是专用于密集矩阵运算的单线程单元,而 SparseCore 处理不规则内存访问模式,如嵌入表、Gather/Scatter 操作以及集合通信的卸载。

这些特性意味着 TPU 是机器学习的强大工具;我们的目标是提供专门支持,以充分利用这些独特能力。这正是 PyTorch 的用武之地:PyTorch 工具链已经在其他设备类型上创建了一致且广泛使用的接口。可用性的核心原则很简单:它应该让人感觉像原生 PyTorch。开发者应该能够获取现有的 PyTorch 脚本,将初始化改为「tpu」,然后运行训练循环,而无需修改任何核心逻辑。实现这一目标需要一种全新的方法来实现 PyTorch 与 TPU 编译器和运行时栈的交互。

TorchTPU 技术栈:工程实现路径

Eager First:无需妥协的灵活性

从概念到 TPU 上的原生 PyTorch 体验,意味着需要重新思考执行栈。我们确立了「Eager First」哲学。我们不要求开发者立即进入静态图编译,而是使用 PyTorch 的「PrivateUse1」接口实现 TorchTPU。无子类、无包装器;只有普通的、熟悉的 PyTorch Tensor 运行在 TPU 上。通过在这一深层级别集成,我们能够充分优先考虑开发者期望的 eager 执行体验。

我们设计了三種不同的 eager 模式来支持开发生命周期。第一种是 Debug Eager,它一次调度一个操作,并在每次执行后与 CPU 同步。虽然本质上很慢,但对于追踪形状不匹配、NaN 值和内存崩溃非常宝贵。第二种是 Strict Eager,它保持单操作调度,但异步执行,旨在镜像默认的 PyTorch 体验。这允许 CPU 和 TPU 同时运行,直到用户脚本中的同步点为止。然而,真正的突破是我们的 Fused Eager 模式。TorchTPU 使用自动化反射对操作流进行分析,在将其交给 TPU 之前动态地将步骤融合成更大、计算密集的块。通过最大化 TensorCore 利用率并最小化内存带宽开销,Fused Eager 能够持续提供比 Strict Eager 高 50% 到 100% 以上的性能提升,且无需用户进行任何设置。

这三种模式都共享一个编译缓存,该缓存可以在单个主机上运行,也可以配置为在多主机设置中持久化。这意味着随着 TorchTPU 学习你的工作负载,你花费在编译上的时间会减少,而花费在运行上的时间会增加。

静态编译: Dynamo、XLA 与 StableHLO

对于想要在 TPU 上释放峰值性能的用户,TorchTPU 原生集成 torch.compile 接口以进行完整图编译。我们使用 Torch Dynamo 捕获 FX 图。然而,我们不通过 Torch Inductor 路由,而是使用 XLA 作为主要后端编译器。这是一个经过深思熟虑的架构决策。XLA 为 TPU 拓扑经过了严格的实战检验。更重要的是,它原生理解如何优化通过 ICI 密集计算与集合通信之间的关键重叠。我们的翻译层将 PyTorch 的算子直接映射到 StableHLO,这是 XLA 的主要中间表示(IR),用于张量数学运算。这创建了从 PyTorch 到 XLA 核心降低路径的直接连接,使我们能够生成高度优化的 TPU 二进制文件,同时重用 eager 模式建立的执行路径。

对于编写自定义算子的开发者,我们确保可扩展性不会破坏性能。TorchTPU 原生支持使用 Pallas 和 JAX 编写的自定义内核。通过使用 @torch_tpu.pallas.custom_jax_kernel 装饰 JAX 函数,工程师可以编写低级硬件指令,直接与我们的降低路径对接。目前正在努力也支持 Helion 内核。

分布式训练与 MPMD 挑战

为了在规模化场景下保持 eager 和编译模式的灵活性和可用性,我们重点关注 PyTorch 的分布式 API。今天,TorchTPU 原生支持分布式数据并行(DDP)、完全分片数据并行 v2(FSDPv2)以及 PyTorch 的 DTensor。我们已经验证,许多基于 PyTorch 分布式 API 构建的第三方库在 TorchTPU 上可以不变地工作。

PyTorch/XLA 的一个主要限制是它仅支持纯 SPMD 代码。PyTorch 输入的现实情况是,不同 rank 上运行的代码经常存在轻微分歧:例如,rank 0 进程做一些额外的日志或分析工作是很常见的。这种输入对 TPU 栈构成了挑战,因为 TPU 栈针对 SPMD 优化进行了大量优化。XLA 在全局视图下效果最佳,但绕过它会增加开发者必须仔细去除不纯行为的开销。TorchTPU 经过精心设计以支持分歧执行(MPMD),并在必要时隔离通信原语以保持正确性,代价最小。这种方法有助于确保在 TPU 上使用 PyTorch 的体验对于现有 PyTorch 开发者来说尽可能自然,同时在可能的情况下保留 XLA 通过分布式 TPU 部署的全局视图重叠通信和计算的能力。

TPU 硬件感知与性能调优

TPU 可以实现非常高的性能和效率,但最佳模型设计可能与其他硬件略有不同。例如,我们经常看到模型将注意力头维度硬编码为 64,而当前一代 TPU 在维度为 128 或 256 时达到峰值矩阵乘法效率。将模型修改为针对 128 或 256 维度可以更好地利用 TPU 芯片上大型、密集且高效的 TensorCore。可移植性并不能消除硬件现实,因此 TorchTPU 促进了分层工作流:首先建立正确执行,然后使用我们即将推出的深度指南来识别和重构次优架构,或注入自定义内核,以实现最佳硬件利用率。

2026 年路线图与未来展望

今天,我们已经在训练和服务支持方面奠定了坚实基础,并且正在积极应对几个开放挑战,以使 TorchTPU 成为 PyTorch 生态系统中无摩擦的后端。我们的编译器团队的一个重点方向是减少由动态序列长度和批量大小变化触发的重编译。通过在 XLA 内实现高级有界动态性,我们的目标是在不产生编译开销的情况下处理形状变化。这对于某些工作负载(如迭代下一个 token 预测)来说可能是一个重要功能。

我们还在构建标准操作预编译 TPU 内核的综合库,以显著减少首次执行迭代的延迟。在 2026 年剩余时间里,我们正在努力推进以下工作:推出包含广泛文档和可复现架构教程的公共 GitHub 仓库;与 PyTorch 的 Helion DSL 集成以进一步扩展自定义内核能力;通过 torch.compile 直接提供一流动态形状支持;原生多队列支持,以缓解严重异步代码库的迁移,这些代码库具有解耦的内存和计算流;与 vLLM 和 TorchTitan 等生态系统支柱的深度集成,以及经过验证的扩展到完整 Pod 规模基础设施的线性扩展。

TorchTPU 代表了我们为在 TPU 硬件上提供无缝、高性能 PyTorch 体验而付出的专门工程努力。我们正在消除你喜爱的框架与下一代 AI 所需的 TPU 超级计算硬件之间的障碍和摩擦。


资料来源:本文主要参考 Google Developers Blog 发布的「TorchTPU: Running PyTorch Natively on TPUs at Google Scale」(2026 年 4 月 7 日)。

mlops