Skip to content

Commit 11d3c8e

Browse files
committed
omit jax backend warnings
1 parent 66eec1b commit 11d3c8e

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/palantir/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919

2020
import importlib.metadata
2121
import warnings
22+
import logging
2223

23-
# Filter JAX warnings about CUDA when GPU support is not available
24+
# Filter JAX warnings and errors about CUDA when GPU support is not available
2425
warnings.filterwarnings("ignore", message=".*CUDA.*", module="jax.*")
2526
warnings.filterwarnings("ignore", message=".*cuSPARSE.*")
2627
warnings.filterwarnings("ignore", message=".*NVIDIA GPU.*")
2728

29+
# Suppress JAX logging errors for CUDA plugin failures
30+
# These are harmless - JAX falls back to CPU automatically
31+
logging.getLogger("jax._src.xla_bridge").setLevel(logging.CRITICAL)
32+
2833
from . import config
2934

3035
# Import modules in a specific order to avoid circular imports

0 commit comments

Comments
 (0)