diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index a358f5fa0b..b3fffb267f 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -167,9 +167,10 @@ def _initialize_ray_cluster(self) -> None: f"RayDistributedExecutor | placement_group_specs={placement_group_specs}" ) - # By default, Ray packs resources as much as possible. + # Use STRICT_SPREAD for PP to ensure each host participates in JAX initialization. + strategy = "STRICT_SPREAD" if pp_size > 1 else "PACK" current_placement_group = ray.util.placement_group( - placement_group_specs, strategy="PACK") + placement_group_specs, strategy=strategy) _wait_until_pg_ready(current_placement_group) assert current_placement_group is not None @@ -321,11 +322,13 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): additional_vars=set(current_platform.additional_env_vars), destination="workers") - # Copy existing env vars to each worker's args - for args in all_args_to_update_environment_variables: + # Copy existing env vars to each worker's args and log the result. + # NOTE(review): loop below copies/overwrites ALL vars, incl. topology bounds — confirm intent.
+ for i, args in enumerate(all_args_to_update_environment_variables): for name in env_vars_to_copy: if name in os.environ: args[name] = os.environ[name] + logger.debug(f"Worker {i} environment variables: {args}") self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 2f3a5893ed..6e7e2a9828 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -403,6 +403,7 @@ def _init_phased_profiling(self) -> None: def _init_mm(self) -> None: self.is_multimodal_model = None self.uses_mrope = self.model_config.uses_mrope + self.supports_mm_inputs = True def _init_speculative_decoding(self) -> None: self.drafter = None diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index e33e309a2f..d3c686dd24 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -167,7 +167,21 @@ def init_device(self, tpu_visible_chips=""): # set tpu visible devices for Jax runtime in single host PP. 
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() + # Log environment variables for debugging + tpu_env_vars = [ + "TPU_PROCESS_ADDRESSES", + "TPU_PROCESS_PORT", + "CLOUD_TPU_TASK_ID", + "TPU_PROCESS_BOUNDS", + "TPU_CHIPS_PER_PROCESS_BOUNDS", + "TPU_VISIBLE_CHIPS", + ] + env_dump = {v: os.environ.get(v) for v in tpu_env_vars} + logger.debug( + f"Worker {self.rank} JAX/TPU environment init_device: {env_dump}") + if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1: + tpu_ports = [ jax_parallel_state.BASE_JAX_PORT + i for i in range(self.pp_config.pp_world_size) @@ -199,6 +213,11 @@ def init_device(self, if tpu_visible_chips \ else self.pp_config.default_tpu_visible_chips + env_dump = {v: os.environ.get(v) for v in tpu_env_vars} + logger.debug( + f"Worker {self.rank} JAX/TPU environment after init_device: {env_dump}" + ) + if not self.devices: sharding_config: ShardingConfigManager = self.vllm_config.sharding_config device_indexes = sharding_config.device_indexes