
Commit efd3e32 (1 parent: 4e466d8)

[triton_kernels] fix test case for distributed routing kernels (#9258)

1 file changed: python/triton_kernels/tests/test_distributed.py (2 additions, 4 deletions)
diff --git a/python/triton_kernels/tests/test_distributed.py b/python/triton_kernels/tests/test_distributed.py
--- a/python/triton_kernels/tests/test_distributed.py
+++ b/python/triton_kernels/tests/test_distributed.py
@@ -7,6 +7,7 @@
 import torch.multiprocessing as mp
 import triton
 from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, SymmetricMemoryPool
+from triton_kernels.distributed_details.mesh import Mesh
 from triton_kernels.reduce import reduce
 from triton_kernels.topk import topk
 from triton_kernels.matmul import matmul
@@ -197,16 +198,14 @@ def _run_expert_sharding(rank, world_size, *, n_tokens, d_model, n_expts_tot, n_
         y_indx=y_indx_global,
     )
 
-    symm_mem_pool = SymmetricMemoryPool()
+    symm_mem_pool = SymmetricMemoryPool(Mesh(dist.group.WORLD))
     symm_mem_pool.initialize_matmul(
         n_tokens_global=n_tokens_global,
         d_input=d_model,
         d_model=d_model,
         n_expts_act=n_expts_act,
         n_expts_tot=n_expts_tot,
         dtype=torch.bfloat16,
-        n_ranks=world_size,
-        group=dist.group.WORLD,
         device=dev,
     )
@@ -239,7 +238,6 @@ def run_moe():
     g.replay()
     dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri_graph)
     triton.testing.assert_close(y_global_ref, y_global_tri)
-    symm_mem_pool.release()
 
 
 @pytest.mark.parametrize("distributed_launcher", [2, 4], indirect=True)
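
For context, the commit moves the distributed topology into the pool's constructor: SymmetricMemoryPool is now built from a Mesh wrapping the process group, so initialize_matmul no longer takes n_ranks/group, and the explicit release() teardown is dropped. Below is a minimal sketch of the new calling convention, assuming only the names and signatures visible in this diff; the argument values are placeholders, and it presumes torch.distributed is already initialized (as the test's distributed_launcher fixture would arrange).

    import torch
    import torch.distributed as dist

    from triton_kernels.distributed import SymmetricMemoryPool
    from triton_kernels.distributed_details.mesh import Mesh

    # Assumes dist.init_process_group() has already run (e.g. via torchrun).
    dev = f"cuda:{dist.get_rank()}"

    # Old style (removed by this commit):
    #   pool = SymmetricMemoryPool()
    #   pool.initialize_matmul(..., n_ranks=world_size, group=dist.group.WORLD, ...)
    #
    # New style: the pool is bound to a Mesh over the process group up front.
    pool = SymmetricMemoryPool(Mesh(dist.group.WORLD))
    pool.initialize_matmul(
        n_tokens_global=1024,  # placeholder value
        d_input=512,           # placeholder value
        d_model=512,           # placeholder value
        n_expts_act=2,         # placeholder value
        n_expts_tot=8,         # placeholder value
        dtype=torch.bfloat16,
        device=dev,
    )
    # Note: the test no longer calls pool.release() explicitly; per this diff,
    # teardown is apparently handled by the pool itself.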
