Skip to content

Commit 4494970

Browse files
authored
avoid nested closures in module (#759)
1 parent 776c3d2 commit 4494970

File tree

1 file changed

+46
-27
lines changed

1 file changed

+46
-27
lines changed

python/mlx/nn/layers/base.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,42 @@
77
from mlx.utils import tree_flatten, tree_unflatten
88

99

10+
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
11+
if is_leaf_fn(model, value_key, value):
12+
return map_fn(value)
13+
14+
elif isinstance(value, Module):
15+
return {
16+
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
17+
for k, v in value.items()
18+
if filter_fn(value, k, v)
19+
}
20+
21+
elif isinstance(value, dict):
22+
nd = {}
23+
for k, v in v.items():
24+
tk = f"{value_key}.{k}"
25+
nd[k] = (
26+
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
27+
if filter_fn(model, tk, v)
28+
else {}
29+
)
30+
return nd
31+
32+
elif isinstance(value, list):
33+
nl = []
34+
for i, vi in enumerate(value):
35+
tk = f"{value_key}.{i}"
36+
nl.append(
37+
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
38+
if filter_fn(model, tk, vi)
39+
else {}
40+
)
41+
return nl
42+
43+
raise RuntimeError("Unexpected leaf found while traversing the module")
44+
45+
1046
class Module(dict):
1147
"""Base class for building neural networks with MLX.
1248
@@ -98,10 +134,13 @@ def __getattr__(self, key: str):
98134
if key in self:
99135
return self[key]
100136
else:
101-
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
137+
super(Module, self).__getattr__(key, val)
102138

103139
def __setattr__(self, key: str, val: Any):
104-
self[key] = val
140+
if isinstance(val, (mx.array, dict, list, tuple)):
141+
self[key] = val
142+
else:
143+
super(Module, self).__setattr__(key, val)
105144

106145
def load_weights(
107146
self,
@@ -245,31 +284,11 @@ def filter_and_map(
245284
is_leaf_fn = is_leaf_fn or (
246285
lambda m, k, v: not isinstance(v, (Module, dict, list))
247286
)
248-
249-
def unwrap(vk, v):
250-
if is_leaf_fn(self, vk, v):
251-
return map_fn(v)
252-
253-
if isinstance(v, Module):
254-
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
255-
256-
if isinstance(v, dict):
257-
nd = {}
258-
for k, v in v.items():
259-
tk = f"{vk}.{k}"
260-
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
261-
return nd
262-
263-
if isinstance(v, list):
264-
nl = []
265-
for i, vi in enumerate(v):
266-
tk = f"{vk}.{i}"
267-
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
268-
return nl
269-
270-
raise RuntimeError("Unexpected leaf found while traversing the module")
271-
272-
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
287+
return {
288+
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
289+
for k, v in self.items()
290+
if filter_fn(self, k, v)
291+
}
273292

274293
def parameters(self):
275294
"""Recursively return all the :class:`mlx.core.array` members of this Module

0 commit comments

Comments
 (0)