def _set_existing_attr(target: object, name: str, value: object) -> None:
    """Overwrite an attribute that ``target`` already has.

    Args:
        target: Object whose attribute is being overridden.
        name: Attribute name; must already be present on ``target``.
        value: New value to store under ``name``.

    Raises:
        ValueError: If ``target`` has no attribute called ``name``. This
            keeps a misspelled or renamed field from being silently
            created as a new attribute.
    """
    if hasattr(target, name):
        setattr(target, name, value)
        return
    raise ValueError(f"{type(target).__name__} has no field {name!r}")


def _make_functional_test_model_small(model: object) -> None:
    """Shrink ``model`` in place to a tiny shape for this functional test.

    Keeps the checkpoint-loading functional test far below runner memory
    limits. The path under test is CP + sequence packing + pretrained
    checkpoint loading, not the full Llama 3.2 1B model shape.

    Args:
        model: Model config object; every overridden field must already
            exist on it.

    Raises:
        ValueError: If ``model`` lacks one of the overridden fields. Fields
            earlier in the list have already been overwritten by then.
    """
    # Applied in this order; the first missing field aborts the rest.
    small_shape_overrides = (
        ("num_layers", 2),
        ("hidden_size", 256),
        ("ffn_hidden_size", 1024),
        ("num_attention_heads", 4),
        ("num_query_groups", 4),
        ("kv_channels", 64),
        ("seq_length", 256),
    )
    for field_name, field_value in small_shape_overrides:
        _set_existing_attr(model, field_name, field_value)