1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import gc
1516import os
1617
1718import pytest
3233)
3334
3435
36+ def _set_existing_attr (target : object , name : str , value : object ) -> None :
37+ if not hasattr (target , name ):
38+ raise ValueError (f"{ type (target ).__name__ } has no field { name !r} " )
39+ setattr (target , name , value )
40+
41+
def _make_functional_test_model_small(model: object) -> None:
    """Shrink ``model`` in place to a tiny shape for the functional test.

    This keeps the checkpoint-loading functional test far below runner memory
    limits. The path under test is CP + sequence packing + pretrained
    checkpoint loading, not the full Llama 3.2 1B model shape.

    Args:
        model: Model config object; it must already expose every field
            listed below.

    Raises:
        ValueError: Propagated from ``_set_existing_attr`` when ``model``
            lacks one of the fields.
    """
    small_shape = (
        ("num_layers", 2),
        ("hidden_size", 256),
        ("ffn_hidden_size", 1024),
        ("num_attention_heads", 4),
        ("num_query_groups", 4),
        ("kv_channels", 64),
        ("seq_length", 256),
    )
    for field_name, field_value in small_shape:
        _set_existing_attr(model, field_name, field_value)
56+
57+
3558class TestPeftSftExample :
3659 """Run the PEFT SFT example as a functional test with packed sequences + CP."""
3760
@@ -57,10 +80,10 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
5780 torch .distributed .barrier ()
5881
5982 pretrain_cfg = llama32_1b_pretrain_config ()
83+ _make_functional_test_model_small (pretrain_cfg .model )
6084 pretrain_cfg .model .tensor_model_parallel_size = 1
6185 pretrain_cfg .model .pipeline_model_parallel_size = 1
6286 pretrain_cfg .model .context_parallel_size = 2
63- pretrain_cfg .model .seq_length = 256
6487 pretrain_cfg .dataset .seq_length = 256
6588 pretrain_cfg .train .train_iters = 1
6689 pretrain_cfg .train .global_batch_size = 2
@@ -75,6 +98,7 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
7598 pretrain_cfg .checkpoint .load = None
7699
77100 cfg = llama32_1b_sft_config ()
101+ _make_functional_test_model_small (cfg .model )
78102 cfg .tokenizer .tokenizer_type = "HuggingFaceTokenizer"
79103 cfg .tokenizer .tokenizer_model = "meta-llama/Llama-3.2-1B"
80104 cfg .model .calculate_per_token_loss = True
@@ -115,7 +139,6 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
115139 rewrite = False ,
116140 )
117141
118- cfg .model .seq_length = 256
119142 cfg .checkpoint .save_interval = cfg .train .train_iters
120143 cfg .checkpoint .save = sft_checkpoint_dir
121144 cfg .checkpoint .load = None
@@ -129,6 +152,9 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
129152 ckpt_format = pretrain_cfg .checkpoint .ckpt_format ,
130153 storage_writers_per_rank = pretrain_cfg .checkpoint .storage_writers_per_rank ,
131154 )
155+ gc .collect ()
156+ torch .cuda .empty_cache ()
157+ torch .distributed .barrier ()
132158
133159 finetune (cfg , forward_step )
134160 verify_checkpoint_files (
0 commit comments