@@ -119,10 +119,16 @@ def get_batch(data_iterator, vp_stage=None):
119119 # Step 1b: merge sequences that are too short for CP
120120 _divisibility = 2 * cp_size
121121 _seq_lens = cu_seq [1 :] - cu_seq [:- 1 ]
122+
123+ _keep = _seq_lens >= _divisibility
124+
125+ # Expand to match cu_seq size and force first/last to stay
122126 _keep = torch .cat ([
123- torch .tensor ([True ], device = device ),
124- _seq_lens >= _divisibility ,
127+ torch .tensor ([True ], device = device ), # always keep first
128+ _keep
125129 ])
130+ _keep [- 1 ] = True # always keep last
131+
126132 cu_seq = cu_seq [_keep ]
127133
128134 if cp_size > 1 :
@@ -147,6 +153,7 @@ def get_batch(data_iterator, vp_stage=None):
147153 padding_token_id = tokenizer .eod ,
148154 padding_label_id = - 100 ,
149155 )
156+
150157 input_ids_padded = input_ids_padded .to (device )
151158 labels_padded = labels_padded .to (device )
152159 cu_seqlens_padded = cu_seqlens_padded .to (device = device , dtype = torch .int32 )
@@ -499,6 +506,14 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa
499506 current_modality_weights = current_modality_weights ,
500507 )
501508 else :
509+
510+ import remote_pdb
511+ import torch .distributed as dist
512+
513+ rank_t = dist .get_rank ()
514+ if rank_t == 0 :
515+ remote_pdb .set_trace (host = "0.0.0.0" , port = 1234 )
516+
502517 output_tensor = model (
503518 tokens , position_ids , attention_mask , labels = labels , loss_mask = loss_mask ,
504519 packed_seq_params = packed_seq_params
@@ -592,7 +607,9 @@ def core_gpt_dataset_config_from_args(args):
592607 "sft_pack_samples" : args .ap_sft_pack_samples ,
593608 "sft_packing_strategy" : args .ap_sft_packing_strategy ,
594609 "sft_equalize_sample_loss" : args .ap_sft_equalize_sample_loss ,
595- "sft_truncate_right" : args .ap_sft_truncate_right
610+ "sft_truncate_right" : args .ap_sft_truncate_right ,
611+ "pretraining_packing_strategy" : args .pretraining_packing_strategy ,
612+ "max_docs_per_bin" : args .max_docs_per_bin ,
596613 }
597614
598615 # add FIM args to the config
0 commit comments