Hotdry.
ai-systems

Parakeet.cpp 中 Metal 计算着色器优化 Beam Search 解码

针对纯C++ Parakeet ASR 的 beam search 解码,提供 Metal compute shader 的内存布局、线程组配置与 dispatch 优化,实现实时边缘推理。

在边缘设备上实现实时自动语音识别(ASR),NVIDIA Parakeet 模型的纯 C++ 实现 parakeet.cpp 提供了高效基础。其基于 axiom 张量库的 Metal GPU 加速已在编码器上表现出色,例如在 Apple M3 上 10 秒音频仅需 27ms 推理时间。然而,现有的贪婪解码(greedy decode)在 CTC/TDT 解码器中准确率有限,beam search 作为提升词错率(WER)的标准方法,却因计算密集而难以实时。为此,通过优化 Metal compute shader 可以将 beam search 解码 GPU 化,保持低延迟。

观点一:内存布局与数据驻留是性能基石。在 beam search 的每步迭代中,需要处理 [batch][beam][vocab] 形状的 logit 张量,其中 vocab 维度通常达数千。传统 CPU 解码频繁读写导致带宽瓶颈,而 Metal 的统一内存(Unified Memory)优势需通过连续布局发挥。将 logit 置于 GPU buffer 中,vocab 作为最内维度,确保线程 coalesced 访问。beam 状态(如 scores、tokens、paths)同样预分配固定大小 buffer,重用避免 realloc。证据显示,repo 中 encoder 已融合 MPSGraph 操作,类似可扩展到解码:dispatch 前预创建 MTLBuffer,步间仅更新指针。

可落地参数:

  • Logit buffer: float32 [1][beam_width][vocab_size],beam_width=8,vocab_size~5000。
  • State buffer: struct {float score; uint32_t token; uint32_t path_len;} [beam_width * max_steps]。
  • 阈值:score < -10.0 早停,防止无效扩展。

观点二:线程组配置针对 top-k 选择优化。Beam search 核心是每 beam 计算 top-k 候选项(k=beam_width),避免全 softmax/sort。设计 kernel:每个线程组处理一个 beam,256 线程 stride 遍历 vocab 块。每线程维护寄存器中固定大小 heap(unroll loop for k=8),末尾 threadgroup_barrier 后 reduction 选 top-k。避免 global sort,使用 partial select 算法如 quickselect 的 GPU 变体。此配置利用 Apple GPU 的 SIMD 宽度(128 lanes/warp),减少 bank conflict。

落地清单:

  1. Threadgroup size: [[threads(256,1,1)]],grid: [num_beams,1,1]。
  2. Local top-k: float heap [8]; uint32_t idx_heap [8]; 寄存器优先,非 threadgroup memory。
  3. Reduction: 使用 atomic_max/min 于 shared float max_logit,追踪 argmax。
  4. 伪码片段:
kernel void beam_topk(device float* logits [[buffer(0)]],
                      device float* scores [[buffer(1)]],
                      constant uint beam_width [[buffer(2)]],
                      uint3 tid [[thread_position_in_grid]],
                      uint3 tgid [[threadgroup_position_in_grid]]) {
  // stride vocab
  uint vocab_idx = tid.x;
  float logit = logits[tgid.x * vocab_size + vocab_idx];
  // local heap push/pop for top-k
  // ...
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (tid.x < beam_width) scores[tgid.x * beam_width + tid.x] = heap[tid.x];
}

观点三:Kernel 融合与 dispatch 启发式最小化开销。将 logit bias(repetition penalty)、top-k、beam expand、prune 融合一 kernel,避免多 dispatch 的 command encoder 开销。对于 streaming 模型如 Nemotron,latency_frames=1 时,每 80ms chunk dispatch 一次,需 <5ms 解码预算。批量多 utterance 或多 beam 填充 grid,提高 occupancy。监控 Metal Capture 工具:目标 kernel occupancy>50%,memory throughput >80GB/s。

参数与启发式:

  • Fusion: 单 kernel 处理 top-k + expand(output new_scores = old_score + log_prob)。
  • Dispatch: MTLCommandBuffer per step,encode 1-2 kernels;若 step <10ms,CPU 调度。
  • Batching: max_beams_per_dispatch=32,利用 M3 的 128GB/s 带宽。
  • 超参:beam_width=5(平衡速度 / 准确,WER 降 10-15%),temperature=1.2,eos_threshold=-2.0。

风险与回滚:Apple GPU threadgroup 限 1024 线程,vocab>8192 时分多 kernel。浮点精度:log-space 全用 float16 加速 1.5x,但 TDT duration 预测敏感,fallback float32。测试 RTF<0.01(real-time factor)。

实现步骤清单:

  1. 扩展 parakeet::tdt_greedy_decode 为 beam_decode,GPU logits。
  2. 编写 MSL kernel:topk_select.metal,集成 axiom MetalGraph。
  3. CLI flag: --beam 8 --gpu-decode。
  4. Benchmark: 10s audio,目标解码 < 50ms。
  5. 集成:model.decoder.to_gpu (),transcribe (..., Decoder::Beam)。

此优化使 parakeet.cpp 在 iPhone/MBP 上支持 beam=8 实时转录,WER 优于 greedy 达 12%,适用于会议 / 语音助手。相比 PyTorch MPS,纯 C++ 避免 Python 开销,部署体积 < 50MB。

资料来源:

查看归档