Open
Description
Hey folks,
I've just noticed that data registered via `nnx.data` in a module is lost during `nnx.eval_shape`. I'm currently migrating our codebase to Flax NNX, and instead of rewriting our entire metrics implementation in NNX style, I just bundle the metrics state into an NNX module as an `nnx.data` field. The problem is that when I construct the abstract module with `nnx.eval_shape`, that field disappears from its state.
Here's a tiny repro:
```python
from flax import nnx
import jax.numpy as jnp


class Module(nnx.Module):
    def __init__(self):
        self.data = nnx.data({"a": jnp.ones((8, 8))})


def _init_module() -> Module:
    module = Module()
    return module


module = _init_module()
print("State with direct init:")
print(nnx.state(module))

module = nnx.eval_shape(_init_module)
print("State with nnx.eval_shape:")
print(nnx.state(module))
```

Output:

```
State with direct init:
State({
  'data': {
    'a': Array([[1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)
  }
})
State with nnx.eval_shape:
State({})
```
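For comparison, a minimal JAX-only sketch with no Flax involved shows what I'd expect the abstract result to look like: plain `jax.eval_shape` applied to a function returning the same dict keeps the `"a"` key and replaces the array with a `jax.ShapeDtypeStruct`. The helper name `init_metrics` is hypothetical, and this only illustrates plain JAX behavior; it says nothing about where NNX drops the field.

```python
import jax
import jax.numpy as jnp


def init_metrics() -> dict:
    # Hypothetical helper: same payload as the repro's nnx.data field,
    # but returned as a plain dict instead of living inside an NNX module.
    return {"a": jnp.ones((8, 8))}


# Direct call: concrete arrays.
concrete = init_metrics()

# jax.eval_shape traces the function abstractly (no FLOPs, no allocation)
# and returns jax.ShapeDtypeStruct leaves in the same pytree structure.
abstract = jax.eval_shape(init_metrics)

print(abstract)
# The pytree structure is preserved, so the "a" key survives.
print(jax.tree_util.tree_structure(concrete) == jax.tree_util.tree_structure(abstract))  # True
```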