Commit 7e03fe1

remove warmup callback
1 parent f17fb23 commit 7e03fe1

File tree

2 files changed: +0, −181 lines


moe_pretraining/nemo/callback.py

Lines changed: 0 additions & 180 deletions
@@ -16,7 +16,6 @@
 import os
 import time
 from functools import wraps
-from pprint import pprint
 from typing import Any, Callable, List, Optional, Protocol, Union
 
 import torch
@@ -40,10 +39,6 @@
 logger = logging.getLogger(__name__)
 
 
-# =============================================================================
-# Utility Functions
-# =============================================================================
-
 def get_last_pp_rank():
     """Check if current rank is the last pipeline parallel rank."""
     is_last_pp = mpu.is_pipeline_last_stage(ignore_virtual=True)
@@ -69,22 +64,9 @@ def broadcast_loss(loss_reduced):
     return loss_synced.item()
 
 
-def get_rank():
-    """Get current process rank for warmup callback."""
-    return int(os.getenv("SLURM_PROCID", 0))
-
-
-# =============================================================================
-# MLPerf Logger
-# =============================================================================
-
 mllogger = MLLoggerWrapper(PyTCommunicationHandler())
 
 
-# =============================================================================
-# Timer Utility
-# =============================================================================
-
 class DeltaTimer:
     """Timer for measuring time deltas."""
 
@@ -378,10 +360,6 @@ def install_callbacks() -> None:
     pretrain_module.train = train_module.train
 
 
-# =============================================================================
-# MLPerf Logging Callback
-# =============================================================================
-
 class MLPerfLoggingCallback:
     """MLPerf logging callback for compliance logging."""
 
@@ -575,10 +553,6 @@ def _get_samples_count(self, global_state):
         return self._get_step(global_state) * self.global_batch_size
 
 
-# =============================================================================
-# Delta Timing Callback
-# =============================================================================
-
 class DeltaTimingCallback:
     """Callback for tracking training step timing."""
 
@@ -633,157 +607,3 @@ def on_validation_end(
     ):
         """Reset timer after validation to avoid including validation time in first train step."""
         self.t0 = time.time()
-
-
-# =============================================================================
-# Warmup Callback
-# =============================================================================
-
-def get_mock_data(config):
-    """Get mock data configuration for warmup."""
-    from megatron.bridge.training.config import MockGPTDatasetConfig
-
-    return MockGPTDatasetConfig(
-        sequence_length=config.model.encoder_seq_length,
-        random_seed=config.model.seed,
-        dataloader_type="single",
-        num_workers=config.model.data.num_workers,
-        reset_position_ids=False,
-        reset_attention_mask=False,
-        eod_mask_loss=False,
-        path_to_cache=None,
-        split="900,50,50",
-    )
-
-
-class WarmupCallback:
-    """Callback for performing training and validation warmup."""
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-        self.train_steps = cfg.model.custom.warmup_train_steps
-        self.val_steps = cfg.model.custom.warmup_validation_steps
-
-    def on_train_start(
-        self,
-        global_state,
-        forward_step_func,
-        model,
-        optimizer,
-        scheduler,
-    ):
-        if get_rank() == 0:
-            print("\nMCore config:", flush=True)
-            pprint(model[0].config)
-
-        torch.distributed.barrier()
-        enable_cuda_graph = int(os.getenv("MCORE_CUDA_GRAPH", "0")) == 1
-        cuda_graph_scope = self.cfg.model.overwritten_attributes.cuda_graph_scope
-
-        forward_backward_func = get_forward_backward_func()
-        if enable_cuda_graph and cuda_graph_scope == "full_iteration":
-            forward_backward_func = FullCudaGraphWrapper(
-                forward_backward_func,
-                cuda_graph_warmup_steps=1,
-            )
-
-        skip_val_warmup = enable_cuda_graph and cuda_graph_scope != "full_iteration"
-
-        warmup_mock_cfg = get_mock_data(self.cfg)
-
-        train_dataloader, val_dataloader, _ = warmup_mock_cfg.build_dataloaders()
-        data_iterator = iter(train_dataloader)
-        eval_data_iterator = iter(val_dataloader)
-
-        pp_group = mpu.get_pipeline_model_parallel_group()
-        torch.distributed.barrier(pp_group)
-
-        for group in optimizer.param_groups:
-            group["betas_"] = group["betas"]
-            group["bias_correction_"] = group["bias_correction"]
-            group["betas"] = [1.0, 1.0]
-            group["bias_correction"] = False
-
-        if torch.distributed.get_rank() == 0:
-            logger.info("Starting training warmup")
-        start = time.time()
-        for step_idx in range(self.train_steps):
-            if torch.distributed.get_rank() == 0:
-                logger.info(f" Starting warmup step {step_idx}")
-            step_timer = time.time()
-            torch.cuda.synchronize()
-            torch.distributed.barrier()
-            forward_backward_func(
-                forward_step_func=forward_step_func,
-                data_iterator=data_iterator,
-                model=model,
-                num_microbatches=get_num_microbatches(),
-                seq_length=self.cfg.model.encoder_seq_length,
-                micro_batch_size=self.cfg.model.micro_batch_size,
-                decoder_seq_length=self.cfg.model.encoder_seq_length,
-                forward_only=False,
-            )
-            optimizer.zero_grad()
-            optimizer.step()
-            optimizer.zero_grad()
-            torch.cuda.synchronize()
-
-            for chunk in model:
-                chunk.module.zero_grad_buffer()
-                chunk.module.zero_grad()
-
-            if torch.distributed.get_rank() == 0:
-                logger.info(f" Finished warmup step {step_idx}, takes {time.time() - step_timer} s")
-
-        torch.cuda.synchronize()
-        torch.distributed.barrier()
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Finished training warmup: {time.time() - start} s. ")
-
-        for group in optimizer.param_groups:
-            group["betas"] = group["betas_"]
-            group["bias_correction"] = group["bias_correction_"]
-            del group["betas_"]
-            del group["bias_correction_"]
-            if "step" in group:
-                if isinstance(group["step"], torch.Tensor):
-                    group["step"].fill_(1)
-                else:
-                    group["step"] = 1
-
-        if not skip_val_warmup:
-            start = time.time()
-            if torch.distributed.get_rank() == 0:
-                logger.info("Starting validation warmups")
-            for model_module in model:
-                model_module.eval()
-            with torch.no_grad():
-                for _ in range(self.val_steps):
-                    torch.cuda.synchronize()
-                    torch.distributed.barrier()
-
-                    forward_backward_func(
-                        forward_step_func=forward_step_func,
-                        data_iterator=eval_data_iterator,
-                        model=model,
-                        num_microbatches=get_num_microbatches(),
-                        seq_length=self.cfg.model.encoder_seq_length,
-                        micro_batch_size=self.cfg.model.micro_batch_size,
-                        forward_only=True,
-                    )
-                    torch.cuda.synchronize()
-
-            torch.distributed.barrier()
-            if torch.distributed.get_rank() == 0:
-                logger.info(f"Finished validation warmup: {time.time() - start} s. ")
-
-        for chunk in model:
-            chunk.module.zero_grad_buffer()
-            chunk.module.zero_grad()
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Finished training warmup: {time.time() - start} s. ")
-
-        torch.cuda.synchronize()
-        torch.distributed.barrier()
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Time spent in run_training_warmup: {time.time() - start}s")

moe_pretraining/nemo/pretrain_deepseek_v3_671b.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 from callback import (
     MLPerfLoggingCallback,
     DeltaTimingCallback,
-    WarmupCallback,
     mllogger,
     install_callbacks,
     register_callback,
