Skip to content

Commit 4b317d5

Browse files
committed
fix: jax property patch for all jax versions
1 parent ae48973 commit 4b317d5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ivy/transpiler/utils/ast_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def device_getter(
124124
import flax.nnx as nnx
125125
126126
def _define_dunders(orig_method_name):
127-
original_method = getattr(jaxlib._jax.ArrayImpl, orig_method_name)
127+
original_method = getattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl, orig_method_name)
128128
patched_method = {
129129
'__add__': jax___add___frnt_,
130130
'__sub__': jax___sub___frnt_,
@@ -140,7 +140,7 @@ def impl(self, rhs):
140140
except Exception as e:
141141
return patched_method(self, rhs)
142142
143-
setattr(jaxlib._jax.ArrayImpl, orig_method_name, impl)
143+
setattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl, orig_method_name, impl)
144144
145145
def _define_properties(orig_property_name):
146146
def device_getter(
@@ -174,14 +174,14 @@ def custom_getattr(self, name):
174174
# Attempt to retrieve the attribute from the wrapped object (`value`)
175175
return getattr(value, name)
176176
return object.__getattribute__(self, name)
177-
original_property = getattr(jaxlib._jax.ArrayImpl, orig_property_name, None)
177+
original_property = getattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl, orig_property_name, None)
178178
patched_method = {
179179
'device': device_getter,
180180
'shape': shape_getter,
181181
'dtype': dtype_getter,
182182
}[orig_property_name]
183183
184-
setattr(jaxlib._jax.ArrayImpl, orig_property_name, property(patched_method))
184+
setattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl, orig_property_name, property(patched_method))
185185
setattr(nnx.Variable, orig_property_name, property(patched_method))
186186
setattr(nnx.Variable, '__getattr__', custom_getattr)
187187

0 commit comments

Comments
 (0)