
Commit 385e4a1

feat: add token_dispatcher tests
1 parent 484c4c6 commit 385e4a1

File tree

8 files changed, +542 -0 lines changed


tests/backends/__init__.py

Whitespace-only changes.

tests/backends/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
import os
import sys
from pathlib import Path


def pytest_configure(config):
    megatron_path = os.environ.get("MEGATRON_PATH")
    if megatron_path is None or not os.path.exists(megatron_path):
        megatron_path = Path(__file__).resolve().parent.parent.parent / "third_party" / "Megatron-LM"
    sys.path.insert(0, str(megatron_path))
    print(f"[Primus] sys.path.insert: {megatron_path}")

tests/backends/megatron/__init__.py

Whitespace-only changes.
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
import torch._dynamo

torch._dynamo.config.suppress_errors = True
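
Setting suppress_errors makes TorchDynamo log compilation failures and fall back to eager execution instead of raising, which keeps compiled code paths from aborting the unit tests. A small illustration of the fallback behaviour (not part of this commit):

# hypothetical snippet: with suppress_errors=True a failing compile falls back to eager
import torch
import torch._dynamo

torch._dynamo.config.suppress_errors = True

@torch.compile
def double(x):
    return x * 2

print(double(torch.ones(2)))  # runs compiled if possible, otherwise silently falls back to eager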
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import os
from datetime import timedelta

import megatron.core.parallel_state as ps
import torch
from torch._C._distributed_c10d import PrefixStore
from torch.distributed import rendezvous


class TestModel(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_layers: int,
        bias: bool,
        shared_embedding: bool = False,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(input_dim, output_dim, bias) for _ in range(num_layers)]
        )
        if shared_embedding:
            self.layers[-1].weight.shared_embedding = True


class Utils:

    world_size = int(os.getenv("WORLD_SIZE", 1))
    rank = int(os.getenv("LOCAL_RANK", 0))
    inited = False
    store = None

    @staticmethod
    def initialize_distributed():

        os.environ.pop("NVTE_FLASH_ATTN", None)
        os.environ.pop("NVTE_FUSED_ATTN", None)
        os.environ.pop("NVTE_UNFUSED_ATTN", None)

        if not torch.distributed.is_initialized() and Utils.rank >= 0:
            print(
                f"Initializing torch.distributed with rank: {Utils.rank}, " f"world_size: {Utils.world_size}"
            )
            torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
            init_method = "tcp://"
            master_ip = os.getenv("MASTER_ADDR", "localhost")
            master_port = os.getenv("MASTER_PORT", "6000")
            init_method += master_ip + ":" + master_port
            rendezvous_iterator = rendezvous(
                init_method, Utils.rank, Utils.world_size, timeout=timedelta(minutes=1)
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timedelta(minutes=1))

            # Use a PrefixStore to avoid accidental overrides of keys used by
            # different systems (e.g. RPC) in case the store is multi-tenant.
            store = PrefixStore("default_pg", store)
            Utils.store = store

            torch.distributed.init_process_group(
                backend="nccl", world_size=Utils.world_size, rank=Utils.rank, store=store
            )

            torch.distributed.barrier()
        Utils.inited = True

    @staticmethod
    def set_world_size(world_size=None, rank=None):
        Utils.world_size = torch.cuda.device_count() if world_size is None else world_size
        if torch.distributed.is_initialized() and Utils.world_size != torch.distributed.get_world_size():
            torch.distributed.destroy_process_group()

        if rank is None:
            Utils.rank = int(os.environ["LOCAL_RANK"])
            if Utils.rank >= Utils.world_size:
                Utils.rank = -1
        else:
            Utils.rank = rank

    @staticmethod
    def destroy_model_parallel():
        os.environ.pop("NVTE_FLASH_ATTN", None)
        os.environ.pop("NVTE_FUSED_ATTN", None)
        os.environ.pop("NVTE_UNFUSED_ATTN", None)
        if not Utils.inited:
            return
        torch.distributed.barrier()
        ps.destroy_model_parallel()
        Utils.inited = False

    @staticmethod
    def initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        **kwargs,
    ):
        # Need to unset these variables to make sure settings from previous
        # tests don't interfere with the current test.
        os.environ.pop("NVTE_FLASH_ATTN", None)
        os.environ.pop("NVTE_FUSED_ATTN", None)
        os.environ.pop("NVTE_UNFUSED_ATTN", None)

        ps.destroy_model_parallel()
        Utils.initialize_distributed()
        ps.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            **kwargs,
        )
        Utils.inited = True

    @staticmethod
    def fake_initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        expert_model_parallel_size=1,
    ):
        """Used for layer-wise UT as a proxy for NeMo-style initialization."""
        ps.set_tensor_model_parallel_world_size(tensor_model_parallel_size)
        ps.set_tensor_model_parallel_rank(0)

        ps.set_expert_model_parallel_world_size(expert_model_parallel_size)
        ps.set_expert_model_parallel_rank(0)
        if virtual_pipeline_model_parallel_size is not None:
            ps.set_virtual_pipeline_model_parallel_world_size(virtual_pipeline_model_parallel_size)
            ps.set_virtual_pipeline_model_parallel_rank(0)

        ps.set_pipeline_model_parallel_world_size(pipeline_model_parallel_size)
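
A sketch of how a token-dispatcher unit test might drive these helpers when launched with torchrun on a single GPU (the test name and parallel sizes are illustrative, not part of this commit):

# hypothetical usage inside a unit test that imports Utils and ps from the helpers above
def test_parallel_state_round_trip():
    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
    try:
        assert ps.get_tensor_model_parallel_world_size() == 1
        assert ps.get_pipeline_model_parallel_world_size() == 1
    finally:
        Utils.destroy_model_parallel()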

tests/backends/megatron/unit_tests/transformer/__init__.py

Whitespace-only changes.

tests/backends/megatron/unit_tests/transformer/moe/__init__.py

Whitespace-only changes.
