diff --git a/ivy/functional/backends/jax/module.py b/ivy/functional/backends/jax/module.py
index ba078d0b02594..6e12b60407e38 100644
--- a/ivy/functional/backends/jax/module.py
+++ b/ivy/functional/backends/jax/module.py
@@ -700,6 +700,11 @@ def _addindent(s_, numSpaces):
 
 
 class Module(nnx.Module, ModelHelpers, TorchModuleHelpers):
+    _v: nnx.Data[dict]
+    _buffers: nnx.Data[dict]
+    _module_dict: nnx.Data[dict]
+    _args: nnx.Data[tuple]
+    _kwargs: nnx.Data[dict]
     _build_mode = None
     _with_partial_v = None
     _store_vars = True
diff --git a/ivy/transpiler/transformations/transformers/native_layers_transformer/__init__.py b/ivy/transpiler/transformations/transformers/native_layers_transformer/__init__.py
index 979f68c70b1a8..c5f512bf39c8d 100644
--- a/ivy/transpiler/transformations/transformers/native_layers_transformer/__init__.py
+++ b/ivy/transpiler/transformations/transformers/native_layers_transformer/__init__.py
@@ -1266,6 +1266,10 @@ def torch_pad(input, pad, mode="constant", value=0):
 
 
 class FlaxConv(nnx.Conv):
+    _v: nnx.Data[dict]
+    _buffers: nnx.Data[dict]
+    pt_weight: nnx.Data[jax.Array]
+    pt_bias: nnx.Data[Optional[jax.Array]]
     def __init__(self, *args, **kwargs):
         self._previous_frame_info = None
         self._built = False
@@ -1298,6 +1302,11 @@ def __init__(self, *args, **kwargs):
             **kwargs,
         )
 
+        # For nnx >= 0.12, avoid DataAttr(None) on bias when bias is disabled.
+        # Ensure bias is a plain None rather than an nnx.data-wrapped value.
+        if not self.use_bias:
+            object.__setattr__(self, "bias", None)
+
         # Compute self._reversed_padding_repeated_twice
         if isinstance(padding_, str):
             self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size)
@@ -1365,6 +1374,13 @@ def __setattr__(self, name, value):
             self.__dict__[name] = value
             return
         elif name in ["weight", "bias"] and hasattr(self, name):
+            # allow explicitly setting bias/weight to None without shape handling
+            if value is None:
+                new_native_name = "kernel" if name == "weight" else name
+                object.__setattr__(self, new_native_name, None)
+                new_pt_name = "pt_weight" if name == "weight" else "pt_bias"
+                object.__setattr__(self, new_pt_name, None)
+                return
             # Determine the transpose type based on the value shape
             if len(value.shape) > 4:
                 # Conv3D case: PT [out_channels, in_channels, depth, height, width]
@@ -1473,6 +1489,13 @@ def to(self, *args, **kwargs):
         return self._apply(lambda t: jax_to_frnt_(t, *args, **kwargs))
 
 class FlaxBatchNorm(nnx.BatchNorm):
+    _v: nnx.Data[dict]
+    _buffers: nnx.Data[dict]
+    pt_weight: nnx.Data[Optional[jax.Array]]
+    pt_bias: nnx.Data[Optional[jax.Array]]
+    running_mean: nnx.Data[Optional[jax.Array]]
+    running_var: nnx.Data[Optional[jax.Array]]
+    num_batches_tracked: nnx.Data[Optional[jax.Array]]
     def __init__(self, *args, **kwargs):
         self._previous_frame_info = None
         self._built = False
@@ -1640,6 +1663,10 @@ def to(self, *args, **kwargs):
         return self._apply(lambda t: jax_to_frnt_(t, *args, **kwargs))
 
 class FlaxLinear(nnx.Linear):
+    _v: nnx.Data[dict]
+    _buffers: nnx.Data[dict]
+    pt_weight: nnx.Data[jax.Array]
+    pt_bias: nnx.Data[Optional[jax.Array]]
    def __init__(self, *args, **kwargs):
         self._previous_frame_info = None
         self._built = False
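
Note for reviewers: the sketch below is illustrative only (`PlainModule` is a hypothetical class, not part of this PR). It shows the Flax NNX >= 0.12 pattern the annotations above rely on: class-level `nnx.Data[...]` annotations mark instance attributes as NNX data (state) rather than static metadata, which is why `_v`, `_buffers`, `pt_weight`, etc. are declared this way, and why a disabled bias is kept as a plain `None` rather than a data-wrapped value.

```python
# Minimal sketch of the annotation pattern, assuming flax >= 0.12 (nnx.Data available).
# PlainModule is a hypothetical example class, not part of this PR.
from typing import Optional

import jax
from flax import nnx


class PlainModule(nnx.Module):
    # Class-level annotations tell NNX to treat these attributes as data/state.
    _v: nnx.Data[dict]
    pt_bias: nnx.Data[Optional[jax.Array]]

    def __init__(self):
        self._v = {}          # tracked as data, mirroring Module._v above
        self.pt_bias = None   # kept as a plain None, as in the FlaxConv bias handling
```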