@@ -348,13 +348,16 @@ def test_megablox(self):
         dtype="bfloat16",
         megablox=True,
         sparse_matmul=True,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
     )
 
     rng = jax.random.PRNGKey(1234)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -373,13 +376,16 @@ def test_ragged_dot(self):
         dtype="bfloat16",
         megablox=False,
         sparse_matmul=True,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
     )
 
     rng = jax.random.PRNGKey(1234)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -398,13 +404,16 @@ def test_dense(self):
         dtype="float32",
         megablox=False,
         sparse_matmul=False,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
     )
 
     rng = jax.random.PRNGKey(2345)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -423,14 +432,47 @@ def test_megablox_expert_parallelism(self):
         dtype="bfloat16",
         megablox=True,
         sparse_matmul=True,
-        per_device_batch_size=4,
+        per_device_batch_size=1,
         ici_expert_parallelism=4,
     )
 
     rng = jax.random.PRNGKey(2345)
     rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg)
+      actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
+  @pytest.mark.tpu_only
+  def test_megablox_context_parallelism(self):
+    cfg = pyconfig.initialize(
+        [None, os.path.join(PKG_DIR, "configs", "base.yml")],
+        run_name="moe_block_megablox_cp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=1,
+        ici_context_parallelism=4,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
     hidden_states = jax.random.uniform(
-        rng_hidden_states, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim), dtype=cfg.dtype
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
     )
 
     devices_array = maxtext_utils.create_device_mesh(cfg)
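Every test in the hunks above now builds `hidden_states` with a leading dimension of `per_device_batch_size * jax.device_count()` instead of `per_device_batch_size` alone: `jax.random.uniform` creates a single host-global array, so its batch dimension must cover all devices for the mesh to shard it evenly. Below is a minimal standalone sketch of that invariant — not MaxText code; the `"data"` mesh axis name is an illustrative assumption:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

per_device_batch_size = 1
global_batch = per_device_batch_size * jax.device_count()

# One 1-D mesh spanning every available device.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# A host-global array whose batch dimension is global_batch...
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (global_batch, 16))

# ...shards evenly along the batch axis: each device ends up holding
# exactly per_device_batch_size rows.
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None)))
assert x.addressable_shards[0].data.shape[0] == per_device_batch_size
```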
@@ -541,7 +583,7 @@ def test_local_permute_no_offset(self):
 
   def test_local_permute_offset(self):
     experts_per_group = 2
-    expert_groups = 4 # aka number of expert shards.
+    expert_groups = 4  # aka number of expert shards.
     num_experts = 8
 
     # Global group sizes for each of the 8 experts
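The hunk above is a comment-spacing fix in `test_local_permute_offset`, where 8 experts are split into 4 expert shards ("expert groups") of 2 experts each. A hedged sketch of the indexing this setup implies — purely illustrative arithmetic, not the MaxText `local_permute` implementation:

```python
num_experts = 8
experts_per_group = 2
expert_groups = num_experts // experts_per_group  # 4, aka number of expert shards

for expert_id in range(num_experts):
    shard = expert_id // experts_per_group    # which expert shard owns this expert
    local_id = expert_id % experts_per_group  # its index within that shard
    print(f"expert {expert_id} -> shard {shard}, local slot {local_id}")
```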