
Commit c92a134

more docs (#421)
* more docs
* fix link
* nits + comments
1 parent 3b4f066 commit c92a134

File tree

6 files changed: +360 −1 lines changed


docs/src/index.rst

Lines changed: 4 additions & 1 deletion

@@ -36,9 +36,12 @@ are the CPU and GPU.
    :maxdepth: 1
 
    usage/quick_start
+   usage/lazy_evaluation
    usage/unified_memory
-   usage/using_streams
+   usage/indexing
+   usage/saving_and_loading
    usage/numpy
+   usage/using_streams
 
 .. toctree::
    :caption: Examples

docs/src/usage/indexing.rst

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@

.. _indexing:

Indexing Arrays
===============

.. currentmodule:: mlx.core

For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.

For example, you can use regular integers and slices (:obj:`slice`) to index arrays:

.. code-block:: shell

  >>> arr = mx.arange(10)
  >>> arr[3]
  array(3, dtype=int32)
  >>> arr[-2]  # negative indexing works
  array(8, dtype=int32)
  >>> arr[2:8:2]  # start, stop, stride
  array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:

.. code-block:: shell

  >>> arr = mx.arange(8).reshape(2, 2, 2)
  >>> arr[:, :, 0]
  array([[0, 2],
         [4, 6]], dtype=int32)
  >>> arr[..., 0]
  array([[0, 2],
         [4, 6]], dtype=int32)

You can index with ``None`` to create a new axis:

.. code-block:: shell

  >>> arr = mx.arange(8)
  >>> arr.shape
  [8]
  >>> arr[None].shape
  [1, 8]

You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

  >>> arr = mx.arange(10)
  >>> idx = mx.array([5, 7])
  >>> arr[idx]
  array([5, 7], dtype=int32)

Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.

Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.
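
A minimal sketch of both, assuming ``mx`` is ``mlx.core``:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([[1, 2], [3, 4]])

   # take selects entries at the given indices along an axis
   mx.take(a, mx.array([1, 0]), axis=1)   # array([[2, 1], [4, 3]], dtype=int32)

   # take_along_axis gathers with an index array of matching rank
   idx = mx.array([[0], [1]])
   mx.take_along_axis(a, idx, axis=1)     # array([[1], [4]], dtype=int32)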

Differences from NumPy
----------------------

.. Note::

   MLX indexing is different from NumPy indexing in two important ways:

   * Indexing does not perform bounds checking. Indexing out of bounds is
     undefined behavior.
   * Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
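
One defensive pattern is to clamp indices into range with :func:`clip` before
indexing. A small sketch, using a hypothetical ``safe_take`` helper that is not
part of MLX:

.. code-block:: python

   import mlx.core as mx

   # Hypothetical helper: clamp indices so out-of-bounds access
   # (which is undefined behavior) cannot occur
   def safe_take(arr, idx):
       idx = mx.clip(idx, 0, arr.shape[0] - 1)
       return arr[idx]

   safe_take(mx.arange(10), mx.array([3, 42]))  # array([3, 9], dtype=int32)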

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output *shapes*
depend on input *data*. Other examples of these types of operations which MLX
does not yet support include :func:`numpy.nonzero` and the single input version
of :func:`numpy.where`.
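
Until then, the three-argument :func:`where` is a common workaround, since its
output shape does not depend on the data. A minimal sketch:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1, -2, 3, -4])
   mask = a > 0

   # Instead of a[mask] (unsupported), replace the masked-out entries;
   # the output keeps the input's shape
   mx.where(mask, a, 0)  # array([1, 0, 3, 0], dtype=int32)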

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:

.. code-block:: shell

  >>> a = mx.array([1, 2, 3])
  >>> a[2] = 0
  >>> a
  array([1, 2, 0], dtype=int32)

Just as in NumPy, in place updates will be reflected in all references to the
same array:

.. code-block:: shell

  >>> a = mx.array([1, 2, 3])
  >>> b = a
  >>> b[2] = 0
  >>> b
  array([1, 2, 0], dtype=int32)
  >>> a
  array([1, 2, 0], dtype=int32)

Transformations of functions which use in-place updates are allowed and work as
expected. For example:

.. code-block:: python

   def fun(x, idx):
       x[idx] = 2.0
       return x.sum()

   dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
   print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above, ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.

docs/src/usage/lazy_evaluation.rst

Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@

.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead, a
compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computation. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations like :func:`simplify`.
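
A minimal sketch of a transformation operating on the recorded graph:

.. code-block:: python

   import mlx.core as mx

   def f(x):
       return (x ** 2).sum()

   # grad traces the lazy graph of f to build a graph for df/dx;
   # nothing has actually been computed yet
   g = mx.grad(f)(mx.array([1.0, 2.0, 3.0]))

   mx.eval(g)  # the computation happens here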

Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

   def fun(x):
       a = fun1(x)
       b = expensive_fun(a)
       return a, b

   y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated with it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation were used
instead.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

   model = Model()  # no memory used yet
   model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

   for _ in range(100):
       a = a + b
       mx.eval(a)
       b = b * 2
       mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

   for batch in dataset:

       # Nothing has been evaluated yet
       loss, grad = value_and_grad_fn(model, batch)

       # Still nothing has been evaluated
       optimizer.update(model, grad)

       # Evaluate the loss and the new parameters which will
       # run the full gradient computation and optimizer update
       mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.

Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())``, then this
will be a partial evaluation, computing only the forward pass.

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.
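
A small sketch of these implicit evaluation points:

.. code-block:: python

   import mlx.core as mx

   a = mx.ones((2, 2)) + 1.0  # lazy: only the graph is recorded so far

   print(a)         # printing implicitly evaluates the graph
   a.sum().item()   # so does .item() on a scalar array
   mx.eval(a)       # evaluating again is effectively a no-op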

.. warning::

   Using scalar arrays for control flow will cause an evaluation.

Here is an example:

.. code-block:: python

   def fun(x):
       h, y = first_layer(x)
       if y > 0:  # An evaluation is done here!
           z = second_layer_a(h)
       else:
           z = second_layer_b(h)
       return z

Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.

docs/src/usage/numpy.rst

Lines changed: 5 additions & 0 deletions

@@ -62,6 +62,11 @@ even though no in-place operations on MLX memory are executed.
 PyTorch
 -------
 
+.. warning::
+
+   PyTorch support for :obj:`memoryview` is experimental and can break for
+   multi-dimensional arrays. Casting to NumPy first is advised for now.
+
 PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
 
 .. code-block:: python
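
A minimal sketch of the NumPy-first route the warning recommends, assuming
``torch`` and ``numpy`` are installed:

.. code-block:: python

   import mlx.core as mx
   import numpy as np
   import torch

   a = mx.arange(6).reshape(2, 3)

   # Go through NumPy rather than memoryview for multi-dimensional arrays
   t = torch.tensor(np.array(a))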

docs/src/usage/quick_start.rst

Lines changed: 3 additions & 0 deletions

@@ -40,6 +40,9 @@ automatically evaluate the array.
    >>> np.array(c)  # Also evaluates c
    array([2., 4., 6., 8.], dtype=float32)
 
+
+See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.
+
 Function and Graph Transformations
 ----------------------------------

docs/src/usage/saving_and_loading.rst

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@

.. _saving_and_loading:

Saving and Loading Arrays
=========================

.. currentmodule:: mlx.core

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:

.. code-block:: shell

  >>> a = mx.array([1.0])
  >>> mx.save("array", a)

The array ``a`` will be saved in the file ``array.npy`` (notice the extension
is automatically added). Including the extension is optional; if it is missing
it will be added. You can load the array with:

.. code-block:: shell

  >>> mx.load("array.npy")
  array([1], dtype=float32)

Here's an example of saving several arrays to a single file:

.. code-block:: shell

  >>> a = mx.array([1.0])
  >>> b = mx.array([2.0])
  >>> mx.savez("arrays", a, b=b)

For compatibility with :func:`numpy.savez`, the MLX :func:`savez` takes arrays
as arguments. If the keywords are missing, then default names will be
provided. This can be loaded with:

.. code-block:: shell

  >>> mx.load("arrays.npz")
  {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

In this case :func:`load` returns a dictionary of names to arrays.

The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:

.. code-block:: shell

  >>> a = mx.array([1.0])
  >>> b = mx.array([2.0])
  >>> mx.save_safetensors("arrays", {"a": a, "b": b})
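
Loading this back with :func:`load` should likewise give a dictionary of names
to arrays; a quick sketch:

.. code-block:: shell

  >>> mx.load("arrays.safetensors")
  {'a': array([1], dtype=float32), 'b': array([2], dtype=float32)}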
