-
Notifications
You must be signed in to change notification settings - Fork 123
Expand file tree
/
Copy pathenvs.py
More file actions
231 lines (202 loc) · 8.43 KB
/
envs.py
File metadata and controls
231 lines (202 loc) · 8.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
import functools
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
    # Static type stubs for the lazily-resolved environment variables
    # (actual values are produced at runtime by the module-level __getattr__
    # from the `environment_variables` table below). Every key in that table
    # should be declared here with its runtime default so type checkers and
    # IDEs can resolve `envs.<NAME>` attribute access.
    JAX_PLATFORMS: str = ""
    TPU_ACCELERATOR_TYPE: str | None = None
    TPU_NAME: str | None = None
    TPU_WORKER_ID: str | None = None
    TPU_MULTIHOST_BACKEND: str = ""
    PREFILL_SLICES: str = ""
    DECODE_SLICES: str = ""
    SKIP_JAX_PRECOMPILE: bool = False
    VLLM_XLA_CHECK_RECOMPILATION: bool = False
    MODEL_IMPL_TYPE: str = "auto"
    # Fix: USE_2D_TP is registered in `environment_variables` but was missing
    # from these stubs, hiding `envs.USE_2D_TP` from static analysis.
    USE_2D_TP: bool = False
    NEW_MODEL_DESIGN: bool = False
    PHASED_PROFILING_DIR: str = ""
    PYTHON_TRACER_LEVEL: int = 1
    USE_MOE_EP_KERNEL: bool = False
    USE_UNFUSED_MEGABLOCKS: bool = False
    USE_DENSE_MOE: bool = False
    NUM_SLICES: int = 1
    RAY_USAGE_STATS_ENABLED: str = "0"
    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
    ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
    REQUANTIZE_BLOCK_SIZE: int | None = None
    REQUANTIZE_WEIGHT_DTYPE: str = "float8_e4m3fn"
    MOE_REQUANTIZE_BLOCK_SIZE: int | None = None
    MOE_REQUANTIZE_WEIGHT_DTYPE: str = "float8_e4m3fn"
    MOE_SKIP_REQUANTIZE: bool = False
    LAYOUT_Q_PROJ_AS_NDH: bool = False
def env_with_choices(
env_name: str,
default: str | None,
choices: list[str] | Callable[[], list[str]],
case_sensitive: bool = True,
) -> Callable[[], str | None]:
"""
Create a lambda that validates environment variable against allowed choices
Args:
env_name: Name of the environment variable
default: Default value if not set (can be None)
choices: List of valid string options or callable that returns list
case_sensitive: Whether validation should be case sensitive
Returns:
Lambda function for environment_variables dict
"""
def _get_validated_env() -> str | None:
value = os.getenv(env_name)
if value is None:
return default
# Resolve choices if it's a callable (for lazy loading)
actual_choices = choices() if callable(choices) else choices
if not case_sensitive:
check_value = value.lower()
check_choices = [choice.lower() for choice in actual_choices]
else:
check_value = value
check_choices = actual_choices
if check_value not in check_choices:
raise ValueError(f"Invalid value '{value}' for {env_name}. "
f"Valid options: {actual_choices}.")
return value
return _get_validated_env
def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
    """Build a getter that parses *env_name* as a boolean.

    Accepts both numeric strings ("0", "1") and boolean strings
    ("true", "false", any casing). Unset or empty means *default*.

    Args:
        env_name: Name of the environment variable.
        default: Default boolean value if not set.

    Returns:
        A zero-argument callable suitable for the environment_variables dict.

    Raises:
        ValueError: (from the returned callable) for any other string.
    """

    def _parse() -> bool:
        raw = os.getenv(env_name)
        # Unset and empty-string both fall back to the default.
        if raw is None or not raw:
            return default
        normalized = raw.lower()
        if normalized in ("true", "1"):
            return True
        if normalized in ("false", "0"):
            return False
        raise ValueError(
            f"Invalid boolean value '{raw}' for {env_name}. "
            f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")

    return _parse
# Mapping from environment-variable name to a zero-argument getter.
# Entries are resolved lazily via the module-level __getattr__ (PEP 562),
# so a variable is only read from os.environ when first accessed.
environment_variables: dict[str, Callable[[], Any]] = {
    # JAX platform selection (e.g., "tpu", "cpu", "proxy")
    "JAX_PLATFORMS":
    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
    "TPU_ACCELERATOR_TYPE":
    lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
    # Name of the TPU resource
    "TPU_NAME":
    lambda: os.getenv("TPU_NAME", None),
    # Worker ID for multi-host TPU setups
    "TPU_WORKER_ID":
    lambda: os.getenv("TPU_WORKER_ID", None),
    # Backend for multi-host communication on TPU
    "TPU_MULTIHOST_BACKEND":
    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
    # Slice configuration for disaggregated prefill workers
    "PREFILL_SLICES":
    lambda: os.getenv("PREFILL_SLICES", ""),
    # Slice configuration for disaggregated decode workers
    "DECODE_SLICES":
    lambda: os.getenv("DECODE_SLICES", ""),
    # Skip JAX precompilation step during initialization
    "SKIP_JAX_PRECOMPILE":
    env_bool("SKIP_JAX_PRECOMPILE", default=False),
    # Check for XLA recompilation during execution
    "VLLM_XLA_CHECK_RECOMPILATION":
    env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
    # Model implementation type (e.g., "flax_nnx")
    "MODEL_IMPL_TYPE":
    env_with_choices("MODEL_IMPL_TYPE", "auto",
                     ["auto", "vllm", "flax_nnx", "jetpack"]),
    # Enable 2D tensor parallelism, shard attention heads across multiple axes
    "USE_2D_TP":
    env_bool("USE_2D_TP", default=False),
    # Enable new experimental model design
    "NEW_MODEL_DESIGN":
    env_bool("NEW_MODEL_DESIGN", default=False),
    # Directory to store phased profiling output
    "PHASED_PROFILING_DIR":
    lambda: os.getenv("PHASED_PROFILING_DIR", ""),
    # Python tracer level for profiling
    "PYTHON_TRACER_LEVEL":
    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
    # Use custom expert-parallel kernel for MoE (Mixture of Experts)
    "USE_MOE_EP_KERNEL":
    env_bool("USE_MOE_EP_KERNEL", default=False),
    # Enable megablocks for JAX sparse matmul for MoE (Mixture of Experts)
    # using Unfused weights
    "USE_UNFUSED_MEGABLOCKS":
    env_bool("USE_UNFUSED_MEGABLOCKS", default=False),
    # Enable the dense backend for Jax MoE (Mixture of Experts)
    # NOTE: this is a naive implementation and should not be used in production
    "USE_DENSE_MOE":
    env_bool("USE_DENSE_MOE", default=False),
    # Number of TPU slices for multi-slice mesh
    "NUM_SLICES":
    lambda: int(os.getenv("NUM_SLICES") or "1"),
    # Enable/disable Ray usage statistics collection
    "RAY_USAGE_STATS_ENABLED":
    lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
    # Ray compiled DAG channel type for TPU
    "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
    # Enable the quantized matmul kernel
    # (was `bool(int(...))`, which rejected "true"/"false" and silently
    # accepted arbitrary integers — now consistent with the other flags)
    "ENABLE_QUANTIZED_MATMUL_KERNEL":
    env_bool("ENABLE_QUANTIZED_MATMUL_KERNEL", default=False),
    # Specify block quantization size
    "REQUANTIZE_BLOCK_SIZE":
    lambda: int(block_size) if
    (block_size := os.getenv("REQUANTIZE_BLOCK_SIZE")) is not None else None,
    # Specify dtype for quantized linear weights
    "REQUANTIZE_WEIGHT_DTYPE":
    lambda: os.getenv("REQUANTIZE_WEIGHT_DTYPE", "float8_e4m3fn"),
    # Specify dtype for quantized MoE weights
    "MOE_REQUANTIZE_WEIGHT_DTYPE":
    lambda: os.getenv("MOE_REQUANTIZE_WEIGHT_DTYPE", "float8_e4m3fn"),
    # Specify requantization block size for MoE weights
    "MOE_REQUANTIZE_BLOCK_SIZE":
    lambda: int(block_size) if (block_size := os.getenv(
        "MOE_REQUANTIZE_BLOCK_SIZE")) is not None else None,
    # Skip MoE dequant→requant when weights are already in the target dtype
    "MOE_SKIP_REQUANTIZE":
    env_bool("MOE_SKIP_REQUANTIZE", default=False),
    # dictates whether to layout q-proj as NDH (q-heads, model dim, head dim)
    # or DNH (model dim, q-heads, head dim), which is the default (False)
    # (was `bool(int(...))` — switched to env_bool for consistency)
    "LAYOUT_Q_PROJ_AS_NDH":
    env_bool("LAYOUT_Q_PROJ_AS_NDH", default=False),
}
def __getattr__(name: str) -> Any:
    """Resolve module attributes lazily from `environment_variables`.

    NOTE: After enable_envs_cache() invocation (which triggered after service
    initialization), all environment variables will be cached.
    """
    getter = environment_variables.get(name)
    if getter is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getter()
def enable_envs_cache() -> None:
    """Cache environment variable lookups from this point on.

    Rebinds the module's __getattr__ to a functools.cache()-wrapped version
    and then primes it for every known variable, so later accesses avoid
    re-reading os.environ.
    NOTE: call only after service initialization. Once enabled, cached values
    no longer reflect changes to os.environ until the process restarts.
    """
    global __getattr__
    __getattr__ = functools.cache(__getattr__)
    # Eagerly evaluate each variable so the cache is fully populated.
    for env_name in environment_variables:
        __getattr__(env_name)
def __dir__() -> list[str]:
    """Expose the lazily-resolved environment variable names."""
    return [*environment_variables]