@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import textwrap
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
@@ -61,6 +61,7 @@ def __init__(self):
 
     @property
     def training(self):
+        """Boolean indicating if the model is in training mode."""
         return self._training
 
     def _extra_repr(self):
@@ -87,15 +88,83 @@ def __getattr__(self, key: str):
     def __setattr__(self, key: str, val: Any):
         self[key] = val
 
-    def load_weights(self, file: str):
+    def load_weights(
+        self,
+        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
+        strict: bool = True,
+    ):
         """
-        Load and update the model's weights from a `.npz` file.
+        Update the model's weights from a ``.npz`` file or a list.
+
+        Args:
+            file_or_weights (str or list(tuple(str, mx.array))): The path to
+                the weights ``.npz`` file or a list of pairs of parameter names
+                and arrays.
+            strict (bool, optional): If ``True`` then checks that the provided
+                weights exactly match the parameters of the model. Otherwise,
+                only the weights actually contained in the model are loaded and
+                shapes are not checked. Default: ``True``.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+                import mlx.nn as nn
+                model = nn.Linear(10, 10)
+
+                # Load from file
+                model.load_weights("weights.npz")
+
+                # Load from list
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                    ("bias", mx.zeros((10,))),
+                ]
+                model.load_weights(weights)
+
+                # Missing weight
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                ]
+
+                # Raises a ValueError exception
+                model.load_weights(weights)
+
+                # Ok, only updates the weight but not the bias
+                model.load_weights(weights, strict=False)
         """
-        self.update(tree_unflatten(list(mx.load(file).items())))
+        weights = file_or_weights
+        if isinstance(weights, str):
+            weights = list(mx.load(weights).items())
+
+        if strict:
+            new_weights = dict(weights)
+            curr_weights = dict(tree_flatten(self.parameters()))
+            if extras := (new_weights.keys() - curr_weights.keys()):
+                extras = " ".join(extras)
+                raise ValueError(f"Received parameters not in model: {extras}.")
+            if missing := (curr_weights.keys() - new_weights.keys()):
+                missing = " ".join(missing)
+                raise ValueError(f"Missing parameters: {missing}.")
+            for k, v in curr_weights.items():
+                v_new = new_weights[k]
+                if not isinstance(v_new, mx.array):
+                    raise ValueError(
+                        "Expected mx.array but received "
+                        f"{type(v_new)} for parameter {k}"
+                    )
+                if v_new.shape != v.shape:
+                    raise ValueError(
+                        f"Expected shape {v.shape} but received "
+                        f"shape {v_new.shape} for parameter {k}"
+                    )
+
+        self.update(tree_unflatten(weights))
 
     def save_weights(self, file: str):
         """
-        Save the model's weights to a `.npz` file.
+        Save the model's weights to a ``.npz`` file.
         """
         mx.savez(file, **dict(tree_flatten(self.parameters())))
 
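For context, a minimal sketch of what the new ``strict`` check (the default) catches beyond a missing key, namely a shape mismatch; ``nn.Linear`` is just a stand-in model, and the printed message follows the f-strings above:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(10, 10)

    # Right parameter names, wrong shape for "weight".
    bad_weights = [
        ("weight", mx.zeros((5, 5))),
        ("bias", mx.zeros((10,))),
    ]

    try:
        model.load_weights(bad_weights)  # strict=True is the default
    except ValueError as e:
        # e.g. "Expected shape (10, 10) but received shape (5, 5) for parameter weight"
        print(e)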
@@ -351,23 +420,26 @@ def freeze(
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.
 
-        This function is idempotent ie freezing a frozen model is a noop.
+        This function is idempotent i.e. freezing a frozen model is a no-op.
+
+        Example:
+            For instance to only train the attention parameters from a Transformer:
 
-        For instance to only train the attention parameters from a transformer:
+            .. code-block:: python
 
-            model = ...
-            model.freeze()
-            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
+                model = nn.Transformer()
+                model.freeze()
+                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
 
         Args:
             recurse (bool, optional): If True then freeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be frozen otherwise all the parameters of a
                 module. For instance freeze all biases by calling
                 ``module.freeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """
 
         def _freeze_impl(_, m):
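A minimal sketch of the ``strict`` flag documented above, with ``nn.Linear`` again standing in for a real model; per the docstring, a misspelled key under ``strict=True`` should raise rather than be silently ignored:

    import mlx.nn as nn

    model = nn.Linear(10, 10)
    model.freeze(keys="bias")                 # freezes only the bias
    model.freeze(keys="biases", strict=True)  # misspelled key: should raise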
@@ -401,21 +473,25 @@ def unfreeze(
         This function is idempotent ie unfreezing a model that is not frozen is
         a noop.
 
-        For instance to only train the biases one can do:
+        Example:
 
-            model = ...
-            model.freeze()
-            model.unfreeze(keys="bias")
+            For instance to only train the biases of a Transformer one can do:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
+                model.freeze()
+                model.unfreeze(keys="bias")
 
         Args:
             recurse (bool, optional): If True then unfreeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be unfrozen otherwise all the parameters of a
                 module. For instance unfreeze all biases by calling
                 ``module.unfreeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """
 
         def _unfreeze_impl(_, m):
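Putting the pair together, a sketch of the bias-only fine-tuning recipe from the docstring, checked via ``trainable_parameters()``; the small two-layer ``nn.Sequential`` model and the printed key names are illustrative assumptions:

    import mlx.nn as nn
    from mlx.utils import tree_flatten

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    model.freeze()
    model.unfreeze(keys="bias")

    # Only the bias arrays should remain trainable,
    # e.g. ['layers.0.bias', 'layers.1.bias']
    print([k for k, _ in tree_flatten(model.trainable_parameters())])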
@@ -432,10 +508,25 @@ def _unfreeze_impl(_, m):
         _unfreeze_impl("", self)
 
     def train(self, mode: bool = True):
+        """Set the model in or out of training mode.
+
+        Training mode only applies to certain layers. For example
+        :obj:`Dropout` applies a random mask in training mode, but is the
+        identity in evaluation mode.
+
+        Args:
+            mode (bool): Indicate if the model should be in training or
+                evaluation mode. Default: ``True``.
+        """
+
         def _set_train(_, m):
             m._training = mode
 
         self.apply_to_modules(_set_train)
 
     def eval(self):
+        """Set the model to evaluation mode.
+
+        See :func:`train`.
+        """
         self.train(False)
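Finally, a small sketch of the documented train/eval behavior using :obj:`Dropout` (the layer the new docstring cites); the exact printed values vary with the random mask:

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Dropout(p=0.5)
    x = mx.ones((4,))

    layer.train()          # training mode (also the default)
    print(layer.training)  # True
    print(layer(x))        # some entries zeroed, the rest scaled by 1 / (1 - p)

    layer.eval()           # evaluation mode: dropout is the identity
    print(layer(x))        # [1, 1, 1, 1]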