@@ -732,6 +732,97 @@ def __init__(self, devices, axis_names, axis_types=None):
732732 role_to_mesh [rl_cluster_lib .Role .ACTOR ],
733733 )
734734
735+ def test_split_mesh_delegates_device_allocation_to_mesh_utils (self ):
736+ extra = """
737+ training_mode: "agentic_grpo"
738+ data_module: "tunix.cli.recipes.deepscaler_data"
739+ apply_chat_template_to_dataset: false
740+ data_config:
741+ train_data_path: "gs://fake/train.json"
742+ eval_data_path: "gs://fake/eval.parquet"
743+ prompt_key: "prompts"
744+ reward_functions: []
745+ verl_compatible: false
746+ chat_parser_config:
747+ type: "default"
748+ agent_class_path: null
749+ agent_kwargs: {}
750+ env_class_path: null
751+ env_kwargs: {}
752+ kubernetes_config: null
753+ agentic_grpo_config:
754+ num_generations: 2
755+ num_iterations: 1
756+ beta: 0.0
757+ epsilon: 0.2
758+ epsilon_high: 0.28
759+ system_prompt: ""
760+ max_concurrency: 1
761+ off_policy_steps: 0
762+ max_turns: 1
763+ context_ratio: 1
764+ sglang_jax_config:
765+ mem_fraction_static: 0.8
766+ vllm_config:
767+ hbm_utilization: 0.4
768+ """
769+ pipeline = _make_pipeline (extra )
770+ actor_model_config = pipeline .config ["actor_model_config" ]
771+ if isinstance (actor_model_config , omegaconf .dictconfig .DictConfig ):
772+ actor_model_config ["mesh" ] = {
773+ "shape" : "(1,2)" ,
774+ "axis_names" : "('fsdp','tp')" ,
775+ }
776+ pipeline .config ["reference_model_config" ] = {"same_mesh_as" : "actor" }
777+ rollout_model_config = pipeline .config ["rollout_model_config" ]
778+ if isinstance (rollout_model_config , omegaconf .dictconfig .DictConfig ):
779+ rollout_model_config ["mesh" ] = {
780+ "shape" : "(1,2)" ,
781+ "axis_names" : "('fsdp','tp')" ,
782+ }
783+
784+ fake_devices = ["a0" , "a1" , "r0" , "r1" ]
785+ allocated_devices = {
786+ "actor_model_config" : ["a0" , "a1" ],
787+ "rollout_model_config" : ["r0" , "r1" ],
788+ }
789+ created_mesh_devices = {}
790+
791+ def fake_create_mesh (model_key , devices = None ):
792+ created_mesh_devices [model_key ] = list (devices )
793+ return (model_key , tuple (devices ))
794+
795+ with mock .patch .object (grpo_main .jax , "devices" , return_value = fake_devices ):
796+ with mock .patch .object (
797+ grpo_main .mesh_lib ,
798+ "allocate_named_mesh_device_slices" ,
799+ return_value = allocated_devices ,
800+ ) as allocate_mock :
801+ with mock .patch .object (pipeline , "create_mesh" , side_effect = fake_create_mesh ):
802+ role_to_mesh = pipeline .create_role_to_mesh ()
803+
804+ allocate_mock .assert_called_once_with (
805+ [
806+ ("actor_model_config" , 2 ),
807+ ("rollout_model_config" , 2 ),
808+ ],
809+ devices = fake_devices ,
810+ )
811+ self .assertEqual (created_mesh_devices ["actor_model_config" ], ["a0" , "a1" ])
812+ self .assertEqual (created_mesh_devices ["rollout_model_config" ], ["r0" , "r1" ])
813+ self .assertEqual (
814+ role_to_mesh [rl_cluster_lib .Role .ACTOR ],
815+ ("actor_model_config" , ("a0" , "a1" )),
816+ )
817+ self .assertIs (
818+ role_to_mesh [rl_cluster_lib .Role .REFERENCE ],
819+ role_to_mesh [rl_cluster_lib .Role .ACTOR ],
820+ )
821+ self .assertEqual (
822+ role_to_mesh [rl_cluster_lib .Role .ROLLOUT ],
823+ ("rollout_model_config" , ("r0" , "r1" )),
824+ )
825+
735826
736827if __name__ == "__main__" :
737828 absltest .main ()
0 commit comments