Skip to content

Commit 49ad5e1

Browse files
committed
Improved JAX detection
1 parent 3eb91da commit 49ad5e1

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

src/zeropybench/_jax.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def from_code(cls, code: str, globals: dict[str, Any]) -> Self:
2525
return cls(tree, globals)
2626

2727
def is_jax_context(self) -> bool:
28-
"""Returns true if a JAX variable or a jitted function is used."""
28+
"""Returns true if a JAX variable, module, or jitted function is used."""
2929
jax = sys.modules.get('jax')
3030
if jax is None:
3131
return False
@@ -39,8 +39,18 @@ def is_jax_context(self) -> bool:
3939
return True
4040
if self._contains_jax_arrays(obj, jax):
4141
return True
42+
if self._is_jax_module(obj):
43+
return True
4244
return False
4345

46+
@staticmethod
47+
def _is_jax_module(obj: Any) -> bool:
48+
"""Check if an object is a JAX module (jax, jax.numpy, etc.)."""
49+
module_name: str | None = getattr(obj, '__name__', None)
50+
if module_name is None:
51+
return False
52+
return module_name == 'jax' or module_name.startswith('jax.')
53+
4454
@staticmethod
4555
def _contains_jax_arrays(obj: Any, jax: Any) -> bool:
4656
"""Check if an object is a pytree containing JAX arrays."""

tests/test_jax.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,19 @@ class Config:
121121
assert parser.is_jax_context() is False
122122

123123

124+
def test_is_jax_context_jax_module() -> None:
125+
"""Test is_jax_context detects JAX modules (jnp.ones, jax.random, etc.)."""
126+
# jax.numpy module
127+
globals_ = {'jnp': jnp}
128+
parser = CodeASTParser.from_code('jnp.ones((10, 10))', globals_)
129+
assert parser.is_jax_context() is True
130+
131+
# jax module directly
132+
globals_ = {'jax': jax}
133+
parser = CodeASTParser.from_code('jax.numpy.ones((10, 10))', globals_)
134+
assert parser.is_jax_context() is True
135+
136+
124137
def test_is_jax_context_no_jax(mocker: MockerFixture) -> None:
125138
"""Test is_jax_context returns False when JAX is not imported."""
126139
mocker.patch.dict('sys.modules', {'jax': None})

0 commit comments

Comments
 (0)