-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathsequence_packing_utils.py
More file actions
1175 lines (1011 loc) · 44.4 KB
/
sequence_packing_utils.py
File metadata and controls
1175 lines (1011 loc) · 44.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import torch
import math
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from torch.utils.data import DataLoader, TensorDataset
from dataclasses import dataclass, field
from megatron.core.utils import log_single_rank
from megatron.training.global_vars import get_args, get_tokenizer
from megatron.training.utils import get_nvtx_range
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core import mpu
import logging
import typing
from megatron.core.num_microbatches_calculator import (
get_num_microbatches,
reconfigure_num_microbatches_calculator,
)
logger = logging.getLogger(__name__)
@dataclass
class PackingInfo:
    """Bookkeeping describing how sequences were packed into bins.

    Attributes:
        bin_seq_indices: One entry per bin; each entry lists the global indices of
            the sequences packed into that bin, in packing order. Padding bins
            (added so all ranks have the same number of bins) have an empty list.
        seq_starts: Maps bin index to the start offset of each sequence in that bin,
            followed by one trailing entry holding the end offset of the last
            sequence, so a bin with k sequences has k + 1 entries. Padding bins map
            to an empty list.
        seq_lengths: Actual (unpadded) length of every sequence, indexed by global
            sequence index. This covers all sequences, not only those in this
            rank's bins.
        seq_to_bin_idx: Global sequence index -> index of the bin it was packed
            into, using the global bin numbering from packing (it is not remapped
            when bins are distributed across ranks).
        packing_algo: How bins are distributed across data-parallel ranks:
            'fifo' assigns contiguous chunks, 'round-robin' assigns strided bins.
    """
    bin_seq_indices: List[List[int]]
    seq_starts: Dict[int, List[int]]
    seq_lengths: List[int]
    seq_to_bin_idx: List[Optional[int]]
    packing_algo: typing.Literal['fifo', 'round-robin']
@dataclass
class PackingContext:
    """Context containing all information needed for sequence packing during training.

    Besides the declared fields, load_packed_data_by_index reads ``old_logprobs``,
    ``ref_logprobs`` and ``packed_inference_logprobs`` through
    ``getattr(..., None)``; they are not declared here and are presumably attached
    to the instance after construction (e.g. once logprobs are computed) —
    NOTE(review): confirm against the code that sets them.

    Attributes:
        bin_size: Maximum size of each bin (in tokens)
        packer: 'SequencePacker' instance used for packing
        packing_info: PackingInfo object with bin assignments and metadata
        original_generation_masks: Generation masks for all sequences before packing
        original_trajs: All trajectories before packing
        packed_trajs: Packed trajectories tensor [num_bins, bin_size]
        packed_position_ids: Position IDs for packed sequences [num_bins, bin_size]
        packed_attention_mask: Attention mask for packed sequences [num_bins, 1, bin_size, bin_size]
        packed_loss_mask: Loss mask for packed sequences [num_bins, bin_size]
        original_inference_logprobs: Inference logprobs for all sequences before packing (optional)
        bin_advantages: List of advantage tensors for each bin, indexed by bin index
        cached_packed_seq_params: Pre-computed PackedSeqParams for each bin; an entry
            is None for a bin that holds no sequences
    """
    bin_size: int
    packer: 'SequencePacker'  # string annotation: SequencePacker is defined later in this module
    packing_info: PackingInfo
    original_generation_masks: torch.Tensor
    original_trajs: torch.Tensor
    packed_trajs: torch.Tensor
    packed_position_ids: torch.Tensor
    packed_attention_mask: torch.Tensor
    packed_loss_mask: torch.Tensor
    original_inference_logprobs: Optional[torch.Tensor] = None
    # default_factory gives each instance its own list (no shared mutable default).
    bin_advantages: List[torch.Tensor] = field(default_factory=list)
    cached_packed_seq_params: List[Optional[PackedSeqParams]] = field(default_factory=list)
def load_packed_data_by_index(bin_idx: int, packing_context: PackingContext, logprobs_is_correction: bool):
    """Load the packed tensors and per-sequence metadata for a single bin.

    Args:
        bin_idx: Index of the bin (in this rank's local bin numbering) to load.
        packing_context: Packing state prepared before the update step.
        logprobs_is_correction: When True, also return the packed inference
            logprobs if the context carries them; otherwise inference_logprobs
            is None.

    Returns:
        Tuple of (tokens, advantages, old_logprobs, loss_mask, position_ids,
        ref_logprobs, inference_logprobs, seq_starts, seq_lengths, seq_indices,
        packed_seq_params). Tensor entries are sliced with ``bin_idx:bin_idx + 1``
        and therefore keep a leading dimension of size 1. old_logprobs,
        ref_logprobs and inference_logprobs are None when unavailable.
    """
    # Slice (not integer index) so tensors keep a leading batch dimension of 1.
    bin_slice = slice(bin_idx, bin_idx + 1)

    # Cached PackedSeqParams (pre-computed to avoid repeated tensor allocations)
    # give Transformer Engine the sequence boundaries for attention masking.
    packed_seq_params = packing_context.cached_packed_seq_params[bin_idx]

    tokens = packing_context.packed_trajs[bin_slice]
    position_ids = packing_context.packed_position_ids[bin_slice]

    # old_logprobs / ref_logprobs are attached after logprobs computation, so they
    # may not exist yet (e.g. during the initial forward pass).
    old_logprobs = getattr(packing_context, 'old_logprobs', None)
    if old_logprobs is not None:
        old_logprobs = old_logprobs[bin_slice]
    ref_logprobs = getattr(packing_context, 'ref_logprobs', None)
    if ref_logprobs is not None:
        ref_logprobs = ref_logprobs[bin_slice]

    # Skip position 0: logprobs predict the next token, so they are shifted by one
    # relative to the input tokens (logprobs have shape [batch, seq_len - 1]).
    loss_mask = packing_context.packed_loss_mask[bin_slice, 1:]

    packing_info = packing_context.packing_info
    seq_starts = packing_info.seq_starts[bin_idx]
    seq_indices = packing_info.bin_seq_indices[bin_idx]

    if not seq_indices:
        # Empty padding bin (added so all ranks run the same number of iterations).
        seq_lengths = []
        advantages = torch.tensor([], device='cuda')
    else:
        seq_lengths = [packing_info.seq_lengths[seq_idx] for seq_idx in seq_indices]
        advantages = packing_context.bin_advantages[bin_idx]

    packed_inference_logprobs = getattr(packing_context, 'packed_inference_logprobs', None)
    if packed_inference_logprobs is not None and logprobs_is_correction:
        inference_logprobs = packed_inference_logprobs[bin_slice]
    else:
        inference_logprobs = None

    return (
        tokens,
        advantages,
        old_logprobs,
        loss_mask,
        position_ids,
        ref_logprobs,
        inference_logprobs,
        seq_starts,
        seq_lengths,
        seq_indices,
        packed_seq_params,
    )
def log_packing_efficiency(packing_context: PackingContext):
    """Log sequence-packing statistics for this rank's bins.

    Reports sequence/bin counts and the fraction of this rank's packed token
    capacity occupied by real sequence tokens. When torch.distributed is
    initialized it also all-gathers per-rank statistics over the data-parallel
    group and logs a per-rank distribution table, so every data-parallel rank
    must call this function.

    Args:
        packing_context: Packing state for this rank (after bin distribution).
    """
    packing_info = packing_context.packing_info
    packed_trajs = packing_context.packed_trajs
    my_bin_seq_indices = packing_info.bin_seq_indices
    num_bins = len(packing_info.bin_seq_indices)

    # seq_lengths covers every sequence; bin_seq_indices only this rank's bins.
    num_total_sequences = len(packing_info.seq_lengths)
    total_tokens = sum(packing_info.seq_lengths)
    my_sequences = sum(len(indices) for indices in my_bin_seq_indices)
    my_tokens = sum(
        packing_info.seq_lengths[seq_idx]
        for indices in my_bin_seq_indices
        for seq_idx in indices
    )
    total_capacity = packed_trajs.shape[0] * packed_trajs.shape[1]
    packing_efficiency = my_tokens / total_capacity if total_capacity > 0 else 0
    # Guard: an empty batch would otherwise raise ZeroDivisionError.
    avg_seq_length = total_tokens / num_total_sequences if num_total_sequences > 0 else 0.0
    rank = mpu.get_data_parallel_rank()

    log_single_rank(logger, logging.INFO, "[Sequence Packing] Statistics:")
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - Total sequences: {num_total_sequences}",
    )
    log_single_rank(
        logger, logging.INFO, f"[Sequence Packing] - Total bins: {num_bins}"
    )
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - Bin size: {packed_trajs.shape[1]} tokens",
    )
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - Average sequence length: {avg_seq_length:.1f} tokens",
    )
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - This rank: {my_sequences} sequences in {packed_trajs.shape[0]} bins",
    )
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - Packing efficiency: {packing_efficiency:.1%} ({my_tokens:,} / {total_capacity:,} tokens)",
    )

    if not torch.distributed.is_initialized():
        return

    # Per-rank sequence distribution analysis.
    seq_counts_per_bin = [len(indices) for indices in my_bin_seq_indices]
    non_empty_bins = [c for c in seq_counts_per_bin if c > 0]
    rank_stats = torch.tensor(
        [
            float(rank),
            float(len(my_bin_seq_indices)),  # total bins
            float(len(non_empty_bins)),  # non-empty bins
            float(my_sequences),  # total sequences
            float(min(non_empty_bins)) if non_empty_bins else 0.0,  # min sequences per bin
            float(max(non_empty_bins)) if non_empty_bins else 0.0,  # max sequences per bin
            # avg sequences per non-empty bin
            float(my_sequences / len(non_empty_bins)) if non_empty_bins else 0.0,
        ],
        device='cuda',
    )

    # Collective: every data-parallel rank must reach this call.
    world_size = mpu.get_data_parallel_world_size()
    all_rank_stats = [torch.zeros_like(rank_stats) for _ in range(world_size)]
    torch.distributed.all_gather(
        all_rank_stats, rank_stats, group=mpu.get_data_parallel_group()
    )

    if rank != 0:
        return

    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] Per-rank distribution ({packing_info.packing_algo} algorithm):",
    )
    log_single_rank(
        logger,
        logging.INFO,
        "[Sequence Packing] Rank | Total Bins | Non-empty | Sequences | Min/Bin | Max/Bin | Avg/Bin",
    )
    log_single_rank(
        logger,
        logging.INFO,
        "[Sequence Packing] -----|------------|-----------|-----------|---------|---------|--------",
    )
    for stats in all_rank_stats:
        r = int(stats[0].item())
        total_bins = int(stats[1].item())
        non_empty = int(stats[2].item())
        sequences = int(stats[3].item())
        min_seq = int(stats[4].item())
        max_seq = int(stats[5].item())
        avg_seq = stats[6].item()
        log_single_rank(
            logger,
            logging.INFO,
            f"[Sequence Packing] {r:3d} | {total_bins:10d} | {non_empty:9d} | {sequences:9d} | {min_seq:7d} | {max_seq:7d} | {avg_seq:6.1f}",
        )

    # Also show first few bins for rank 0 as example
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] Example (Rank 0 first 10 bins): {seq_counts_per_bin[:10]}",
    )

    total_seqs_all_ranks = sum(int(stats[3].item()) for stats in all_rank_stats)
    avg_seqs_per_rank = total_seqs_all_ranks / world_size
    max_deviation = max(
        abs(int(stats[3].item()) - avg_seqs_per_rank)
        for stats in all_rank_stats
    )
    # Guard: with zero sequences overall the relative deviation is undefined.
    deviation_pct = (
        max_deviation / avg_seqs_per_rank * 100 if avg_seqs_per_rank > 0 else 0.0
    )
    log_single_rank(
        logger,
        logging.INFO,
        "[Sequence Packing] Round-robin distribution quality:",
    )
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - Average sequences per rank: {avg_seqs_per_rank:.1f}",
    )
    log_single_rank(
        logger,
        logging.INFO,
        f"[Sequence Packing] - Max deviation from average: {max_deviation:.0f} sequences ({deviation_pct:.1f}%)",
    )
def get_actual_sequence_lengths(sequences: torch.Tensor, pad_token: int) -> List[int]:
    """Get actual sequence lengths for right-padded sequences.

    The length of a row is the index of its last non-padding token plus one, so
    padding tokens that appear before real tokens are counted. A row consisting
    only of padding has length 0.

    Args:
        sequences: Tensor of shape [batch_size, seq_len].
        pad_token: The padding token ID.

    Returns:
        List of actual sequence lengths as Python ints, one per row.

    Raises:
        ValueError: If ``sequences`` is not 2D.
    """
    if len(sequences.shape) != 2:
        raise ValueError(f"Expected 2D tensor, got shape {sequences.shape}")

    batch_size, seq_len = sequences.shape
    if batch_size == 0 or seq_len == 0:
        # amax over an empty dimension would raise; every row is length 0.
        return [0] * batch_size

    # One vectorized reduction (single device->host transfer) instead of a
    # per-row .item() sync: 1-based positions of non-pad tokens, max per row.
    non_pad_mask = (sequences != pad_token).to(torch.long)
    positions = torch.arange(1, seq_len + 1, dtype=torch.long, device=sequences.device)
    lengths = (non_pad_mask * positions).amax(dim=1)
    return lengths.tolist()
def create_empty_bins(
    num_empty_bins: int,
    bin_size: int,
    packed_trajs: torch.Tensor,
    packed_position_ids: torch.Tensor,
    packed_loss_mask: torch.Tensor,
    packed_attention_mask: Optional[torch.Tensor],
    tokenizer,
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    List[Dict[str, Any]],
]:
    """Create empty bins for padding so all ranks have the same number of bins.

    Args:
        num_empty_bins: Number of empty bins to create.
        bin_size: Size of each bin (in tokens).
        packed_trajs: Packed trajectories; dtype/device reference for the new bins.
        packed_position_ids: Packed position IDs; dtype reference.
        packed_loss_mask: Packed loss mask; dtype reference.
        packed_attention_mask: Packed attention mask (dtype reference), or None if
            no attention mask is used.
        tokenizer: Tokenizer whose ``pad`` token fills the empty trajectories.

    Returns:
        Tuple of (empty_trajs [n, bin_size] filled with pad tokens,
        empty_position_ids [n, bin_size] of zeros,
        empty_loss_mask [n, bin_size] of zeros so padding bins contribute no loss,
        empty_attention_mask [n, 1, bin_size, bin_size] of zeros or None,
        empty_packing_info_entries). When ``num_empty_bins <= 0`` all four tensors
        are None and the entry list is empty.
    """
    device = packed_trajs.device

    # A separate dict (and lists) per bin so callers can mutate them independently.
    empty_packing_info_entries: List[Dict[str, Any]] = [
        {
            'bin_seq_indices': [],  # No sequences in empty bin
            'seq_starts': [],  # No sequence starts
        }
        for _ in range(num_empty_bins)
    ]

    if num_empty_bins <= 0:
        return None, None, None, None, empty_packing_info_entries

    # Allocate each tensor once at full size rather than per-bin + torch.cat.
    empty_trajs = torch.full(
        (num_empty_bins, bin_size), tokenizer.pad, dtype=packed_trajs.dtype, device=device
    )
    empty_position_ids = torch.zeros(
        (num_empty_bins, bin_size), dtype=packed_position_ids.dtype, device=device
    )
    empty_loss_mask = torch.zeros(
        (num_empty_bins, bin_size), dtype=packed_loss_mask.dtype, device=device
    )
    empty_attention_mask = None
    if packed_attention_mask is not None:
        # Attention mask is 4D: [num_bins, 1, bin_size, bin_size]
        empty_attention_mask = torch.zeros(
            (num_empty_bins, 1, bin_size, bin_size),
            dtype=packed_attention_mask.dtype,
            device=device,
        )

    return (
        empty_trajs,
        empty_position_ids,
        empty_loss_mask,
        empty_attention_mask,
        empty_packing_info_entries,
    )
def get_default_packed_seq_params(seq_length: int, max_sequences_per_bin: int, device: torch.device) -> PackedSeqParams:
    """Create a default PackedSeqParams describing a single unpacked sequence.

    This keeps the CUDA graph signature consistent when packed_seq_params would
    otherwise be None. ``cu_seqlens`` is ``[0, seq_length, seq_length, ...]``:
    one sequence spanning the full length followed by zero-length entries, so no
    actual packing boundaries are introduced.

    Args:
        seq_length: The sequence length.
        max_sequences_per_bin: Max sequences to pack in a bin; fixes the
            ``cu_seqlens`` length at ``max_sequences_per_bin + 2``.
        device: Device to create tensors on.

    Returns:
        PackedSeqParams configured as a single unpacked sequence.
    """
    # Fixed size for the attention kernel: +2 for the leading 0 and the final end
    # offset, matching the padded cu_seqlens length used for real packed bins.
    cu_seqlens = torch.full(
        (max_sequences_per_bin + 2,), seq_length, dtype=torch.int32, device=device,
    )
    cu_seqlens[0] = 0
    return PackedSeqParams(
        qkv_format='thd',
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_q_padded=None,
        cu_seqlens_kv_padded=None,
        max_seqlen_q=seq_length,
        max_seqlen_kv=seq_length,
        total_tokens=seq_length,
    )
def create_packed_seq_params(packing_context: PackingContext):
    """Build the PackedSeqParams for every bin held in ``packing_context``.

    Args:
        packing_context: Packing state whose bins (``packed_trajs`` rows) need params.

    Returns:
        A list with one entry per bin, in the same order as
        ``packing_context.packed_trajs``. Entries are None for bins that hold no
        sequences.
    """
    info = packing_context.packing_info
    size = packing_context.bin_size
    per_bin_limit = packing_context.packer.max_sequences_per_bin
    target_device = packing_context.packed_trajs.device

    return [
        create_packed_seq_params_for_bin(
            packing_info=info,
            bin_idx=i,
            bin_size=size,
            max_sequences_per_bin=per_bin_limit,
            device=target_device,
        )
        for i in range(len(packing_context.packed_trajs))
    ]
def create_packed_seq_params_for_bin(
    packing_info: PackingInfo,
    bin_idx: int,
    bin_size: int,
    max_sequences_per_bin: int,
    device: torch.device
) -> Optional[PackedSeqParams]:
    """Build the Transformer Engine ``PackedSeqParams`` for one bin.

    TE needs cumulative sequence lengths (``cu_seqlens``) to know where each
    sequence inside a packed bin begins and ends, which stops attention from
    leaking across unrelated sequences.

    The boundaries are ``[0, l1, l1+l2, ..., sum(l), bin_size]``; the trailing
    ``bin_size`` turns the bin's padding tail into one extra segment. The list is
    then right-filled with ``bin_size`` (zero-length ghost segments) up to a fixed
    length of ``max_sequences_per_bin + 2`` so CUDA graphs always see the same
    tensor shape.

    Args:
        packing_info: Packing metadata produced by SequencePacker.
        bin_idx: Bin whose params are built.
        bin_size: Size of the bin (padded sequence length).
        max_sequences_per_bin: Maximum number of sequences per bin.
        device: Device on which ``cu_seqlens`` is created.

    Returns:
        The PackedSeqParams for the bin, or None when the bin holds no sequences.
    """
    members = packing_info.bin_seq_indices[bin_idx]
    if not members:
        # Padding bin: nothing to describe.
        return None

    boundaries = [0]
    for member in members:
        boundaries.append(boundaries[-1] + packing_info.seq_lengths[member])
    boundaries.append(bin_size)

    fixed_len = max_sequences_per_bin + 2  # leading 0 + final bin_size
    missing = fixed_len - len(boundaries)
    if missing > 0:
        boundaries.extend([bin_size] * missing)

    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=device)

    return PackedSeqParams(
        qkv_format='thd',
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_q_padded=None,
        cu_seqlens_kv_padded=None,
        max_seqlen_q=bin_size,
        max_seqlen_kv=bin_size,
        total_tokens=bin_size,
    )
def pack_inference_logprobs(
    inference_logprobs: List[torch.Tensor],
    packing_info: 'PackingInfo',
    generation_masks: torch.Tensor,
    bin_size: int,
) -> torch.Tensor:
    """Pack inference logprobs into bins aligned with packed sequences.

    For each sequence present in this rank's bins, its inference logprobs are
    written starting at the logprob position of its first generated token
    (first True in its generation mask, minus one for the next-token shift),
    truncated so they never run past the end of that sequence.

    Args:
        inference_logprobs: Per-sequence inference logprobs, indexed by global
            sequence index; non-tensor entries are skipped.
        packing_info: Bin assignments and per-bin sequence start offsets for this rank.
        generation_masks: Per-sequence masks indicating which tokens were generated.
        bin_size: Size of each bin.

    Returns:
        CPU float32 tensor of shape [num_bins, bin_size - 1]; positions with no
        inference logprob are 0.
    """
    num_bins = len(packing_info.bin_seq_indices)
    # Logprobs are one token shorter than sequences.
    packed_inference_logprobs = torch.zeros(
        (num_bins, bin_size - 1), dtype=torch.float32, device='cpu'
    )

    # Global sequence index -> (local bin index, position within that bin).
    # seq_to_bin_idx uses global bin indices, but after distribution each rank only
    # holds a subset of bins, so locations are rebuilt from this rank's bins.
    # Built once to avoid an O(n) list.index() lookup per sequence.
    seq_location = {}
    for local_bin_idx, seq_indices in enumerate(packing_info.bin_seq_indices):
        for pos_in_bin, seq_idx in enumerate(seq_indices):
            seq_location[seq_idx] = (local_bin_idx, pos_in_bin)

    for seq_idx, seq_inf_logprobs in enumerate(inference_logprobs):
        location = seq_location.get(seq_idx)
        if location is None:
            continue  # Sequence not on this rank
        if not isinstance(seq_inf_logprobs, torch.Tensor):
            continue  # No inference logprobs for this sequence

        local_bin_idx, pos_in_bin = location
        seq_start = packing_info.seq_starts[local_bin_idx][pos_in_bin]

        gen_mask = generation_masks[seq_idx]
        if not bool(gen_mask.any()):
            # No generated tokens: argmax would return 0 and produce a negative
            # offset (a bad slice or a write into the previous sequence).
            continue
        # First generated token, shifted by one to match next-token logprobs.
        first_gen_idx = int(gen_mask.int().argmax().item()) - 1

        pack_start = seq_start + first_gen_idx
        pack_end = min(
            pack_start + len(seq_inf_logprobs), seq_start + packing_info.seq_lengths[seq_idx] - 1
        )
        actual_len = pack_end - pack_start
        if actual_len > 0 and pack_start >= 0 and pack_end <= bin_size - 1:
            packed_inference_logprobs[local_bin_idx, pack_start:pack_end] = seq_inf_logprobs[
                :actual_len
            ]

    return packed_inference_logprobs
def compute_packed_inference_logprobs_stats(
    old_logprobs: torch.Tensor,
    packed_inference_logprobs: torch.Tensor,
    packed_loss_mask: torch.Tensor,
    group_stats: Any,
) -> None:
    """Compute statistics for packed inference logprobs for logging purposes.

    Compares packed inference logprobs with old logprobs, using the packed loss
    mask (shifted by one token) to select valid positions, and updates
    ``group_stats`` through ``update_inference_logprobs_group_stats``. If the
    shifted mask does not match the shape of ``old_logprobs``, a warning is
    logged and the stats are skipped.

    Args:
        old_logprobs: Old logprobs tensor in packed format [num_bins, seq_len-1]
        packed_inference_logprobs: Packed inference logprobs [num_bins, seq_len-1]
        packed_loss_mask: Loss mask indicating valid positions [num_bins, seq_len]
        group_stats: Statistics object to update with computed metrics
    """
    # Lazy import to avoid circular dependency (rl_utils imports from this module)
    from megatron.rl.rl_utils import update_inference_logprobs_group_stats

    # Move everything to CPU so the stats computation uses a single device.
    old_logprobs = old_logprobs.cpu()
    packed_inference_logprobs = packed_inference_logprobs.cpu()
    packed_loss_mask = packed_loss_mask.cpu()

    # Logprobs are shifted by one relative to tokens, so drop the first mask column.
    mask = packed_loss_mask[:, 1:].bool()

    if mask.shape != old_logprobs.shape:
        # Best-effort logging helper: skip, but make the skip visible.
        log_single_rank(
            logger,
            logging.WARNING,
            f"[Sequence Packing] Skipping inference logprobs stats: loss mask shape "
            f"{tuple(mask.shape)} does not match old_logprobs shape {tuple(old_logprobs.shape)}",
        )
        return

    update_inference_logprobs_group_stats(
        old_logprobs=old_logprobs,
        inference_logprobs=packed_inference_logprobs,
        mask=mask,
        group_stats=group_stats,
    )
class SequencePacker:
    """Packs multiple sequences into fixed-size bins to minimize padding.

    Sequences are sorted by actual (unpadded) length, longest first, and placed
    with a greedy next-fit strategy: a sequence joins the currently open bin if
    it fits within ``bin_size`` tokens and the bin holds fewer than
    ``max_sequences_per_bin`` sequences; otherwise the open bin is closed and a
    new one is started. Closed bins are never revisited.
    """

    def __init__(self, bin_size: int, pad_token: int, max_sequences_per_bin: int = 16):
        """Configure the packer.

        Args:
            bin_size: Capacity of each bin, in tokens.
            pad_token: Padding token ID; fills unused bin slots and is used to
                detect each input sequence's actual length.
            max_sequences_per_bin: Maximum number of sequences placed in one bin.
        """
        self.bin_size = bin_size
        self.pad_token = pad_token
        self.max_sequences_per_bin = max_sequences_per_bin

    def pack_sequences(
        self, trajs: torch.Tensor, generation_masks: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
        """Pack sequences into bins (next-fit over length-sorted sequences).

        Args:
            trajs: Padded sequences, 2D [num_sequences, seq_len] (enforced by
                get_actual_sequence_lengths).
            generation_masks: Optional per-sequence masks indexed like ``trajs``;
                when given, each sequence's loss mask is multiplied by its mask
                truncated to the sequence's actual length.

        Returns:
            (packed_sequences [num_bins, bin_size], position_ids [num_bins, bin_size],
            attention_mask [num_bins, 1, bin_size, bin_size] where True = masked,
            loss_mask [num_bins, bin_size], packing_info).

        Raises:
            ValueError: If any sequence's actual length exceeds ``bin_size``.
        """
        seq_lengths = get_actual_sequence_lengths(trajs, self.pad_token)

        # Fail early with a clear message; an oversized sequence would otherwise
        # surface later as an opaque tensor shape mismatch while filling bins.
        oversized = [length for length in seq_lengths if length > self.bin_size]
        if oversized:
            raise ValueError(
                f"[SequencePacker] {len(oversized)} sequence(s) exceed bin_size="
                f"{self.bin_size} (longest: {max(oversized)} tokens)"
            )

        # Trim sequences to actual lengths
        sequences = [trajs[i, :length] for i, length in enumerate(seq_lengths)]
        sorted_indices = sorted(range(len(sequences)), key=lambda i: seq_lengths[i], reverse=True)

        bins = []
        bin_seq_indices = []  # Track which sequences are in each bin
        current_bin = []
        current_bin_indices = []
        current_bin_length = 0
        sequences_per_bin = []

        for idx in sorted_indices:
            seq = sequences[idx]
            seq_len = len(seq)
            if (
                current_bin_length + seq_len <= self.bin_size
                and len(current_bin) < self.max_sequences_per_bin
            ):
                current_bin.append(seq)
                current_bin_indices.append(idx)
                current_bin_length += seq_len
            else:
                # Close the open bin and start a new one with this sequence.
                if current_bin:
                    bins.append(current_bin)
                    bin_seq_indices.append(current_bin_indices)
                    sequences_per_bin.append(len(current_bin))
                current_bin = [seq]
                current_bin_indices = [idx]
                current_bin_length = seq_len

        # Don't forget the last bin
        if current_bin:
            bins.append(current_bin)
            bin_seq_indices.append(current_bin_indices)
            sequences_per_bin.append(len(current_bin))

        num_bins = len(bins)
        device = sequences[0].device
        dtype = sequences[0].dtype

        if sequences_per_bin:
            avg_seqs_per_bin = sum(sequences_per_bin) / len(sequences_per_bin)
            min_seqs = min(sequences_per_bin)
            max_seqs = max(sequences_per_bin)
            log_single_rank(
                logger,
                logging.INFO,
                (
                    f"[SequencePacker] Packing distribution: {num_bins} bins, "
                    f"avg {avg_seqs_per_bin:.1f} seqs/bin, "
                    f"min {min_seqs}, max {max_seqs} seqs/bin "
                    f"(limit: {self.max_sequences_per_bin})"
                ),
            )
            # Store for later use
            self.last_avg_seqs_per_bin = avg_seqs_per_bin

        packed_sequences = torch.full(
            (num_bins, self.bin_size), self.pad_token, dtype=dtype, device=device
        )
        position_ids = torch.zeros(
            (num_bins, self.bin_size), dtype=torch.long, device=device, requires_grad=False
        )
        attention_mask = torch.zeros(
            (num_bins, 1, self.bin_size, self.bin_size), dtype=torch.bool, device=device
        )
        loss_mask = torch.zeros((num_bins, self.bin_size), dtype=torch.float, device=device)

        seq_starts_dict: Dict[int, List[int]] = {}
        seq_to_bin_idx: List[Optional[int]] = [None] * len(sequences)
        for bin_idx, seq_indices in enumerate(bin_seq_indices):
            for seq_idx in seq_indices:
                seq_to_bin_idx[seq_idx] = bin_idx

        # Fill bins
        for bin_idx, (bin_seqs, seq_indices) in enumerate(zip(bins, bin_seq_indices)):
            seq_starts = []
            current_pos = 0
            for seq_idx, seq in enumerate(bin_seqs):
                start = current_pos
                end = start + len(seq)
                seq_starts.append(start)
                current_pos = end

                packed_sequences[bin_idx, start:end] = seq
                # Position IDs reset for each sequence
                position_ids[bin_idx, start:end] = torch.arange(
                    len(seq), device=device, requires_grad=False
                )
                # Causal attention mask within each sequence
                seq_len = end - start
                attention_mask[bin_idx, 0, start:end, start:end] = torch.tril(
                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
                )
                # Loss mask (excluding padding)
                loss_mask[bin_idx, start:end] = 1.0
                if generation_masks is not None:
                    orig_idx = seq_indices[seq_idx]
                    # Truncate to actual seq length
                    gen_mask = generation_masks[orig_idx][: len(seq)]
                    loss_mask[bin_idx, start:end] *= gen_mask.float()

            # Trailing entry: end offset of the last sequence in the bin.
            seq_starts.append(current_pos)
            seq_starts_dict[bin_idx] = seq_starts

        # Invert attention mask, before inversion: (True = attend, False = mask)
        attention_mask.bitwise_not_()

        packing_info = PackingInfo(
            bin_seq_indices=bin_seq_indices,
            seq_starts=seq_starts_dict,
            seq_lengths=seq_lengths,
            seq_to_bin_idx=seq_to_bin_idx,
            packing_algo='fifo'
        )

        seq_per_bin = [len(indices) for indices in packing_info.bin_seq_indices]
        log_single_rank(
            logger, logging.DEBUG, ("Initial packing output (before distribution):")
        )
        log_single_rank(
            logger,
            logging.DEBUG,
            f" - Total bins created: {len(packing_info.bin_seq_indices)}",
        )
        log_single_rank(
            logger, logging.DEBUG, f" - Total sequences packed: {sum(seq_per_bin)}"
        )
        log_single_rank(
            logger,
            logging.DEBUG,
            f" - Sequences per bin: min={min(seq_per_bin)}, max={max(seq_per_bin)}, avg={sum(seq_per_bin)/len(seq_per_bin):.1f}",
        )
        log_single_rank(logger, logging.DEBUG, f" - First 20 bins: {seq_per_bin[:20]}")

        return packed_sequences, position_ids, attention_mask, loss_mask, packing_info
def distribute_packed_bins(
    packed_trajs: torch.Tensor,
    packed_position_ids: torch.Tensor,
    packed_attention_mask: torch.Tensor,
    packed_loss_mask: torch.Tensor,
    packing_info: PackingInfo,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
    """Distribute packed bins across the data parallel ranks.

    With 'round-robin', rank i takes bins i, i + world_size, ...; otherwise
    ('fifo') bins are split into contiguous chunks, with the first
    ``num_bins % world_size`` ranks taking one extra bin. Every rank is then
    padded with empty bins up to ``ceil(num_bins / world_size)`` so all ranks run
    the same number of iterations.

    Args:
        packed_trajs: Packed trajectories [num_bins, bin_size].
        packed_position_ids: Packed position IDs [num_bins, bin_size].
        packed_attention_mask: Packed attention mask [num_bins, 1, bin_size, bin_size],
            or None.
        packed_loss_mask: Packed loss mask [num_bins, bin_size].
        packing_info: Global packing metadata.

    Returns:
        The same four tensors restricted to this rank's bins (plus padding bins),
        and a PackingInfo whose bin_seq_indices / seq_starts use local bin indices
        while seq_lengths / seq_to_bin_idx are kept global. The attention mask
        stays None only if it was None on input.
    """
    rank = mpu.get_data_parallel_rank()
    # Guard both branches against a degenerate world size (range step / division).
    world_size = max(mpu.get_data_parallel_world_size(), 1)
    tokenizer = get_tokenizer()

    num_bins, bin_size = packed_trajs.shape
    packing_algo = packing_info.packing_algo

    if packing_algo == 'round-robin':
        # Round-robin assignment: rank i gets bins [i, i+world_size, i+2*world_size, ...]
        my_bin_indices = list(range(rank, num_bins, world_size))
    else:  # fifo (default)
        # FIFO assignment: divide bins sequentially; first `extra_bins` ranks get one more.
        bins_per_rank, extra_bins = divmod(num_bins, world_size)
        if rank < extra_bins:
            start_idx = rank * (bins_per_rank + 1)
            end_idx = start_idx + bins_per_rank + 1
        else:
            start_idx = rank * bins_per_rank + extra_bins
            end_idx = start_idx + bins_per_rank
        my_bin_indices = list(range(start_idx, end_idx))

    # Maximum bins any rank has (for synchronization)
    max_bins_per_rank = (num_bins + world_size - 1) // world_size

    def _select_bins(tensor: torch.Tensor) -> torch.Tensor:
        # Gather this rank's rows; an empty index keeps the trailing dims, shape (0, ...).
        index = torch.as_tensor(my_bin_indices, dtype=torch.long, device=tensor.device)
        return tensor.index_select(0, index)

    packed_trajs = _select_bins(packed_trajs)
    packed_position_ids = _select_bins(packed_position_ids)
    packed_loss_mask = _select_bins(packed_loss_mask)
    # Keep a (possibly 0-row) tensor rather than collapsing to None when this rank has
    # no real bins, so its padding bins still get an attention mask like other ranks.
    if packed_attention_mask is not None:
        packed_attention_mask = _select_bins(packed_attention_mask)

    my_bin_seq_indices = [packing_info.bin_seq_indices[old_idx] for old_idx in my_bin_indices]
    my_seq_starts = {
        new_idx: packing_info.seq_starts[old_idx]
        for new_idx, old_idx in enumerate(my_bin_indices)
    }

    # Debug: Check what we're extracting
    log_single_rank(logger, logging.DEBUG, (f"Rank 0 {packing_algo} bin assignment:"))
    log_single_rank(
        logger, logging.DEBUG, f" - Total bins before distribution: {num_bins}"
    )
    log_single_rank(
        logger,
        logging.DEBUG,
        f" - Bins assigned to rank 0: {my_bin_indices[:10]}... (showing first 10)",
    )
    log_single_rank(
        logger,
        logging.DEBUG,
        f" - Number of bins for this rank: {len(my_bin_indices)}",
    )
    log_single_rank(
        logger,
        logging.DEBUG,
        f" - Length of my_bin_seq_indices: {len(my_bin_seq_indices)}",
    )
    if len(my_bin_seq_indices) > 0:
        log_single_rank(
            logger,
            logging.DEBUG,
            f" - Sequences in first 5 bins: {[len(indices) for indices in my_bin_seq_indices[:5]]}",
        )

    new_packing_info = PackingInfo(
        bin_seq_indices=my_bin_seq_indices,
        seq_starts=my_seq_starts,
        seq_lengths=packing_info.seq_lengths,  # Keep all sequence lengths
        seq_to_bin_idx=packing_info.seq_to_bin_idx,  # Keep mapping
        packing_algo=packing_algo,
    )

    # Add empty bins if this rank has fewer than max_bins_per_rank
    current_bins = len(my_bin_indices)
    if current_bins < max_bins_per_rank:
        num_empty_bins = max_bins_per_rank - current_bins
        (
            empty_trajs,
            empty_position_ids,
            empty_loss_mask,
            empty_attention_mask,
            empty_packing_entries,
        ) = create_empty_bins(
            num_empty_bins,
            bin_size,
            packed_trajs,
            packed_position_ids,
            packed_loss_mask,
            packed_attention_mask,
            tokenizer,
        )

        packed_trajs = torch.cat([packed_trajs, empty_trajs], dim=0)
        packed_position_ids = torch.cat(
            [packed_position_ids, empty_position_ids], dim=0
        )
        packed_loss_mask = torch.cat([packed_loss_mask, empty_loss_mask], dim=0)
        if packed_attention_mask is not None and empty_attention_mask is not None:
            packed_attention_mask = torch.cat(
                [packed_attention_mask, empty_attention_mask], dim=0
            )

        for offset, entry in enumerate(empty_packing_entries):
            bin_idx = current_bins + offset
            new_packing_info.bin_seq_indices.append(entry['bin_seq_indices'])
            new_packing_info.seq_starts[bin_idx] = entry['seq_starts']

    return packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, new_packing_info
def pack_all_trajectories(trajs, generation_masks, inference_logprobs, global_advantages, bin_size, max_sequences_per_bin, packing_algo):
tokenizer = get_tokenizer()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
nvtx_range = get_nvtx_range()
with nvtx_range("rl/regather-trajectories", time=True):
def _gather(data):
data = data.cuda()
data_list = [torch.empty_like(data) for _ in range(data_parallel_world_size)]
torch.distributed.all_gather(data_list, data, group=data_parallel_group)
return torch.cat(data_list, dim=0)
trajs = _gather(trajs)
generation_masks = _gather(generation_masks)
if inference_logprobs is not None:
inference_logprobs = _gather(inference_logprobs)
with nvtx_range("rl/pack-sequences", time=True):
# Create packer with max sequences per bin limit to prevent extreme imbalance
packer = SequencePacker(
bin_size=bin_size,
pad_token=tokenizer.pad,
max_sequences_per_bin=max_sequences_per_bin,
)
# Pack sequences with generation masks
(