Skip to content

Commit 3fd84cc

Browse files
wedu-nvidia, dgtm777
authored and committed
Update omr recipe (#973)
Signed-off-by: Wei Du <wedu@nvidia.com>
1 parent 388344e commit 3fd84cc

3 files changed

Lines changed: 114 additions & 66 deletions

File tree

recipes/openmathreasoning/scripts/simplified_recipe.py

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525

2626

27-
def prepare(workspace, cluster, num_gpus, expname_prefix, wandb_params):
27+
def prepare(workspace, cluster, expname_prefix):
2828
# data preparation needs to run locally without container, so not wrapping with run_cmd
2929
prepare_datasets(["aime24", "aime25"])
3030

@@ -90,7 +90,7 @@ def run_sdg(workspace, cluster, num_gpus, expname_prefix, wandb_params):
9090
)
9191

9292

93-
def run_training(workspace, cluster, num_gpus, expname_prefix, wandb_params):
93+
def run_training(workspace, cluster, num_gpus, expname_prefix, backend, wandb_params):
9494
# convert the generated solutions to a format that can be used for training
9595
run_cmd(
9696
ctx=wrap_arguments(
@@ -110,47 +110,54 @@ def run_training(workspace, cluster, num_gpus, expname_prefix, wandb_params):
110110
)
111111

112112
# train the model
113-
114-
sft_nemo_rl(
115-
ctx=wrap_arguments(
116-
"++policy.max_total_sequence_length=8192 "
117-
"++policy.train_global_batch_size=32 "
118-
"++policy.tensor_model_parallel_size=4 "
119-
"++policy.context_parallel_size=2 "
120-
"++policy.lr=1e-5 "
121-
"++sft.max_num_epochs=2 "
122-
),
123-
cluster=cluster,
124-
output_dir=f"{workspace}/training",
125-
hf_model="Qwen/Qwen2.5-14B-Instruct",
126-
backend="megatron",
127-
num_gpus=num_gpus,
128-
num_nodes=1,
129-
disable_wandb=wandb_params["disable_wandb"],
130-
wandb_project=wandb_params["wandb_project"],
131-
training_data=f"{workspace}/sft-data.jsonl",
132-
expname=f"{expname_prefix}-training",
133-
run_after=f"{expname_prefix}-prepare-training-data",
134-
final_hf_path=f"{workspace}/training/qwen2.5-14b-improved-hf",
135-
)
136-
137-
138-
def final_eval(workspace, cluster, num_gpus, expname_prefix, wandb_params):
113+
base_args = [
114+
"++policy.max_total_sequence_length=8192",
115+
"++policy.train_global_batch_size=32",
116+
"++policy.tensor_model_parallel_size=4",
117+
"++policy.context_parallel_size=2",
118+
"++policy.lr=1e-5",
119+
"++sft.max_num_epochs=2",
120+
]
121+
# For FSDP, sequence_packing cannot be used with context parallel
122+
for training_backend in backend:
123+
args = list(base_args)
124+
if training_backend == "fsdp":
125+
args.append("++policy.sequence_packing.enabled=False")
126+
127+
sft_nemo_rl(
128+
ctx=wrap_arguments(" ".join(args)),
129+
cluster=cluster,
130+
output_dir=f"{workspace}/training-{training_backend}",
131+
hf_model="Qwen/Qwen2.5-14B-Instruct",
132+
backend=training_backend,
133+
num_gpus=num_gpus,
134+
num_nodes=1,
135+
disable_wandb=wandb_params["disable_wandb"],
136+
wandb_project=wandb_params["wandb_project"],
137+
training_data=f"{workspace}/sft-data.jsonl",
138+
expname=f"{expname_prefix}-training-{training_backend}",
139+
run_after=f"{expname_prefix}-prepare-training-data",
140+
final_hf_path=f"{workspace}/training-{training_backend}/qwen2.5-14b-improved-hf",
141+
)
142+
143+
144+
def final_eval(workspace, cluster, num_gpus, expname_prefix, backend, wandb_params):
139145
# launching evaluation
140-
eval(
141-
ctx=wrap_arguments("++inference.tokens_to_generate=16384 ++parse_reasoning=True "),
142-
cluster=cluster,
143-
model=f"{workspace}/training/qwen2.5-14b-improved-hf",
144-
server_type="vllm",
145-
server_gpus=num_gpus,
146-
benchmarks="aime24:8,aime25:8",
147-
output_dir=f"{workspace}/evals/after-training",
148-
num_jobs=1,
149-
expname=f"{expname_prefix}-final-eval",
150-
run_after=f"{expname_prefix}-training",
151-
wandb_name=f"{expname_prefix}-final-eval" if not wandb_params["disable_wandb"] else None,
152-
wandb_project=wandb_params["wandb_project"],
153-
)
146+
for training_backend in backend:
147+
eval(
148+
ctx=wrap_arguments("++inference.tokens_to_generate=16384 ++parse_reasoning=True "),
149+
cluster=cluster,
150+
model=f"{workspace}/training-{training_backend}/qwen2.5-14b-improved-hf",
151+
server_type="vllm",
152+
server_gpus=num_gpus,
153+
benchmarks="aime24:8,aime25:8",
154+
output_dir=f"{workspace}/evals/after-training-{training_backend}",
155+
num_jobs=1,
156+
expname=f"{expname_prefix}-final-eval-{training_backend}",
157+
run_after=f"{expname_prefix}-training-{training_backend}",
158+
wandb_name=f"{expname_prefix}-final-eval" if not wandb_params["disable_wandb"] else None,
159+
wandb_project=wandb_params["wandb_project"],
160+
)
154161

155162

156163
def initial_eval(workspace, cluster, num_gpus, expname_prefix, wandb_params):
@@ -203,21 +210,42 @@ def initial_eval(workspace, cluster, num_gpus, expname_prefix, wandb_params):
203210
default="nemo-skills",
204211
help="WandB project name for tracking experiments.",
205212
)
213+
parser.add_argument(
214+
"--backend",
215+
type=str,
216+
nargs="+",
217+
choices=["megatron", "fsdp"],
218+
default=["megatron"],
219+
)
220+
206221
args = parser.parse_args()
207222

208223
wandb_params = {
209224
"disable_wandb": args.disable_wandb,
210225
"wandb_project": args.wandb_project,
211226
}
212-
args = (
227+
common_args = (
213228
args.workspace,
214229
args.cluster,
215230
args.num_gpus,
216231
args.expname_prefix,
232+
args.backend,
217233
wandb_params,
218234
)
219-
prepare(*args)
220-
initial_eval(*args)
221-
run_sdg(*args)
222-
run_training(*args)
223-
final_eval(*args)
235+
prepare(workspace=args.workspace, cluster=args.cluster, expname_prefix=args.expname_prefix)
236+
initial_eval(
237+
workspace=args.workspace,
238+
cluster=args.cluster,
239+
num_gpus=args.num_gpus,
240+
expname_prefix=args.expname_prefix,
241+
wandb_params=wandb_params,
242+
)
243+
run_sdg(
244+
workspace=args.workspace,
245+
cluster=args.cluster,
246+
num_gpus=args.num_gpus,
247+
expname_prefix=args.expname_prefix,
248+
wandb_params=wandb_params,
249+
)
250+
run_training(*common_args)
251+
final_eval(*common_args)

tests/slurm-tests/omr_simple_recipe/check_results.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
}
3333

3434

35-
def check_results(benchmark: str, baseline_results: dict, after_training_results: dict):
35+
def check_results(benchmark: str, baseline_results: dict, after_training_results: dict, backend: str):
3636
for metric in ["pass@1[avg-of-8]", "majority@8"]:
3737
baseline_acc = baseline_results[benchmark][metric]["symbolic_correct"]
3838
after_acc = after_training_results[benchmark][metric]["symbolic_correct"]
@@ -46,22 +46,33 @@ def check_results(benchmark: str, baseline_results: dict, after_training_results
4646
)
4747
soft_assert(
4848
lo_a <= after_acc <= hi_a,
49-
f"{benchmark}: after_training {after_acc}% out of range [{lo_a}%, {hi_a}%] for metric {metric}",
49+
f"{benchmark} for {backend}: after_training {after_acc}% out of range [{lo_a}%, {hi_a}%] for metric {metric}",
5050
)
5151

5252

5353
def main():
5454
ap = argparse.ArgumentParser()
5555
ap.add_argument("--workspace", required=True, help="Workspace directory containing eval results.")
56+
ap.add_argument(
57+
"--backend",
58+
type=str,
59+
nargs="+",
60+
choices=["megatron", "fsdp"],
61+
default=["megatron"],
62+
)
5663
args = ap.parse_args()
57-
58-
for benchmark in ("aime24", "aime25"):
59-
common_path = Path(args.workspace) / "evals"
60-
baseline_results = load_json(common_path / "baseline" / "eval-results" / benchmark / "metrics.json")
61-
after_training_results = load_json(
62-
common_path / "after-training" / "eval-results" / benchmark / "metrics.json"
63-
)
64-
check_results(benchmark, baseline_results, after_training_results)
64+
for training_backend in args.backend:
65+
for benchmark in ("aime24", "aime25"):
66+
common_path = Path(args.workspace) / "evals"
67+
baseline_results = load_json(common_path / "baseline" / "eval-results" / benchmark / "metrics.json")
68+
after_training_results = load_json(
69+
common_path
70+
/ "after-training-{}".format(training_backend)
71+
/ "eval-results"
72+
/ benchmark
73+
/ "metrics.json"
74+
)
75+
check_results(benchmark, baseline_results, after_training_results, training_backend)
6576

6677
assert_all()
6778

tests/slurm-tests/omr_simple_recipe/run_test.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,21 @@ def main():
2424
ap.add_argument("--wandb_project", default="nemo-skills-slurm-ci", help="W&B project name")
2525
ap.add_argument("--expname_prefix", required=True, help="Experiment name prefix used inside the recipe")
2626
ap.add_argument("--disable_wandb", action="store_true", help="Disable W&B logging in the recipe")
27+
ap.add_argument(
28+
"--backend",
29+
type=str,
30+
nargs="+",
31+
choices=["megatron", "fsdp"],
32+
default=["megatron"],
33+
)
2734
args = ap.parse_args()
2835

2936
cmd = (
3037
f"python -m recipes.openmathreasoning.scripts.simplified_recipe "
31-
f" --cluster {args.cluster} "
32-
f" --workspace {args.workspace} "
33-
f" --expname_prefix {args.expname_prefix} "
38+
f" --cluster {args.cluster} "
39+
f" --workspace {args.workspace} "
40+
f" --expname_prefix {args.expname_prefix} "
41+
f" --backend {' '.join(args.backend)} "
3442
)
3543

3644
if args.disable_wandb:
@@ -40,17 +48,18 @@ def main():
4048

4149
subprocess.run(cmd, shell=True, check=True)
4250

43-
checker_cmd = f"python tests/slurm-tests/omr_simple_recipe/check_results.py --workspace {args.workspace}"
51+
checker_cmd = f"python tests/slurm-tests/omr_simple_recipe/check_results.py --workspace {args.workspace} --backend {' '.join(args.backend)}"
52+
53+
final_eval_name = [f"{args.expname_prefix}-final-eval-{training_backend}" for training_backend in args.backend]
4454

4555
run_cmd(
4656
ctx=wrap_arguments(checker_cmd),
4757
cluster=args.cluster,
4858
expname=args.expname_prefix + "-check-results",
4959
log_dir=f"{args.workspace}/check-results-logs",
50-
run_after=[ # these are launched in simplified recipe
51-
f"{args.expname_prefix}-final-eval",
52-
f"{args.expname_prefix}-baseline-eval",
53-
],
60+
# these are launched in simplified recipe
61+
run_after=final_eval_name + [f"{args.expname_prefix}-baseline-eval"],
62+
reuse_code=True,
5463
)
5564

5665

0 commit comments

Comments
 (0)