-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathrl_utils.py
More file actions
2023 lines (1753 loc) · 86.6 KB
/
rl_utils.py
File metadata and controls
2023 lines (1753 loc) · 86.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import gc
import copy
from functools import partial
# Keep this to make the env registered.
import itertools
import math
import logging
import json
import os
from collections import Counter, defaultdict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional
import numpy as np
import torch
import torch.distributed as dist
import yaml
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from megatron.core import mpu
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.full_cuda_graph import FullCudaGraphWrapper
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator
from megatron.core.optimizer import MegatronOptimizer
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.pipeline_parallel.utils import is_pp_last_stage, get_pp_last_rank
from megatron.core.rerun_state_machine import RerunDataIterator
from megatron.core.tokenizers import MegatronTokenizer
from megatron.core.tokenizers.text.libraries.huggingface_tokenizer import HuggingFaceTokenizer
from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.utils import (
toggle_cuda_graphs,
transition_moe_cudagraphs,
)
from megatron.core.inference.utils import set_decode_expert_padding
from megatron.core.resharding.refit import swap_model_weights
from megatron.core.inference.unified_memory import (
advise_managed_module_parameters_preferred_location,
prefetch_managed_module_parameters,
)
from megatron.core.inference.utils import device_memory_summary
from megatron.core.utils import get_asyncio_loop, log_single_rank
from megatron.rl.sequence_packing_utils import (
get_microbatch_dataloader,
pack_inference_logprobs,
compute_packed_inference_logprobs_stats,
pack_all_trajectories,
load_packed_data_by_index,
get_sequence_packing_tensorboard_metrics,
get_sequence_packing_log_info,
get_default_packed_seq_params,
update_microbatch_calculator,
)
from megatron.rl.agent.api import (
EvaluationRequest,
EvaluationResponse,
GroupedRolloutRequest,
RewardEvaluationResult,
Rollout,
TokenRollout,
)
from megatron.rl.agent.weighted_multi_task import WeightedMultiTask
from megatron.rl.inference.megatron import MegatronLocal
from megatron.rl.logging import LOG_DIR as lang_rl_log_dir
from megatron.rl.logging import log as lang_rl_log
from megatron.rl.server.inference.inference_interface_server import InferenceInterfaceServer
from megatron.training.global_vars import (
get_args,
get_tensorboard_writer,
get_tokenizer,
get_wandb_writer,
)
from megatron.training.utils import (
get_ltor_masks_and_position_ids,
get_nvtx_range,
print_rank_0,
unwrap_model,
)
from megatron.core.utils import get_pg_rank, get_pg_size, get_attr_wrapped_model
from megatron.core.process_groups_config import ProcessGroupCollection
from wandb import wandb_run
from megatron.core.transformer.custom_layers.batch_invariant_kernels import (
is_batch_invariant_mode_enabled,
)
from megatron.core.inference.contexts.dynamic_context import HAVE_TORCH_MEMORY_SAVER
# torch_memory_saver is optional (guarded by HAVE_TORCH_MEMORY_SAVER); it is used by
# _torch_saver_swap_inference_model() for the non-UVM inference-model offload path.
if HAVE_TORCH_MEMORY_SAVER:
    from torch_memory_saver import torch_memory_saver
logger = logging.getLogger(__name__)
# Global variable to store packing context for forward_step
_GLOBAL_PACKING_CONTEXT = None
# Track whether the inference model is currently paused (offloaded to CPU).
# Model starts on GPU after creation and is used immediately, so starts as False.
_INFERENCE_MODEL_IS_PAUSED = False
def _torch_saver_swap_inference_model(*, to_cpu: bool) -> None:
    """Pause or resume the RL inference model's torch_memory_saver region.

    The inference model weights are expected to have been allocated inside a
    ``torch_memory_saver.region()`` tagged ``"rl_inference_model"``. Pausing moves them
    off the GPU and resuming brings them back. The module-level
    ``_INFERENCE_MODEL_IS_PAUSED`` flag makes the call a no-op when the model is already
    in the requested state.

    Args:
        to_cpu: True to pause (offload weights to CPU), False to resume (restore to GPU).

    Raises:
        RuntimeError: If torch_memory_saver is not installed.
    """
    global _INFERENCE_MODEL_IS_PAUSED

    if not HAVE_TORCH_MEMORY_SAVER:
        raise RuntimeError(
            "torch_memory_saver is required for inference model offloading when not using UVM. "
            "Please install it: pip install torch_memory_saver "
            "(see https://github.com/fzyzcjy/torch_memory_saver)"
        )

    region_tag = "rl_inference_model"

    # Already in the requested state: nothing to do.
    if bool(to_cpu) == _INFERENCE_MODEL_IS_PAUSED:
        return

    if to_cpu:
        print_rank_0(f"torch_memory_saver: pausing {region_tag}, before: {device_memory_summary()}")
        torch_memory_saver.pause(region_tag)
        _INFERENCE_MODEL_IS_PAUSED = True
        print_rank_0(f"torch_memory_saver: paused {region_tag}, after: {device_memory_summary()}")
    else:
        print_rank_0(f"torch_memory_saver: resuming {region_tag}, before: {device_memory_summary()}")
        torch_memory_saver.resume(region_tag)
        _INFERENCE_MODEL_IS_PAUSED = False
        print_rank_0(f"torch_memory_saver: resumed {region_tag}, after: {device_memory_summary()}")
def _maybe_prefetch_separate_inference_model_weights(model_core, *, to_cpu: bool) -> None:
    """Move the separate RL inference model's weights between CPU and GPU.

    Does nothing unless ``args.rl_offload_inference_model_weights_when_idle`` is set.
    Otherwise the mechanism is chosen by ``args.rl_inference_model_unified_memory_level``:

    * level 1: the parameters are treated as unified (managed) memory. Their preferred
      location is advised and they are prefetched to the target device.
    * any other level: the torch_memory_saver region is paused/resumed instead.

    This assumes the inference model was allocated with UVM or inside a
    torch_memory_saver region, matching the configured mode.

    Args:
        model_core: Unwrapped inference model whose parameters should be moved.
        to_cpu: True to offload to CPU, False to bring the weights back to the current GPU.
    """
    args = get_args()
    if not args.rl_offload_inference_model_weights_when_idle:
        return

    uses_uvm = args.rl_inference_model_unified_memory_level == 1
    if not uses_uvm:
        # Offloading is enabled but UVM is not: fall back to torch_memory_saver.
        _torch_saver_swap_inference_model(to_cpu=to_cpu)
        return

    # -1 is used as the CPU target when offloading; otherwise the current CUDA device.
    target = -1 if to_cpu else int(torch.cuda.current_device())
    # Buffers are skipped: buffers created with an explicit device= in register_buffer() are
    # not allocated via the UVM mempool and fail UVM operations; only parameters are UVM-allocated.
    advise_managed_module_parameters_preferred_location(model_core, device=target, include_buffers=False)
    moved_bytes = prefetch_managed_module_parameters(model_core, device=target, include_buffers=False)
    # Make the pages resident before CUDA-graph capture / inference, or before training continues.
    torch.cuda.synchronize()

    moved_mb = moved_bytes / 1024**2
    if to_cpu:
        print_rank_0(f"[Rank 0] offloaded {moved_mb:.2f} MB of separate RL inference model weights to CPU (other ranks may vary)")
    else:
        print_rank_0(f"[Rank 0] prefetched {moved_mb:.2f} MB of separate RL inference model weights to GPU (other ranks may vary)")
def verify_model_weights_swap(
    train_model: LanguageModule,
    inference_model: LanguageModule,
    seq_len: int = 8,
    batch_size: int = 2,
    atol: float = 1e-4,
    rtol: float = 1e-4,
) -> None:
    """Verify that the inference model produces the same forward pass outputs
    as the training model after the weights have been swapped.

    This function should be called after swap_model_weights to ensure the weight
    transfer was successful. It runs a forward pass on both models with the same
    token batch and asserts the outputs match. This is meant for debugging
    purposes only.

    The test tokens are drawn from a dedicated fixed-seed ``torch.Generator``. This makes
    the input identical on every rank without reseeding the global RNG, which would
    otherwise perturb the randomness used by training afterwards.

    Both models are switched to eval mode for the check. Their previous train/eval mode
    is restored afterwards, even if the check fails.

    Args:
        train_model: The training model (source of weights). A list/tuple of model
            chunks is accepted, in which case only the first chunk is used.
        inference_model: The inference model (target of weights); same conventions.
        seq_len: Sequence length for test input (capped at ``args.seq_length``).
        batch_size: Batch size for test input.
        atol: Absolute tolerance for comparing outputs.
        rtol: Relative tolerance for comparing outputs.

    Raises:
        AssertionError: If output shapes differ or outputs do not match within
            tolerance. This is only checked on ranks where both models return output.
    """
    args = get_args()

    # Unwrap models to get the core module (first chunk if a list/tuple is given).
    train_lm = train_model[0] if isinstance(train_model, (list, tuple)) else train_model
    inf_lm = inference_model[0] if isinstance(inference_model, (list, tuple)) else inference_model
    train_core = unwrap_model(train_lm)
    inf_core = unwrap_model(inf_lm)

    actual_vocab_size = getattr(args, 'padded_vocab_size', 128256)
    actual_seq_len = min(seq_len, getattr(args, 'seq_length', seq_len))
    device = torch.device(f"cuda:{torch.cuda.current_device()}")

    # Deterministic test input, identical across ALL ranks. A private generator is used so
    # this debug check does not reset the global (training) RNG state.
    generator = torch.Generator(device=device)
    generator.manual_seed(1234)
    test_tokens = torch.randint(
        low=0, high=actual_vocab_size, size=(batch_size, actual_seq_len),
        device=device, dtype=torch.long, generator=generator,
    )
    test_position_ids = (
        torch.arange(actual_seq_len, device=device, dtype=torch.long)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    test_attention_mask = torch.ones(
        (batch_size, 1, actual_seq_len, actual_seq_len), device=device, dtype=torch.bool
    )

    # Save and restore training state
    train_was_training = train_core.training
    inf_was_training = inf_core.training
    train_core.eval()
    inf_core.eval()
    try:
        with torch.no_grad():
            train_output = train_lm(
                test_tokens, test_position_ids, test_attention_mask,
                runtime_gather_output=True
            )
            inf_output = inf_lm(
                test_tokens, test_position_ids, test_attention_mask,
                runtime_gather_output=True
            )
        # Only check on ranks that have output (last PP stage)
        if train_output is not None and inf_output is not None:
            assert train_output.shape == inf_output.shape, (
                f"Output shape mismatch: train={train_output.shape}, infer={inf_output.shape}"
            )
            max_diff = (train_output - inf_output).abs().max().item()
            assert torch.allclose(train_output, inf_output, atol=atol, rtol=rtol), (
                f"Forward pass outputs do not match: max_diff={max_diff:.6e}, atol={atol}, rtol={rtol}"
            )
    finally:
        # Restore training state
        if train_was_training:
            train_core.train()
        if inf_was_training:
            inf_core.train()
# A list of rollouts. Each element is either a TokenRollout (token-id trajectories) or a
# Rollout (presumably text trajectories; compute_group_stats logs their length in chars).
Rollouts = list[TokenRollout | Rollout]
# A list of groups, each group being the Rollouts sampled for the same prompt (GRPO group).
GroupedRollouts = list[Rollouts]
@dataclass(slots=True)
class RolloutStats:
    """Statistics about a batch of rollouts, returned by compute_group_stats for logging.

    Fields typed ``list[list[...]]`` are grouped: the outer list has one entry per rollout
    group (all rollouts for the same prompt).
    """

    rewards: list[list[float]]  # inner list is for a group
    env_ids: list[str]  # same length as len(rewards)
    turn_lens: list[list[int]]  # token lengths of turns, grouped.
    traj_lens: list[list[int]]  # all turns comprise one trajectory.
    num_turns: None | list[list[int]]  # num_turns per traj
    advantages: None | list[list[float]]
    # The min/max/mean fields below share their names with the attributes that
    # update_inference_logprobs_group_stats() assigns; presumably populated by it.
    # *_piold_to_inf_prob: exp(old_logprob - inference_logprob) over masked tokens.
    min_piold_to_inf_prob: None | float
    max_piold_to_inf_prob: None | float
    mean_piold_to_inf_prob: None | float
    # *_inf_train_prob_abs_diff: |exp(old_logprob) - exp(inference_logprob)| over masked tokens.
    min_inf_train_prob_abs_diff: None | float
    max_inf_train_prob_abs_diff: None | float
    mean_inf_train_prob_abs_diff: None | float
    # *_inf_prob: exp(inference_logprob) over masked tokens.
    min_inf_prob: None | float
    max_inf_prob: None | float
    mean_inf_prob: None | float
    # Per-group lists; presumably collected from the per-rollout fields of the same name.
    policy_staleness: list[list[int]]
    kv_cache_staleness: list[list[int]]
    completed_at_steps: list[list[int]]
    num_evictions: list[list[int]]
# Runtime state container for RL-specific data that shouldn't be checkpointed
class RLRuntimeState:
    """Per-process RL bookkeeping that lives outside of checkpoints.

    Tracks state between rollout collections:

    * ``packing_context``: sequence-packing context; ``None`` until a caller sets it.
    * ``last_collection_iteration``: iteration given to the latest ``reset_iteration_counters``.
    * ``sequences_this_iteration_on_rank``: sequences counted on this rank since that reset.
    * ``latest_batch_num_sequences``: ``count`` from the most recent ``increment_sequences``.
    """

    def __init__(self):
        # Counters start at zero; no packing context until one is assigned.
        self.last_collection_iteration = 0
        self.sequences_this_iteration_on_rank = 0
        self.latest_batch_num_sequences = 0
        self.packing_context = None

    def reset_iteration_counters(self, iteration):
        """Begin a new iteration: remember *iteration* and zero the per-rank sequence count."""
        self.last_collection_iteration = iteration
        self.sequences_this_iteration_on_rank = 0

    def increment_sequences(self, count):
        """Record *count* as the latest batch size and add it to this iteration's total."""
        self.latest_batch_num_sequences = count
        self.sequences_this_iteration_on_rank += count


# Global runtime state instance (one per process).
_rl_runtime_state = RLRuntimeState()


def get_rl_runtime_state() -> RLRuntimeState:
    """Return the module-level RLRuntimeState singleton shared by all callers."""
    return _rl_runtime_state
def update_inference_logprobs_group_stats(
    old_logprobs: torch.Tensor,
    inference_logprobs: torch.Tensor,
    mask: torch.Tensor,
    group_stats: Any,
) -> None:
    """Write train-vs-inference logprob comparison metrics onto ``group_stats``.

    Shared by the packed and unpacked code paths. Three families of min/max/mean
    statistics are computed over the positions selected by ``mask``:

    * ``*_piold_to_inf_prob``: ``exp(old_logprobs - inference_logprobs)``
    * ``*_inf_train_prob_abs_diff``: ``|exp(old_logprobs) - exp(inference_logprobs)|``
    * ``*_inf_prob``: ``exp(inference_logprobs)``

    ``group_stats`` is left untouched when the mask selects no positions.

    Args:
        old_logprobs: Old logprobs tensor (train side).
        inference_logprobs: Inference logprobs aligned to the shape of ``old_logprobs``.
        mask: Boolean mask indicating valid positions for statistics.
        group_stats: Object whose attributes receive the computed metrics.
    """
    n_elems = mask.sum()
    if n_elems <= 0:
        return

    inf_exp = inference_logprobs.exp()
    log_ratio = old_logprobs - inference_logprobs

    ratios = log_ratio.exp()[mask]
    abs_diffs = (old_logprobs.exp() - inf_exp).abs()[mask]
    inf_probs = inf_exp[mask]

    # pi_old / pi_inference
    group_stats.min_piold_to_inf_prob = ratios.min().item()
    group_stats.max_piold_to_inf_prob = ratios.max().item()
    group_stats.mean_piold_to_inf_prob = (ratios.sum() / n_elems).item()

    # |pi_old - pi_inference|
    group_stats.min_inf_train_prob_abs_diff = abs_diffs.min().item()
    group_stats.max_inf_train_prob_abs_diff = abs_diffs.max().item()
    group_stats.mean_inf_train_prob_abs_diff = (abs_diffs.sum() / n_elems).item()

    # pi_inference
    group_stats.min_inf_prob = inf_probs.min().item()
    group_stats.max_inf_prob = inf_probs.max().item()
    group_stats.mean_inf_prob = inf_probs.mean().item()
def align_unpacked_inference_logprobs(
    inference_logprobs: List[torch.Tensor],
    old_logprobs_for_data: torch.Tensor,
    generation_masks: torch.Tensor,
    group_stats: Any,
) -> torch.Tensor:
    """Align inference logprobs with old_logprobs for unpacked sequences and compute statistics.

    ``get_logprobs()`` drops the first token, so logprob index ``j`` belongs to token
    ``j + 1``. Inference logprobs only cover generated tokens. They are written into a copy
    of ``old_logprobs_for_data`` starting at the logprob index of the first generated
    token. Every other position keeps its ``old_logprobs_for_data`` value: prompt,
    padding, and the eod token that the train side appends but inference does not return.

    Args:
        inference_logprobs: List of inference logprobs tensors, one per sequence.
        old_logprobs_for_data: Template tensor with correct shape for alignment.
        generation_masks: Tensor indicating which tokens were generated.
        group_stats: Statistics object to update with computed metrics.

    Returns:
        Aligned inference logprobs tensor (same shape as ``old_logprobs_for_data``).

    Raises:
        AssertionError: If any masked position has |p_old - p_inf| > 1.0, which is
            impossible for valid probabilities.
    """
    gen_masks_for_alignment = generation_masks
    # Logprob index of the first generation token per row (token position - 1, because of
    # the first-token shift). Moved to host once, instead of syncing per row in the loop.
    # NOTE(review): a row without any generation token gives argmax 0, i.e. -1 here;
    # presumably every row has a prompt followed by a generation — confirm.
    first_gen_tok = (gen_masks_for_alignment.int().argmax(dim=1) - 1).tolist()

    # Start from old_logprobs so positions not covered by inference keep the train value.
    padded_inference_logprobs = old_logprobs_for_data.clone()
    num_positions = padded_inference_logprobs.shape[1]

    for i, inf_logprobs in enumerate(inference_logprobs):
        first_gen_idx = first_gen_tok[i]
        # Clip inference logprobs that would run past the end of the row.
        end_idx = min(first_gen_idx + len(inf_logprobs), num_positions)
        actual_len = end_idx - first_gen_idx
        if actual_len > 0:
            padded_inference_logprobs[i, first_gen_idx:end_idx] = inf_logprobs[:actual_len]

    # Create the statistics mask: shift the generation mask by one (logprob space) and
    # make its width match old_logprobs_for_data.
    if num_positions + 1 < gen_masks_for_alignment.shape[1]:
        gen_masks_for_alignment = gen_masks_for_alignment[:, : num_positions + 1]
    truncated_mask = gen_masks_for_alignment[:, 1:].bool()
    # Final safety check
    if truncated_mask.shape != old_logprobs_for_data.shape:
        if truncated_mask.shape[1] > num_positions:
            truncated_mask = truncated_mask[:, :num_positions]
        elif truncated_mask.shape[1] < num_positions:
            pad_size = num_positions - truncated_mask.shape[1]
            truncated_mask = torch.nn.functional.pad(truncated_mask, (0, pad_size), value=False)

    # Sanity check: two probability values cannot be more than 1.0 apart. Reduce with
    # Tensor.all() instead of Python's all(), which iterates (and syncs) per element.
    abs_diffs = (old_logprobs_for_data.exp() - padded_inference_logprobs.exp()).abs()[truncated_mask]
    assert bool((abs_diffs <= 1.0).all()), (
        f"Inference/train probability difference exceeds 1.0 (max abs diff: {abs_diffs.max().item()})"
    )

    # Update group statistics using common helper
    update_inference_logprobs_group_stats(
        old_logprobs=old_logprobs_for_data,
        inference_logprobs=padded_inference_logprobs,
        mask=truncated_mask,
        group_stats=group_stats,
    )
    return padded_inference_logprobs
def get_agent(args, parallel_generation_tasks: int | None = None):
"""Get an agent based on environment configuration.
If args.langrl_env_config is provided, uses weighted environment selection.
Otherwise falls back to legacy single environment selection.
"""
with open(args.langrl_env_config, 'r') as f:
config = yaml.safe_load(f)
return WeightedMultiTask.from_config(
config,
parallel_generation_tasks=parallel_generation_tasks,
)
_INFERENCE_INTERFACE = None


def get_inference_interface(args, loop, model):
    """Return the process-wide inference interface, launching it on first use.

    The first call runs ``MegatronLocal.launch`` for ``model[0]`` on *loop* with host
    ``'0.0.0.0'`` and port 8294, and caches the result. Later calls return the cached
    instance and ignore their arguments.
    """
    global _INFERENCE_INTERFACE
    if _INFERENCE_INTERFACE is not None:
        return _INFERENCE_INTERFACE

    # NOTE(review): '0.0.0.0' presumably binds the server on all network interfaces.
    launch_coro = MegatronLocal.launch(
        model[0],
        host='0.0.0.0',
        port=8294,
        verbose=args.inference_text_gen_server_logging,
    )
    _INFERENCE_INTERFACE = loop.run_until_complete(launch_coro)
    return _INFERENCE_INTERFACE
_ROLLOUT_GENERATOR = None


def get_rollout_generator(args, inference_interface, n_prompts, samples_per_group):
    """Return the async generator that yields grouped rollouts.

    With ``args.rl_partial_rollouts`` enabled, the generator is created once with
    ``num_groups=-1`` and reused on every later call. Otherwise a fresh generator for
    exactly ``n_prompts`` groups is built on each call.
    """
    global _ROLLOUT_GENERATOR
    if args.rl_partial_rollouts and _ROLLOUT_GENERATOR is not None:
        return _ROLLOUT_GENERATOR

    agent = get_agent(args, parallel_generation_tasks=args.rl_parallel_generation_tasks)
    sampling_args = {
        'temperature': args.rl_default_temperature,
        'max_tokens': args.inference_max_seq_length,
        'top_p': args.rl_default_top_p,
        'top_k': args.rl_default_top_k,
    }
    # Collect Rollouts
    request = GroupedRolloutRequest(
        num_groups=n_prompts if not args.rl_partial_rollouts else -1,
        rollouts_per_group=samples_per_group,
        inference_interface=inference_interface,
        generation_args=sampling_args,
        filter_groups_with_same_reward=args.grpo_filter_groups_with_same_reward,
    )
    _ROLLOUT_GENERATOR = agent.get_grouped_rollouts(request)
    return _ROLLOUT_GENERATOR
def get_environment_rollouts(
    model: LanguageModule, inference_model: LanguageModule, optimizer: MegatronOptimizer, n_prompts: int, samples_per_group: int
):
    """Sample environment rollouts from an LLM.

    Rank 0 drives rollout collection through the inference interface. The collected groups
    are then sent to all ranks with ``torch.distributed.broadcast_object_list(src=0)``.

    Side effects:
      * If ``args.rl_offload_optimizer_during_inference``: optimizer state is offloaded
        before inference and restored afterwards. Gradient buffers are offloaded too,
        except when ``args.rl_training_cuda_graphs`` is set (they are still restored).
      * A separate ``inference_model`` has its weights brought back to GPU (if offloaded)
        and receives the training weights via ``swap_model_weights``. The swap is
        optionally verified by ``verify_model_weights_swap``.
      * If ``lang_rl_log_dir`` is set, rollouts are dumped to a JSON file.

    Args:
        model: Model to sample from (the training model; accessed as ``model[0]``).
        inference_model: Separate inference model to refit and run inference with, or
            None to run inference with ``model`` itself.
        optimizer: Optimizer whose state may be offloaded during inference.
        n_prompts: Number of prompts to sample for across *all* data parallel workers.
            Must be divisible by the expert-parallel group size of the inference model.
        samples_per_group: Amount of trajectories per prompt.

    Returns:
        GroupedRollouts object which is a nested list with each element being a list of rollouts of a group.
    """
    args = get_args()
    nvtx_range = get_nvtx_range()

    if args.rl_offload_optimizer_during_inference:
        with nvtx_range("rl/offload-optimizer-before-inference", time=True):
            if not args.rl_training_cuda_graphs:
                with nvtx_range("rl/offload/grad-buffers", time=True):
                    model[0].offload_grad_buffers()
            else:
                logger.warning(
                    "Gradient buffers will not be offloaded when training cudagraphs are enabled!")
            with nvtx_range("rl/offload/optimizer-state", time=True):
                optimizer.offload_to_cpu()

    # If we have separate training and inference models we need to refit weights from the training model to the inference model.
    has_separate_inference_model = inference_model is not None
    if has_separate_inference_model:
        # If the separate inference model weights were prefetched to CPU while idle, bring them
        # back to GPU before refit/copy and before any CUDA-graph'd inference.
        with nvtx_range("rl/prefetch-weights-to-gpu", time=True):
            inf_core = unwrap_model(inference_model[0])
            _maybe_prefetch_separate_inference_model_weights(inf_core, to_cpu=False)
        swap_model_weights(model, inference_model, args.refit_method)
        if args.rl_verify_model_weights_swap:
            verify_model_weights_swap(
                train_model=model,
                inference_model=inference_model,
                atol=.1,
                rtol=5e-4,
            )
    else:
        # No separate inference model: run inference with the training model itself.
        inference_model = model

    # Prompts must split evenly across the inference model's expert-parallel group.
    inference_pg_collection = get_attr_wrapped_model(inference_model[0], "pg_collection")
    pg_size = get_pg_size(inference_pg_collection.ep)
    assert (n_prompts % pg_size == 0), f"{n_prompts=} must be divisible by {pg_size=}"

    with nvtx_range("rl/rollout-collection", time=True):
        loop = get_asyncio_loop()
        with megatron_rl_inference_mode(
            inference_model,
            optimizer,
            args.cuda_graph_impl,
            False,  # offload optimizer during rollout collection is handled above
            training_model=model if has_separate_inference_model else None,
            increment_staleness_on_suspend=True,
        ) as inference_interface:

            with nvtx_range("rl/inference-setup", time=True):
                # Asynchronously run inference and rollout collection
                rollout_generator = get_rollout_generator(
                    args, inference_interface, n_prompts, samples_per_group
                )

            # NOTE(jbarker): we need to double check this when using PP>1
            rank = torch.distributed.get_rank()
            with nvtx_range("rl/collect-rollouts", time=True):
                if rank == 0:
                    log_single_rank(
                        logger,
                        logging.INFO,
                        f"Collecting rollouts, Iteration {args.curr_iteration}...",
                    )
                    # Pull exactly n_prompts groups from the async generator.
                    rollouts = [
                        loop.run_until_complete(anext(rollout_generator)) for _ in range(n_prompts)
                    ]
                    # In deterministic mode, sort rollouts by problem_id for consistent ordering
                    # regardless of completion order due to system timing jitter.
                    if torch.are_deterministic_algorithms_enabled():
                        rollouts.sort(key=lambda group: group[0].problem_id if group and group[0].problem_id else "")
                    if not args.rl_partial_rollouts:
                        # The generator was created for exactly n_prompts groups, so it must
                        # now be exhausted; any extra group is an error.
                        while True:
                            try:
                                loop.run_until_complete(anext(rollout_generator))
                                assert False, "Unexpected group left in generator."
                            except StopAsyncIteration:
                                break
                else:
                    # Other ranks only allocate placeholders that the broadcast below fills in.
                    rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]

        with nvtx_range("rl/sync-rollouts", time=True):
            # Wait for Rollouts to be collected
            # TODO(jbarker): double check why this isn't causing rank 0 memory allocations
            torch.distributed.broadcast_object_list(rollouts, src=0)
        logger.debug(f"Got rollouts on rank {rank}")

    if args.rl_offload_optimizer_during_inference:
        with nvtx_range("rl/restore-optimizer-after-inference", time=True):
            with nvtx_range("rl/restore/grad-buffers", time=True):
                model[0].restore_grad_buffers()
            with nvtx_range("rl/restore/optimizer-state", time=True):
                optimizer.restore_from_cpu()

    # NOTE(review): `rank` is the global rank, while get_pg_rank(...tp) is the rank inside the
    # tensor-parallel group, so this only holds where the two coincide. Confirm whether
    # `get_pg_rank(inference_pg_collection.tp) == 0` was intended.
    if lang_rl_log_dir and rank == get_pg_rank(inference_pg_collection.tp):
        with open(
            lang_rl_log_dir
            + f'/rollouts_rank{rank}_iteration{args.curr_iteration}_'
            + f'{Path(args.langrl_env_config).stem}.json',
            'w',
        ) as f:
            json.dump([[r.model_dump() for r in group] for group in rollouts], f)
    return rollouts
def selective_log_softmax(logits, index):
    """Gather ``log_softmax(logits)`` at ``index`` without materializing the full log-softmax.

    Taken from: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/utils.py#L1659.

    The result equals the naive
    ```python
    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    ```
    but the leading dimension is processed one row at a time to keep peak memory low.

    Two code paths are used:

    * float32/float64 logits (and batch-invariant mode disabled): gather the selected
      logits and subtract a per-row ``logsumexp``.
    * otherwise (e.g. bfloat16, where the logsumexp approach is unstable, or when
      batch-invariant mode is enabled): per-row ``log_softmax`` followed by ``gather``.

    Args:
        logits (`torch.Tensor`):
            Logits tensor of shape `(..., num_classes)`.
        index (`torch.Tensor`):
            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.

    Returns:
        `torch.Tensor`:
            Gathered log probabilities with the same shape as `index`.
    """
    batch_invariant = is_batch_invariant_mode_enabled()
    full_precision = logits.dtype in (torch.float32, torch.float64)

    if not full_precision or batch_invariant:
        # Slightly less efficient but numerically stable: log_softmax one row at a time.
        return torch.stack([
            torch.nn.functional.log_softmax(row_logits, dim=-1)
            .gather(dim=-1, index=row_index.unsqueeze(-1))
            .squeeze(-1)
            for row_logits, row_index in zip(logits, index)
        ])

    # log_softmax(x_i) = x_i - logsumexp(x); logsumexp computed per row to bound peak memory.
    picked = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    row_logsumexp = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
    return picked - row_logsumexp
def get_logprobs(model, tokens, position_ids, no_grad=False, sequence_packing=False, packed_seq_params=None):
    """Get sequence logprobs from their token ids.

    Args:
        model: model to predict with (its ``config.flash_decode`` is temporarily disabled).
        tokens: inputs for which we want to get logprobs (indexed as ``[batch, seq]``).
        position_ids: position ids that come with tokens.
        no_grad: whether to run the forward pass under ``torch.no_grad()``.
        sequence_packing: whether ``tokens`` hold packed bins. Only used to build the
            default ``packed_seq_params`` when none are given.
        packed_seq_params: Optional PackedSeqParams passed to the model. When None, a
            default is constructed so the CUDA-graph signature matches the training
            forward step. It comes from the packing config if ``sequence_packing`` is set,
            otherwise it is a single ``thd`` sequence spanning the whole row.

    Returns:
        On the last pipeline stage: logprobs of ``tokens[:, 1:]``, i.e. one fewer position
        than ``tokens``, since no logprob is produced for the first token. On other
        pipeline stages: the model output (hidden states), unchanged.
    """
    args = get_args()
    # Ensure packed_seq_params is always provided for CUDA graph signature consistency.
    # When sequence_packing is enabled, construct from packing config (max_sequences_per_bin).
    # When sequence_packing is disabled, construct a single-sequence default so the CUDA
    # graph signature matches the training forward_step in train_rl.py.
    # This is necessary because reference logprobs steps will reuse the training forward graph.
    if packed_seq_params is None:
        if sequence_packing:
            packed_seq_params = get_default_packed_seq_params(
                seq_length=tokens.shape[1],
                max_sequences_per_bin=args.rl_sequence_packing_max_sequences_per_bin,
                device=tokens.device,
            )
        else:
            cu_seqlens = torch.tensor([0, tokens.shape[1]], dtype=torch.int32, device=tokens.device)
            packed_seq_params = PackedSeqParams(
                qkv_format='thd',
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_kv=cu_seqlens,
                max_seqlen_q=tokens.shape[1],
                max_seqlen_kv=tokens.shape[1],
                total_tokens=tokens.shape[1],
            )

    nvtx_range = get_nvtx_range()
    with nvtx_range("rl/get-logprobs", time=True):
        with nvtx_range("rl/forward-pass", time=True):
            # TODO(vitalyk): use fp16/bf16 as a function argument. Do not use args.
            attention_mask_for_forward = None
            # This is a hack to fix megatron's behaviour when flash-decode affects the training code flow.
            # The flag is restored in `finally` so a failing forward pass cannot leave it disabled.
            flash_decode = model.config.flash_decode
            model.config.flash_decode = False
            fp32_output = not (args.fp16 or args.bf16)
            try:
                with torch.no_grad() if no_grad else nullcontext():
                    logits_or_hidden_states = model(
                        tokens,
                        position_ids,
                        attention_mask_for_forward,
                        packed_seq_params=packed_seq_params,
                        runtime_gather_output=True,
                        fp32_output=fp32_output,
                    )
            finally:
                model.config.flash_decode = flash_decode

        pg_collection = get_attr_wrapped_model(model, "pg_collection")
        pp_group = pg_collection.pp
        # Only the last pipeline stage produces logits; earlier stages pass hidden states on.
        if not is_pp_last_stage(pp_group):
            return logits_or_hidden_states

        logits = logits_or_hidden_states
        with nvtx_range("rl/log-softmax", time=True):
            # We do not need logprobs for the n+1 token.
            logprobs = selective_log_softmax(logits[:, :-1, :], tokens[:, 1:])
        return logprobs
def calculate_grpo_advantages(rewards: list[list[float]], num_turns: list[list[int]]) -> list[float]:
    """Calculate GRPO advantages from grouped rewards.

    Each reward is normalized within its group as ``(r - mean) / (1e-4 + std)``. Here
    ``mean``/``std`` are that group's mean and (population) standard deviation.

    For multiturn rollouts, each turn becomes its own training trajectory carrying the
    reward of the rollout it came from. For example, if the trajectory [[a, b], [c, d, e]]
    has reward 1.0, the updates see [a, b] with 1.0 and [c, d, e] with 1.0. The advantage
    of rollout ``i`` in group ``g`` is therefore repeated ``num_turns[g][i]`` times.

    Args:
        rewards: ``rewards[g][i]`` is the reward of rollout ``i`` in group ``g``.
            All groups are assumed to have the same size.
        num_turns: Same shape as ``rewards``; number of turns of each rollout.

    Returns:
        Flat list of advantages, one per turn, ordered group by group and rollout by
        rollout. Empty input yields an empty list.
    """
    if len(rewards) == 0:
        return []

    rewards = np.array(rewards)
    num_turns = np.array(num_turns)
    # rewards are [g, group_size]. The sum of each group's num_turns is the number of
    # trajectories that group contributes, so each group's mean/std is repeated that often
    # (ndarray.repeat without an axis flattens [g, 1] to [g] first).
    group_turns = num_turns.sum(axis=-1)
    reward_means = rewards.mean(axis=1, keepdims=True).repeat(group_turns)
    reward_stds = rewards.std(axis=1, keepdims=True).repeat(group_turns)
    # Making an assumption that all groups are of the same size!
    # @vitalyk: this will go away when we start sending env-based sample reqs.
    rewards = rewards.flatten().repeat(num_turns.flatten())
    return ((rewards - reward_means) / (1e-4 + reward_stds)).tolist()
def compute_group_stats(
    rollouts: GroupedRollouts, tokenizer: MegatronTokenizer, seq_len: int,
) -> RolloutStats:
    """Add group-based rollout stats for logging.

    Every rollout is logged through ``lang_rl_log``. Token rollouts are
    detokenized turn by turn and each turn is asserted to either be exactly
    ``seq_len`` long or to end with ``tokenizer.eod``.

    Args:
        rollouts: Rollouts to generate the stats for. Each inner list is a group
            (as in GRPO group), i.e. all rollouts are for the same prompt.
        tokenizer: Tokenizer to detokenize the rollouts when they are token ids.
        seq_len: Maximum sequence length.

    Returns:
        RolloutStats object containing all the stats. The probability-ratio
        fields are left as ``None``.
    """
    # TODO (rkirby) Maybe do some of this after the tensor building
    turn_lens = []
    traj_lens = []
    rewards = []
    env_ids = []
    num_turns = []  # num_turns per traj
    all_policy_staleness = []
    all_kv_cache_staleness = []
    all_completed_at_steps = []
    all_num_evictions = []
    for group in rollouts:
        group_rewards = []
        group_traj_lengths = []
        group_turn_lengths = []
        group_num_turns = []
        group_policy_staleness = []
        group_kv_staleness = []
        group_completed_at_steps = []
        group_num_evictions = []
        for rollout in group:
            # A trajectory is a list of turns; each turn is a token list
            # (TokenRollout) or a string (Rollout).
            roll_turn_lens = [len(t) for t in rollout.trajectory]
            if isinstance(rollout, TokenRollout):
                for turn_traj in rollout.trajectory:
                    detokenized_traj = tokenizer.detokenize(turn_traj)
                    lang_rl_log(
                        f"Rollout: [{rollout.env_id}] [{rollout.reward} : {len(turn_traj)} tokens] {detokenized_traj}"
                    )
                    # TODO(vitalyk): how does multiturn change EOD/EOT?
                    assert (len(turn_traj) == seq_len) or (
                        turn_traj[-1] == tokenizer.eod
                    ), f"Rollout is not the correct length: {len(turn_traj)} {turn_traj[-1]}\n{detokenized_traj}"
            else:
                lang_rl_log(
                    f"Rollout: [{rollout.env_id}] [{rollout.reward} : {sum(roll_turn_lens)} chars] {rollout.trajectory}"
                )
            group_num_turns.append(len(rollout.trajectory))
            group_rewards.append(rollout.reward)
            group_turn_lengths.extend(roll_turn_lens)
            group_traj_lengths.append(sum(roll_turn_lens))
            # Staleness values are nested per turn; flatten them per group.
            group_policy_staleness.extend(s for turn in rollout.policy_staleness for s in turn)
            group_kv_staleness.extend(s for turn in rollout.kv_cache_staleness for s in turn)
            group_completed_at_steps.extend(rollout.completed_at_step)
            group_num_evictions.append(sum(rollout.num_evictions))
        all_policy_staleness.append(group_policy_staleness)
        all_kv_cache_staleness.append(group_kv_staleness)
        all_completed_at_steps.append(group_completed_at_steps)
        all_num_evictions.append(group_num_evictions)
        # https://arxiv.org/abs/2504.21233 reports that lens variance hurts.
        # Let's track this.
        traj_lens.append(group_traj_lengths)
        turn_lens.append(group_turn_lengths)
        env_ids.append(group[0].env_id)  # All rollouts in a group share the env_id by design.
        rewards.append(group_rewards)
        num_turns.append(group_num_turns)
    stats = RolloutStats(
        traj_lens=traj_lens,
        turn_lens=turn_lens,
        rewards=rewards,
        # --------
        # The fields above are grouped, i.e. lists of lists with the inner
        # list holding one group's data.
        env_ids=env_ids,
        num_turns=num_turns,
        advantages=calculate_grpo_advantages(rewards, num_turns),
        min_piold_to_inf_prob=None,
        max_piold_to_inf_prob=None,
        mean_piold_to_inf_prob=None,
        min_inf_train_prob_abs_diff=None,
        max_inf_train_prob_abs_diff=None,
        mean_inf_train_prob_abs_diff=None,
        min_inf_prob=None,
        max_inf_prob=None,
        mean_inf_prob=None,
        policy_staleness=all_policy_staleness,
        kv_cache_staleness=all_kv_cache_staleness,
        completed_at_steps=all_completed_at_steps,
        num_evictions=all_num_evictions,
    )
    return stats
def compute_true_staleness(
    per_token_staleness: list[list[int]],
    completed_at_steps: list[list[int]],
    turn_lens: list[list[int]],
    current_iteration: int,
) -> list[int]:
    """Shift raw per-token staleness by how long ago each turn was completed.

    For every token, the result is its raw staleness plus
    ``current_iteration - completed_at`` of the turn the token belongs to.
    Tokens are assigned to turns consecutively using ``turn_lens``.

    Args:
        per_token_staleness: Grouped flat list of per-token raw staleness values.
        completed_at_steps: Grouped list of per-turn completion steps.
        turn_lens: Grouped list of per-turn token counts.
        current_iteration: Current training iteration.

    Returns:
        Flat list of true staleness values (one per token across all groups).
    """
    true_staleness: list[int] = []
    for raw, finished_steps, lengths in zip(per_token_staleness, completed_at_steps, turn_lens):
        offset = 0
        for finished_at, length in zip(finished_steps, lengths):
            gap = current_iteration - finished_at
            span = range(offset, offset + length)
            true_staleness.extend(raw[i] + gap for i in span)
            # len(span) advances by exactly the number of tokens consumed.
            offset += len(span)
    return true_staleness
def prep_wandb_metrics(
    wandb_writer: wandb_run.Run,
    traj_lens: List[List[int]],
    turn_lens: List[List[int]],
    rewards: List[List[float]],
    num_turns: List[List[int]],
    advantages: List[float],
    policy_staleness: List[List[int]],
    kv_cache_staleness: List[List[int]],
    num_evictions: List[List[int]],
    completed_at_steps: List[List[int]],
    current_iteration: int,
    example_group: list[TokenRollout | Rollout] | None = None,
    tokenizer: MegatronTokenizer | None = None,
):
    """Make a wandb-parseable dictionary of metrics for logging.

    Args:
        wandb_writer: Wandb handle; only its ``Table`` and ``plot.histogram``
            attributes are used here. NOTE(review): those live on the ``wandb``
            module rather than obviously on ``wandb_run.Run`` — confirm what
            callers actually pass and fix the annotation.
        traj_lens: Grouped list of trajectory lengths.
        turn_lens: Grouped list of turn lengths.
        rewards: Grouped list of rewards.
        num_turns: Grouped list of number of turns in the trajectories.
        advantages: Flattened list of advantages.
        policy_staleness: Grouped list of per-token policy staleness.
        kv_cache_staleness: Grouped list of per-token KV cache staleness.
        num_evictions: Grouped list of per-rollout number of evictions.
        completed_at_steps: Grouped list of per-turn completed at steps.
        current_iteration: Current training iteration.
        example_group: A list of rollouts of one group to log examples of trajectories.
        tokenizer: Tokenizer to untokenize trajectories for logging.

    Returns:
        Dict mapping metric names to scalars, wandb tables, or wandb histogram plots.

    Raises:
        ValueError: If ``example_group`` is given without a ``tokenizer``.
    """
    if example_group and tokenizer is None:
        raise ValueError("If you provide an example group to log, you need to provide a tokenizer too.")

    flat_rewards = [r for g in rewards for r in g]
    group_table = wandb_writer.Table(
        columns=['group_means', 'group_stds'],
        data=[[np.mean(g), np.std(g)] for g in rewards],
    )
    # Raw staleness plus the gap between each turn's completion step and now.
    true_policy_staleness = compute_true_staleness(
        policy_staleness, completed_at_steps, turn_lens, current_iteration)
    true_kv_staleness = compute_true_staleness(
        kv_cache_staleness, completed_at_steps, turn_lens, current_iteration)
    metrics = {
        'group_means_hist': wandb_writer.plot.histogram(
            group_table, 'group_means', 'Group Means'
        ),
        'group_stds_hist': wandb_writer.plot.histogram(
            group_table, 'group_stds', 'Group STDs'
        ),
        'rewards_hist': wandb_writer.plot.histogram(
            wandb_writer.Table(
                columns=['reward'], data=[[r] for r in flat_rewards]
            ),
            'reward', 'All Rewards'
        ),
        'advantages_hist': wandb_writer.plot.histogram(
            wandb_writer.Table(
                columns=['advantages'], data=[[x] for x in advantages]
            ),
            'advantages', 'Advantages'
        ),
        'rollout_table': wandb_writer.Table(
            columns=['reward', 'traj_length', 'num_evictions'],
            data=list(zip(
                flat_rewards,
                [length for g in traj_lens for length in g],
                [e for g in num_evictions for e in g],
            )),
        ),
        # Length stats: means/stds are computed per group, then averaged.
        'mean_turn_length': np.mean([np.mean(g) for g in turn_lens]),
        'mean_turn_length_std': np.mean([np.std(g) for g in turn_lens]),
        'max_turn_length': max(max(g) for g in turn_lens),
        'min_turn_length': min(min(g) for g in turn_lens),
        'mean_traj_length': np.mean([np.mean(g) for g in traj_lens]),
        'mean_traj_length_std': np.mean([np.std(g) for g in traj_lens]),
        'max_traj_length': max(max(g) for g in traj_lens),
        'min_traj_length': min(min(g) for g in traj_lens),
        'mean_num_turns': np.mean([np.mean(g) for g in num_turns]),
        'max_num_turns': max(max(g) for g in num_turns),
        'min_num_turns': min(min(g) for g in num_turns),
        'mean_reward': np.mean([np.mean(g) for g in rewards]),
        'mean_advantage': np.mean(advantages),
        # Fraction of (per-turn) advantage entries that are nonzero; the key
        # name is kept as-is for dashboard continuity.
        'nonzero_groups_ratio': np.count_nonzero(advantages)
        / len(advantages),
        'mean_policy_staleness': np.mean(true_policy_staleness),
        'max_policy_staleness': max(true_policy_staleness),
        'min_policy_staleness': min(true_policy_staleness),
        'mean_kv_cache_staleness': np.mean(true_kv_staleness),
        'max_kv_cache_staleness': max(true_kv_staleness),
        'min_kv_cache_staleness': min(true_kv_staleness),
        'total_eviction_count': sum(sum(g) for g in num_evictions),
        'max_num_evictions': max(max(g) for g in num_evictions),
        'mean_completion_gap': np.mean([current_iteration - s for g in completed_at_steps for s in g]),
    }
    if example_group:
        # One row per turn: the readable text, that turn's raw content, and the
        # reward of the trajectory it belongs to.
        metrics['rollouts'] = wandb_writer.Table(
            columns=['Trajectories', 'Tokens', 'Rewards'],
            data=[
                [
                    tokenizer.detokenize(turn) if isinstance(r, TokenRollout) else turn,
                    turn,
                    r.reward,
                ]
                for r in example_group for turn in r.trajectory
            ],
        )
    return metrics
def maybe_log_training_metrics(
group_stats: RolloutStats,
current_iteration: int,
tokenizer: MegatronTokenizer,
example_groups: dict[str, list[TokenRollout | Rollout]],
):
"""Log training metrics if writers are available.
Args:
group_stats: RolloutStats object to pass to writers.
current_iteration: Current training iteration.
tokenizer: Tokenizer to untokenize trajectories for logging.
example_groups: A dict with values as list of rollouts of one group to log examples of trajectories. Keys are env names.
"""