Hotdry.
ai-systems

TiDAR:扩散思考、自回归表达的混合生成架构

TiDAR在单一前向传播中融合扩散并行草稿生成与自回归拒绝采样,实现LLM生成吞吐量4.71~5.91倍提升,同时保持AR级质量,详解训练参数、推理优化与部署清单。

大型语言模型(LLM)的自回归(AR)生成范式虽质量可靠,但受内存带宽瓶颈制约,GPU 利用率低下,仅逐 token 输出导致实际吞吐低下。扩散模型(dLM)支持并行多 token 生成,利用 “免费 token 槽” 提速,却因破坏因果依赖而质量衰减。TiDAR(Think in Diffusion, Talk in Autoregression)提出序列级混合架构,在单模型单前向传播中并行执行扩散 “思考”(草稿生成)与 AR “表达”(质量验证),首次弥合效率 - 质量鸿沟,实现 1.5B 模型每步 7.45 token(4.71× AR)、8B 模型 8.25 token(5.91× AR),在 HumanEval 等任务匹配 AR 性能。

TiDAR 核心在于结构化注意力掩码(structured attention masks),将输入序列分为三段:前缀(prefix,已确认 token,因果注意力,支持 KV 缓存)、当前草稿(drafts,上步预生成 token,自回归拒绝采样验证)、预草稿(pre-drafts,下步扩散并行生成,双向块注意力)。全过程无需额外模型或多 pass,利用 GPU 内存带宽饱和时的免费计算槽(H100 ctx=4096 时~100 token 延迟不变)。

训练采用双模式联合损失:[\mathcal {L} = \lambda \mathcal {L}{AR} + (1-\lambda) \mathcal{L}{diff} ],其中 (\mathcal {L}{AR}) 为因果交叉熵,(\mathcal {L}{diff}) 为扩散区全掩码(full-mask)交叉熵。关键创新:扩散区全设 [mask],提供稠密监督信号,避免随机 mask 分布偏移,确保训练 - 推理一致。推荐参数:λ=1.0(均衡),序列扩展 2 倍(draft_len=8~16),学习率 1e-4,warmup 10%,基于 Qwen2.5 预训 32k 步。消融显示 full-mask 提升 HumanEval 43.29%(接近 Qwen2.5),Pareto 前沿优于 Block Diffusion。

推理流程无缝:1) 输入 prefix + 上 drafts(验证)+ 全 mask pre-drafts;2) 自回归采样 drafts(top-k/p=0.95 拒绝,accept 率 > 80% 阈值警报);3) 基于 accept 前缀扩散采样 pre-drafts(一步去噪,temp=1.0);4) KV 缓存精确切片复用(预 init mask 动态 slice,避免双向重算)。H100 基准:bs=1, ctx=4096, TiDAR-1.5B 吞吐 4.71× Qwen2.5,8B 达 5.91×,优于 EAGLE-3(draft 弱)和 Llada(质量低)。“TiDAR 在编码任务中表现优异,每前向生成 7.45 个词。”

工程落地参数清单:

  • Draft 配置:draft_len=8(1.5B)/16(8B),>16 mem 峰值升 20%,测试峰值 < 80% GPU mem 回滚 len=4。
  • 采样阈值:AR 拒绝 temp=1.0, nucleus p=0.95;扩散一步,β~cosine scheduler(T=1000 步预训)。
  • Mask 实现:FlexAttention/PyTorch 2.5+,预 gen 3x3 block mask(prefix causal, drafts causal, pre-drafts bidirectional)。
  • 优化内核:FlashAttention-2 KV 复用,动态 slice O (1);bs>1 扩展需 custom kernel 验证(当前 bs=1 最稳)。
  • 监控指标:accept_rate(>0.7 正常,<0.5 降 draft_len);throughput tokens/s;OOM fallback 纯 AR 模式(loss<1% 质量)。 部署脚本示例(伪码):
def tidar_step(prefix, kv_cache, draft_len=8):
    drafts = mask_pre_drafts(draft_len)  # full mask
    logits = model(prefix + prev_drafts + drafts, mask=structured_mask, kv_cache=kv_cache)
    accept_mask = ar_reject_sample(logits[:drafts], temp=1.0)
    new_prefix = prefix + accept_mask * prev_drafts
    new_kv = kv_cache.slice(new_prefix)
    pre_drafts = diffusion_sample(logits[drafts:], new_prefix)
    return new_prefix, pre_drafts, new_kv

回滚策略:若 accept_rate<0.6,切换纯 AR(质量保证,吞吐降);长 ctx>8k,chunked inference 分段。

TiDAR 虽聚焦 bs=1 autoreg-like 场景,长 ctx 训练翻倍易 OOM,大 bs 需并行 mask 优化,但为低延迟服务(如聊天 / 代码补全)提供即插即用加速,未来可扩展 VLA 代理多步规划。实际部署中,从 1.5B 起步,渐进 8B,结合 vLLM 包装,支持 trust-diffusion 模式(纯扩散 fallback)。

资料来源

查看归档