|
7 | 7 | from mlx.utils import tree_flatten, tree_unflatten |
8 | 8 |
|
9 | 9 |
|
| 10 | +def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn): |
| 11 | + if is_leaf_fn(model, value_key, value): |
| 12 | + return map_fn(value) |
| 13 | + |
| 14 | + elif isinstance(value, Module): |
| 15 | + return { |
| 16 | + k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn) |
| 17 | + for k, v in value.items() |
| 18 | + if filter_fn(value, k, v) |
| 19 | + } |
| 20 | + |
| 21 | + elif isinstance(value, dict): |
| 22 | + nd = {} |
| 23 | + for k, v in v.items(): |
| 24 | + tk = f"{value_key}.{k}" |
| 25 | + nd[k] = ( |
| 26 | + _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn) |
| 27 | + if filter_fn(model, tk, v) |
| 28 | + else {} |
| 29 | + ) |
| 30 | + return nd |
| 31 | + |
| 32 | + elif isinstance(value, list): |
| 33 | + nl = [] |
| 34 | + for i, vi in enumerate(value): |
| 35 | + tk = f"{value_key}.{i}" |
| 36 | + nl.append( |
| 37 | + _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn) |
| 38 | + if filter_fn(model, tk, vi) |
| 39 | + else {} |
| 40 | + ) |
| 41 | + return nl |
| 42 | + |
| 43 | + raise RuntimeError("Unexpected leaf found while traversing the module") |
| 44 | + |
| 45 | + |
10 | 46 | class Module(dict): |
11 | 47 | """Base class for building neural networks with MLX. |
12 | 48 |
|
@@ -98,10 +134,13 @@ def __getattr__(self, key: str): |
98 | 134 | if key in self: |
99 | 135 | return self[key] |
100 | 136 | else: |
101 | | - raise AttributeError(f"{type(self)!r} has no attribute {key!r}") |
| 137 | + super(Module, self).__getattr__(key, val) |
102 | 138 |
|
103 | 139 | def __setattr__(self, key: str, val: Any): |
104 | | - self[key] = val |
| 140 | + if isinstance(val, (mx.array, dict, list, tuple)): |
| 141 | + self[key] = val |
| 142 | + else: |
| 143 | + super(Module, self).__setattr__(key, val) |
105 | 144 |
|
106 | 145 | def load_weights( |
107 | 146 | self, |
@@ -245,31 +284,11 @@ def filter_and_map( |
245 | 284 | is_leaf_fn = is_leaf_fn or ( |
246 | 285 | lambda m, k, v: not isinstance(v, (Module, dict, list)) |
247 | 286 | ) |
248 | | - |
249 | | - def unwrap(vk, v): |
250 | | - if is_leaf_fn(self, vk, v): |
251 | | - return map_fn(v) |
252 | | - |
253 | | - if isinstance(v, Module): |
254 | | - return v.filter_and_map(filter_fn, map_fn, is_leaf_fn) |
255 | | - |
256 | | - if isinstance(v, dict): |
257 | | - nd = {} |
258 | | - for k, v in v.items(): |
259 | | - tk = f"{vk}.{k}" |
260 | | - nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {} |
261 | | - return nd |
262 | | - |
263 | | - if isinstance(v, list): |
264 | | - nl = [] |
265 | | - for i, vi in enumerate(v): |
266 | | - tk = f"{vk}.{i}" |
267 | | - nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {}) |
268 | | - return nl |
269 | | - |
270 | | - raise RuntimeError("Unexpected leaf found while traversing the module") |
271 | | - |
272 | | - return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)} |
| 287 | + return { |
| 288 | + k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn) |
| 289 | + for k, v in self.items() |
| 290 | + if filter_fn(self, k, v) |
| 291 | + } |
273 | 292 |
|
274 | 293 | def parameters(self): |
275 | 294 | """Recursively return all the :class:`mlx.core.array` members of this Module |
|
0 commit comments