forked from NVIDIA-NeMo/RL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgrpo.py
More file actions
3094 lines (2729 loc) · 137 KB
/
grpo.py
File metadata and controls
3094 lines (2729 loc) · 137 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from pathlib import Path
from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
import numpy as np
import ray
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoProcessor
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from nemo_rl.algorithms.advantage_estimator import (
GRPOAdvantageEstimator,
ReinforcePlusPlusAdvantageEstimator,
)
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.loss_functions import (
ClippedPGLossConfig,
ClippedPGLossDataDict,
ClippedPGLossFn,
)
from nemo_rl.algorithms.reward_functions import (
RewardShapingConfig,
apply_reward_shaping,
)
from nemo_rl.algorithms.utils import (
calculate_baseline_and_std_per_prompt,
log_generation_metrics_to_wandb,
print_performance_metrics,
set_seed,
)
from nemo_rl.data import DataConfig
from nemo_rl.data.collate_fn import rl_collate_fn
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.interfaces import DatumSpec
from nemo_rl.data.llm_message_utils import (
batched_message_log_to_flat_message,
get_keys_from_message_log,
)
from nemo_rl.data.utils import extract_necessary_env_names
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.experience.rollouts import (
run_async_multi_turn_rollout,
run_async_nemo_gym_rollout,
run_multi_turn_rollout,
)
from nemo_rl.models.generation.interfaces import GenerationInterface
from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration
from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface
from nemo_rl.models.policy.lm_policy import Policy
from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
from nemo_rl.utils.logger import (
Logger,
LoggerConfig,
print_message_log_samples,
)
from nemo_rl.utils.memory_tracker import MemoryTracker
from nemo_rl.utils.nsys import maybe_gpu_profile_step
from nemo_rl.utils.timer import TimeoutChecker, Timer
from nemo_rl.utils.venvs import create_local_venv_on_each_node
# ===============================================================================
# Configuration
# ===============================================================================
# Generic tokenizer type used to annotate setup(); any subclass of
# transformers' PreTrainedTokenizerBase is accepted.
TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
class RewardScalingConfig(TypedDict):
"""Configure linear reward scaling with clamping.
When `enabled` is True, each reward is clamped to the source interval
[source_min, source_max] and linearly mapped to the target interval
[target_min, target_max]. Refer to the scale_rewards function for the implementation.
Defaults:
source_min=0.0, source_max=1.0, target_min=0.0, target_max=1.0
"""
enabled: bool
source_min: NotRequired[float]
source_max: NotRequired[float]
target_min: NotRequired[float]
target_max: NotRequired[float]
class AsyncGRPOConfig(TypedDict):
enabled: bool
# Maximum trajectory age in training steps for samples drawn from the
# async replay buffer. Trajectories older than this are excluded during
# sampling; buffer sizing also scales with this value.
max_trajectory_age_steps: int
# Does the weight synchronization as soon as the training is done
# without waiting for the pending generations to finish.
in_flight_weight_updates: NotRequired[bool]
# Recomputes the KV cache after the in-flight weight updates.
recompute_kv_cache_after_weight_updates: NotRequired[bool]
class AdvEstimatorConfig(TypedDict):
"""Configuration for advantage estimator (GRPO or Reinforce++)."""
name: str # "grpo" or "reinforce_plus_plus"
# GRPO specific
normalize_rewards: NotRequired[bool]
use_leave_one_out_baseline: NotRequired[bool]
# Reinforce++ specific
minus_baseline: NotRequired[bool]
class GRPOConfig(TypedDict):
    """Hyperparameters of the GRPO algorithm (the ``grpo`` section of the config)."""

    # --- batch shape and run length ---
    num_prompts_per_step: int
    num_generations_per_prompt: int
    max_num_epochs: int
    max_num_steps: int
    max_rollout_turns: int

    # --- reward normalization / baseline ---
    normalize_rewards: bool
    use_leave_one_out_baseline: bool

    # --- validation ---
    # setup() builds a validation dataloader when val_period > 0,
    # val_at_start or val_at_end is set.
    val_period: int
    val_batch_size: int
    val_at_start: bool
    # Run validation on the final training step as well. Enabling this guarantees
    # the last checkpoint carries validation metrics, which
    # get_best_checkpoint_path() needs.
    val_at_end: bool
    max_val_samples: int

    skip_reference_policy_logprobs_calculation: NotRequired[bool]
    seed: int
    async_grpo: NotRequired[AsyncGRPOConfig]
    overlong_filtering: NotRequired[bool]

    # --- dynamic sampling ---
    # Discard prompts whose group of rewards has zero standard deviation.
    use_dynamic_sampling: bool
    # Under dynamic sampling, the number of generation batches allowed before
    # an error is raised.
    dynamic_sampling_max_gen_batches: NotRequired[int]
    # Under dynamic sampling, the generation prompt batch size becomes
    # num_prompts_per_step * batch_multiplier.
    batch_multiplier: NotRequired[float]

    # --- reward post-processing ---
    reward_shaping: RewardShapingConfig
    reward_scaling: RewardScalingConfig

    # Advantages are computed on CPU unless this is True, in which case the GPU is used.
    calculate_advantages_on_gpu: NotRequired[bool]

    # Sequence-level logprob-error masking for training stability: when set,
    # sequences whose mult_prob_error exceeds this threshold are masked out
    # (same scale as the token_mult_prob_error metric, e.g. 1.5).
    # Unlike Masked Importance Sampling (MIS), which uses the raw difference
    # between training and generation logprobs, this uses the absolute value of
    # that difference.
    seq_logprob_error_threshold: float | None

    # Which advantage estimator to use (grpo or reinforce_plus_plus).
    adv_estimator: NotRequired[AdvEstimatorConfig]
class GRPOSaveState(TypedDict):
consumed_samples: int
current_step: int
current_epoch: int
total_steps: int
total_valid_tokens: int # Track total number of non-padding tokens during training
val_reward: NotRequired[
float
] # Optional field - may not be present during training
def _default_grpo_save_state() -> GRPOSaveState:
    """Return a fresh save state for a run that has no checkpoint yet.

    All counters start at zero. ``val_reward`` starts at a very large negative
    number (presumably so that any real validation reward compares greater;
    NOTE(review): confirm against the best-checkpoint logic).
    """
    # Calling a TypedDict class simply builds a plain dict with these keys.
    return GRPOSaveState(
        consumed_samples=0,
        current_step=0,
        current_epoch=0,
        total_steps=0,
        total_valid_tokens=0,
        val_reward=-99999999.0,
    )
class GRPOLoggerConfig(LoggerConfig):
    """Logger configuration for GRPO: the base ``LoggerConfig`` plus GRPO-specific fields."""

    num_val_samples_to_print: int  # number of val samples to print to stdout
class MasterConfig(TypedDict):
    """Top-level GRPO experiment configuration consumed by setup()."""

    policy: PolicyConfig  # includes the nested "generation" section read by setup()
    loss_fn: ClippedPGLossConfig  # passed to ClippedPGLossFn
    env: dict[str, Any]  # per-environment configs, keyed by environment name
    data: DataConfig
    grpo: GRPOConfig
    logger: GRPOLoggerConfig
    cluster: ClusterConfig  # num_nodes / gpus_per_node for the Ray cluster(s)
    checkpointing: CheckpointingConfig
# ===============================================================================
# Setup & Initialization
# ===============================================================================
def setup(
    master_config: MasterConfig,
    tokenizer: TokenizerType,
    dataset: AllTaskProcessedDataset,
    val_dataset: Optional[AllTaskProcessedDataset],
    processor: Optional[AutoProcessor] = None,
) -> tuple[
    ColocatablePolicyInterface,
    Optional[GenerationInterface],
    tuple[RayVirtualCluster, RayVirtualCluster],
    StatefulDataLoader,
    Optional[StatefulDataLoader],
    ClippedPGLossFn,
    Logger,
    CheckpointManager,
    GRPOSaveState,
    MasterConfig,
]:
    """Build every component needed to run the GRPO training loop.

    Seeds the RNGs and creates the logger and checkpoint manager. When a
    checkpoint exists, its training info and train-dataloader state are
    restored. Then builds the train/validation dataloaders and the loss
    function, allocates the Ray cluster(s), and initializes the training
    policy. For the vLLM/SGLang backends it also initializes the generation
    engine. Finally it prepares weight refitting between policy and generation.

    Args:
        master_config: Full experiment configuration. Mutated in place:
            ``policy.generation.model_name`` is filled from
            ``policy.model_name``, backend-specific generation fields are set,
            and ``policy.megatron_cfg.train_iters`` is set when megatron is
            enabled.
        tokenizer: Tokenizer handed to the training policy.
        dataset: Training dataset.
        val_dataset: Validation dataset. Required when ``grpo.val_period > 0``
            or ``grpo.val_at_start`` / ``grpo.val_at_end`` is set.
        processor: Optional processor handed to the training policy.

    Returns:
        A tuple ``(policy, policy_generation, (train_cluster,
        inference_cluster), dataloader, val_dataloader, loss_fn, logger,
        checkpointer, grpo_save_state, master_config)``.

        - ``policy_generation`` is ``None`` for the megatron backend.
        - ``val_dataloader`` is ``None`` when validation is disabled.
        - In colocated mode, both clusters are the same object.

    Raises:
        ValueError: If ``policy.generation.backend`` is not one of
            ``"megatron"``, ``"vllm"`` or ``"sglang"``.
    """
    # Start timing the entire setup process
    setup_start_time = time.perf_counter()

    # Extract individual configs for easier access
    policy_config = master_config["policy"]
    generation_config = master_config["policy"]["generation"]
    env_configs = master_config["env"]
    loss_config = master_config["loss_fn"]
    grpo_config = master_config["grpo"]
    data_config = master_config["data"]
    logger_config = master_config["logger"]
    cluster_config = master_config["cluster"]

    assert generation_config is not None, (
        "A generation config in the PolicyConfig is required for GRPO"
    )

    # Set seed for all random number generators
    set_seed(grpo_config["seed"])

    # ==========================
    # Logger
    # ==========================
    logger = Logger(logger_config)
    logger.log_hyperparams(master_config)

    # ==========================
    # Checkpointing
    # ==========================
    checkpointer = CheckpointManager(master_config["checkpointing"])
    last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
    grpo_save_state: Optional[GRPOSaveState] = cast(
        Optional[GRPOSaveState], checkpointer.load_training_info(last_checkpoint_path)
    )
    if grpo_save_state is None:
        grpo_save_state = _default_grpo_save_state()

    # ==========================
    # Data
    # ==========================
    # Validate batch_multiplier. The field is NotRequired in GRPOConfig, so fall
    # back to 1 (no over-sampling) instead of raising KeyError when it's absent.
    batch_multiplier = grpo_config.get("batch_multiplier", 1)
    dataloader_batch_size = grpo_config["num_prompts_per_step"]
    if not grpo_config["use_dynamic_sampling"]:
        assert batch_multiplier == 1, (
            "batch_multiplier>1 can only be used if use_dynamic_sampling=True"
        )
    else:
        # Dynamic sampling discards some prompts, so over-sample per step.
        dataloader_batch_size = int(dataloader_batch_size * batch_multiplier)

    dataloader = StatefulDataLoader(
        dataset,
        batch_size=dataloader_batch_size,
        shuffle=data_config["shuffle"],
        collate_fn=rl_collate_fn,
        drop_last=True,
        num_workers=data_config["num_workers"],
    )
    if last_checkpoint_path is not None:
        # Resume the dataloader position from the checkpoint.
        dataloader_state_dict = torch.load(
            os.path.join(last_checkpoint_path, "train_dataloader.pt")
        )
        dataloader.load_state_dict(dataloader_state_dict)

    print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True)

    # Load validation dataset if provided
    val_dataloader: Optional[StatefulDataLoader] = None
    # If validation is enabled, load the validation dataloader
    if (
        grpo_config["val_period"] > 0
        or grpo_config["val_at_start"]
        or grpo_config["val_at_end"]
    ):
        assert val_dataset is not None, (
            "Validation dataset is required if validation is enabled"
        )
        val_dataloader = StatefulDataLoader(
            val_dataset,
            batch_size=grpo_config["val_batch_size"],
            shuffle=False,
            collate_fn=rl_collate_fn,
            num_workers=data_config["num_workers"],
        )
        print(
            f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",
            flush=True,
        )

    # ==========================
    # Loss Function
    # ==========================
    loss_fn = ClippedPGLossFn(loss_config)

    # Validate force_on_policy_ratio
    if loss_config.get("force_on_policy_ratio", False):
        assert (
            grpo_config["num_prompts_per_step"]
            * grpo_config["num_generations_per_prompt"]
            == policy_config["train_global_batch_size"]
        ), (
            "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
        )
        os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1"
        print(" ✓ force_on_policy_ratio enabled")

    # ==========================
    # Cluster
    # ==========================
    print("\n▶ Setting up compute cluster...", flush=True)
    colocated_inference = generation_config["colocated"]["enabled"]
    env_name_list = extract_necessary_env_names(data_config)
    rm_env_enabled = "reward_model" in env_name_list

    total_nodes = cluster_config["num_nodes"]
    if rm_env_enabled:
        rm_resource = env_configs["reward_model"]["resources"]
        rm_nodes = rm_resource["num_nodes"]
        rm_gpus_per_node = rm_resource["gpus_per_node"]
    else:
        rm_nodes = 0
        rm_gpus_per_node = 0

    if total_nodes == 1:
        # Single node: the reward model (if any) shares the node with the policy.
        policy_nodes = total_nodes
    else:
        # Multi node: the reward model gets its own nodes.
        policy_nodes = total_nodes - rm_nodes
        assert policy_nodes > 0, (
            "policy_nodes must be > 0, but got "
            f"policy_nodes:{policy_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}"
        )

    if colocated_inference:
        if total_nodes == 1:
            policy_gpus_per_node = cluster_config["gpus_per_node"] - rm_gpus_per_node
            assert policy_gpus_per_node > 0, (
                "policy.generation.colocated.resources.gpus_per_node must be > 0 "
                "when cluster.num_nodes = 1, "
                f"but got {policy_gpus_per_node}."
            )
        else:
            policy_gpus_per_node = cluster_config["gpus_per_node"]

        cluster = RayVirtualCluster(
            name="grpo_policy_cluster",
            bundle_ct_per_node_list=[policy_gpus_per_node] * policy_nodes,
            use_gpus=True,
            num_gpus_per_node=policy_gpus_per_node,
            max_colocated_worker_groups=1
            if generation_config["backend"] == "megatron"
            else 2,
        )
        # Training and inference share the same cluster in colocated mode.
        train_cluster = cluster
        inference_cluster = cluster
        print(
            f" ✓ Ray cluster for policy initialized with {policy_nodes} nodes",
            flush=True,
        )

    else:
        assert generation_config["backend"] != "megatron", (
            "Non-colocated inference is not supported for Megatron generation backends. "
            "Please use vLLM backend for generation."
        )

        # train resources will be updated through overall and inference resources below
        train_gpus_per_node = cluster_config["gpus_per_node"]
        train_nodes = policy_nodes

        inference_resources = generation_config["colocated"]["resources"]
        inference_gpus_per_node = inference_resources["gpus_per_node"]
        inference_nodes = inference_resources["num_nodes"]

        # validate and configure resources
        if policy_nodes == 1:
            # When policy_nodes == 1, train and inference are on the same node
            assert (
                inference_gpus_per_node is not None and inference_gpus_per_node > 0
            ), (
                "policy.generation.colocated.resources.gpus_per_node must be explicitly set to a value > 0 "
                "when policy_nodes = 1 and inference is non-colocated, "
                f"but got {inference_gpus_per_node}."
            )
            assert inference_nodes is None or inference_nodes == 1, (
                "policy.generation.colocated.resources.num_nodes must be 1 or set to null "
                "when policy_nodes = 1 and inference is non-colocated, "
                f"but got {inference_nodes}."
            )

            inference_nodes = 1
            # If total_nodes == 1, reward model is also on the same node; otherwise it's on a different node
            reward_gpus_to_subtract = (
                rm_gpus_per_node if total_nodes == 1 and rm_env_enabled else 0
            )
            train_gpus_per_node -= inference_gpus_per_node + reward_gpus_to_subtract
            assert train_gpus_per_node > 0, (
                "No enough GPUs for training, "
                f"train_gpus_per_node:{train_gpus_per_node} = cluster_config['gpus_per_node']:{cluster_config['gpus_per_node']} - inference_gpus_per_node:{inference_gpus_per_node}"
                + (
                    f" - rm_gpus_per_node:{rm_gpus_per_node}"
                    if total_nodes == 1 and rm_env_enabled
                    else ""
                )
            )
        else:
            # train, inference, and reward model are all on different nodes
            assert inference_nodes > 0, (
                "policy.generation.colocated.resources.num_nodes must be > 0 "
                "when cluster.num_nodes > 1 and inference is non-colocated, "
                f"but got {inference_nodes}."
            )
            assert (
                inference_gpus_per_node is not None
                and inference_gpus_per_node == cluster_config["gpus_per_node"]
            ), (
                "policy.generation.colocated.resources.gpus_per_node must be explicitly set and equal to cluster.gpus_per_node "
                "when cluster.num_nodes > 1 and inference is non-colocated, "
                f"but got inference_gpus_per_node={inference_gpus_per_node}, cluster.gpus_per_node={cluster_config['gpus_per_node']}."
            )
            train_nodes -= inference_nodes

        # initialize train cluster
        train_cluster = RayVirtualCluster(
            name="grpo_train_cluster",
            bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes,
            use_gpus=True,
            num_gpus_per_node=train_gpus_per_node,
            max_colocated_worker_groups=1,
        )
        print(
            f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node",
            flush=True,
        )

        # initialize inference cluster
        inference_cluster = RayVirtualCluster(
            name="grpo_inference_cluster",
            bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes,
            use_gpus=True,
            num_gpus_per_node=inference_gpus_per_node,
            max_colocated_worker_groups=1,
        )
        print(
            f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
            flush=True,
        )

    # ==========================
    # Training and Inference
    # ==========================
    print("\n▶ Setting up model and training...", flush=True)

    # vllm model loading prefers clean environment, initialize policy_generation before policy in colocated mode
    backend = generation_config["backend"]
    generation_config["model_name"] = policy_config["model_name"]  # Needed for vLLM

    # Dictionary to store worker initialization timing stats for logging
    worker_init_timing_metrics = {}

    # Prepare checkpoint paths
    if last_checkpoint_path:
        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
    else:
        weights_path = None
        optimizer_path = None

    if policy_config.get("megatron_cfg", {}).get("enabled", False):
        ## NOTE: this is equal to the total number of scheduler steps
        total_train_iters = min(
            grpo_config["max_num_steps"],
            grpo_config["max_num_epochs"] * len(dataloader),
        )
        policy_config["megatron_cfg"]["train_iters"] = total_train_iters

    # Define initialization functions that will be used in all paths
    def init_policy():
        """Initialize policy training workers; return (policy, seconds elapsed)."""
        t0 = time.perf_counter()
        p = Policy(
            cluster=train_cluster,
            config=policy_config,
            tokenizer=tokenizer,
            processor=processor,
            weights_path=weights_path,
            optimizer_path=optimizer_path,
            init_optimizer=True,
        )
        return p, time.perf_counter() - t0

    def init_vllm():
        """Initialize vLLM generation workers; return (engine, seconds elapsed)."""
        t0 = time.perf_counter()
        pg = VllmGeneration(cluster=inference_cluster, config=generation_config)
        pg.finish_generation()
        return pg, time.perf_counter() - t0

    def init_sglang():
        """Initialize SGLang generation workers; return (engine, seconds elapsed)."""
        t0 = time.perf_counter()
        pg = SGLangGeneration(cluster=inference_cluster, config=generation_config)
        pg.finish_generation()
        return pg, time.perf_counter() - t0

    def initialize_generation_with_policy(
        init_generation_fn,
        generation_name: str,
        init_time_key: str,
        colocated_inference: bool,
        worker_init_timing_metrics: dict,
    ):
        """Generic function to initialize a generation engine (vLLM or SGLang) along with policy.

        Args:
            init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang)
            generation_name: Name of the generation engine ("vLLM" or "SGLang"); currently informational only
            init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s")
            colocated_inference: Whether inference is colocated with training
            worker_init_timing_metrics: Dictionary to store timing metrics (mutated in place)

        Returns:
            Tuple of (policy_generation, policy)
        """
        # Determine if parallel initialization is possible (non-colocated mode)
        use_parallel_init = not colocated_inference

        if use_parallel_init:
            # Parallel initialization: Generation engine and Policy can initialize simultaneously
            print(
                " ⚡ Using parallel worker initialization (non-colocated mode)",
                flush=True,
            )

            # Execute both initializations in parallel
            parallel_start_time = time.perf_counter()
            with ThreadPoolExecutor(max_workers=2) as executor:
                generation_future = executor.submit(init_generation_fn)
                policy_future = executor.submit(init_policy)
                policy_generation, generation_time = generation_future.result()
                policy, policy_time = policy_future.result()
            parallel_wall_time = time.perf_counter() - parallel_start_time

            # Store timing metrics
            worker_init_timing_metrics[init_time_key] = generation_time
            worker_init_timing_metrics["policy_init_time_s"] = policy_time
            worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time
            worker_init_timing_metrics["parallel_init_enabled"] = True
        else:
            # Sequential initialization: colocated mode (GPU memory requires generation engine first)
            print(
                " ⚙️ Using sequential worker initialization (colocated mode)",
                flush=True,
            )

            # Initialize generation engine first (clean GPU memory), then policy
            policy_generation, generation_time = init_generation_fn()
            worker_init_timing_metrics[init_time_key] = generation_time

            policy, policy_time = init_policy()
            worker_init_timing_metrics["policy_init_time_s"] = policy_time
            # Same type as the parallel branch so the metric stays consistent.
            worker_init_timing_metrics["parallel_init_enabled"] = False

        return policy_generation, policy

    # Handle generation-specific setup
    if backend == "megatron":
        # Megatron generation: policy_generation is None, only initialize policy
        policy_generation = None
        print(
            f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
            flush=True,
        )

        policy, policy_time = init_policy()
        worker_init_timing_metrics["policy_init_time_s"] = policy_time

    elif backend == "vllm":
        # vLLM generation: setup config, then initialize with policy
        generation_config = cast(VllmConfig, generation_config)
        if generation_config["vllm_cfg"]["precision"] == "fp8":
            assert loss_config["use_importance_sampling_correction"] is True, (
                "Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
            )
        if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"):
            # FP8 KV cache requires FP8 model precision
            assert generation_config["vllm_cfg"]["precision"] == "fp8", (
                f"kv_cache_dtype='{generation_config['vllm_cfg']['kv_cache_dtype']}' requires precision='fp8'. "
                "FP8 KV cache can only be used together with FP8 model weights."
            )
            # FP8 KV cache compatibility checks
            assert not policy_config["dtensor_cfg"]["enabled"], (
                "DTensor backend is not supported with kv cache fp8 enabled."
            )
            assert not _should_use_async_rollouts(master_config), (
                "Async rollouts is not supported with kv cache fp8 enabled."
            )
            assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, (
                "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future."
            )

        ## make vllm hf overrides match the training policy
        generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
            "hf_config_overrides", {}
        )

        policy_generation, policy = initialize_generation_with_policy(
            init_generation_fn=init_vllm,
            generation_name="vLLM",
            init_time_key="vllm_init_time_s",
            colocated_inference=colocated_inference,
            worker_init_timing_metrics=worker_init_timing_metrics,
        )

        print(
            f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
            flush=True,
        )

    elif backend == "sglang":
        generation_config = cast(SGLangConfig, generation_config)

        # Set model_path if not already set
        if "model_path" not in generation_config["sglang_cfg"]:
            generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"]

        policy_generation, policy = initialize_generation_with_policy(
            init_generation_fn=init_sglang,
            generation_name="SGLang",
            init_time_key="sglang_init_time_s",
            colocated_inference=colocated_inference,
            worker_init_timing_metrics=worker_init_timing_metrics,
        )

        print(
            f" ✓ Using SGLang backend for generation with {policy_config['model_name']}",
            flush=True,
        )

    else:
        # Without this, an unknown backend would surface later as a confusing
        # NameError on `policy`.
        raise ValueError(
            f"Unsupported generation backend: {backend!r}. "
            "Expected one of 'megatron', 'vllm', 'sglang'."
        )

    # Record when worker initialization completes (for calculating other setup time)
    worker_init_complete_time = time.perf_counter() - setup_start_time

    # print the node IP and GPU ID of the policy workers for debugging
    policy.print_node_ip_and_gpu_id()

    # if it is not colocated inference, initialize collective communication for update weights
    if not colocated_inference:
        t0 = time.perf_counter()
        ip, port = train_cluster.get_master_address_and_port()
        print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
        # world includes all training workers and all inference workers
        train_world_size = train_cluster.world_size()
        inference_world_size = inference_nodes * inference_gpus_per_node
        world_size = train_world_size + inference_world_size
        # init collective
        futures_train = policy.init_collective(
            ip, port, world_size, train_world_size=train_world_size
        )
        futures_inference = policy_generation.init_collective(
            ip, port, world_size, train_world_size=train_world_size
        )  # type: ignore
        # wait for all futures to complete
        ray.get(futures_train + futures_inference)
        worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0

    # prepare refit info
    state_dict_info = policy.prepare_refit_info()
    if policy_generation is not None:
        policy_generation.prepare_refit_info(state_dict_info)

    # Calculate total setup time
    total_setup_time = time.perf_counter() - setup_start_time
    worker_init_timing_metrics["total_setup_time_s"] = total_setup_time

    # Log worker initialization timing metrics to logger
    if worker_init_timing_metrics:
        print("\n▶ Worker Initialization Timing:")

        vllm_time = worker_init_timing_metrics.get("vllm_init_time_s", 0)
        sglang_time = worker_init_timing_metrics.get("sglang_init_time_s", 0)
        policy_time = worker_init_timing_metrics.get("policy_init_time_s", 0)
        total_setup = worker_init_timing_metrics.get("total_setup_time_s", 0)

        if vllm_time:
            print(f" vLLM init: {vllm_time:.1f}s")
        if sglang_time:
            print(f" SGLang init: {sglang_time:.1f}s")

        if policy_time:
            print(f" Policy init: {policy_time:.1f}s")

        # Calculate "other" time (time after worker init completes)
        other_time = total_setup - worker_init_complete_time
        worker_init_timing_metrics["other_setup_time_s"] = other_time
        print(f" Other setup: {other_time:.1f}s")

        print(f" Total setup: {total_setup:.1f}s")

        # Log all metrics to the logger for analysis
        logger.log_metrics(worker_init_timing_metrics, step=0, prefix="timing/setup")

    print("\n" + "=" * 60)
    print(" " * 18 + "SETUP COMPLETE")
    print(f" Total setup time: {total_setup_time:.1f}s")
    print("=" * 60 + "\n", flush=True)

    return (
        policy,
        policy_generation,
        (train_cluster, inference_cluster),
        dataloader,
        val_dataloader,
        loss_fn,
        logger,
        checkpointer,
        grpo_save_state,
        master_config,
    )
# ===============================================================================
# Core Algorithm Functions
# ===============================================================================
def dynamic_sampling(
    repeated_batch: BatchedDataDict[DatumSpec],
    std: torch.Tensor,
    baseline: torch.Tensor,
    dynamic_sampling_num_gen_batches: int,
    master_config: MasterConfig,
    timer: Timer,
    batch_cache: BatchedDataDict[DatumSpec] | None = None,
) -> tuple[
    BatchedDataDict[DatumSpec], bool, BatchedDataDict[DatumSpec] | None, dict[str, int]
]:
    """Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.

    This function filters the current batch to retain only those entries that have a
    non-zero standard deviation. The required batch size is
    num_prompts_per_step * num_generations_per_prompt.

    If the accumulated filtered entries are fewer than the required batch size, they
    are stored in ``batch_cache`` to be used in later iterations. In that case
    ``is_batch_complete`` is False, which the GRPO training loop uses as a signal to
    keep sampling instead of proceeding to training.

    If the accumulated filtered entries are more than the required batch size, the
    batch is sliced so that its size equals the required batch size.

    When ``grpo.use_dynamic_sampling`` is disabled, the unfiltered ``repeated_batch``
    is returned with ``baseline`` and ``std`` attached. ``batch_cache`` is returned
    unchanged in that case.

    This approach is based on the dynamic sampling algorithm from the DAPO paper:
    https://arxiv.org/pdf/2503.14476.

    Args:
        repeated_batch (BatchedDataDict[DatumSpec]): The current batch of data
            containing prompts, responses and rewards (``total_reward``).
        std (torch.Tensor): Standard deviation values, indexed in step with the
            rows of ``repeated_batch``.
        baseline (torch.Tensor): Baseline values, indexed in step with the rows of
            ``repeated_batch``.
        dynamic_sampling_num_gen_batches (int): Number of generation batches
            processed at the current step.
        master_config (MasterConfig): Configuration containing GRPO and policy settings.
        timer (Timer): Timer used to record the ``"dynamic_sampling"`` section.
        batch_cache (BatchedDataDict[DatumSpec] | None, optional): Cache storing
            previously selected entries with non-zero std.

    Returns:
        tuple: A 4-tuple containing:
            - batch (BatchedDataDict[DatumSpec]): The filtered (and possibly sliced)
              batch when dynamic sampling is enabled, otherwise ``repeated_batch``.
            - is_batch_complete (bool): Whether there are enough samples with
              non-zero std for training.
            - batch_cache (BatchedDataDict[DatumSpec] | None): Updated cache for
              future iterations.
            - dynamic_sampling_metrics (dict[str, int]): Contains
              ``"dynamic_sampling_num_discarded_valid_samples"`` when the batch was
              sliced down to the required size; empty otherwise.

    Raises:
        AssertionError: If more samples are needed and
            ``grpo.dynamic_sampling_max_gen_batches`` is not > 0.
        ValueError: If more samples are needed but
            ``dynamic_sampling_num_gen_batches`` exceeds
            ``grpo.dynamic_sampling_max_gen_batches``.
    """
    # is_batch_complete is used to indicate if the current batch was able to generate enough prompts with non-zero std.
    is_batch_complete = True

    # Required batch size for training
    train_prompts_size = (
        master_config["grpo"]["num_prompts_per_step"]
        * master_config["grpo"]["num_generations_per_prompt"]
    )

    # Store the baseline, std and total_reward for the current unfiltered batch.
    repeated_batch["baseline"] = baseline
    repeated_batch["std"] = std
    total_rewards = repeated_batch["total_reward"]
    dynamic_sampling_metrics = {}

    # Dynamic sampling algorithm (used in DAPO algorithm)
    # This block implements dynamic sampling by selecting prompt groups with non-zero std.
    # If sampled prompts (with non-zero std) are fewer than num_prompts_per_step * num_generations_per_prompt, continue sampling until dynamic_sampling_max_gen_batches is reached.
    if master_config["grpo"]["use_dynamic_sampling"]:
        with timer.time("dynamic_sampling"):
            # Get the row indices with non-zero std
            non_zero_std_mask = std != 0.0
            keep_prompt_indices = torch.arange(
                len(non_zero_std_mask), device=std.device
            )[non_zero_std_mask].tolist()

            # Only select the inputs that have non-zero std
            # total_reward is already a part of repeated_batch so we don't need to add it again
            filtered_repeated_batch = repeated_batch.select_indices(keep_prompt_indices)
            filtered_repeated_batch["std"] = std[keep_prompt_indices]
            filtered_repeated_batch["baseline"] = baseline[keep_prompt_indices]

            # Store filtered and total rewards to track them separately:
            # "total_reward" keeps the rewards of the full, unfiltered batch, while
            # "filtered_reward" holds only the rewards of the kept rows.
            filtered_rewards = filtered_repeated_batch["total_reward"]
            filtered_repeated_batch["total_reward"] = total_rewards
            filtered_repeated_batch["filtered_reward"] = filtered_rewards

            # If none of the entries in the current batch have non-zero std, filtered_repeated_batch.size will be 0.
            # In this case, the current batch will be ignored and the next batch will be processed and we generate responses for it.
            if filtered_repeated_batch.size > 0:
                # Concatenate the previous partially filled batch with the current batch. This serves as a cache to store and collect the prompts with non-zero std.
                # This is used in the next iteration when the current batch is not enough to fill the buffer.
                batch_cache = (
                    filtered_repeated_batch
                    if batch_cache is None
                    else BatchedDataDict.from_batches(
                        [batch_cache, filtered_repeated_batch]
                    )
                )
                filtered_repeated_batch = batch_cache

            filtered_prompts_size = filtered_repeated_batch.size
            print(
                f"Detected {filtered_prompts_size} prompts with non-zero std; "
                f"{train_prompts_size} are required and used for training."
            )

            # If the generation samples size is smaller than a fixed threshold (train_prompts_size), keep generating by processing the next batch
            if filtered_prompts_size < train_prompts_size:
                dynamic_sampling_max_gen_batches = master_config["grpo"][
                    "dynamic_sampling_max_gen_batches"
                ]
                assert dynamic_sampling_max_gen_batches > 0, (
                    "When using grpo.use_dynamic_sampling, grpo.dynamic_sampling_max_gen_batches must be > 0"
                )
                # NOTE(review): with `<=`, reaching exactly max batches still asks for one
                # more batch. Whether that is intended depends on how the caller counts
                # dynamic_sampling_num_gen_batches; confirm.
                if dynamic_sampling_num_gen_batches <= dynamic_sampling_max_gen_batches:
                    print(
                        f"Generation sample buffer size: {filtered_prompts_size} is smaller than train_prompts_size: {train_prompts_size}. Processed {dynamic_sampling_num_gen_batches} batches so far out of {dynamic_sampling_max_gen_batches}."
                    )
                    is_batch_complete = False
                else:
                    raise ValueError(
                        f"Dynamic sampling has reached the maximum allowed number of batches ({dynamic_sampling_max_gen_batches}). Consider evaluating the complexity of your data or adjusting the num_prompts_per_step or num_generations_per_prompt parameters to enhance the diversity of the samples."
                    )
            else:
                num_discarded_valid_samples = filtered_prompts_size - train_prompts_size
                dynamic_sampling_metrics[
                    "dynamic_sampling_num_discarded_valid_samples"
                ] = num_discarded_valid_samples

                # Slice the batch, rewards, baselines and std to ensure batch size is train_prompts_size
                filtered_repeated_batch = filtered_repeated_batch.slice(
                    0, train_prompts_size
                )

    # filtered_repeated_batch only exists when dynamic sampling ran; the conditional
    # expression never evaluates it otherwise.
    batch_to_return = (
        filtered_repeated_batch
        if master_config["grpo"]["use_dynamic_sampling"]
        else repeated_batch
    )
    return batch_to_return, is_batch_complete, batch_cache, dynamic_sampling_metrics
def scale_rewards(
    repeated_batch: BatchedDataDict[DatumSpec], reward_scaling_cfg: RewardScalingConfig
) -> BatchedDataDict[DatumSpec]:
    """Linearly scales rewards from a source range to a target range.

    If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]`
    is clamped to the configured source interval [source_min, source_max] and then
    rescaled to the target interval [target_min, target_max]. If it is False, the
    batch is returned untouched.

    Default configuration:
        source_min = 0.0
        source_max = 1.0
        target_min = 0.0
        target_max = 1.0

    Args:
        repeated_batch: Batch whose ``"total_reward"`` tensor is rescaled.
        reward_scaling_cfg: Mapping with ``enabled``, ``source_min``, ``source_max``,
            ``target_min`` and ``target_max`` entries.

    Returns:
        The same ``repeated_batch`` object. When scaling is enabled, its
        ``"total_reward"`` entry is replaced by the scaled tensor.

    Raises:
        ValueError: If scaling is enabled and ``source_max <= source_min``. An empty
            source interval would divide by zero and turn every reward into NaN.
    """
    if not reward_scaling_cfg["enabled"]:
        return repeated_batch

    rewards = repeated_batch["total_reward"]
    source_min = float(reward_scaling_cfg["source_min"])
    source_max = float(reward_scaling_cfg["source_max"])
    target_min = float(reward_scaling_cfg["target_min"])
    target_max = float(reward_scaling_cfg["target_max"])

    # An empty or inverted source interval makes the linear map undefined
    # (0/0 -> NaN for every reward), so fail loudly instead of corrupting training.
    if source_max <= source_min:
        raise ValueError(
            f"[reward_scaling] source_max ({source_max}) must be greater than "
            f"source_min ({source_min}) when reward scaling is enabled."
        )

    # Detect out-of-range values
    out_of_range_mask = (rewards < source_min) | (rewards > source_max)
    if torch.any(out_of_range_mask):
        print(
            f"[reward_scaling] WARNING: {int(out_of_range_mask.sum())} rewards "
            f"are outside the configured source range [{source_min}, {source_max}]. "
            f"Values will be clipped before scaling."
        )

    # Clamp and scale
    rewards = torch.clamp(rewards, min=source_min, max=source_max)
    scaled_rewards = target_min + (rewards - source_min) / (
        source_max - source_min
    ) * (target_max - target_min)
    repeated_batch["total_reward"] = scaled_rewards
    return repeated_batch
def _should_use_async_rollouts(master_config: MasterConfig) -> bool:
    """Determine if async rollouts should be used based on the configuration.

    Returns True only when ``policy.generation`` is configured, its ``backend`` is
    ``"vllm"``, and ``policy.generation.vllm_cfg.async_engine`` is truthy.

    Args:
        master_config: Full training configuration. ``policy.generation`` may be None.

    Returns:
        Whether the vLLM async engine should drive rollouts.
    """
    generation_config = master_config["policy"]["generation"]
    if generation_config is None:
        return False

    if generation_config.get("backend", "") != "vllm":
        return False

    # `or {}` treats an explicit `vllm_cfg: null` the same as a missing entry,
    # instead of crashing on `None.get(...)`.
    vllm_cfg = generation_config.get("vllm_cfg") or {}
    return bool(vllm_cfg.get("async_engine", False))
def _should_use_nemo_gym(master_config: MasterConfig) -> bool:
    """Decide whether rollouts and validation go through NeMo-Gym.

    The decision is read from ``env.should_use_nemo_gym``. A missing or None ``env``
    section counts as disabled. When NeMo-Gym is enabled, the vLLM generation
    setup is also validated. It must run with ``async_engine: true`` and
    ``expose_http_server: true``.

    Raises:
        AssertionError: If NeMo-Gym is requested without the required vLLM settings.
    """
    env_section = master_config.get("env") or dict()
    use_gym = bool(env_section.get("should_use_nemo_gym"))

    if use_gym:
        # Validate the setup for training with NeMo-Gym
        assert _should_use_async_rollouts(master_config), (
            "❌ Error: In order to use NeMo-Gym, you must use vllm generation backend with `async_engine: true`!"
        )
        # Once the check above passes, `_should_use_async_rollouts` has confirmed
        # that policy.generation and its vllm_cfg exist, so plain indexing is safe.
        vllm_section = master_config["policy"]["generation"]["vllm_cfg"]
        assert vllm_section.get("expose_http_server"), (
            "In order to use NeMo-Gym, you must expose the vllm server via `expose_http_server: true`!"
        )

    return use_gym
def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool:
    """Return whether NeMo-Gym responses should be logged.

    The flag is read from ``env.should_log_nemo_gym_responses``. A missing or None
    ``env`` section, or a missing flag, yields False.
    """
    env_section = master_config.get("env") or {}
    return bool(env_section.get("should_log_nemo_gym_responses"))
def _create_advantage_estimator(master_config: MasterConfig):
"""Create and return an advantage estimator based on configuration.
Args:
master_config: The master configuration dictionary.
Returns:
An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator).
Raises:
ValueError: If the advantage estimator name is not recognized.
"""
grpo_config = master_config["grpo"]
loss_config = master_config["loss_fn"]
# Provide backward-compatible defaults when adv_estimator is not in config.
# Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline
# which older configs still use.
adv_estimator_config = grpo_config.get(
"adv_estimator",
{
"name": "grpo",