Skip to content

Commit efdf89b

Browse files
committed
add ep test
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 2c3b1fc commit efdf89b

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for expert parallelism (EP) in the Mixtral MoE model.
17+
18+
Verifies that running with EP=2 (experts sharded across 2 GPUs) produces
19+
the same logits and loss as EP=1 (all experts on a single GPU).
20+
"""
21+
22+
import os
23+
import subprocess
24+
import sys
25+
from dataclasses import dataclass, field
26+
from pathlib import Path
27+
28+
29+
sys.path.insert(0, str(Path(__file__).parent.parent))
30+
31+
import pytest
32+
import torch
33+
34+
from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM
35+
36+
37+
# Skip marker for tests that need two or more CUDA devices.
requires_multi_gpu = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.device_count() >= 2),
    reason="Test requires at least 2 GPUs",
)
41+
42+
43+
def _create_small_mixtral_config(**overrides) -> NVMixtralConfig:
    """Build a tiny Mixtral config for fast tests; keyword args override the defaults."""
    base = {
        "hidden_size": 128,
        "intermediate_size": 256,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "num_local_experts": 4,
        "num_experts_per_tok": 2,
        "max_position_embeddings": 128,
        "vocab_size": 1000,
        "attn_input_format": "bshd",
        "self_attn_mask_type": "causal",
        "router_jitter_noise": 0.0,
    }
    # Later keys win, so caller overrides take precedence over the defaults.
    return NVMixtralConfig(**{**base, **overrides})
61+
62+
63+
def _get_dummy_batch(vocab_size: int, seq_len: int = 32, batch_size: int = 2, device: str = "cuda"):
64+
"""Create a simple dummy batch for testing."""
65+
torch.manual_seed(42)
66+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
67+
attention_mask = torch.ones_like(input_ids)
68+
labels = input_ids.clone()
69+
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
70+
71+
72+
@requires_multi_gpu
def test_ep2_matches_ep1(unused_tcp_port):
    """Launch this file under torchrun with 2 ranks and require a clean exit.

    The distributed worker (``_run_ep_equivalence_test``) performs the actual
    EP=2 vs EP=1 comparison; this wrapper only checks the child's exit code
    and dumps its output on failure.

    Args:
        unused_tcp_port: pytest fixture supplying a free TCP port for the
            c10d rendezvous endpoint.
    """
    script = Path(__file__).resolve()
    cmd = [
        "torchrun",
        "--nproc_per_node=2",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        # Use the absolute path: os.path.relpath(__file__) is computed against
        # the *pytest* process cwd, not the child's cwd below, so the relative
        # path breaks whenever pytest is invoked from another directory.
        str(script),
    ]
    result = subprocess.run(
        cmd,
        check=False,  # we inspect the return code ourselves to attach output
        text=True,
        cwd=str(script.parent.parent),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=300,
    )
    if result.returncode != 0:
        print(f"STDOUT:\n{result.stdout}")
        print(f"STDERR:\n{result.stderr}")
        pytest.fail(f"EP equivalence test failed with exit code {result.returncode}")
95+
96+
97+
# ---------------------------------------------------------------------------
98+
# Distributed worker executed via torchrun
99+
# ---------------------------------------------------------------------------
100+
101+
102+
@dataclass(frozen=True)
class DistributedConfig:
    """Immutable snapshot of the torchrun distributed environment.

    Constructing an instance also writes the fallback values into
    ``os.environ`` (via ``setdefault``) so single-process runs get a
    complete, usable configuration.
    """

    # rank: global rank of this process; local_rank: GPU index on this node;
    # world_size: total process count. Master address/port are kept only for
    # the env-var side effect.
    rank: int = field(default_factory=lambda: int(DistributedConfig._env("RANK", "0")))
    local_rank: int = field(default_factory=lambda: int(DistributedConfig._env("LOCAL_RANK", "0")))
    world_size: int = field(default_factory=lambda: int(DistributedConfig._env("WORLD_SIZE", "1")))
    _master_addr: str = field(default_factory=lambda: DistributedConfig._env("MASTER_ADDR", "localhost"))
    _master_port: str = field(default_factory=lambda: DistributedConfig._env("MASTER_PORT", "12355"))

    @staticmethod
    def _env(name: str, fallback: str) -> str:
        # setdefault both reads the variable and persists the fallback if unset.
        return os.environ.setdefault(name, fallback)

    def is_main_process(self) -> bool:
        """Return True if this is the global rank 0 process."""
        return self.rank == 0
115+
116+
117+
def _shard_expert_weights(full_state_dict: dict, ep_rank: int, ep_size: int, num_experts: int) -> dict:
118+
"""Shard expert weights from a full (EP=1) state dict for a given EP rank.
119+
120+
Expert weight keys follow the TE GroupedLinear naming convention:
121+
``...experts_gate_up.weight{i}`` and ``...experts_down.weight{i}``
122+
where ``i`` is the global expert index.
123+
124+
For EP, each rank keeps only its local slice of experts and renumbers
125+
the weight keys starting from 0.
126+
127+
Args:
128+
full_state_dict: Complete state dict from an EP=1 model (without _extra_state keys).
129+
ep_rank: This rank's index in the EP group.
130+
ep_size: Total number of EP ranks.
131+
num_experts: Total number of experts (before sharding).
132+
"""
133+
experts_per_rank = num_experts // ep_size
134+
start_expert = ep_rank * experts_per_rank
135+
end_expert = start_expert + experts_per_rank
136+
137+
new_state_dict = {}
138+
for key, value in full_state_dict.items():
139+
if "experts_gate_up.weight" in key or "experts_down.weight" in key:
140+
# Extract global expert index from key like "...weight3"
141+
prefix, weight_part = key.rsplit("weight", 1)
142+
global_idx = int(weight_part)
143+
if start_expert <= global_idx < end_expert:
144+
local_idx = global_idx - start_expert
145+
new_key = f"{prefix}weight{local_idx}"
146+
new_state_dict[new_key] = value
147+
else:
148+
# Non-expert weights are replicated
149+
new_state_dict[key] = value
150+
151+
return new_state_dict
152+
153+
154+
def _run_ep_equivalence_test():
    """Main worker function for the EP equivalence test.

    1. Set up each rank's device, init distributed.
    2. Every rank creates an EP=1 model on its own GPU, runs forward, saves reference.
    3. Create EP=2 model with sharded expert weights, run forward, compare.
    """
    # --- Setup distributed first so each rank uses its own GPU ---
    dist_config = DistributedConfig()
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    ep_rank = dist_config.rank
    ep_size = dist_config.world_size

    # --- Phase 1: EP=1 reference (every rank computes independently) ---
    config_ep1 = _create_small_mixtral_config(expert_parallel_size=1)
    # Seed before construction so the EP=1 and EP=2 models (re-seeded below)
    # are initialized with identical parameters.
    torch.manual_seed(0)
    model_ep1 = NVMixtralForCausalLM(config_ep1).to(dtype=torch.bfloat16, device=device)
    model_ep1.eval()

    # _get_dummy_batch re-seeds internally, so every rank gets the same batch.
    batch = _get_dummy_batch(config_ep1.vocab_size, seq_len=32, batch_size=2, device=device)

    with torch.no_grad():
        outputs_ep1 = model_ep1(**batch)

    # Move reference outputs to CPU so they survive freeing the EP=1 model.
    logits_ep1 = outputs_ep1.logits.detach().clone().cpu()
    loss_ep1 = outputs_ep1.loss.detach().clone().cpu()

    # Save EP=1 full state dict on CPU for sharding
    full_state_dict = {k: v.clone().cpu() for k, v in model_ep1.state_dict().items()}

    # Free GPU memory before building the second model.
    del model_ep1, outputs_ep1
    torch.cuda.empty_cache()

    # --- Phase 2: EP=2 distributed run ---
    config_ep2 = _create_small_mixtral_config(expert_parallel_size=ep_size)
    # Same seed as phase 1; the loaded state dict below overrides the expert
    # weights anyway, this keeps any remaining initialization identical.
    torch.manual_seed(0)
    model_ep2 = NVMixtralForCausalLM(config_ep2).to(dtype=torch.bfloat16, device=device)

    # Load sharded expert weights (strict=False to skip TE _extra_state keys)
    sharded_state_dict = _shard_expert_weights(full_state_dict, ep_rank, ep_size, config_ep1.num_local_experts)
    model_ep2.load_state_dict(sharded_state_dict, strict=False)
    model_ep2.eval()

    # Set EP process group on all MoE blocks
    ep_group = torch.distributed.group.WORLD
    model_ep2.model.set_ep_groups(ep_group)

    # Same batch on all ranks (EP dispatches tokens, input is replicated)
    batch_cuda = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    with torch.no_grad():
        outputs_ep2 = model_ep2(**batch_cuda)

    logits_ep2 = outputs_ep2.logits.detach().cpu()
    loss_ep2 = outputs_ep2.loss.detach().cpu()

    # --- Phase 3: Compare on rank 0 ---
    if dist_config.is_main_process():
        # NOTE(review): tolerances are loose, presumably to absorb bf16
        # accumulation-order differences between the sharded and unsharded
        # expert paths — confirm against observed deltas.
        torch.testing.assert_close(
            logits_ep2,
            logits_ep1,
            atol=1e-2,
            rtol=1e-2,
            msg="EP=2 logits do not match EP=1 logits",
        )

        torch.testing.assert_close(
            loss_ep2,
            loss_ep1,
            atol=1e-3,
            rtol=1e-3,
            msg="EP=2 loss does not match EP=1 loss",
        )

        print("EP equivalence test PASSED: EP=2 logits and loss match EP=1")

    torch.distributed.destroy_process_group()
233+
234+
235+
if __name__ == "__main__":
    # Entry point when this file is executed directly under torchrun; each
    # spawned rank runs the distributed worker.
    _run_ep_equivalence_test()

0 commit comments

Comments
 (0)