.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead, a
compute graph is recorded. The actual computation only happens when
:func:`eval` is called.

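For example, here is a minimal sketch of this behavior (the arrays and values
are arbitrary, chosen only for illustration):

.. code-block:: python

  import mlx.core as mx

  a = mx.array([1.0, 2.0])
  b = mx.array([3.0, 4.0])

  c = a + b   # nothing is computed here, only the graph is recorded
  mx.eval(c)  # the addition actually runs here
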
MLX uses lazy evaluation because it enables a number of useful features, some
of which we describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computation. This is useful for function transformations like :func:`grad` and
:func:`vmap`, and for graph optimizations like :func:`simplify`.

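As a brief sketch (the function ``loss`` here is hypothetical), :func:`grad`
transforms the recorded graph of a function, and nothing runs until the result
is evaluated:

.. code-block:: python

  import mlx.core as mx

  def loss(x):
      return (x ** 2).sum()

  grad_fn = mx.grad(loss)            # transforms the recorded graph of `loss`
  g = grad_fn(mx.array([1.0, 2.0]))  # still lazy: only the gradient graph is built
  mx.eval(g)                         # the gradient is computed here
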
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

  def fun(x):
      a = fun1(x)
      b = expensive_fun(a)
      return a, b

  y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated with it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your peak
memory consumption will be about half of what it would be with eager
computation.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

  model = Model()  # no memory used yet
  model.load_weights("weights_fp16.safetensors")

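The ``float16`` weights are only materialized once the graph is evaluated. As a
small sketch reusing ``model`` from above:

.. code-block:: python

  # The weights are computed (in float16) only at this point
  mx.eval(model.parameters())
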
When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

  for _ in range(100):
      a = a + b
      mx.eval(a)
      b = b * 2
      mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

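A better version of the loop above (a sketch) batches the work and evaluates
once at the end, so the fixed per-evaluation overhead is paid a single time:

.. code-block:: python

  for _ in range(100):
      a = a + b
      b = b * 2

  # One evaluation runs the whole recorded graph
  mx.eval(a, b)
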
Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

  for batch in dataset:

      # Nothing has been evaluated yet
      loss, grad = value_and_grad_fn(model, batch)

      # Still nothing has been evaluated
      optimizer.update(model, grad)

      # Evaluate the loss and the new parameters which will
      # run the full gradient computation and optimizer update
      mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving function) will also evaluate the array.

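For example, a small sketch of operations that trigger an implicit evaluation:

.. code-block:: python

  import numpy as np
  import mlx.core as mx

  a = mx.array([1.0, 2.0])
  b = mx.array([3.0, 4.0])

  c = a + b        # lazy: nothing is computed yet
  print(c)         # printing evaluates the graph
  d = np.array(c)  # converting to a numpy.ndarray also evaluates it
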
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are placed before ``mx.eval(loss, model.parameters())``, then this
will be a partial evaluation, computing only the forward pass.

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

  Using scalar arrays for control flow will cause an evaluation.

Here is an example:

.. code-block:: python

  def fun(x):
      h, y = first_layer(x)
      if y > 0:  # An evaluation is done here!
          z = second_layer_a(h)
      else:
          z = second_layer_b(h)
      return z

Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.

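If such evaluations become a bottleneck, one option (a sketch that is not part
of the original example, and only sensible when both branches are cheap enough
to compute) is to keep the branch inside the graph with :func:`where`:

.. code-block:: python

  def fun(x):
      h, y = first_layer(x)
      # Both branches are recorded in the graph, but no implicit
      # evaluation is triggered by Python control flow
      return mx.where(y > 0, second_layer_a(h), second_layer_b(h))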