.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   Some operations are not yet supported or are not as fast as they should
   be. We are adding more operations and tuning the existing ones as we
   figure out the best way to do distributed computing on Macs using MLX.

Getting Started
---------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.

To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:

.. code:: shell

   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
the standard output of both. Each process sends its array of ones to the
other and each prints the resulting sum, an array of 2s. Launching with
``mpirun -np 4 ...`` would print arrays of 4s, and so on.

Installing MPI
--------------

MPI can be installed with Homebrew, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can be done by passing
the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py

Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  a password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines (see the example entry below).

.. note::
   For example, for the hostname ``foo.bar.com``, MPI may pass only ``foo`` as
   the hostname to ssh if the current hostname matches ``*.bar.com``.

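An entry in ``~/.ssh/config`` that satisfies the first and last items of the
checklist could look like the following sketch. The host alias, address, user,
and key path are placeholders; substitute your own values.

.. code::

   Host host1
       HostName host1.example.com
       User myuser
       IdentityFile ~/.ssh/id_ed25519
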
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs of these hosts.

.. code::

   host1 slots=1
   host2 slots=1

When using MLX, it is very likely that you want to use 1 slot per host, i.e.
one process per host. The host file also needs to contain the current host if
you want to run on the local host. Passing the host file to ``mpirun`` is done
using the ``--hostfile`` command line argument, as shown below.

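For example, assuming the host file above is saved as ``hosts.txt`` (the file
name is arbitrary), a two-host run of the earlier ``test.py`` script can be
launched as:

.. code:: shell

   $ mpirun --hostfile hosts.txt -np 2 python test.py
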
Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset, and optimizer initialization.

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads)

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss

Tuning All Reduce
-----------------

We are working on improving the performance of all reduce on MLX, but for now
the two main things one can do to extract the most out of distributed training
with MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency (see the sketch after this list).
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
   connections between each host to improve bandwidth.

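The first point can be illustrated with a sketch that flattens all gradients
into a single buffer, performs one large :func:`all_sum`, and then splits the
result back into the original shapes. The helper below is only an
illustration, not part of MLX, and it assumes all gradients share the same
dtype.

.. code:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten

   def all_avg_batched(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       flat = tree_flatten(grads)  # list of (key, array) pairs
       # One large reduction instead of one small reduction per array
       big = mx.concatenate([v.reshape(-1) for _, v in flat])
       big = mx.distributed.all_sum(big) / N
       # Split the buffer back into the original shapes
       indices, total = [], 0
       for _, v in flat[:-1]:
           total += v.size
           indices.append(total)
       pieces = mx.split(big, indices)
       return tree_unflatten(
           [(k, p.reshape(v.shape)) for (k, v), p in zip(flat, pieces)]
       )

Whether this helps in practice depends on the model and the network, so it is
worth benchmarking against the simple per-array ``all_avg`` shown earlier.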