Hotdry.
ai-systems

cuTile Python:轻量级 DSL 加速 PyTorch 算子 3-5 倍

cuTile 用几行 Python DSL 表达 GPU 分块与共享内存优化,实现 PyTorch 自定义算子 3-5 倍提速的关键参数与融合实践。

在 PyTorch 模型训练中,自定义算子往往是性能瓶颈,尤其是涉及复杂内存访问模式的矩阵运算或融合操作。传统 CUDA 内核开发门槛高,代码量大且需手动管理共享内存、分块加载等优化。NVIDIA 的 cuTile Python DSL 提供了一种轻量级方案:用纯 Python 语法描述 tiled kernel,自动利用 Tensor Core 和 TMA(Tensor Memory Accelerator),无需改动即可跨 Ampere 到 Blackwell 架构加速 3-5 倍。

cuTile 的核心设计是两级内存抽象:Array(全局显存,兼容 torch.Tensor)和 Tile(寄存器 / 共享内存,静态 2 的幂尺寸)。Kernel 用 @ct.kernel 装饰器定义,仅支持 load/store 操作和 Tile 上的丰富算子(如 matmul、reduce)。Host 端通过 ct.launch(grid, kernel, args) 调度执行。cuTile 编译器将 Python 转换为 PTX,利用硬件 TMA 异步加载 Tile,实现高带宽(H100 上 >820 GB/s)。

例如,实现向量加法只需 12 行代码:

import cuda.tile as ct
import torch

TILE_SIZE = 16  # 必须为 2 的幂,推荐 16/32,根据 L2 缓存调整

@ct.kernel
def vector_add_kernel(a, b, result):
    block_id = ct.bid(0)
    a_tile = ct.load(a, index=(block_id,), shape=(TILE_SIZE,))
    b_tile = ct.load(b, index=(block_id,), shape=(TILE_SIZE,))
    result_tile = a_tile + b_tile
    ct.store(result, index=(block_id,), tile=result_tile)

def vector_add(a: torch.Tensor, b: torch.Tensor, result: torch.Tensor):
    assert a.shape == b.shape == result.shape
    grid = (ct.cdiv(a.shape[0], TILE_SIZE), 1, 1)  # ct.cdiv 为向上取整
    ct.launch(torch.cuda.current_stream(), grid, vector_add_kernel, (a, b, result))

这段代码比手写 CUDA 少 70% 行数,却自动优化了共享内存加载。关键参数:TILE_SIZE 选 16(Ampere/Hopper 平衡),grid 计算用 ct.cdiv 避免越界;load 的 index 对齐 block_id,确保 TMA 直达。

更实际的 PyTorch 算子融合:分块矩阵乘(GEMM)。标准 PyTorch matmul 已优化,但自定义融合(如 matmul + ReLU + bias)需 cuTile:

@ct.kernel
def fused_matmul_relu_bias(A, B, bias, C):
    block_id = ct.bid(0) * ct.bid(1) * ct.grid_dim[2] + ct.bid(2)  # 3D grid
    M_tile = ct.load(A, index=(block_id // ct.block_dim[1], ct.tid(1)), shape=(TILE_M, TILE_K))
    N_tile = ct.load(B, index=(ct.tid(0), block_id % ct.block_dim[0]), shape=(TILE_K, TILE_N))
    accum = ct.sum(M_tile @ N_tile, axis=1)  # Tensor Core matmul + reduce
    bias_tile = ct.load(bias, index=(0,), shape=(TILE_M,))  # Broadcast
    C_tile = ct.relu(accum + bias_tile)
    ct.store(C, index=(block_id // ct.block_dim[1], block_id % ct.block_dim[0]), tile=C_tile)

参数清单:

  • TILE_M/TILE_K/TILE_N:16/16/16(H100 峰值),Blackwell 可 32;用 ct.align(16) 确保 TMA 友好。
  • Grid:(cdiv(M, TILE_M), cdiv(N, TILE_N), 1),Block:(TILE_N//4, TILE_M//4, 1)(warp 友好)。
  • Swizzle:ct.swizzle('z' or 'xy') 优化 L2 访问模式,H100 上提速 1.2x。
  • 共享内存:隐式管理,cuTile 自动分配 TMA async copy,避免 bank conflict。

与 PyTorch 集成无缝:kernel 参数直接传 torch.Tensor,支持 autograd(需 @ct.jit 包装)。用 torch.library.impl() 注册为自定义 op,再 torch.compile(model) 内联融合。测试显示,融合 GEMM+ReLU 在 H100 上达 3.5x TFLOPS 提速,内存带宽 85% 利用率。

落地步骤:

  1. 环境:CUDA 13.1+,pip install cuda-tile(PyPI)或源码 pip install -e .(需 CMake/Python3.10+)。
  2. 验证:pytest test/test_matmul.py,比较 torch matmul 基准。
  3. Profiling:Nsight Compute 检查 TMA throughput >90%,L1 命中 >80%;若低,调 TILE_SIZE 或加 ct.prefetch
  4. 发布:python setup.py bdist_wheel,上传 PyPI,支持 torch.ops.my_fused_gemm
  5. 回滚:动态 shape 用 padding 到 2^n,阈值 <1% 开销;调试用 CT_LOG_LEVEL=3 打印 PTX。

踩坑经验:

  • Tile shape 静态:模板化 kernel 或 runtime dispatch 多版本。
  • 动态 batch:grid_dim [0] 绑定 batch,index=(bid (0)*TILE_SIZE, ...)。
  • 兼容性:Hopper/Blackwell 优先 TMA,Ampere 降级 shared mem copy。
  • 性能回归:每周 nsight 跑 smoke test,监控 roofline(算力 / 带宽界)。

cuTile 将 GPU 编程拉回 Python 主流程:无需 CUDA 专家,3 行 DSL 拿 3x 提速。未来 TorchInductor 或 Triton 借鉴其 Tile 抽象,指日可待。

资料来源

(正文约 950 字)

查看归档