Skip to content

Commit f997b96

Browse files
Support additional slurm features (priority, email notifications, dependency), update primerl pin (#6)
* Add SLURM priority/mail/resume/account options, switch to gpus-per-task with 16 CPU/GPU defaults, and sync docs/tests. * Add dependency and test-only flags to medarc_slurm * formatting fixes * group arguments, support passing primerl config as arg or option * update primerl, now uses IPO loss * update templates * match upstream config override behavior
1 parent 5aedf35 commit f997b96

11 files changed

Lines changed: 754 additions & 320 deletions

File tree

README.md

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ uv sync
3131
For flash attention support:
3232

3333
```bash
34-
uv sync --extra flash-attn # flash-attn v2
35-
uv sync --extra flash-attn-3 # flash-attn v2 + v3 (use for H100s)
34+
uv sync --extra flash-attn-2 # flash-attn 2
35+
uv sync --extra flash-attn-3 # flash-attn 2 + 3 (use for H100s)
36+
uv sync --extra flash-attn-4 # flash-attn 2, 3, & 4 (use for B200s)
3637
```
3738

3839
## medarc_slurm
@@ -41,20 +42,41 @@ uv sync --extra flash-attn-3 # flash-attn v2 + v3 (use for H100s)
4142

4243
```bash
4344
# SFT: single torchrun job
44-
medarc_slurm sft config.toml --output-dir runs/my-sft --gpus 2
45+
medarc_slurm sft --config config.toml --output-dir runs/my-sft --gpus 2
4546

4647
# RL: splits GPUs between vLLM inference and training
47-
medarc_slurm rl config.toml --output-dir runs/my-rl --train-gpus 1 --infer-gpus 2
48+
medarc_slurm rl --config config.toml --output-dir runs/my-rl --train-gpus 1 --infer-gpus 2
4849

4950
# RL: share a single GPU between inference and training
50-
medarc_slurm rl config.toml --output-dir runs/my-rl --single-gpu
51+
medarc_slurm rl --config config.toml --output-dir runs/my-rl --single-gpu
52+
53+
# SFT: low-priority queue + email notifications + resume from latest checkpoint
54+
medarc_slurm sft --config config.toml \
55+
--output-dir runs/my-sft \
56+
--gpus 2 \
57+
--priority low \
58+
--mail all \
59+
--mail-user email@domain.com \
60+
--slurm-resume
61+
62+
# Validate an RL submission (including dependency syntax) without creating a job
63+
medarc_slurm rl --config config.toml \
64+
--output-dir runs/my-rl \
65+
--train-gpus 1 \
66+
--infer-gpus 2 \
67+
--dependency afterok:123456 \
68+
--test-only
5169
```
5270

5371
Generated artifacts are written to `--output-dir`:
5472
- `sft.sh` or `rl.sh` — the SLURM batch script
5573
- `configs/` — resolved TOML subconfigs passed to each component
5674

57-
You can pass PRIME-RL config overrides directly as extra flags (for example `--wandb.project my-proj --wandb.name my-run`). You may also insert `--` before passthrough overrides for readability, but it is optional.
75+
You can pass PRIME-RL config overrides directly as extra flags (for example `--wandb.project my-proj --wandb.name my-run`). You may also insert `--` before passthrough overrides for readability, but it is optional. To layer multiple PRIME-RL configs, repeat `--config` with later files overriding earlier ones.
76+
77+
`medarc_slurm` now defaults `--account` to `training`. You can override it with `--account <name>`.
78+
Email mode is `--mail all` or `--mail begin_end` (with `--mail-user`).
79+
Use `--dependency "<expr>"` to pass SLURM dependencies and `--test-only` to run `sbatch` validation without submitting.
5880

5981
Run `medarc_slurm sft --help` or `medarc_slurm rl --help` for more details on available options.
6082

medarc_rl/launchers/rl_local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from subprocess import Popen
1717
from threading import Event, Thread
1818

19+
from pydantic_config import cli
1920
from prime_rl.configs.rl import RLConfig
2021
from prime_rl.entrypoints.rl import write_subconfigs
2122
from prime_rl.utils.logger import setup_logger
2223
from prime_rl.utils.pathing import get_log_dir
2324
from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process
24-
from prime_rl.utils.pydantic_config import parse_argv
2525
from prime_rl.utils.utils import get_free_port
2626

2727

@@ -315,7 +315,7 @@ def rl_local(config: RLConfig) -> None:
315315

316316

317317
def main() -> None:
318-
config = parse_argv(RLConfig)
318+
config = cli(RLConfig)
319319
rl_local(config)
320320

321321

medarc_rl/medarc_slurm.py

Lines changed: 197 additions & 60 deletions
Large diffs are not rendered by default.

medarc_rl/medarc_train.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import typer
99
from pydantic import ValidationError
10-
from typer import Argument, Option
10+
from typer import Option
1111

1212
from medarc_rl.utils import TYPER_PASSTHROUGH_CONTEXT, _load_settings_from_toml, _write_toml, extra_config_args
1313

@@ -25,6 +25,26 @@ def _gpu_ids(n: int) -> str:
2525
return ",".join(str(i) for i in range(n))
2626

2727

28+
def _enable_sft_resume(config, *, enabled: bool) -> None:
29+
if not enabled:
30+
return
31+
if config.ckpt is None:
32+
from prime_rl.configs.trainer import CheckpointConfig as TrainerCheckpointConfig
33+
34+
config.ckpt = TrainerCheckpointConfig()
35+
config.ckpt.resume_step = -1
36+
37+
38+
def _enable_rl_resume(config, *, enabled: bool) -> None:
39+
if not enabled:
40+
return
41+
if config.ckpt is None:
42+
from prime_rl.configs.rl import SharedCheckpointConfig
43+
44+
config.ckpt = SharedCheckpointConfig()
45+
config.ckpt.resume_step = -1
46+
47+
2848
@app.command(
2949
context_settings=TYPER_PASSTHROUGH_CONTEXT,
3050
help=(
@@ -33,19 +53,24 @@ def _gpu_ids(n: int) -> str:
3353
)
3454
def sft(
3555
ctx: typer.Context,
36-
config_toml: Annotated[Path, Argument(metavar="CONFIG_TOML", help="Path to the PRIME-RL SFT trainer TOML.")],
3756
output_dir: Annotated[Path, Option("--output-dir", file_okay=False, dir_okay=True, help="Directory to write resolved configs and checkpoints.")],
57+
config: Annotated[list[Path] | None, Option("--config", "--config-toml", help="One or more PRIME-RL SFT trainer TOMLs. Repeat `--config` to layer files with later files overriding earlier ones.")] = None,
3858
gpus: Annotated[int, Option("--gpus", min=1, max=8, help="Number of GPUs for SFT.")] = 1,
59+
resume: Annotated[bool, Option("--resume/--no-resume", help="Resume from the latest checkpoint (sets ckpt.resume_step=-1).")] = False,
3960
) -> None: # fmt: skip
4061
from prime_rl.configs.sft import SFTConfig
4162

63+
config_tomls = list(config or [])
64+
if not config_tomls:
65+
raise typer.BadParameter("Missing config path. Pass one or more --config values.", param_hint="--config")
4266
output_dir = output_dir.expanduser().resolve()
4367
config = _load_settings_from_toml(
4468
SFTConfig,
45-
config_toml.expanduser().resolve(),
69+
[config_toml.expanduser().resolve() for config_toml in config_tomls],
4670
output_dir=output_dir,
47-
extra_cli_args=extra_config_args(ctx),
71+
extra_cli_args=extra_config_args(ctx, positional_count=0),
4872
)
73+
_enable_sft_resume(config, enabled=resume)
4974

5075
config_dir = output_dir / "configs"
5176
config_dir.mkdir(parents=True, exist_ok=True)
@@ -82,37 +107,41 @@ def sft(
82107
)
83108
def rl(
84109
ctx: typer.Context,
85-
config_toml: Annotated[Path, Argument(metavar="CONFIG_TOML", help="Path to the PRIME-RL RL TOML.")],
86110
output_dir: Annotated[Path, Option("--output-dir", file_okay=False, dir_okay=True, help="Directory to write resolved configs and checkpoints.")],
111+
config: Annotated[list[Path] | None, Option("--config", "--config-toml", help="One or more PRIME-RL RL TOMLs. Repeat `--config` to layer files with later files overriding earlier ones.")] = None,
87112
train_gpus: Annotated[int, Option("--train-gpus", min=1, max=4, help="Number of GPUs for training.")] = 1,
88113
infer_gpus: Annotated[int, Option("--infer-gpus", min=1, max=7, help="Number of GPUs for inference.")] = 1,
89114
single_gpu: Annotated[bool, Option("--single-gpu", help="Share a single GPU between trainer and inference.")] = False,
115+
resume: Annotated[bool, Option("--resume/--no-resume", help="Resume from the latest checkpoint (sets ckpt.resume_step=-1).")] = False,
90116
) -> None: # fmt: skip
91117
from prime_rl.configs.rl import RLConfig
92118

93119
from medarc_rl.launchers.rl_local import rl_local
94120

121+
config_tomls = list(config or [])
122+
if not config_tomls:
123+
raise typer.BadParameter("Missing config path. Pass one or more --config values.", param_hint="--config")
95124
output_dir = output_dir.expanduser().resolve()
96125
train_gpus = 1 if single_gpu else train_gpus
97126
infer_gpus = 1 if single_gpu else infer_gpus
98-
total_gpus = 1 if single_gpu else (train_gpus + infer_gpus)
127+
gpus = 1 if single_gpu else (train_gpus + infer_gpus)
99128

100-
if not single_gpu and total_gpus < 2:
129+
if not single_gpu and gpus < 2:
101130
raise typer.BadParameter(
102-
f"Total GPUs must be at least 2, got train_gpus ({train_gpus}) + infer_gpus ({infer_gpus}) = {total_gpus}.",
131+
f"Total GPUs must be at least 2, got train_gpus ({train_gpus}) + infer_gpus ({infer_gpus}) = {gpus}.",
103132
param_hint="--train-gpus/--infer-gpus",
104133
)
105-
if total_gpus > 8:
134+
if gpus > 8:
106135
raise typer.BadParameter(
107-
f"Total GPUs must be at most 8, got train_gpus ({train_gpus}) + infer_gpus ({infer_gpus}) = {total_gpus}.",
136+
f"Total GPUs must be at most 8, got train_gpus ({train_gpus}) + infer_gpus ({infer_gpus}) = {gpus}.",
108137
param_hint="--train-gpus/--infer-gpus",
109138
)
110139

111140
try:
112141
config = _load_settings_from_toml(
113142
RLConfig,
114-
config_toml.expanduser().resolve(),
115-
extra_cli_args=extra_config_args(ctx),
143+
[config_toml.expanduser().resolve() for config_toml in config_tomls],
144+
extra_cli_args=extra_config_args(ctx, positional_count=0),
116145
output_dir=output_dir,
117146
deployment={"type": "single_node", "num_train_gpus": train_gpus, "num_infer_gpus": infer_gpus},
118147
)
@@ -121,6 +150,7 @@ def rl(
121150
f"RL config validation failed:\n{e}",
122151
param_hint="CONFIG_TOML/--train-gpus/--infer-gpus",
123152
) from e
153+
_enable_rl_resume(config, enabled=resume)
124154

125155
if single_gpu and getattr(config.trainer.weight_broadcast, "type", None) == "nccl":
126156
raise typer.BadParameter(
@@ -135,10 +165,10 @@ def rl(
135165
)
136166

137167
# Set env vars for rl_local
138-
os.environ["CUDA_VISIBLE_DEVICES"] = _gpu_ids(total_gpus)
168+
os.environ["CUDA_VISIBLE_DEVICES"] = _gpu_ids(gpus)
139169
os.environ["MEDARC_SINGLE_GPU"] = "1" if single_gpu else "0"
140170

141-
typer.echo(f"Starting RL on {total_gpus} GPU(s) (single_gpu={single_gpu})")
171+
typer.echo(f"Starting RL on {gpus} GPU(s) (single_gpu={single_gpu})")
142172
rl_local(config)
143173

144174

medarc_rl/slurm_templates/one_node_rl.j2

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@
33
#SBATCH --job-name={{ job_name }}
44
#SBATCH --nodes=1
55
#SBATCH --ntasks=1
6-
#SBATCH --gres=gpu:{{ total_gpus }}
7-
{% if total_gpus == 8 %}
6+
#SBATCH --gpus-per-task={{ gpus }}
7+
{% if gpus == 8 %}
88
#SBATCH --exclusive
99
{% else %}
1010
#SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
1111
{% endif %}
12+
{% if slurm_resume %}
13+
#SBATCH --requeue
14+
{% endif %}
15+
{% if qos %}
16+
#SBATCH --qos={{ qos }}
17+
{% endif %}
18+
{% if mail_type %}
19+
#SBATCH --mail-type={{ mail_type }}
20+
{% endif %}
21+
{% if mail_user %}
22+
#SBATCH --mail-user={{ mail_user }}
23+
{% endif %}
1224
#SBATCH --export=ALL
1325
#SBATCH --output="{{ output_dir }}/slurm/job_%j.log"
1426
#SBATCH --error="{{ output_dir }}/slurm/job_%j.log"

medarc_rl/slurm_templates/one_node_sft.j2

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@
33
#SBATCH --job-name={{ job_name }}
44
#SBATCH --nodes=1
55
#SBATCH --ntasks=1
6-
#SBATCH --gres=gpu:{{ gpus }}
6+
#SBATCH --gpus-per-task={{ gpus }}
77
{% if gpus == 8 %}
88
#SBATCH --exclusive
99
{% else %}
1010
#SBATCH --cpus-per-gpu={{ cpus_per_gpu }}
1111
{% endif %}
12+
{% if slurm_resume %}
13+
#SBATCH --requeue
14+
{% endif %}
15+
{% if qos %}
16+
#SBATCH --qos={{ qos }}
17+
{% endif %}
18+
{% if mail_type %}
19+
#SBATCH --mail-type={{ mail_type }}
20+
{% endif %}
21+
{% if mail_user %}
22+
#SBATCH --mail-user={{ mail_user }}
23+
{% endif %}
1224
#SBATCH --export=ALL
1325
#SBATCH --output="{{ output_dir }}/slurm/job_%j.log"
1426
#SBATCH --error="{{ output_dir }}/slurm/job_%j.log"

medarc_rl/utils.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import json
34
from pathlib import Path
45
from typing import Any, TypeVar
56

67
import tomli_w
78
import typer
8-
from prime_rl.utils.pydantic_config import extract_toml_paths, to_kebab_case
9+
from pydantic import ValidationError
10+
from pydantic_config import ConfigFileError
11+
from prime_rl.utils.config import cli
912

1013
TYPER_PASSTHROUGH_CONTEXT = {"allow_extra_args": True, "ignore_unknown_options": True}
1114
T = TypeVar("T")
@@ -46,26 +49,30 @@ def _write_toml(path: Path, data: dict[str, Any]) -> None:
4649

4750
def _load_settings_from_toml(
4851
config_cls: type[T],
49-
config_path: Path,
52+
config_paths: list[Path],
5053
*,
5154
extra_cli_args: list[str] | None = None,
5255
**overrides: Any,
5356
) -> T:
54-
if not config_path.exists():
55-
raise typer.BadParameter(f"Config file does not exist: {config_path}", param_hint="CONFIG_TOML")
57+
if not config_paths:
58+
raise typer.BadParameter("At least one config file is required.", param_hint="CONFIG_TOML")
59+
for config_path in config_paths:
60+
if not config_path.exists():
61+
raise typer.BadParameter(f"Config file does not exist: {config_path}", param_hint="CONFIG_TOML")
5662

5763
reserved_roots = set(overrides)
5864
filtered_extra_args = filter_wrapper_owned_cli_args(extra_cli_args or [], override_roots=reserved_roots)
59-
args = ["@", str(config_path), *filtered_extra_args]
60-
toml_paths, cli_args = extract_toml_paths(args)
61-
if not toml_paths:
62-
raise typer.BadParameter(f"Failed to resolve TOML paths from {config_path}", param_hint="CONFIG_TOML")
63-
64-
config_cls.set_toml_files([str(path) for path in toml_paths])
6565
try:
66-
return config_cls(_cli_parse_args=to_kebab_case(cli_args), **overrides)
67-
finally:
68-
config_cls.clear_toml_files()
66+
return cli(
67+
config_cls,
68+
args=[
69+
*[item for config_path in config_paths for item in ("@", str(config_path))],
70+
*filtered_extra_args,
71+
*_overrides_to_cli_args(overrides),
72+
],
73+
)
74+
except (ConfigFileError, ValidationError, SystemExit) as e:
75+
raise typer.BadParameter(str(e), param_hint="CONFIG_TOML") from e
6976

7077

7178
def extra_config_args(ctx: typer.Context, *, positional_count: int = 1) -> list[str]:
@@ -129,3 +136,34 @@ def filter_wrapper_owned_cli_args(cli_args: list[str], *, override_roots: set[st
129136
i += 1
130137

131138
return filtered
139+
140+
141+
def _overrides_to_cli_args(overrides: dict[str, Any]) -> list[str]:
142+
args: list[str] = []
143+
for key, value in overrides.items():
144+
args.extend(_flatten_override(key, value))
145+
return args
146+
147+
148+
def _flatten_override(key: str, value: Any) -> list[str]:
149+
option = f"--{key.replace('_', '-')}"
150+
151+
if value is None:
152+
return []
153+
154+
if isinstance(value, dict):
155+
args: list[str] = []
156+
for subkey, subvalue in value.items():
157+
args.extend(_flatten_override(f"{key}.{subkey}", subvalue))
158+
return args
159+
160+
if isinstance(value, bool):
161+
return [option] if value else [f"--no-{key.replace('_', '-')}"]
162+
163+
if isinstance(value, Path):
164+
return [option, str(value)]
165+
166+
if isinstance(value, (list, tuple)):
167+
return [option, json.dumps(value)]
168+
169+
return [option, str(value)]

prime-rl

Submodule prime-rl updated 78 files

0 commit comments

Comments
 (0)