Commit dff4a38

Module checks the weight on load_weights (#337)
* update module to check weights on load, also fix docs and reorganize tests
* nits + rebase
* a few more docs updates for Module
* use manual module file
* comment
1 parent 0782a45 commit dff4a38
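In short, ``Module.load_weights`` now accepts either a ``.npz`` path or a list of ``(name, array)`` pairs and, by default, verifies that the provided weights match the model's parameters in name and shape. A minimal sketch of the new behavior, following the docstring and implementation added in the diff below:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(10, 10)

# Weights with missing names or mismatched shapes now raise a ValueError.
try:
    model.load_weights([("weight", mx.zeros((5, 5)))])  # strict=True is the default
except ValueError as e:
    print(e)

# Opting out of the check only updates the parameters that are provided.
model.load_weights([("bias", mx.zeros((10,)))], strict=False)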

File tree

6 files changed (+575, -354 lines)
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+{{ fullname | escape | underline}}
+
+.. currentmodule:: {{ module }}
+
+.. add toctree option to make autodoc generate the pages
+
+.. autoclass:: {{ objname }}
+
+{% block attributes %}
+{% if attributes %}
+.. rubric:: Attributes
+
+.. autosummary::
+   :toctree: .
+{% for item in attributes %}
+   ~{{ fullname }}.{{ item }}
+{%- endfor %}
+{% endif %}
+{% endblock %}
+
+{% block methods %}
+{% if methods %}
+.. rubric:: Methods
+
+.. autosummary::
+   :toctree: .
+{% for item in methods %}
+   {%- if item not in inherited_members and item != '__init__' %}
+   ~{{ fullname }}.{{ item }}
+   {%- endif -%}
+{%- endfor %}
+{% endif %}
+{% endblock %}

docs/src/python/nn.rst

Lines changed: 1 addition & 2 deletions
@@ -170,14 +170,13 @@ In detail:
 :meth:`mlx.core.value_and_grad`
 
 .. autosummary::
-   :recursive:
    :toctree: _autosummary
 
    value_and_grad
-   Module
 
 .. toctree::
 
+   nn/module
    nn/layers
    nn/functions
    nn/losses
docs/src/python/nn/module.rst

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+Module
+======
+
+.. currentmodule:: mlx.nn
+
+.. autoclass:: Module
+
+.. rubric:: Attributes
+
+.. autosummary::
+   :toctree: _autosummary
+
+   Module.training
+
+.. rubric:: Methods
+
+.. autosummary::
+   :toctree: _autosummary
+
+   Module.apply
+   Module.apply_to_modules
+   Module.children
+   Module.eval
+   Module.filter_and_map
+   Module.freeze
+   Module.leaf_modules
+   Module.load_weights
+   Module.modules
+   Module.named_modules
+   Module.parameters
+   Module.save_weights
+   Module.train
+   Module.trainable_parameters
+   Module.unfreeze
+   Module.update
+   Module.update_modules
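Several of the listed methods interact; for instance, freezing a parameter removes it from ``trainable_parameters()`` while ``parameters()`` still reports it. A short illustrative sketch, assuming the behavior documented for ``freeze`` in this commit:

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Linear(10, 10)
model.freeze(keys="bias")

# parameters() lists every parameter; trainable_parameters() omits frozen ones.
print([name for name, _ in tree_flatten(model.parameters())])            # e.g. ['weight', 'bias']
print([name for name, _ in tree_flatten(model.trainable_parameters())])  # e.g. ['weight']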

python/mlx/nn/layers/base.py

Lines changed: 111 additions & 20 deletions
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import textwrap
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
@@ -61,6 +61,7 @@ def __init__(self):
 
     @property
     def training(self):
+        """Boolean indicating if the model is in training mode."""
        return self._training
 
    def _extra_repr(self):
@@ -87,15 +88,83 @@ def __getattr__(self, key: str):
     def __setattr__(self, key: str, val: Any):
         self[key] = val
 
-    def load_weights(self, file: str):
+    def load_weights(
+        self,
+        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
+        strict: bool = True,
+    ):
         """
-        Load and update the model's weights from a `.npz` file.
+        Update the model's weights from a ``.npz`` or a list.
+
+        Args:
+            file_or_weights (str or list(tuple(str, mx.array))): The path to
+                the weights ``.npz`` file or a list of pairs of parameter names
+                and arrays.
+            strict (bool, optional): If ``True`` then checks that the provided
+                weights exactly match the parameters of the model. Otherwise,
+                only the weights actually contained in the model are loaded and
+                shapes are not checked. Default: ``True``.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+                import mlx.nn as nn
+                model = nn.Linear(10, 10)
+
+                # Load from file
+                model.load_weights("weights.npz")
+
+                # Load from list
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                    ("bias", mx.zeros((10,))),
+                ]
+                model.load_weights(weights)
+
+                # Missing weight
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                ]
+
+                # Raises a ValueError exception
+                model.load_weights(weights)
+
+                # Ok, only updates the weight but not the bias
+                model.load_weights(weights, strict=False)
         """
-        self.update(tree_unflatten(list(mx.load(file).items())))
+        weights = file_or_weights
+        if isinstance(weights, str):
+            weights = list(mx.load(weights).items())
+
+        if strict:
+            new_weights = dict(weights)
+            curr_weights = dict(tree_flatten(self.parameters()))
+            if extras := (new_weights.keys() - curr_weights.keys()):
+                extras = " ".join(extras)
+                raise ValueError(f"Received parameters not in model: {extras}.")
+            if missing := (curr_weights.keys() - new_weights.keys()):
+                missing = " ".join(missing)
+                raise ValueError(f"Missing parameters: {missing}.")
+            for k, v in curr_weights.items():
+                v_new = new_weights[k]
+                if not isinstance(v_new, mx.array):
+                    raise ValueError(
+                        "Expected mx.array but received "
+                        f"{type(v_new)} for parameter {k}"
+                    )
+                if v_new.shape != v.shape:
+                    raise ValueError(
+                        f"Expected shape {v.shape} but received "
+                        f" shape {v_new.shape} for parameter {k}"
+                    )
+
+        self.update(tree_unflatten(weights))
 
     def save_weights(self, file: str):
         """
-        Save the model's weights to a `.npz` file.
+        Save the model's weights to a ``.npz`` file.
         """
         mx.savez(file, **dict(tree_flatten(self.parameters())))
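The strict check operates on flat, dotted parameter names: ``tree_flatten(self.parameters())`` turns the nested parameter dict into ``(name, array)`` pairs, so extra or missing weights fall out of a key-set difference and shapes can be compared pairwise. A rough illustration of that flattened view, assuming a small ``nn.Sequential`` model (the printed names and shapes are indicative only):

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Dotted names such as "layers.0.weight" are what load_weights compares against.
for name, param in tree_flatten(model.parameters()):
    print(name, param.shape)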

@@ -351,23 +420,26 @@ def freeze(
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.
 
-        This function is idempotent ie freezing a frozen model is a noop.
+        This function is idempotent i.e. freezing a frozen model is a no-op.
+
+        Example:
+            For instance to only train the attention parameters from a Transformer:
 
-        For instance to only train the attention parameters from a transformer:
+            .. code-block:: python
 
-            model = ...
-            model.freeze()
-            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
+                model = nn.Transformer()
+                model.freeze()
+                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
 
         Args:
             recurse (bool, optional): If True then freeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be frozen otherwise all the parameters of a
                 module. For instance freeze all biases by calling
                 ``module.freeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """
 
         def _freeze_impl(_, m):
@@ -401,21 +473,25 @@ def unfreeze(
         This function is idempotent ie unfreezing a model that is not frozen is
         a noop.
 
-        For instance to only train the biases one can do:
+        Example:
 
-            model = ...
-            model.freeze()
-            model.unfreeze(keys="bias")
+            For instance to only train the biases of a Transformer one can do:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
+                model.freeze()
+                model.unfreeze(keys="bias")
 
         Args:
             recurse (bool, optional): If True then unfreeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be unfrozen otherwise all the parameters of a
                 module. For instance unfreeze all biases by calling
                 ``module.unfreeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """
 
         def _unfreeze_impl(_, m):
@@ -432,10 +508,25 @@ def _unfreeze_impl(_, m):
         _unfreeze_impl("", self)
 
     def train(self, mode: bool = True):
+        """Set the model in or out of training mode.
+
+        Training mode only applies to certain layers. For example
+        :obj:`Dropout` applies a random mask in training mode, but is the
+        identity in evaluation mode.
+
+        Args:
+            mode (bool): Indicate if the model should be in training or
+                evaluation mode. Default: ``True``.
+        """
+
         def _set_train(_, m):
             m._training = mode
 
         self.apply_to_modules(_set_train)
 
     def eval(self):
+        """Set the model to evaluation mode.
+
+        See :func:`train`.
+        """
         self.train(False)
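The new ``train``/``eval`` docstrings describe a toggle that only affects mode-dependent layers. A small sketch of that behavior, assuming ``nn.Dropout`` consults ``Module.training`` as the docstring suggests:

import mlx.core as mx
import mlx.nn as nn

layer = nn.Dropout(0.5)
x = mx.ones((8,))

layer.train()          # training mode: dropout randomly zeros elements
print(layer.training)  # True
print(layer(x))

layer.eval()           # equivalent to layer.train(False)
print(layer.training)  # False
print(layer(x))        # identity in evaluation mode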
