4040)
4141
4242
43- @requires_multi_gpu
44- def test_ep2_matches_ep1 (unused_tcp_port ):
45- """Test that EP=2 produces the same logits as EP=1."""
43+ def _run_torchrun (test_name : str , port : int ):
44+ """Run a named test worker via torchrun with 2 GPUs."""
4645 cmd = [
4746 "torchrun" ,
4847 "--nproc_per_node=2" ,
4948 "--rdzv-backend=c10d" ,
50- f"--rdzv-endpoint=localhost:{ unused_tcp_port } " ,
49+ f"--rdzv-endpoint=localhost:{ port } " ,
5150 str (Path (__file__ ).resolve ()),
51+ test_name ,
5252 ]
5353 result = subprocess .run (
5454 cmd ,
@@ -62,7 +62,19 @@ def test_ep2_matches_ep1(unused_tcp_port):
6262 if result .returncode != 0 :
6363 print (f"STDOUT:\n { result .stdout } " )
6464 print (f"STDERR:\n { result .stderr } " )
65- pytest .fail (f"EP equivalence test failed with exit code { result .returncode } " )
65+ pytest .fail (f"EP { test_name } test failed with exit code { result .returncode } " )
66+
67+
@requires_multi_gpu
def test_ep2_matches_ep1(unused_tcp_port):
    """Forward equivalence: EP=2 must reproduce the EP=1 logits bit-for-tolerance."""
    _run_torchrun("forward", unused_tcp_port)
72+
73+
@requires_multi_gpu
def test_ep2_backward_matches_ep1(unused_tcp_port):
    """Backward equivalence: EP=2 must reproduce the EP=1 gradients."""
    _run_torchrun("backward", unused_tcp_port)
6678
6779
6880# ---------------------------------------------------------------------------
@@ -85,18 +97,22 @@ def _distribute_state_dict(full_state_dict: dict, model: torch.nn.Module, device
8597 from torch .distributed .tensor import DTensor , distribute_tensor
8698
8799 distributed_state : dict = {}
88- for key , value in model .state_dict ().items ():
89- if key not in full_state_dict :
100+ # model.state_dict() filters _extra_state keys via the NVMixtralPreTrainedModel
101+ # override, so use nn.Module.state_dict to get the unfiltered dict that includes
102+ # TransformerEngine _extra_state entries required by load_state_dict(strict=True).
103+ for key , value in torch .nn .Module .state_dict (model ).items ():
104+ if key .endswith ("_extra_state" ):
105+ distributed_state [key ] = value
106+ elif key not in full_state_dict :
90107 continue
91- full_value = full_state_dict [key ]
92- if isinstance (value , DTensor ):
108+ elif isinstance (value , DTensor ):
93109 distributed_state [key ] = distribute_tensor (
94- full_value .to (device ),
110+ full_state_dict [ key ] .to (device ),
95111 value .device_mesh ,
96112 list (value .placements ),
97113 )
98114 else :
99- distributed_state [key ] = full_value
115+ distributed_state [key ] = full_state_dict [ key ]
100116 return distributed_state
101117
102118
@@ -183,5 +199,95 @@ def _run_ep_equivalence_test():
183199 torch .distributed .destroy_process_group ()
184200
185201
def _run_ep_backward_test():
    """Worker for the EP backward-equivalence test (launched under torchrun).

    Phase 1: every rank independently builds an EP=1 model and records the
    loss and per-parameter gradients from one forward+backward pass.
    Phase 2: the same weights are resharded into an EP=2 model, the identical
    batch is replayed, and the (gathered) gradients are recorded.
    Phase 3: rank 0 asserts the EP=2 loss and gradients match the reference.
    """
    from torch.distributed.tensor.device_mesh import DeviceMesh

    dist_config = DistributedConfig()
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    ep_size = dist_config.world_size

    # --- Phase 1: EP=1 reference, computed independently on every rank ---
    # NOTE: config creation happens before the seed so the RNG state at model
    # construction matches the EP=2 phase below exactly.
    cfg_ref = create_small_mixtral_config(expert_parallel_size=1)
    torch.manual_seed(0)
    ref_model = NVMixtralForCausalLM(cfg_ref).to(dtype=torch.bfloat16, device=device)

    batch = get_dummy_batch(cfg_ref.vocab_size, seq_len=32, batch_size=2, device=device)

    ref_out = ref_model(**batch)
    ref_out.loss.backward()

    ref_grads = {}
    for name, param in ref_model.named_parameters():
        if param.grad is not None:
            ref_grads[name] = param.grad.detach().clone().cpu()
    ref_loss = ref_out.loss.detach().clone().cpu()

    full_state_dict = {key: tensor.clone().cpu() for key, tensor in ref_model.state_dict().items()}
    # Free the reference model before allocating the EP=2 one.
    del ref_model, ref_out
    torch.cuda.empty_cache()

    # --- Phase 2: EP=2 distributed run with the same weights and batch ---
    cfg_ep = create_small_mixtral_config(expert_parallel_size=ep_size)
    torch.manual_seed(0)
    ep_model = NVMixtralForCausalLM(cfg_ep).to(dtype=torch.bfloat16, device=device)

    mesh = DeviceMesh("cuda", list(range(ep_size)))
    ep_model.model.set_ep_groups(mesh.get_group(), mesh)

    ep_model.load_state_dict(_distribute_state_dict(full_state_dict, ep_model, device), strict=True)

    moved_batch = {key: val.to(device) if isinstance(val, torch.Tensor) else val for key, val in batch.items()}

    ep_out = ep_model(**moved_batch)
    ep_out.loss.backward()

    ep_grads = {}
    for name, param in ep_model.named_parameters():
        if param.grad is None:
            continue
        grad = param.grad
        if hasattr(grad, "full_tensor"):
            # Sharded (DTensor) expert gradients are gathered to full tensors
            # so they can be compared against the EP=1 reference.
            grad = grad.full_tensor()
        ep_grads[name] = grad.detach().clone().cpu()
    ep_loss = ep_out.loss.detach().clone().cpu()

    # --- Phase 3: rank 0 compares against the EP=1 reference ---
    if dist_config.is_main_process():
        torch.testing.assert_close(
            ep_loss,
            ref_loss,
            atol=1e-3,
            rtol=1e-3,
            msg="EP=2 backward: loss does not match EP=1 loss",
        )

        # Every parameter that received a gradient under EP=1 must also
        # receive one under EP=2.
        for name in ref_grads:
            assert name in ep_grads, f"EP=2 model missing gradient for {name}"

        for name in ref_grads:
            torch.testing.assert_close(
                ep_grads[name],
                ref_grads[name],
                atol=3e-2,
                rtol=3e-2,
                msg=f"EP=2 gradient mismatch for {name}",
            )

        print("EP backward test PASSED: EP=2 gradients match EP=1")

    torch.distributed.destroy_process_group()
286+
287+
if __name__ == "__main__":
    # torchrun re-executes this file; the worker name arrives as argv[1]
    # (defaults to the forward-equivalence worker when absent).
    selected = sys.argv[1] if len(sys.argv) > 1 else "forward"
    worker = _run_ep_backward_test if selected == "backward" else _run_ep_equivalence_test
    worker()
0 commit comments