Commit 119be1f

Add CP annotation to sparse_matmul

Parent: e67180e

2 files changed: +56 −14 lines

MaxText/layers/moe.py

Lines changed: 5 additions & 5 deletions
@@ -571,7 +571,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
       output = output[: hs_shape[0]]
       return output
 
-    # Currently, we only support data and tensor parallelism with Megablox.
+    # Currently, we support data, tensor, and expert parallelism with Megablox.
     # We all gather the input activations over tensor parallelism to follow strategy
     # in https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf.
 
@@ -589,10 +589,10 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
     else:
      batch_logical_axis = "activation_batch_no_exp"
 
-    input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, None, None))
-    gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, None, None))
+    input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
+    gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, None, None))
+      pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
     else:
      # pre_bias_logits is None for non-DeepSeek v3 models
      pre_bias_logits_pspec = None
@@ -610,7 +610,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
         shard_map.shard_map,
         mesh=self.mesh,
         in_specs=(input_partition_pspec, gate_logits_pspec, pre_bias_logits_pspec, w0_pspec, w1_pspec, wo_pspec),
-        out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, None, "activation_embed"))),
+        out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", "activation_embed"))),
         check_rep=False,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo):
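
Note: the substance of the moe.py change is that the second (sequence-length) dimension of the input, gate-logit, and output partition specs, previously None (replicated), is now bound to the "activation_length" logical axis, so a context-parallel mesh axis (the "CP" in the commit title) can shard it. A minimal sketch of how such a logical spec resolves to a mesh PartitionSpec under Flax axis rules; the rule table below is illustrative, not MaxText's actual logical_axis_rules:

from flax import linen as nn
from flax.linen import partitioning as nn_partitioning

# Hypothetical logical-to-mesh rules for illustration only.
rules = (
    ("activation_batch_no_exp", "data"),  # batch dim -> data-parallel mesh axis
    ("activation_length", "context"),     # sequence dim -> context-parallel mesh axis
    ("activation_embed", "tensor"),       # embed dim -> tensor-parallel mesh axis
)

with nn_partitioning.axis_rules(rules):
  # Before this commit the second entry was None, so the sequence dim was replicated:
  old_spec = nn.logical_to_mesh_axes(("activation_batch_no_exp", None, None))
  # After: the sequence dimension is sharded over the context mesh axis.
  new_spec = nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length", None))
  print(old_spec)  # PartitionSpec('data', None, None)
  print(new_spec)  # PartitionSpec('data', 'context', None)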

MaxText/tests/moe_test.py

Lines changed: 51 additions & 9 deletions
@@ -348,13 +348,16 @@ def test_megablox(self):
         dtype="bfloat16",
         megablox=True,
         sparse_matmul=True,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
     )
 
     rng = jax.random.PRNGKey(1234)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
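
Note: the same fix repeats across the tests below. The hidden_states batch dimension is now the global batch, per_device_batch_size times the device count, so the mesh can shard it evenly across all devices rather than only on a single-device run. A small sketch of the invariant, with illustrative values:

import jax

# Illustrative only; names mirror the test code above.
per_device_batch_size = 1
device_count = jax.device_count()   # e.g. 4 on the 4-way meshes used here
global_batch = per_device_batch_size * device_count

# hidden_states is built with shape (global_batch, max_target_length, base_emb_dim),
# so sharding its leading dim leaves exactly per_device_batch_size rows per device.
assert global_batch % device_count == 0
print(global_batch // device_count)  # == per_device_batch_size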
@@ -373,13 +376,16 @@ def test_ragged_dot(self):
         dtype="bfloat16",
         megablox=False,
         sparse_matmul=True,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
     )
 
     rng = jax.random.PRNGKey(1234)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -398,13 +404,16 @@ def test_dense(self):
         dtype="float32",
         megablox=False,
         sparse_matmul=False,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
     )
 
     rng = jax.random.PRNGKey(2345)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -423,14 +432,47 @@ def test_megablox_expert_parallelism(self):
         dtype="bfloat16",
         megablox=True,
         sparse_matmul=True,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
         ici_expert_parallelism=4,
     )
 
     rng = jax.random.PRNGKey(2345)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg)
+      actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
+  @pytest.mark.tpu_only
+  def test_megablox_context_parallelism(self):
+    cfg = pyconfig.initialize(
+        [None, os.path.join(PKG_DIR, "configs", "base.yml")],
+        run_name="moe_block_megablox_cp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=1,
+        ici_context_parallelism=4,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
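
Note: the new test runs the same Mixtral 8x7B MoE block under ici_context_parallelism=4, i.e. with the sequence dimension split four ways. A toy shard_map, separate from MaxText's wrapper, showing what the in_specs/out_specs annotations mean for the per-device view (the mesh construction and mapped function here are hypothetical):

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map

# One-axis mesh over all local devices; "context" plays the role of the
# context-parallel axis. Illustrative setup, not MaxText's mesh.
mesh = Mesh(np.array(jax.devices()), ("context",))

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=P(None, "context", None),   # shard dim 1 (sequence) over "context"
    out_specs=P(None, "context", None),  # output stays sequence-sharded
)
def scale(x):
  # Per-device x has shape (batch, seq // jax.device_count(), embed).
  return x * 2.0

x = jnp.ones((2, 8 * jax.device_count(), 16))
y = scale(x)
assert y.shape == x.shape  # global shape is unchanged; work ran on sequence shards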
@@ -541,7 +583,7 @@ def test_local_permute_no_offset(self):
 
   def test_local_permute_offset(self):
     experts_per_group = 2
-    expert_groups = 4 # aka number of expert shards.
+    expert_groups = 4  # aka number of expert shards.
     num_experts = 8
 
     # Global group sizes for each of the 8 experts
