
nnx.data is lost during nnx.eval_shape #5128

@qGentry

Hey folks,

I've just noticed that data registered via `nnx.data` in a module is lost during `nnx.eval_shape`. I'm currently migrating our codebase to Flax NNX, and instead of rewriting our entire metrics implementation in NNX style, I just bundle the metrics state into an NNX module as an `nnx.data` field. The problem is that when I construct the abstract module with `nnx.eval_shape`, this field is missing from its state.
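For context, here is what an abstract shape looks like in plain JAX. This minimal sketch (not part of the original report) applies `jax.eval_shape` to the same payload returned as an ordinary dict: the key is preserved and the array is replaced by a `jax.ShapeDtypeStruct`. This is the behavior one would expect `nnx.eval_shape` to mirror for the `nnx.data` field.

```python
import jax
import jax.numpy as jnp


def _init_state() -> dict:
    # Same payload as the repro's nnx.data field, returned as a plain pytree.
    return {"a": jnp.ones((8, 8))}


# jax.eval_shape traces the function without allocating the array and returns
# a pytree with the same structure whose leaves are jax.ShapeDtypeStruct.
abstract = jax.eval_shape(_init_state)
print(abstract)
```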

Here's a tiny repro:

```python
from flax import nnx
import jax.numpy as jnp


class Module(nnx.Module):
    def __init__(self):
        self.data = nnx.data({"a": jnp.ones((8, 8))})


def _init_module() -> Module:
    module = Module()
    return module

module = _init_module()
print("State with direct init:")
print(nnx.state(module))
module = nnx.eval_shape(_init_module)
print("State with nnx.eval_shape:")
print(nnx.state(module))
```

Output:

```
State with direct init:
State({
  'data': {
    'a': Array([[1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)
  }
})
State with nnx.eval_shape:
State({})
```