Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions tpu_inference/executors/ray_distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ def _initialize_ray_cluster(self) -> None:
f"RayDistributedExecutor | placement_group_specs={placement_group_specs}"
)

# By default, Ray packs resources as much as possible.
# Use STRICT_SPREAD for PP to ensure each host participates in JAX initialization.
strategy = "STRICT_SPREAD" if pp_size > 1 else "PACK"
current_placement_group = ray.util.placement_group(
placement_group_specs, strategy="PACK")
placement_group_specs, strategy=strategy)
_wait_until_pg_ready(current_placement_group)

assert current_placement_group is not None
Expand Down Expand Up @@ -321,11 +322,13 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
additional_vars=set(current_platform.additional_env_vars),
destination="workers")

# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# Copy existing env vars to each worker's args, but don't overwrite
# and don't copy global topology bounds from the driver.
for i, args in enumerate(all_args_to_update_environment_variables):
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
logger.debug(f"Worker {i} environment variables: {args}")

self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
Expand Down
1 change: 1 addition & 0 deletions tpu_inference/runner/tpu_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def _init_phased_profiling(self) -> None:
def _init_mm(self) -> None:
self.is_multimodal_model = None
self.uses_mrope = self.model_config.uses_mrope
self.supports_mm_inputs = True

def _init_speculative_decoding(self) -> None:
self.drafter = None
Expand Down
21 changes: 20 additions & 1 deletion tpu_inference/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,21 @@ def init_device(self,
tpu_visible_chips=""):
# set tpu visible devices for Jax runtime in single host PP.
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
if self.parallel_config.pipeline_parallel_size > 1:
# Log environment variables for debugging
tpu_env_vars = [
"TPU_PROCESS_ADDRESSES",
Copy link

@ryanaoleary ryanaoleary Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe TPU_PROCESS_ADDRESSES is only introduced for tpu7x, otherwise it'll be TPU_WORKER_HOSTNAMES which should have the same values as the former, except without the TPU port appended to the address. Also it might be useful to dump the TPU_WORKER_ID as well

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with TPU_PROCESS_PORT - I don't think this is set automatically before tpu7x

"TPU_PROCESS_PORT",
"CLOUD_TPU_TASK_ID",
"TPU_PROCESS_BOUNDS",
"TPU_CHIPS_PER_PROCESS_BOUNDS",
"TPU_VISIBLE_CHIPS",
]
env_dump = {v: os.environ.get(v) for v in tpu_env_vars}
logger.debug(
f"Worker {self.rank} JAX/TPU environment before init_device: {env_dump}"
)

tpu_ports = [
jax_parallel_state.BASE_JAX_PORT + i
for i in range(self.pp_config.pp_world_size)
Expand Down Expand Up @@ -199,6 +213,11 @@ def init_device(self,
if tpu_visible_chips \
else self.pp_config.default_tpu_visible_chips

env_dump = {v: os.environ.get(v) for v in tpu_env_vars}
logger.debug(
f"Worker {self.rank} JAX/TPU environment after init_device: {env_dump}"
)

if not self.devices:
sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
device_indexes = sharding_config.device_indexes
Expand Down
Loading