Skip to content

Commit b17315a

Browse files
committed
[test] fix: shrink CP packed SFT functional model
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent 2e4322c commit b17315a

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

tests/functional_tests/test_groups/training/test_seqpacking_cp_example.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import gc
1516
import os
1617

1718
import pytest
@@ -32,6 +33,28 @@
3233
)
3334

3435

36+
def _set_existing_attr(target: object, name: str, value: object) -> None:
    """Overwrite an attribute that ``target`` already exposes.

    Unlike a bare ``setattr``, this refuses to create a new attribute, so a
    misspelled or renamed field fails loudly instead of being silently added.

    Args:
        target: Object whose attribute is overwritten.
        name: Name of the attribute; must already be visible via ``hasattr``.
        value: New value to store.

    Raises:
        ValueError: If ``target`` has no attribute called ``name``.
    """
    if hasattr(target, name):
        setattr(target, name, value)
        return
    raise ValueError(f"{type(target).__name__} has no field {name!r}")
40+
41+
42+
def _make_functional_test_model_small(model: object) -> None:
    """Shrink ``model`` in place to a tiny shape for the functional test.

    Every override goes through ``_set_existing_attr``, so a field that is
    missing on ``model`` raises ``ValueError`` rather than being added.

    Args:
        model: Model configuration object to mutate.
    """
    # This functional test loads a checkpoint, so stay well under the runner's
    # memory limits. What is being exercised is CP + sequence packing +
    # pretrained checkpoint loading, not the full Llama 3.2 1B model shape.
    small_shape = (
        ("num_layers", 2),
        ("hidden_size", 256),
        ("ffn_hidden_size", 1024),
        ("num_attention_heads", 4),
        ("num_query_groups", 4),
        ("kv_channels", 64),
        ("seq_length", 256),
    )
    for field_name, field_value in small_shape:
        _set_existing_attr(model, field_name, field_value)
56+
57+
3558
class TestPeftSftExample:
3659
"""Run the PEFT SFT example as a functional test with packed sequences + CP."""
3760

@@ -57,10 +80,10 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
5780
torch.distributed.barrier()
5881

5982
pretrain_cfg = llama32_1b_pretrain_config()
83+
_make_functional_test_model_small(pretrain_cfg.model)
6084
pretrain_cfg.model.tensor_model_parallel_size = 1
6185
pretrain_cfg.model.pipeline_model_parallel_size = 1
6286
pretrain_cfg.model.context_parallel_size = 2
63-
pretrain_cfg.model.seq_length = 256
6487
pretrain_cfg.dataset.seq_length = 256
6588
pretrain_cfg.train.train_iters = 1
6689
pretrain_cfg.train.global_batch_size = 2
@@ -75,6 +98,7 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
7598
pretrain_cfg.checkpoint.load = None
7699

77100
cfg = llama32_1b_sft_config()
101+
_make_functional_test_model_small(cfg.model)
78102
cfg.tokenizer.tokenizer_type = "HuggingFaceTokenizer"
79103
cfg.tokenizer.tokenizer_model = "meta-llama/Llama-3.2-1B"
80104
cfg.model.calculate_per_token_loss = True
@@ -115,7 +139,6 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
115139
rewrite=False,
116140
)
117141

118-
cfg.model.seq_length = 256
119142
cfg.checkpoint.save_interval = cfg.train.train_iters
120143
cfg.checkpoint.save = sft_checkpoint_dir
121144
cfg.checkpoint.load = None
@@ -129,6 +152,9 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
129152
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
130153
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
131154
)
155+
gc.collect()
156+
torch.cuda.empty_cache()
157+
torch.distributed.barrier()
132158

133159
finetune(cfg, forward_step)
134160
verify_checkpoint_files(

0 commit comments

Comments
 (0)