From d7a5d9b258d06450e25caea9d72c8df232b7cd23 Mon Sep 17 00:00:00 2001 From: Pysith Vanuptikul Date: Mon, 9 Mar 2026 22:06:26 +0000 Subject: [PATCH 1/2] fix multihost pp errors Signed-off-by: Pysith Vanuptikul Signed-off-by: --- .../executors/ray_distributed_executor.py | 11 ++++++---- tpu_inference/runner/tpu_runner.py | 1 + tpu_inference/worker/tpu_worker.py | 21 ++++++++++++++++++- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index a358f5fa0b..b3fffb267f 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -167,9 +167,10 @@ def _initialize_ray_cluster(self) -> None: f"RayDistributedExecutor | placement_group_specs={placement_group_specs}" ) - # By default, Ray packs resources as much as possible. + # Use STRICT_SPREAD for PP to ensure each host participates in JAX initialization. + strategy = "STRICT_SPREAD" if pp_size > 1 else "PACK" current_placement_group = ray.util.placement_group( - placement_group_specs, strategy="PACK") + placement_group_specs, strategy=strategy) _wait_until_pg_ready(current_placement_group) assert current_placement_group is not None @@ -321,11 +322,13 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): additional_vars=set(current_platform.additional_env_vars), destination="workers") - # Copy existing env vars to each worker's args - for args in all_args_to_update_environment_variables: + # Copy existing env vars to each worker's args, but don't overwrite + # and don't copy global topology bounds from the driver. 
+ for i, args in enumerate(all_args_to_update_environment_variables): for name in env_vars_to_copy: if name in os.environ: args[name] = os.environ[name] + logger.debug(f"Worker {i} environment variables: {args}") self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 2f3a5893ed..6e7e2a9828 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -403,6 +403,7 @@ def _init_phased_profiling(self) -> None: def _init_mm(self) -> None: self.is_multimodal_model = None self.uses_mrope = self.model_config.uses_mrope + self.supports_mm_inputs = True def _init_speculative_decoding(self) -> None: self.drafter = None diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index e33e309a2f..3dd84833da 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -167,7 +167,21 @@ def init_device(self, tpu_visible_chips=""): # set tpu visible devices for Jax runtime in single host PP. 
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() - if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1: + if self.parallel_config.pipeline_parallel_size > 1: + # Log environment variables for debugging + tpu_env_vars = [ + "TPU_PROCESS_ADDRESSES", + "TPU_PROCESS_PORT", + "CLOUD_TPU_TASK_ID", + "TPU_PROCESS_BOUNDS", + "TPU_CHIPS_PER_PROCESS_BOUNDS", + "TPU_VISIBLE_CHIPS", + ] + env_dump = {v: os.environ.get(v) for v in tpu_env_vars} + logger.debug( + f"Worker {self.rank} JAX/TPU environment before init_device: {env_dump}" + ) + tpu_ports = [ jax_parallel_state.BASE_JAX_PORT + i for i in range(self.pp_config.pp_world_size) @@ -199,6 +213,11 @@ def init_device(self, if tpu_visible_chips \ else self.pp_config.default_tpu_visible_chips + env_dump = {v: os.environ.get(v) for v in tpu_env_vars} + logger.debug( + f"Worker {self.rank} JAX/TPU environment after init_device: {env_dump}" + ) + if not self.devices: sharding_config: ShardingConfigManager = self.vllm_config.sharding_config device_indexes = sharding_config.device_indexes From 544c2ef869163ce56401b873aa3339beee9469b2 Mon Sep 17 00:00:00 2001 From: Pysith Vanuptikul Date: Mon, 9 Mar 2026 22:41:28 +0000 Subject: [PATCH 2/2] Remove unnecessary change Signed-off-by: Pysith Vanuptikul Signed-off-by: --- tpu_inference/worker/tpu_worker.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 3dd84833da..d3c686dd24 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -167,20 +167,20 @@ def init_device(self, tpu_visible_chips=""): # set tpu visible devices for Jax runtime in single host PP. 
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() - if self.parallel_config.pipeline_parallel_size > 1: - # Log environment variables for debugging - tpu_env_vars = [ - "TPU_PROCESS_ADDRESSES", - "TPU_PROCESS_PORT", - "CLOUD_TPU_TASK_ID", - "TPU_PROCESS_BOUNDS", - "TPU_CHIPS_PER_PROCESS_BOUNDS", - "TPU_VISIBLE_CHIPS", - ] - env_dump = {v: os.environ.get(v) for v in tpu_env_vars} - logger.debug( - f"Worker {self.rank} JAX/TPU environment before init_device: {env_dump}" - ) + # Log environment variables for debugging + tpu_env_vars = [ + "TPU_PROCESS_ADDRESSES", + "TPU_PROCESS_PORT", + "CLOUD_TPU_TASK_ID", + "TPU_PROCESS_BOUNDS", + "TPU_CHIPS_PER_PROCESS_BOUNDS", + "TPU_VISIBLE_CHIPS", + ] + env_dump = {v: os.environ.get(v) for v in tpu_env_vars} + logger.debug( + f"Worker {self.rank} JAX/TPU environment init_device: {env_dump}") + + if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1: tpu_ports = [ jax_parallel_state.BASE_JAX_PORT + i