Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os

import pytest
Expand All @@ -32,6 +33,28 @@
)


def _set_existing_attr(target: object, name: str, value: object) -> None:
    """Overwrite an attribute that ``target`` already has.

    Unlike a bare ``setattr``, this refuses to create a new attribute.

    Args:
        target: Object whose attribute is overwritten in place.
        name: Attribute name; ``hasattr(target, name)`` must be true.
        value: New value to store under ``name``.

    Raises:
        ValueError: If ``target`` has no attribute called ``name``.
    """
    # Check first: setattr alone would silently add the attribute on most
    # objects instead of reporting that the requested field does not exist.
    if not hasattr(target, name):
        raise ValueError(f"{type(target).__name__} has no field {name!r}")
    setattr(target, name, value)


def _make_functional_test_model_small(model: object) -> None:
    """Overwrite the shape fields of ``model`` in place with small test values.

    Args:
        model: Object that must already expose every field listed below.

    Raises:
        ValueError: Propagated from ``_set_existing_attr`` when ``model`` lacks
            one of the fields. Fields earlier in the table have already been
            overwritten by then.
    """
    # Keep this checkpoint-loading functional test far below runner memory limits.
    # The path under test is CP + sequence packing + pretrained checkpoint loading,
    # not the full Llama 3.2 1B model shape.
    small_shape = {
        "num_layers": 2,
        "hidden_size": 256,
        "ffn_hidden_size": 1024,
        "num_attention_heads": 4,
        "num_query_groups": 4,
        "kv_channels": 64,
        "seq_length": 256,
    }
    for name, value in small_shape.items():
        _set_existing_attr(model, name, value)


class TestPeftSftExample:
"""Run the PEFT SFT example as a functional test with packed sequences + CP."""

Expand All @@ -57,10 +80,10 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
torch.distributed.barrier()

pretrain_cfg = llama32_1b_pretrain_config()
_make_functional_test_model_small(pretrain_cfg.model)
pretrain_cfg.model.tensor_model_parallel_size = 1
pretrain_cfg.model.pipeline_model_parallel_size = 1
pretrain_cfg.model.context_parallel_size = 2
pretrain_cfg.model.seq_length = 256
pretrain_cfg.dataset.seq_length = 256
pretrain_cfg.train.train_iters = 1
pretrain_cfg.train.global_batch_size = 2
Expand All @@ -75,6 +98,7 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
pretrain_cfg.checkpoint.load = None

cfg = llama32_1b_sft_config()
_make_functional_test_model_small(cfg.model)
cfg.tokenizer.tokenizer_type = "HuggingFaceTokenizer"
cfg.tokenizer.tokenizer_model = "meta-llama/Llama-3.2-1B"
cfg.model.calculate_per_token_loss = True
Expand Down Expand Up @@ -115,7 +139,6 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
rewrite=False,
)

cfg.model.seq_length = 256
cfg.checkpoint.save_interval = cfg.train.train_iters
cfg.checkpoint.save = sft_checkpoint_dir
cfg.checkpoint.load = None
Expand All @@ -129,6 +152,9 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
)
gc.collect()
torch.cuda.empty_cache()
torch.distributed.barrier()

finetune(cfg, forward_step)
verify_checkpoint_files(
Expand Down
Loading