大型语言模型(LLM)的自回归(AR)生成范式虽质量可靠,但受内存带宽瓶颈制约,GPU利用率低下,仅逐token输出导致实际吞吐低下。扩散模型(dLM)支持并行多token生成,利用“免费token槽”提速,却因破坏因果依赖而质量衰减。TiDAR(Think in Diffusion, Talk in Autoregression)提出序列级混合架构,在单模型单前向传播中并行执行扩散“思考”(草稿生成)与AR“表达”(质量验证),首次弥合效率-质量鸿沟,实现1.5B模型每步7.45 token(4.71× AR)、8B模型8.25 token(5.91× AR),在HumanEval等任务匹配AR性能。
TiDAR核心在于结构化注意力掩码(structured attention masks),将输入序列分为三段:前缀(prefix,已确认token,因果注意力,支持KV缓存)、当前草稿(drafts,上步预生成token,自回归拒绝采样验证)、预草稿(pre-drafts,下步扩散并行生成,双向块注意力)。全过程无需额外模型或多pass,利用GPU内存带宽饱和时的免费计算槽(H100 ctx=4096时~100 token延迟不变)。
训练采用双模式联合损失:[ \mathcal{L} = \lambda \mathcal{L}{AR} + (1-\lambda) \mathcal{L}{diff} ],其中(\mathcal{L}{AR})为因果交叉熵,(\mathcal{L}{diff})为扩散区全掩码(full-mask)交叉熵。关键创新:扩散区全设[mask],提供稠密监督信号,避免随机mask分布偏移,确保训练-推理一致。推荐参数:λ=1.0(均衡),序列扩展2倍(draft_len=8~16),学习率1e-4,warmup 10%,基于Qwen2.5预训32k步。消融显示full-mask提升HumanEval 43.29%(接近Qwen2.5),Pareto前沿优于Block Diffusion。
推理流程无缝:1) 输入prefix + 上drafts(验证)+ 全mask pre-drafts;2) 自回归采样drafts(top-k/p=0.95拒绝,accept率>80%阈值警报);3) 基于accept前缀扩散采样pre-drafts(一步去噪,temp=1.0);4) KV缓存精确切片复用(预init mask动态slice,避免双向重算)。H100基准:bs=1, ctx=4096, TiDAR-1.5B吞吐4.71× Qwen2.5,8B达5.91×,优于EAGLE-3(draft弱)和Llada(质量低)。“TiDAR在编码任务中表现优异,每前向生成7.45个词。”
工程落地参数清单:
- Draft配置:draft_len=8(1.5B)/16(8B),>16 mem峰值升20%,测试峰值<80% GPU mem回滚len=4。
- 采样阈值:AR拒绝temp=1.0, nucleus p=0.95;扩散一步,β~cosine scheduler(T=1000步预训)。
- Mask实现:FlexAttention/PyTorch 2.5+,预gen 3x3 block mask(prefix causal, drafts causal, pre-drafts bidirectional)。
- 优化内核:FlashAttention-2 KV复用,动态slice O(1);bs>1扩展需custom kernel验证(当前bs=1最稳)。
- 监控指标:accept_rate(>0.7正常,<0.5降draft_len);throughput tokens/s;OOM fallback纯AR模式(loss<1%质量)。
部署脚本示例(伪码):
def tidar_step(prefix, kv_cache, draft_len=8):
drafts = mask_pre_drafts(draft_len) # full mask
logits = model(prefix + prev_drafts + drafts, mask=structured_mask, kv_cache=kv_cache)
accept_mask = ar_reject_sample(logits[:drafts], temp=1.0)
new_prefix = prefix + accept_mask * prev_drafts
new_kv = kv_cache.slice(new_prefix)
pre_drafts = diffusion_sample(logits[drafts:], new_prefix)
return new_prefix, pre_drafts, new_kv
回滚策略:若accept_rate<0.6,切换纯AR(质量保证,吞吐降);长ctx>8k,chunked inference分段。
TiDAR虽聚焦bs=1 autoreg-like场景,长ctx训练翻倍易OOM,大bs需并行mask优化,但为低延迟服务(如聊天/代码补全)提供即插即用加速,未来可扩展VLA代理多步规划。实际部署中,从1.5B起步,渐进8B,结合vLLM包装,支持trust-diffusion模式(纯扩散fallback)。
资料来源: