Skip to content

Commit b0981d8

Browse files
committed
Fix: Add Final Argument
1 parent 3007afb commit b0981d8

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

pretrain_gpt.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -119,10 +119,16 @@ def get_batch(data_iterator, vp_stage=None):
119119
# Step 1b: merge sequences that are too short for CP
120120
_divisibility = 2 * cp_size
121121
_seq_lens = cu_seq[1:] - cu_seq[:-1]
122+
123+
_keep = _seq_lens >= _divisibility
124+
125+
# Expand to match cu_seq size and force first/last to stay
122126
_keep = torch.cat([
123-
torch.tensor([True], device=device),
124-
_seq_lens >= _divisibility,
127+
torch.tensor([True], device=device), # always keep first
128+
_keep
125129
])
130+
_keep[-1] = True # always keep last
131+
126132
cu_seq = cu_seq[_keep]
127133

128134
if cp_size > 1:
@@ -147,6 +153,7 @@ def get_batch(data_iterator, vp_stage=None):
147153
padding_token_id=tokenizer.eod,
148154
padding_label_id=-100,
149155
)
156+
150157
input_ids_padded = input_ids_padded.to(device)
151158
labels_padded = labels_padded.to(device)
152159
cu_seqlens_padded = cu_seqlens_padded.to(device=device, dtype=torch.int32)
@@ -499,6 +506,14 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa
499506
current_modality_weights=current_modality_weights,
500507
)
501508
else:
509+
510+
import remote_pdb
511+
import torch.distributed as dist
512+
513+
rank_t = dist.get_rank()
514+
if rank_t == 0:
515+
remote_pdb.set_trace(host = "0.0.0.0", port = 1234)
516+
502517
output_tensor = model(
503518
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask,
504519
packed_seq_params=packed_seq_params
@@ -592,7 +607,9 @@ def core_gpt_dataset_config_from_args(args):
592607
"sft_pack_samples": args.ap_sft_pack_samples,
593608
"sft_packing_strategy": args.ap_sft_packing_strategy,
594609
"sft_equalize_sample_loss": args.ap_sft_equalize_sample_loss,
595-
"sft_truncate_right": args.ap_sft_truncate_right
610+
"sft_truncate_right": args.ap_sft_truncate_right,
611+
"pretraining_packing_strategy": args.pretraining_packing_strategy,
612+
"max_docs_per_bin": args.max_docs_per_bin,
596613
}
597614

598615
# add FIM args to the config

0 commit comments

Comments
 (0)