@@ -408,11 +408,12 @@ def _init_cluster(self):
408408 raise ValueError ("Rollout vllm model version or path is missing!" )
409409
410410 # TODO(linchai): maybe support offloading for vllm rollout.
411- # vLLM handles model initialization and loading internally, so we need to provide
412- # logical axis rules for vLLM to correctly shard the model on the rollout mesh.
413- # This is important for out-of-tree models in vLLM that are implemented with custom
414- # logical axis rules, like is the case for MaxText models.
415- with self ._get_logical_axis_rules_cm (Role .ROLLOUT ):
411+ with self ._get_mesh_and_logical_axis_rules_cm (Role .ROLLOUT ):
412+ # vLLM handles model initialization and loading internally, so we need
413+ # to provide logical axis rules for vLLM to correctly shard the model on
414+ # the rollout mesh. This is important for out-of-tree models in vLLM
415+ # that are implemented with custom logical axis rules, like is the case
416+ # for MaxText models.
416417 self ._rollout = vllm_rollout .VllmRollout (
417418 self .rollout_actor ,
418419 self .tokenizer ,
@@ -504,9 +505,7 @@ def _init_cluster(self):
504505 critic_config .checkpoint_root_directory = os .path .join (
505506 critic_config .checkpoint_root_directory , "critic"
506507 )
507- with self .cluster_config .role_to_mesh [
508- Role .CRITIC
509- ], self ._get_logical_axis_rules_cm (Role .CRITIC ):
508+ with self ._get_mesh_and_logical_axis_rules_cm (Role .CRITIC ):
510509 self ._critic_trainer = rl_trainer .Trainer (
511510 model = self .critic ,
512511 optimizer = self .cluster_config .training_config .critic_optimizer ,
@@ -528,9 +527,7 @@ def _init_cluster(self):
528527 actor_config .checkpoint_root_directory = os .path .join (
529528 actor_config .checkpoint_root_directory , "actor"
530529 )
531- with self .cluster_config .role_to_mesh [
532- Role .ACTOR
533- ], self ._get_logical_axis_rules_cm (Role .ACTOR ):
530+ with self ._get_mesh_and_logical_axis_rules_cm (Role .ACTOR ):
534531 self ._actor_trainer = rl_trainer .Trainer (
535532 model = self .train_actor ,
536533 optimizer = self .cluster_config .training_config .actor_optimizer ,
@@ -782,18 +779,14 @@ def buffer_metrics_async(
782779 self ._log_metrics (m )
783780
def update_actor(self, train_ds, eval_ds, skip_jit=False):
  """Runs one actor training pass on the actor mesh with its axis rules.

  Loads the actor model from CPU if it was offloaded, trains it inside the
  "actor_training" perf span, then offloads it back to CPU (both load and
  offload are no-ops unless CPU offloading is enabled).

  Args:
    train_ds: Training dataset passed through to the trainer.
    eval_ds: Evaluation dataset passed through to the trainer.
    skip_jit: If True, the trainer runs without jit compilation.
  """
  with self._get_mesh_and_logical_axis_rules_cm(Role.ACTOR):
    trainer = self.actor_trainer
    self._maybe_load_model_from_cpu(trainer.model, Role.ACTOR)
    with self._perf.span_group("actor_training"):
      trainer.train(train_ds, eval_ds, skip_jit)
    self._maybe_offload_model_to_cpu(trainer.model, Role.ACTOR)
792787
793788 def update_critic (self , train_ds , eval_ds , skip_jit = False ):
794- with self .cluster_config .role_to_mesh [
795- Role .CRITIC
796- ] as _ , self ._get_logical_axis_rules_cm (Role .CRITIC ):
789+ with self ._get_mesh_and_logical_axis_rules_cm (Role .CRITIC ):
797790 self ._maybe_load_model_from_cpu (self .critic_trainer .model , Role .CRITIC )
798791 with self ._perf .span_group ("critic_training" ):
799792 self ._critic_trainer .train (train_ds , eval_ds , skip_jit )
@@ -840,9 +833,7 @@ def generate(
840833 raise ValueError ("Cannot generate from an empty list of prompts." )
841834 micro_batch_size = micro_batch_size or len (string_prompts )
842835
843- with self .cluster_config .role_to_mesh [
844- Role .ROLLOUT
845- ] as mesh , self ._get_logical_axis_rules_cm (Role .ROLLOUT ):
836+ with self ._get_mesh_and_logical_axis_rules_cm (Role .ROLLOUT ) as (mesh , _ ):
846837 model = self .rollout .model ()
847838 self ._maybe_load_model_from_cpu (model , Role .ROLLOUT )
848839 if self .cluster_config .offload_to_cpu :
@@ -909,9 +900,7 @@ def get_ref_per_token_logps(
909900 )
910901 micro_batch_size = micro_batch_size or batch_size
911902
912- reference_mesh = self .cluster_config .role_to_mesh [Role .REFERENCE ]
913-
914- with reference_mesh , self ._get_logical_axis_rules_cm (Role .REFERENCE ):
903+ with self ._get_mesh_and_logical_axis_rules_cm (Role .REFERENCE ):
915904 # This assumes reference model shards same data sharding as actor, which
916905 # should be true as ref model and policy model shares same architecture.
917906 dest_prompt_tokens = sharding_utils .shard_input (
@@ -959,9 +948,7 @@ def get_old_per_token_logps(
959948 raise ValueError ("Cannot get old log probabilities from an empty batch." )
960949 micro_batch_size = micro_batch_size or batch_size
961950
962- with self .cluster_config .role_to_mesh [
963- Role .ROLLOUT
964- ], self ._get_logical_axis_rules_cm (Role .ROLLOUT ):
951+ with self ._get_mesh_and_logical_axis_rules_cm (Role .ROLLOUT ):
965952 model = self .rollout .model ()
966953 self ._maybe_load_model_from_cpu (model , Role .ROLLOUT )
967954 if self .cluster_config .offload_to_cpu :
@@ -1014,9 +1001,7 @@ def get_values(
10141001 eos_id : int ,
10151002 completion_mask : jax .Array | None = None ,
10161003 ) -> jax .Array :
1017- with self .cluster_config .role_to_mesh [
1018- Role .CRITIC
1019- ], self ._get_logical_axis_rules_cm (Role .CRITIC ):
1004+ with self ._get_mesh_and_logical_axis_rules_cm (Role .CRITIC ):
10201005 return self .inference_worker .get_values (
10211006 prompt_tokens ,
10221007 completion_tokens ,
@@ -1032,19 +1017,17 @@ def get_rewards(
10321017 pad_id : int ,
10331018 eos_id : int ,
10341019 ) -> jax .Array :
1035- with self .cluster_config .role_to_mesh [
1036- Role .REWARD
1037- ], self ._get_logical_axis_rules_cm (Role .REWARD ):
1020+ with self ._get_mesh_and_logical_axis_rules_cm (Role .REWARD ):
10381021 return self .inference_worker .get_rewards (
10391022 prompt_tokens ,
10401023 completion_tokens ,
10411024 pad_id ,
10421025 eos_id ,
10431026 )
10441027
@contextlib.contextmanager
def _get_mesh_and_logical_axis_rules_cm(self, role: Role):
  """Enters the mesh and the logical axis rules for the given role.

  This is used for models that use logical sharding, so XLA can generate the
  correct graph based on the physical mesh.

  Args:
    role: The role of the model (e.g., ACTOR, CRITIC, REFERENCE, etc.).

  Yields:
    A `(mesh, axis_rules)` tuple: the value produced by entering the role's
    mesh context, and the value produced by entering the logical axis rules
    context (None when no rules are configured for the role). Callers that
    only need the contexts active may ignore the tuple.
  """
  role_logical_axis_rule = self.cluster_config.role_to_logical_axis_rule
  if role_logical_axis_rule and role in role_logical_axis_rule:
    # Out-of-tree models (e.g. MaxText in vLLM) are implemented with custom
    # logical axis rules; activate the role-specific rules so sharding
    # annotations resolve against the right physical axes.
    logical_axis_rule_ctx = nn_partitioning.axis_rules(
        role_logical_axis_rule[role]
    )
  else:
    logical_axis_rule_ctx = contextlib.nullcontext()
  # Both context managers are statically known here, so a flat `with` is
  # simpler than an ExitStack while keeping identical enter/exit ordering
  # (mesh entered first, exited last) and exception propagation.
  with self.cluster_config.role_to_mesh[role] as mesh, (
      logical_axis_rule_ctx
  ) as axis_rules:
    yield (mesh, axis_rules)
0 commit comments