-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathjax_setup.py
More file actions
33 lines (26 loc) · 1.51 KB
/
Copy pathjax_setup.py
File metadata and controls
33 lines (26 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""JAX environment setup — must be importable without triggering JAX init."""
import os
def set_jax_flags_before_importing_jax(jax_platforms: "str | None" = None):
    """Set JAX/XLA environment variables for performance and reproducibility.

    This should be called before importing JAX.

    Values the user has already chosen are respected:
      - Environment variables are only filled in via ``os.environ.setdefault``.
      - Flags already present in ``XLA_FLAGS`` are not appended again.

    Repeated calls therefore do not accumulate duplicate flags.

    Args:
        jax_platforms: Value for ``JAX_PLATFORMS`` ("cpu", "cuda", or "tpu").
            Ignored if ``JAX_PLATFORMS`` is already set in the environment.
            If None, ``JAX_PLATFORMS`` is left alone and JAX auto-detects
            available hardware.

    Environment:
        FABRICPC_DISABLE_TRITON_GEMM: When "1" (the default if unset),
            ``--xla_gpu_enable_triton_gemm=false`` is added to ``XLA_FLAGS``.
            Any other value leaves the Triton GEMM setting untouched.
    """
    if jax_platforms is not None:
        os.environ.setdefault("JAX_PLATFORMS", jax_platforms)
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")  # Suppress XLA warnings

    def _with_flag(flags, flag, present_marker=None):
        """Return *flags* with *flag* appended, unless *present_marker* (default: *flag*) already occurs."""
        marker = flag if present_marker is None else present_marker
        if marker in flags:
            return flags
        return (flags + " " + flag).strip()

    xla_flags = os.environ.get("XLA_FLAGS", "")

    # Keep deterministic kernels.
    xla_flags = _with_flag(xla_flags, "--xla_gpu_deterministic_ops=true")

    # Default to disabling Triton GEMM, which can trigger CUDA runtime errors
    # on some GPUs for small/irregular matmuls: Triton tiling logic fails on
    # certain fused operations whose dimension bounds are not divisible by the
    # tile size.
    if os.environ.get("FABRICPC_DISABLE_TRITON_GEMM", "1") == "1":
        xla_flags = _with_flag(xla_flags, "--xla_gpu_enable_triton_gemm=false")

    # Default autotune level. The check is on the flag name, not the full
    # name=value pair, so that:
    #   - an explicit user-provided level is not overridden by a later "=1";
    #   - repeated calls do not append duplicates.
    xla_flags = _with_flag(
        xla_flags,
        "--xla_gpu_autotune_level=1",
        present_marker="--xla_gpu_autotune_level=",
    )

    os.environ["XLA_FLAGS"] = xla_flags