@@ -644,7 +644,6 @@ def test_cli_empty_system_prompt_stays_empty_string(self):
644644 )
645645 self .assertEqual (p .config ["agentic_grpo_config" ]["system_prompt" ], "" )
646646
647-
648647class SplitMeshConfigTest (absltest .TestCase ):
649648
650649 def test_split_mesh_uses_explicit_role_meshes (self ):
@@ -688,7 +687,6 @@ def test_split_mesh_uses_explicit_role_meshes(self):
688687 "shape" : "(2,1)" ,
689688 "axis_names" : "('fsdp','tp')" ,
690689 }
691- pipeline .config ["reference_model_config" ] = {"same_mesh_as" : "actor" }
692690 rollout_model_config = pipeline .config ["rollout_model_config" ]
693691 if isinstance (rollout_model_config , omegaconf .dictconfig .DictConfig ):
694692 rollout_model_config ["mesh" ] = {
@@ -732,6 +730,128 @@ def __init__(self, devices, axis_names, axis_types=None):
732730 role_to_mesh [rl_cluster_lib .Role .ACTOR ],
733731 )
734732
def test_colocate_with_reuses_device_slice_with_different_mesh(self):
  """Checks a colocated rollout role shares the actor's devices.

  The actor is configured with a (2,1) mesh and the rollout with a (1,2)
  mesh plus `colocate_with: "actor"`. With four fake devices available, both
  role meshes are expected to span devices [0, 1] while each keeps the mesh
  shape it asked for.
  """
  yaml_overrides = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
apply_chat_template_to_dataset: false
data_config:
  train_data_path: "gs://fake/train.json"
  eval_data_path: "gs://fake/eval.parquet"
  prompt_key: "prompts"
reward_functions: []
verl_compatible: false
chat_parser_config:
  type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
  num_generations: 2
  num_iterations: 1
  beta: 0.0
  epsilon: 0.2
  epsilon_high: 0.28
  system_prompt: ""
  max_concurrency: 1
  off_policy_steps: 0
  max_turns: 1
  context_ratio: 1
sglang_jax_config:
  mem_fraction_static: 0.8
vllm_config:
  hbm_utilization: 0.4
"""
  pipeline = _make_pipeline(yaml_overrides)
  dict_config_cls = omegaconf.dictconfig.DictConfig

  # Actor: two devices along 'fsdp', one along 'tp'.
  actor_cfg = pipeline.config["actor_model_config"]
  if isinstance(actor_cfg, dict_config_cls):
    actor_cfg["mesh"] = {
        "shape": "(2,1)",
        "axis_names": "('fsdp','tp')",
    }

  # Rollout: colocated with the actor, but laid out as (1,2) instead.
  rollout_cfg = pipeline.config["rollout_model_config"]
  if isinstance(rollout_cfg, dict_config_cls):
    rollout_cfg["colocate_with"] = "actor"
    rollout_cfg["mesh"] = {
        "shape": "(1,2)",
        "axis_names": "('fsdp','tp')",
    }

  device_ids = list(range(4))

  class _StubMesh:
    """Minimal stand-in for jax.sharding.Mesh that records its arguments."""

    def __init__(self, devices, axis_names, axis_types=None):
      self.devices = devices
      self.axis_names = axis_names
      self.axis_types = axis_types

  # Patch device discovery and mesh construction so no real accelerators
  # are needed.
  with mock.patch.object(
      grpo_main.jax, "devices", return_value=device_ids
  ), mock.patch.object(
      grpo_main.jax.sharding, "Mesh", side_effect=_StubMesh
  ):
    meshes = pipeline.create_role_to_mesh()

  actor_role = rl_cluster_lib.Role.ACTOR
  rollout_role = rl_cluster_lib.Role.ROLLOUT

  # Both roles must sit on the same two devices...
  for role in (actor_role, rollout_role):
    self.assertSequenceEqual(
        meshes[role].devices.flatten().tolist(),
        [0, 1],
    )
  # ...while each keeps its own requested mesh shape.
  for role, expected_shape in ((actor_role, (2, 1)), (rollout_role, (1, 2))):
    self.assertEqual(meshes[role].devices.shape, expected_shape)
813+
def test_empty_string_colocate_with_is_treated_as_unset(self):
  """Checks that `colocate_with: ""` yields an empty colocation map.

  The rollout model config gets an empty-string `colocate_with`, after which
  `_get_colocate_with_map()` is expected to return an empty mapping.
  """
  extra = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
apply_chat_template_to_dataset: false
data_config:
  train_data_path: "gs://fake/train.json"
  eval_data_path: "gs://fake/eval.parquet"
  prompt_key: "prompts"
reward_functions: []
verl_compatible: false
chat_parser_config:
  type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
  num_generations: 2
  num_iterations: 1
  beta: 0.0
  epsilon: 0.2
  epsilon_high: 0.28
  system_prompt: ""
  max_concurrency: 1
  off_policy_steps: 0
  max_turns: 1
  context_ratio: 1
sglang_jax_config:
  mem_fraction_static: 0.8
vllm_config:
  hbm_utilization: 0.4
"""
  pipeline = _make_pipeline(extra)
  rollout_model_config = pipeline.config["rollout_model_config"]
  if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
    rollout_model_config["colocate_with"] = ""
  else:
    # Without this branch the assertion below would run against an
    # unmodified config and pass without ever exercising the empty string.
    self.fail(
        "rollout_model_config is not a DictConfig; cannot set colocate_with"
    )

  self.assertEmpty(pipeline._get_colocate_with_map())
854+
735855
# Script entry point: hand control to absltest when run directly.
if __name__ == "__main__":
  absltest.main()
0 commit comments