Hotdry.
ai-engineering

VERL中带KL正则化的多GPU PPO训练:可扩展离线RLHF工程参数

VERL框架下多GPU PPO训练的关键配置,包括KL系数、批次大小、FSDP并行策略,实现高效离线RLHF对齐大模型。

在 LLM 对齐管道中,离线 RLHF 是提升模型推理能力的关键步骤,而 PPO 算法结合 KL 正则化能有效防止策略崩溃,确保训练稳定。VERL 作为字节跳动 Seed 团队开源的 RLHF 框架,支持多 GPU 并行 PPO 训练,通过 HybridFlow 编程模型和 3D-HybridEngine 优化,实现 SOTA 吞吐量和 671B 模型规模扩展。这使得离线数据集(如 GSM8K 数学题)能高效转化为高性能对齐模型。

VERL 的 PPO 实现继承了经典 Proximal Policy Optimization 的核心:actor 生成轨迹、critic 评估价值、KL 散度控制策略偏离参考模型。KL 正则化通过自适应系数(如 kl_ctrl.kl_coef 初始 0.001)惩罚过度偏离,公式为 loss = pg_loss - entropy + kl_coef * KL (π_old || π),防止模式崩溃。证据显示,在 Qwen2.5-7B 上训练 GSM8K,KL 控制下准确率从基线提升 2-5 点,同时吞吐量达 12k+ tokens/s(FSDP+vLLM 后端)。“HybridFlow 框架实现了灵活高效的 RLHF 训练,支持 PPO 等算法在多 GPU 上的无缝扩展。”

多 GPU 并行是 VERL 的核心优势,支持 FSDP/FSDP2(训练)、vLLM/SGLang(rollout 生成),结合 tensor_model_parallel_size 和 sequence_parallel 优化通信。3D-HybridEngine 在 rollout-to-train 切换时重分片 actor 模型,消除内存冗余,通信开销降至传统 1/3。离线 RLHF 场景下,先预处理 Parquet 数据(prompt+ground_truth),用规则奖励(如 GSM8K 答案匹配)或 RM 模型评分,避免在线人类反馈瓶颈。

可落地配置清单

1. 环境与安装(单节点 8GPU 起步)

  • Docker: docker pull verlai/verl:app-verl0.5...(预装 PyTorch2.4+vLLM0.8+FlashAttention2.5)
  • 安装: pip install -e . && bash scripts/install_vllm_sglang_mcore.sh
  • 数据预处理: python examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k

2. 核心 PPO 配置(YAML 或 CLI 覆盖)

data:
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  train_batch_size: 1024  # 全局批次,微批累积
  max_prompt_length: 512
  max_response_length: 256

actor_rollout_ref:
  model:
    path: Qwen/Qwen2.5-7B-Instruct  # HF模型
    use_remove_padding: true  # 序列打包,减填充20-30%
  actor:
    strategy: fsdp2  # FSDP2优先,内存-7%、吞吐+1.5%
    ulysses_sequence_parallel_size: 2  # 长序列优化
    ppo_micro_batch_size_per_gpu: 4-8  # H100:8, A100:4
    ppo_mini_batch_size: 256
    use_dynamic_bsz: true  # 动态批,最大化利用率+50-100%
    ppo_max_token_len_per_gpu: 3072  # 3x(prompt+response)
    clip_range: [0.2, 0.3]  # 裁剪比,防大步更新
    optim:
      lr: 1e-6
  rollout:
    name: vllm
    gpu_memory_utilization: 0.6  # 平衡OOM风险
    tensor_model_parallel_size: 2  # TP=2, DP=4 (8GPU)
  ref:
    log_prob_micro_batch_size_per_gpu: 4

critic:
  model:
    path: Qwen/Qwen2.5-7B-Instruct  # 共享或独立RM
  strategy: fsdp2
  ppo_micro_batch_size_per_gpu: 8  # Critic可大2x
  optim:
    lr: 1e-5

algorithm:
  kl_ctrl:
    type: fixed  # 或adaptive
    kl_coef: 0.001  # 起始值,监控ppo_kl<0.03早停
    target_kl: 0.1
    horizon: 10000
  gamma: 1.0
  lam: 0.95  # GAE lambda
  adv_estimator: gae

reward:  # 离线规则奖励
  gsm8k_rule_reward: true  # 匹配####后答案

trainer:
  n_gpus_per_node: 8
  nnodes: 1  # 多机: torchrun --nnodes=2
  total_epochs: 15
  save_freq: 10
  test_freq: 10
  logger: ["console", "wandb"]  # 项目: "verl-ppo-offline"

3. 启动命令(8xH100 示例)

torchrun -n 8 --nnodes=1 --nproc_per_node=8 \
  -m verl.trainer.main_ppo \
  +上述覆盖参数

4. 监控与调优要点

  • 指标阈值:
    指标 健康范围 异常行动
    ppo_kl 0.01-0.03 >0.05 增大 kl_coef,<0.005 减小
    grad_norm <10 >100 减 lr/clip
    val/test_score 递增 停滞 > 3epoch 早停
    throughput >10k t/s/GPU 查 Nsight: gen/ref 瓶颈
  • Nsight 剖析: global_profiler.steps=[1,5,10],分析 gen/ref/update 占比。
  • 常见风险与回滚:
    1. KL 崩溃: 轨迹重复→kl_coef*=2,clip_ratio=0.1。
    2. OOM: micro_bsz/2,启用 gradient_checkpointing+CPU_offload。
    3. 收敛慢: entropy_coeff=0.01 鼓励探索;多轮 PPO_epochs=4。
    4. 多机同步: NCCL_IB_DISABLE=0,torchrun+--master_addr。

5. 扩展到离线大规模

  • 数据集: Parquet 格式,prompt_key+reward_ground_truth;规模 10w + 样本。
  • LoRA RL: lora_rank=32/128,内存减半,8x80G 训 70B。
  • 离线变体: GRPO(无 critic,组相对优势),DAPO(数学 SOTA)。
  • Checkpoint 合并: python -m verl.model_merger merge --backend fsdp --local_dir=checkpoints/.../actor --target_dir=hf_model

实际部署中,从单 GPU 验证 KL 稳定性→8GPU 规模化→多节点(SkyPilot/KubeRay)。VERL 在 Doubao-1.5-pro 等生产中验证,AIME pass@1 达 70%,证明其工程可靠性。

资料来源:

查看归档