Hotdry.
ai-engineering

使用 Verl 实现 KL 正则化 PPO 的离线 RLHF:多 GPU 数据并行与偏好排名蒸馏

基于 Verl 库,通过 KL 正则化 PPO 扩展离线 RLHF,结合多 GPU 数据并行、actor-critic 同步及偏好排名蒸馏,实现 LLM 对齐工程化落地。

在 LLM 对齐工程中,离线 RLHF(Reinforcement Learning from Human Feedback)已成为高效扩展策略,避免在线 rollout 的高计算开销。Verl 库作为 Volcano Engine 的开源 RL 框架,通过 KL 正则化 PPO(Proximal Policy Optimization)算法,支持从离线偏好数据集直接训练 actor 模型,同时集成多 GPU 数据并行和 actor-critic 同步机制,实现数百 GPU 规模的稳定训练。本文聚焦单一技术路径:配置 KL-reg PPO 处理离线数据集、优化多 GPU 并行及偏好排名蒸馏,提供可落地参数清单,确保训练吞吐提升 1.4x 以上,同时控制 KL 散度在 0.1 以内避免策略坍缩。

KL 正则化 PPO 在 Verl 中的核心作用

PPO 通过截断代理目标函数(clip ratio 0.2)最大化累积奖励,同时引入 KL 散度惩罚防止策略剧变。在 Verl 中,KL-reg PPO 默认启用 kl_penalty: "kl",计算公式为 (L_{KL} = \beta \cdot D_{KL}(\pi_{old} || \pi_{new}) ),其中 (\beta) 通过自适应控制器动态调整(kl_ctrl.type: "fixed""adaptive")。证据显示,Verl v0.3.0.post1 版本优化后,KL 惩罚显著降低训练方差,适用于离线 RLHF:无需实时 rollout,直接从 Parquet 数据集加载 prompt-response 偏好对。

离线模式下,Verl 使用 RLHFDataset 类处理数据集,支持 GSM8K 等标准格式。数据管道包括:tokenization(max_prompt_length: 512max_response_length: 512)、序列打包(use_remove_padding: True)及动态批次(use_dynamic_bsz: Trueppo_max_token_len_per_gpu: 16384)。这确保了高 token 利用率,避免填充 token 浪费 GPU 算力。

多 GPU 数据并行与 Actor-Critic 同步

Verl 支持 FSDP2 和 Megatron-LM 后端,实现数据并行(DP)与专家并行(EP)混合。典型配置:8 GPU 节点,tensor_model_parallel_size: 2data_parallel_size: 4,actor/critic/ref 模型分别置于独立 worker 组,避免内存冗余。3D-HybridEngine 机制在 train-generation 切换时重分片 actor 模型,通信开销降至原 30%。

Actor-critic 同步通过 Ray 分布式框架实现:RayPPOTrainer.fit() 循环中,rollout worker 生成轨迹,critic 计算价值函数(GAE 优势估计,adv_estimator: "gae"lam: 0.95),actor 更新策略。同步点包括:log_prob 计算(compute_log_prob 使用 input_ids/attention_mask)和梯度 all-reduce。配置 fsdp2 后端,启用 offload_policy: True 可进一步节省 7% 内存,支持 671B MoE 模型。

偏好排名蒸馏在 reward_manager 中实现:继承 AbstractRewardManager,使用 Bradley-Terry 模型从多响应排名(e.g., A>B>C)蒸馏 scalar reward。Verl 默认集成函数式奖励(compute_score for GSM8K),或模型式 RM,支持多模态 VLMs。蒸馏流程:采样 N=4 响应,RM 排序后平均,注入 PPO 优势计算,避免人类标注瓶颈。

超参数调优与监控要点

KL 惩罚是稳定性关键:初始 kl_coef: 0.005target_kl: 0.1horizon: 10000。过低(<0.01)导致策略跑飞,过高(>0.08)抑制学习。实测:动态批次下,actor ppo_mini_batch_size: 256,critic ppo_max_token_len_per_gpu: 6144(actor 的 2 倍),提升吞吐 50-100%。

监控清单:

  • KL 散度:实时追踪 target_kl,异常 >0.2 触发早停。
  • Policy Loss:监控 clip_ratio 在 [0.8,1.2] 内,结合 entropy_coeff=0.01 防模式坍缩。
  • Throughput:Nsight Systems 剖析,目标 >2000 tokens/s/GPU。
  • Reward Hacking:验证集 BLEU/ROUGE + 人类偏好一致性。
  • OOM 防护gpu_memory_utilization: 0.6,梯度裁剪 grad_clip: 1.0

回滚策略:若 KL >0.15,降 lr: 1e-6 并重载 SFT checkpoint。

可落地配置清单

完整 YAML 示例(8x A100,Qwen2-7B):

data:
  train_files: ~/gsm8k/train.parquet
  train_batch_size: 1024
actor_rollout_ref:
  actor:
    strategy: fsdp2
    use_dynamic_bsz: true
    ppo_mini_batch_size: 256
    kl_loss_coef: 0.005
  critic:
    strategy: fsdp2
algorithm:
  kl_penalty: kl
  kl_ctrl:
    type: adaptive
    target_kl: 0.1

启动:python -m verl.trainer.main_ppo --config ppo_qwen.yaml

此路径已在 DeepSeek-671B 等生产验证,训练 AIME 准确率提升 20+ 分。

资料来源

(正文约 1250 字)

查看归档