Skip to content

Commit 280d737

Browse files
committed
[nnx] refactor GraphDef
1 parent a1e3bbf commit 280d737

File tree

11 files changed

+283
-264
lines changed

11 files changed

+283
-264
lines changed

benchmarks/nnx_graph_overhead.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,9 @@ def main(argv):
9797
def step_nnx(model: MLP, optimizer: nnx.Optimizer):
9898
pass
9999

100-
cached_step_nnx = nnx.cached_partial(step_nnx, model, optimizer)
101-
102100
t0 = time()
103101
for _ in range(total_steps):
104-
cached_step_nnx()
102+
step_nnx(model, optimizer)
105103

106104
total_time = time() - t0
107105
time_per_step = total_time / total_steps

flax/nnx/bridge/variables.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
154154
elif isinstance(v, variablelib.VariableState):
155155
col_name = variablelib.variable_name_from_type(v.type)
156156
v = to_linen_var(v)
157-
elif isinstance(v, graph.NodeDef) or isinstance(v, graph.NodeRef):
157+
elif isinstance(v, graph.GraphDef):
158158
col_name = 'nnx' # an nnx.GraphDef for some ToLinen submodule
159159
else:
160160
raise ValueError(f'Cannot infer collection name from value: {v}')

flax/nnx/bridge/wrappers.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
@dataclasses.dataclass
3939
class Functional(tp.Generic[M]):
4040
module_type: tp.Type[M]
41-
graphdef: tp.Optional[graph.NodeDef[M]]
41+
graphdef: tp.Optional[graph.GraphDef[M]]
4242
args: tuple[tp.Any, ...]
4343
kwargs: dict[str, tp.Any]
4444

@@ -48,7 +48,6 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State:
4848
kwargs['rngs'] = rngs
4949
module = self.module_type(*self.args, **self.kwargs, **kwargs)
5050
graphdef, state = nnx.split(module)
51-
assert type(graphdef) is graph.NodeDef
5251
self.graphdef = graphdef
5352
return state # type: ignore
5453

@@ -217,7 +216,7 @@ class ToLinen(linen.Module):
217216
>>> variables.keys()
218217
dict_keys(['nnx', 'params'])
219218
>>> type(variables['nnx']['graphdef'])
220-
<class 'flax.nnx.graph.NodeDef'>
219+
<class 'flax.nnx.graph.GraphDef'>
221220
222221
Args:
223222
nnx_class: The NNX Module class (not instance!).

0 commit comments

Comments
 (0)