5 changes: 5 additions & 0 deletions ivy/functional/backends/jax/module.py
@@ -700,6 +700,11 @@ def _addindent(s_, numSpaces):


class Module(nnx.Module, ModelHelpers, TorchModuleHelpers):
    _v: nnx.Data[dict]
    _buffers: nnx.Data[dict]
    _module_dict: nnx.Data[dict]
    _args: nnx.Data[tuple]
    _kwargs: nnx.Data[dict]
    _build_mode = None
    _with_partial_v = None
    _store_vars = True
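A minimal, hedged sketch of the annotation pattern this diff adds, assuming flax.nnx >= 0.12 semantics in which a class-level nnx.Data[...] annotation marks a plain-Python attribute as pytree data on the module rather than static metadata; the Holder class below is purely illustrative and not part of the PR.

# Illustrative only: mirrors the nnx.Data[...] annotations added above.
# Assumes flax.nnx >= 0.12, where these annotations declare the annotated
# attributes as data (part of the module's state) rather than static fields.
from flax import nnx


class Holder(nnx.Module):
    _v: nnx.Data[dict]      # dict container carried as data
    _args: nnx.Data[tuple]  # tuple container carried as data

    def __init__(self):
        self._v = {}
        self._args = ()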
@@ -1266,6 +1266,10 @@ def torch_pad(input, pad, mode="constant", value=0):


class FlaxConv(nnx.Conv):
    _v: nnx.Data[dict]
    _buffers: nnx.Data[dict]
    pt_weight: nnx.Data[jax.Array]
    pt_bias: nnx.Data[Optional[jax.Array]]
    def __init__(self, *args, **kwargs):
        self._previous_frame_info = None
        self._built = False
@@ -1298,6 +1302,11 @@ def __init__(self, *args, **kwargs):
            **kwargs,
        )

        # For nnx >= 0.12, avoid DataAttr(None) on bias when bias is disabled.
        # Ensure bias is a plain None rather than an nnx.data-wrapped value.
        if not self.use_bias:
            object.__setattr__(self, "bias", None)

        # Compute self._reversed_padding_repeated_twice
        if isinstance(padding_, str):
            self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size)
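A hedged usage sketch of the post-condition the bias fix above targets; the constructor arguments are hypothetical (they assume this frontend mirrors a torch-style Conv2d signature) and are not taken from the PR.

# Hypothetical usage: with bias disabled, `bias` should end up as a plain
# Python None rather than an nnx.data-wrapped None (args are illustrative).
conv = FlaxConv(in_channels=16, out_channels=32, kernel_size=3, bias=False)
assert conv.bias is None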
@@ -1365,6 +1374,13 @@ def __setattr__(self, name, value):
            self.__dict__[name] = value
            return
        elif name in ["weight", "bias"] and hasattr(self, name):
            # Allow explicitly setting bias/weight to None without shape handling.
            if value is None:
                new_native_name = "kernel" if name == "weight" else name
                object.__setattr__(self, new_native_name, None)
                new_pt_name = "pt_weight" if name == "weight" else "pt_bias"
                object.__setattr__(self, new_pt_name, None)
                return
            # Determine the transpose type based on the value shape
            if len(value.shape) > 4:
                # Conv3D case: PT [out_channels, in_channels, depth, height, width]
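A minimal, hedged sketch of the weight-layout conversion the comment above refers to, assuming the standard PyTorch conv weight layout (out_channels, in_channels, *spatial) and the flax.nnx.Conv kernel layout (*spatial, in_features, out_features); the shapes below are illustrative, not taken from the PR.

# Illustrative only: PT -> Flax kernel layout conversion for 2D and 3D convs.
import jax.numpy as jnp

pt_w2d = jnp.zeros((32, 16, 3, 3))                # PT: (O, I, kH, kW)
flax_k2d = jnp.transpose(pt_w2d, (2, 3, 1, 0))    # Flax: (kH, kW, I, O)
assert flax_k2d.shape == (3, 3, 16, 32)

pt_w3d = jnp.zeros((32, 16, 2, 3, 3))                 # PT: (O, I, D, kH, kW)
flax_k3d = jnp.transpose(pt_w3d, (2, 3, 4, 1, 0))     # Flax: (D, kH, kW, I, O)
assert flax_k3d.shape == (2, 3, 3, 16, 32)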
@@ -1473,6 +1489,13 @@ def to(self, *args, **kwargs):
        return self._apply(lambda t: jax_to_frnt_(t, *args, **kwargs))

class FlaxBatchNorm(nnx.BatchNorm):
    _v: nnx.Data[dict]
    _buffers: nnx.Data[dict]
    pt_weight: nnx.Data[Optional[jax.Array]]
    pt_bias: nnx.Data[Optional[jax.Array]]
    running_mean: nnx.Data[Optional[jax.Array]]
    running_var: nnx.Data[Optional[jax.Array]]
    num_batches_tracked: nnx.Data[Optional[jax.Array]]
    def __init__(self, *args, **kwargs):
        self._previous_frame_info = None
        self._built = False
@@ -1640,6 +1663,10 @@ def to(self, *args, **kwargs):
        return self._apply(lambda t: jax_to_frnt_(t, *args, **kwargs))

class FlaxLinear(nnx.Linear):
    _v: nnx.Data[dict]
    _buffers: nnx.Data[dict]
    pt_weight: nnx.Data[jax.Array]
    pt_bias: nnx.Data[Optional[jax.Array]]
    def __init__(self, *args, **kwargs):
        self._previous_frame_info = None
        self._built = False