Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
from megatron.bridge.data.datasets.packed_sequence import PackedSequenceSpecs
from megatron.bridge.data.hf_processors.squad import process_squad_example
from megatron.bridge.recipes.llama.llama3 import llama32_1b_sft_config
from megatron.bridge.recipes.llama.llama3 import llama32_1b_pretrain_config, llama32_1b_sft_config
from megatron.bridge.training.finetune import finetune
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from tests.functional_tests.utils import (
broadcast_path,
clear_directories,
Expand All @@ -43,14 +44,36 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
pytest.skip("requires >=2 GPUs for context_parallel_size=2")

shared_dir = broadcast_path(tmp_path)
checkpoint_dir = os.path.join(shared_dir, "checkpoints")
tensorboard_dir = os.path.join(shared_dir, "tensorboard")
pretrain_checkpoint_dir = os.path.join(shared_dir, "pretrain_checkpoints")
pretrain_tensorboard_dir = os.path.join(shared_dir, "pretrain_tensorboard")
sft_checkpoint_dir = os.path.join(shared_dir, "sft_checkpoints")
sft_tensorboard_dir = os.path.join(shared_dir, "sft_tensorboard")

if torch.distributed.get_rank() == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(tensorboard_dir, exist_ok=True)
os.makedirs(pretrain_checkpoint_dir, exist_ok=True)
os.makedirs(pretrain_tensorboard_dir, exist_ok=True)
os.makedirs(sft_checkpoint_dir, exist_ok=True)
os.makedirs(sft_tensorboard_dir, exist_ok=True)
torch.distributed.barrier()

pretrain_cfg = llama32_1b_pretrain_config()
pretrain_cfg.model.tensor_model_parallel_size = 1
pretrain_cfg.model.pipeline_model_parallel_size = 1
pretrain_cfg.model.context_parallel_size = 2
pretrain_cfg.model.seq_length = 256
pretrain_cfg.dataset.seq_length = 256
pretrain_cfg.train.train_iters = 1
pretrain_cfg.train.global_batch_size = 2
pretrain_cfg.train.micro_batch_size = 1
pretrain_cfg.validation.eval_interval = 1
pretrain_cfg.validation.eval_iters = 0
pretrain_cfg.scheduler.lr_warmup_iters = 0
pretrain_cfg.logger.log_interval = 1
pretrain_cfg.logger.tensorboard_dir = pretrain_tensorboard_dir
pretrain_cfg.checkpoint.save_interval = pretrain_cfg.train.train_iters
pretrain_cfg.checkpoint.save = pretrain_checkpoint_dir
pretrain_cfg.checkpoint.load = None

cfg = llama32_1b_sft_config()
cfg.tokenizer.tokenizer_type = "HuggingFaceTokenizer"
cfg.tokenizer.tokenizer_model = "meta-llama/Llama-3.2-1B"
Expand All @@ -70,7 +93,7 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
cfg.validation.eval_iters = 0
cfg.scheduler.lr_warmup_iters = 0
cfg.logger.log_interval = 1
cfg.logger.tensorboard_dir = tensorboard_dir
cfg.logger.tensorboard_dir = sft_tensorboard_dir

# Use a small packed SQuAD dataset to exercise THD/context-parallel slicing
cfg.dataset = HFDatasetConfig(
Expand All @@ -94,13 +117,22 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):

cfg.model.seq_length = 256
cfg.checkpoint.save_interval = cfg.train.train_iters
cfg.checkpoint.save = checkpoint_dir
cfg.checkpoint.pretrained_checkpoint = None
cfg.checkpoint.save = sft_checkpoint_dir
cfg.checkpoint.load = None
cfg.checkpoint.pretrained_checkpoint = pretrain_checkpoint_dir

try:
pretrain(pretrain_cfg, forward_step)
verify_checkpoint_files(
pretrain_checkpoint_dir,
pretrain_cfg.train.train_iters,
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
)

finetune(cfg, forward_step)
verify_checkpoint_files(
checkpoint_dir,
sft_checkpoint_dir,
cfg.train.train_iters,
ckpt_format=cfg.checkpoint.ckpt_format,
storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank,
Expand Down
Loading