@@ -124,7 +124,7 @@ def device_getter(
124124import flax.nnx as nnx
125125
126126def _define_dunders(orig_method_name):
127- original_method = getattr(jaxlib._jax.ArrayImpl, orig_method_name)
127+ original_method = getattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl , orig_method_name)
128128 patched_method = {
129129 '__add__': jax___add___frnt_,
130130 '__sub__': jax___sub___frnt_,
@@ -140,7 +140,7 @@ def impl(self, rhs):
140140 except Exception as e:
141141 return patched_method(self, rhs)
142142
143- setattr(jaxlib._jax.ArrayImpl, orig_method_name, impl)
143+ setattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl , orig_method_name, impl)
144144
145145def _define_properties(orig_property_name):
146146 def device_getter(
@@ -174,14 +174,14 @@ def custom_getattr(self, name):
174174 # Attempt to retrieve the attribute from the wrapped object (`value`)
175175 return getattr(value, name)
176176 return object.__getattribute__(self, name)
177- original_property = getattr(jaxlib._jax.ArrayImpl, orig_property_name, None)
177+ original_property = getattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl , orig_property_name, None)
178178 patched_method = {
179179 'device': device_getter,
180180 'shape': shape_getter,
181181 'dtype': dtype_getter,
182182 }[orig_property_name]
183183
184- setattr(jaxlib._jax.ArrayImpl, orig_property_name, property(patched_method))
184+ setattr(jaxlib._jax.ArrayImpl if jax.__version__ >= '0.6.0' else jaxlib.xla_extension.ArrayImpl , orig_property_name, property(patched_method))
185185 setattr(nnx.Variable, orig_property_name, property(patched_method))
186186 setattr(nnx.Variable, '__getattr__', custom_getattr)
187187
0 commit comments