Skip to content
Open

env fix #2077

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project

import importlib
import os

import pytest

import tpu_inference.env_override
import tpu_inference.envs as envs
from tpu_inference.envs import enable_envs_cache, environment_variables

Expand Down Expand Up @@ -306,3 +310,25 @@ def test_cache_preserves_values_across_env_changes(

# Now it should reflect the new value
assert envs.JAX_PLATFORMS == "cpu"


def test_env_libtpu_default(monkeypatch: pytest.MonkeyPatch):
    """The override module injects both default LIBTPU flags when unset.

    Clears ``LIBTPU_INIT_ARGS`` first and re-runs the module-level override
    logic so the test does not depend on import order or on state leaked by
    earlier tests; ``monkeypatch.delenv`` restores the original value at
    teardown.
    """
    monkeypatch.delenv("LIBTPU_INIT_ARGS", raising=False)
    importlib.reload(tpu_inference.env_override)

    libtpu_init_args = os.environ.get("LIBTPU_INIT_ARGS", "")

    assert "xla_tpu_use_tc_device_shape_on_sc=true" in libtpu_init_args
    assert ("xla_tpu_scheduler_percent_shared_memory_limit=1000"
            in libtpu_init_args)


def test_env_libtpu_overwrite(monkeypatch: pytest.MonkeyPatch):
    """User-provided LIBTPU flag values are respected, not clobbered.

    Uses ``monkeypatch.setenv`` (instead of assigning ``os.environ``
    directly) so the mutation is rolled back after the test and cannot leak
    into other tests in the session.
    """
    monkeypatch.setenv(
        "LIBTPU_INIT_ARGS",
        "--xla_tpu_use_tc_device_shape_on_sc=false "
        "--xla_tpu_scheduler_percent_shared_memory_limit=3000")

    # Re-run the module-level override logic against the patched environment.
    importlib.reload(tpu_inference.env_override)

    actual_args = os.environ.get("LIBTPU_INIT_ARGS", "")

    # The user's values must survive untouched...
    assert "--xla_tpu_use_tc_device_shape_on_sc=false" in actual_args
    assert "--xla_tpu_scheduler_percent_shared_memory_limit=3000" in actual_args

    # ...and the defaults must NOT have been appended on top of them.
    assert "--xla_tpu_use_tc_device_shape_on_sc=true" not in actual_args
    assert ("--xla_tpu_scheduler_percent_shared_memory_limit=1000"
            not in actual_args)
26 changes: 24 additions & 2 deletions tpu_inference/env_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,29 @@
# This prevents errors when trying to create CUDA streams on TPU hardware
# The issue was introduced by vllm-project/vllm#26440
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

# Default LIBTPU flags we want active unless the user already configured the
# same flag themselves — a user-supplied value (even a different one) must
# win over our default. Each entry is (flag name, full "--name=value" arg).
_DEFAULT_LIBTPU_FLAGS = (
    ("xla_tpu_use_tc_device_shape_on_sc",
     "--xla_tpu_use_tc_device_shape_on_sc=true"),
    ("xla_tpu_scheduler_percent_shared_memory_limit",
     "--xla_tpu_scheduler_percent_shared_memory_limit=1000"),
)

# Fetch the existing args, defaulting to an empty string if unset.
_current_args = os.environ.get("LIBTPU_INIT_ARGS", "")

# Collect only the flags the user has not already set in any form. A
# substring check on the flag name is enough to detect both "--name=..."
# and "name=..." spellings.
_args_to_append = [
    flag for name, flag in _DEFAULT_LIBTPU_FLAGS if name not in _current_args
]

# If any defaults are missing, append them to the environment variable.
if _args_to_append:
    # filter(None, ...) drops the empty string when LIBTPU_INIT_ARGS was
    # unset, avoiding a spurious leading space in the final value.
    os.environ["LIBTPU_INIT_ARGS"] = " ".join(
        filter(None, [_current_args, *_args_to_append]))

# Monkeypatch vLLM to avoid ImportError: cannot import name 'SamplingParams' from 'vllm'
# in vllm/v1/... submodules due to circular imports or lazy loading failures.
Expand All @@ -22,4 +44,4 @@
from vllm.sampling_params import RequestOutputKind
vllm.RequestOutputKind = RequestOutputKind
except ImportError:
pass
pass
Loading