Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/medqa/rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Qwen3.5-9B GRPO on MedQA-USMLE-4-options.
#
# Notes specific to Qwen3.5 as a vision-language model in PRIME-RL:
# - max_async_level = 1 + filesystem weight broadcast: the NCCL broadcaster
# fails to initialize on the 9B VLM (init_broadcaster 500). Filesystem
# broadcast is reliable and `maybe_clean` keeps only async_level+1 = 2
# broadcast dirs on disk during the run.
# - optim_cpu_offload = true: needed to fit FSDP + optimizer state into
# 2 train GPUs at seq_len=8192. Without it, OOMs on H100.
# - Do NOT enable [trainer.model.compile]: silently hangs on the 9B VLM
# during the first forward pass.
# - enable_thinking must be passed through chat_template_kwargs in BOTH
# the orchestrator sampling block and the eval sampling block — Qwen3.5
# is an instruct model that requires the flag to emit <think> tokens.
#
# 8 GPUs total: 2 train (FSDP) + 5 inference (dp=5) + 1 free for an external
# judge server (medqa itself doesn't need a judge; this layout matches the
# eval-time setup that adds judge-scored open-ended benches).

max_steps = 100
seq_len = 8192
max_async_level = 1

[model]
name = "Qwen/Qwen3.5-9B"

[weight_broadcast]
type = "filesystem"

[orchestrator]
batch_size = 64
rollouts_per_example = 8
max_inflight_rollouts = 128

[orchestrator.sampling]
max_tokens = 8192
extra_body = { chat_template_kwargs = { enable_thinking = true } }

[orchestrator.buffer]
online_difficulty_filtering = true
easy_threshold = 1.0
hard_threshold = 0.0

[[orchestrator.env]]
id = "medqa"
name = "medqa"
args = { shuffle_answers = true }

[orchestrator.eval]
interval = 50

[orchestrator.eval.sampling]
max_tokens = 12288
extra_body = { chat_template_kwargs = { enable_thinking = true } }

[[orchestrator.eval.env]]
id = "medqa"
name = "medqa"
args = { shuffle_answers = false }
num_examples = 200
rollouts_per_example = 1

[[orchestrator.eval.env]]
id = "medbullets"
name = "medbullets"
args = { shuffle_answers = false }
num_examples = 100
rollouts_per_example = 1

[[orchestrator.eval.env]]
id = "metamedqa"
name = "metamedqa"
args = { shuffle_answers = false }
num_examples = 100
rollouts_per_example = 1

[[orchestrator.eval.env]]
id = "head-qa"
name = "head-qa"
args = { shuffle_answers = false }
num_examples = 50
rollouts_per_example = 1

[trainer.model]
name = "Qwen/Qwen3.5-9B"
attn = "flash_attention_3"
seq_len = 8192
optimization_dtype = "bfloat16"
reduce_dtype = "bfloat16"
dp_replicate = 1
reshard_after_forward = true
optim_cpu_offload = true

[trainer.model.ac]
freq = 1

[trainer.optim]
lr = 1.5e-6

[ckpt]
interval = 50
resume_step = -1

[deployment]
type = "single_node"
num_train_gpus = 2
num_infer_gpus = 5

[inference]
gpu_memory_utilization = 0.90

[inference.parallel]
dp = 5

[inference.model]
max_model_len = 24576
dtype = "bfloat16"