11# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
22
3- from megatron .core .models .bert .bert_model import BertModel
4- import pytest
5-
63import os
4+
5+ import pytest
76import torch
8- from torch .distributed ._tensor import DeviceMesh
97
10- from megatron .core .dist_checkpointing import save , load , load_plain_tensors
118from megatron .core import parallel_state as ps
12- from megatron .core .dist_checkpointing .dict_utils import diff
9+ from megatron .core .models .bert .bert_layer_specs import (
10+ bert_layer_local_spec ,
11+ bert_layer_with_transformer_engine_spec ,
12+ )
13+ from megatron .core .models .bert .bert_model import BertModel
14+ from megatron .core .tensor_parallel .random import model_parallel_cuda_manual_seed
1315from megatron .core .transformer .transformer_config import TransformerConfig
14- from tests .unit_tests .dist_checkpointing import TempNamedDir
15- from tests .unit_tests .dist_checkpointing .models .common import \
16- common_test_simple_sharded_state_dict_save_load , \
17- common_test_parallel_reconfiguration_e2e , common_test_state_dict_comparison , \
18- common_test_vocab_size_padding_change
16+ from tests .unit_tests .dist_checkpointing .models .common import (
17+ common_test_parallel_reconfiguration_e2e ,
18+ common_test_simple_sharded_state_dict_save_load ,
19+ common_test_state_dict_comparison ,
20+ common_test_vocab_size_padding_change ,
21+ )
1922from tests .unit_tests .test_utilities import Utils
20- from megatron .core .tensor_parallel .random import model_parallel_cuda_manual_seed
21- from megatron .core .models .bert .bert_layer_specs import bert_layer_local_spec , bert_layer_with_transformer_engine_spec
2223
2324
2425def initialize_bert_model (seed , layer_spec_fn = bert_layer_with_transformer_engine_spec , vocab_size = 128 , ** config_kwargs ):
@@ -52,6 +53,12 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt,
5253
5354
5455class TestBERTModelReconfiguration :
def setup_method(self, method):
    """Intentionally a no-op: nothing is prepared before a test method runs.

    Args:
        method: The test method about to run; unused.
    """
58+
def teardown_method(self, method):
    """Call ``Utils.destroy_model_parallel()`` once a test method has finished.

    Args:
        method: The test method that just ran; unused.
    """
    Utils.destroy_model_parallel()
61+
5562 @pytest .mark .parametrize (
5663 ('use_fpsl' , 'src_tp_pp' , 'dest_tp_pp' , 'src_layer_spec' , 'dst_layer_spec' ),
5764 [
@@ -67,6 +74,8 @@ class TestBERTModelReconfiguration:
def test_parallel_reconfiguration_e2e(
    self,
    tmp_path_dist_ckpt,
    src_tp_pp,
    dest_tp_pp,
    src_layer_spec,
    dst_layer_spec,
    use_fpsl,
):
    """Test model saving and loading with different TP/PP.

    Model parallelism is first initialized from the first two entries of
    ``src_tp_pp``; the shared ``common_test_parallel_reconfiguration_e2e``
    helper is then invoked with ``initialize_bert_model`` as its first
    argument, followed by the test parameters.
    """
    # Only indices 0 and 1 are read, exactly as the helper call below expects
    # the full ``src_tp_pp`` value to be forwarded untouched.
    src_first, src_second = src_tp_pp[0], src_tp_pp[1]
    Utils.initialize_model_parallel(src_first, src_second)

    common_test_parallel_reconfiguration_e2e(
        initialize_bert_model,
        tmp_path_dist_ckpt,
        src_tp_pp,
        dest_tp_pp,
        src_layer_spec,
        dst_layer_spec,
        use_fpsl,
    )
7281
@@ -82,5 +91,6 @@ def test_state_dict_comparison(self, tmp_path_dist_ckpt):
8291 ])
def test_vocab_size_padding_change(
    self,
    tmp_path_dist_ckpt,
    vocab_size_base,
    src_tp_pp,
    dest_tp_pp,
):
    """Test model loading with different vocab size (caused by TP padding).

    Model parallelism is first initialized from the first two entries of
    ``src_tp_pp``; the shared ``common_test_vocab_size_padding_change``
    helper is then invoked with ``initialize_bert_model`` as its first
    argument, followed by the test parameters.
    """
    src_first, src_second = src_tp_pp[0], src_tp_pp[1]
    Utils.initialize_model_parallel(src_first, src_second)

    common_test_vocab_size_padding_change(
        initialize_bert_model,
        tmp_path_dist_ckpt,
        vocab_size_base,
        src_tp_pp,
        dest_tp_pp,
    )
0 commit comments