diff --git a/examples/medqa/rl.toml b/examples/medqa/rl.toml
new file mode 100644
index 0000000..fda6777
--- /dev/null
+++ b/examples/medqa/rl.toml
@@ -0,0 +1,117 @@
+# Qwen3.5-9B GRPO on MedQA-USMLE-4-options.
+#
+# Notes specific to Qwen3.5 as a vision-language model in PRIME-RL:
+# - max_async_level = 1 together with filesystem weight broadcast: the NCCL
+#   broadcaster fails to initialize on the 9B VLM (init_broadcaster 500).
+#   Filesystem broadcast is reliable, and `maybe_clean` keeps only
+#   async_level + 1 = 2 broadcast dirs on disk during the run.
+# - optim_cpu_offload = true: needed to fit FSDP + optimizer state into
+#   2 train GPUs at seq_len = 8192. Without it, training OOMs on H100.
+# - Do NOT enable [trainer.model.compile]: it silently hangs on the 9B VLM
+#   during the first forward pass.
+# - enable_thinking must be passed through chat_template_kwargs in BOTH
+#   the orchestrator sampling block and the eval sampling block — Qwen3.5
+#   is an instruct model that requires the flag to emit tokens.
+#
+# 8 GPUs total: 2 train (FSDP) + 5 inference (dp=5) + 1 free for an external
+# judge server (medqa itself doesn't need a judge; this layout matches the
+# eval-time setup that adds judge-scored open-ended benches).
+
+max_async_level = 1
+max_steps = 100
+seq_len = 8192
+
+[model]
+name = "Qwen/Qwen3.5-9B"
+
+[weight_broadcast]
+# The NCCL broadcaster fails to initialize on this 9B VLM; filesystem is reliable.
+type = "filesystem"
+
+[orchestrator]
+rollouts_per_example = 8
+batch_size = 64
+max_inflight_rollouts = 128
+
+[orchestrator.sampling]
+max_tokens = 8192
+# enable_thinking must be set here AND in [orchestrator.eval.sampling].
+extra_body.chat_template_kwargs.enable_thinking = true
+
+[orchestrator.buffer]
+online_difficulty_filtering = true
+hard_threshold = 0.0
+easy_threshold = 1.0
+
+[[orchestrator.env]]
+name = "medqa"
+id = "medqa"
+args.shuffle_answers = true
+
+[orchestrator.eval]
+interval = 50
+
+[orchestrator.eval.sampling]
+max_tokens = 12288
+extra_body.chat_template_kwargs.enable_thinking = true  # mirrors [orchestrator.sampling]
+
+[[orchestrator.eval.env]]
+name = "medqa"
+id = "medqa"
+rollouts_per_example = 1
+num_examples = 200
+args.shuffle_answers = false
+
+[[orchestrator.eval.env]]
+name = "medbullets"
+id = "medbullets"
+rollouts_per_example = 1
+num_examples = 100
+args.shuffle_answers = false
+
+[[orchestrator.eval.env]]
+name = "metamedqa"
+id = "metamedqa"
+rollouts_per_example = 1
+num_examples = 100
+args.shuffle_answers = false
+
+[[orchestrator.eval.env]]
+name = "head-qa"
+id = "head-qa"
+rollouts_per_example = 1
+num_examples = 50
+args.shuffle_answers = false
+
+[trainer.model]
+name = "Qwen/Qwen3.5-9B"
+seq_len = 8192
+attn = "flash_attention_3"
+optimization_dtype = "bfloat16"
+reduce_dtype = "bfloat16"
+dp_replicate = 1
+reshard_after_forward = true
+# Needed to fit FSDP + optimizer state on 2 train GPUs at seq_len = 8192.
+optim_cpu_offload = true
+ac.freq = 1
+
+[trainer.optim]
+lr = 1.5e-6
+
+[ckpt]
+resume_step = -1
+interval = 50
+
+[deployment]
+type = "single_node"
+# 2 train + 5 inference; the 8th GPU is left free for an external judge server.
+num_train_gpus = 2
+num_infer_gpus = 5
+
+[inference]
+gpu_memory_utilization = 0.9
+parallel.dp = 5  # matches num_infer_gpus
+
+[inference.model]
+dtype = "bfloat16"
+max_model_len = 24576