@@ -753,6 +753,97 @@ def __init__(self, devices, axis_names, axis_types=None):
753753 role_to_mesh [rl_cluster_lib .Role .ACTOR ],
754754 )
755755
756+ def test_split_mesh_delegates_device_allocation_to_mesh_utils (self ):
757+ extra = """
758+ training_mode: "agentic_grpo"
759+ data_module: "tunix.cli.recipes.deepscaler_data"
760+ apply_chat_template_to_dataset: false
761+ data_config:
762+ train_data_path: "gs://fake/train.json"
763+ eval_data_path: "gs://fake/eval.parquet"
764+ prompt_key: "prompts"
765+ reward_functions: []
766+ verl_compatible: false
767+ chat_parser_config:
768+ type: "default"
769+ agent_class_path: null
770+ agent_kwargs: {}
771+ env_class_path: null
772+ env_kwargs: {}
773+ kubernetes_config: null
774+ agentic_grpo_config:
775+ num_generations: 2
776+ num_iterations: 1
777+ beta: 0.0
778+ epsilon: 0.2
779+ epsilon_high: 0.28
780+ system_prompt: ""
781+ max_concurrency: 1
782+ off_policy_steps: 0
783+ max_turns: 1
784+ context_ratio: 1
785+ sglang_jax_config:
786+ mem_fraction_static: 0.8
787+ vllm_config:
788+ hbm_utilization: 0.4
789+ """
790+ pipeline = _make_pipeline (extra )
791+ actor_model_config = pipeline .config ["actor_model_config" ]
792+ if isinstance (actor_model_config , omegaconf .dictconfig .DictConfig ):
793+ actor_model_config ["mesh" ] = {
794+ "shape" : "(1,2)" ,
795+ "axis_names" : "('fsdp','tp')" ,
796+ }
797+ pipeline .config ["reference_model_config" ] = {"same_mesh_as" : "actor" }
798+ rollout_model_config = pipeline .config ["rollout_model_config" ]
799+ if isinstance (rollout_model_config , omegaconf .dictconfig .DictConfig ):
800+ rollout_model_config ["mesh" ] = {
801+ "shape" : "(1,2)" ,
802+ "axis_names" : "('fsdp','tp')" ,
803+ }
804+
805+ fake_devices = ["a0" , "a1" , "r0" , "r1" ]
806+ allocated_devices = {
807+ "actor_model_config" : ["a0" , "a1" ],
808+ "rollout_model_config" : ["r0" , "r1" ],
809+ }
810+ created_mesh_devices = {}
811+
812+ def fake_create_mesh (model_key , devices = None ):
813+ created_mesh_devices [model_key ] = list (devices )
814+ return (model_key , tuple (devices ))
815+
816+ with mock .patch .object (grpo_main .jax , "devices" , return_value = fake_devices ):
817+ with mock .patch .object (
818+ grpo_main .mesh_lib ,
819+ "allocate_named_mesh_device_slices" ,
820+ return_value = allocated_devices ,
821+ ) as allocate_mock :
822+ with mock .patch .object (pipeline , "create_mesh" , side_effect = fake_create_mesh ):
823+ role_to_mesh = pipeline .create_role_to_mesh ()
824+
825+ allocate_mock .assert_called_once_with (
826+ [
827+ ("actor_model_config" , 2 ),
828+ ("rollout_model_config" , 2 ),
829+ ],
830+ devices = fake_devices ,
831+ )
832+ self .assertEqual (created_mesh_devices ["actor_model_config" ], ["a0" , "a1" ])
833+ self .assertEqual (created_mesh_devices ["rollout_model_config" ], ["r0" , "r1" ])
834+ self .assertEqual (
835+ role_to_mesh [rl_cluster_lib .Role .ACTOR ],
836+ ("actor_model_config" , ("a0" , "a1" )),
837+ )
838+ self .assertIs (
839+ role_to_mesh [rl_cluster_lib .Role .REFERENCE ],
840+ role_to_mesh [rl_cluster_lib .Role .ACTOR ],
841+ )
842+ self .assertEqual (
843+ role_to_mesh [rl_cluster_lib .Role .ROLLOUT ],
844+ ("rollout_model_config" , ("r0" , "r1" )),
845+ )
846+
756847
757848if __name__ == "__main__" :
758849 absltest .main ()
0 commit comments