
Commit 0163a8e

Add docs for the distributed namespace (#1184)
1 parent 5788429 commit 0163a8e

File tree

12 files changed: +202 -15 lines changed


docs/src/index.rst

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@ are the CPU and GPU.
    usage/function_transforms
    usage/compile
    usage/numpy
+   usage/distributed
    usage/using_streams

 .. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.
    python/metal
    python/nn
    python/optimizers
+   python/distributed
    python/tree_utils

 .. toctree::

docs/src/python/distributed.rst

Lines changed: 19 additions & 0 deletions
.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
==========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather

docs/src/usage/distributed.rst

Lines changed: 166 additions & 0 deletions
.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX uses `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not yet be supported, or may not be as fast as they
   should be. We are adding more operations and tuning the existing ones as we
   figure out the best way to do distributed computing on Macs using MLX.

Getting Started
---------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all distributed
processes. If simply run with ``python``, however, only one process is
launched and no distributed communication takes place.

To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec``, depending on the MPI installation. The simplest possible
invocation is the following:

.. code:: shell

   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of ones to each
other and compute the sum, which is printed. Launching with ``mpirun -np 4
...`` would print arrays of 4s, and so on.

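As a quick sanity check (a small variation on the example above, not part of
the original snippet), each process can contribute a value that depends on its
rank, so the result visibly reflects every rank:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   # Each rank contributes an array filled with its own rank index, so the
   # all_sum result is 0 + 1 + ... + (size - 1) on every rank.
   x = mx.distributed.all_sum(world.rank() * mx.ones(4))
   print(world.rank(), world.size(), x)
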
Installing MPI
--------------

MPI can be installed with Homebrew, with the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install openmpi

Installing with Homebrew may require specifying the location of
``libmpi.dylib`` so that MLX can find it and load it at runtime. This can be
achieved by passing the ``DYLD_LIBRARY_PATH`` environment variable to
``mpirun``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py

Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist for
debugging connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  a password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

.. note::
   For example, for the hostname ``foo.bar.com``, MPI may pass only ``foo`` as
   the hostname to ssh if the current host's name matches ``*.bar.com``.

An easy way to pass the host names to MPI is with a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs of the hosts.

.. code::

   host1 slots=1
   host2 slots=1

When using MLX, it is very likely that you want to use 1 slot per host, i.e.,
one process per host. The host file also needs to contain the current host if
you want to run on the local host. Passing the host file to ``mpirun`` is done
with the ``--hostfile`` command line argument, as shown below.

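For instance, if the host file above is saved as ``hosts.txt`` (the file name
is arbitrary), the script from the Getting Started section can be launched on
both hosts with:

.. code:: shell

   $ mpirun --hostfile hosts.txt -np 2 python test.py
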
Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset, and optimizer initialization.

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training step looks as follows, with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads)

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss

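If the adapted script is saved as, say, ``train.py`` (the name is only an
example), it can be launched across the machines in the host file from the
previous section in the same way as before:

.. code:: shell

   $ mpirun --hostfile hosts.txt python train.py
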
Tuning All Reduce
-----------------

We are working on improving the performance of all reduce in MLX, but for now
the two main things you can do to get the most out of distributed training with
MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency (see the sketch after this list).
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
   connections between each pair of hosts to improve bandwidth.

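As an illustration of the first point, the helper below batches all gradient
reductions into a single large :func:`all_sum` by flattening the gradient tree,
concatenating, reducing once, and splitting the result back. It is only a
sketch; ``batched_all_avg`` is a hypothetical name and not part of the MLX API.

.. code:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten

   def batched_all_avg(grads):
       # Hypothetical helper: average gradients with one large reduction
       # instead of one all_sum per parameter. Assumes all gradients share
       # a dtype.
       world = mx.distributed.init()
       if world.size() == 1:
           return grads

       # Flatten the gradient tree into a list of (name, array) pairs.
       flat = tree_flatten(grads)
       shapes = [v.shape for _, v in flat]
       sizes = [v.size for _, v in flat]

       # One big reduction instead of many small ones.
       big = mx.concatenate([v.reshape(-1) for _, v in flat])
       big = mx.distributed.all_sum(big) / world.size()

       # Split the reduced buffer back into the original shapes.
       offsets, total = [], 0
       for s in sizes[:-1]:
           total += s
           offsets.append(total)
       pieces = mx.split(big, offsets)
       return tree_unflatten(
           [(k, p.reshape(s)) for (k, _), p, s in zip(flat, pieces, shapes)]
       )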

examples/cpp/distributed.cpp

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ int main() {
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

   array x = ones({10});
-  array out = distributed::all_reduce_sum(x, global_group);
+  array out = distributed::all_sum(x, global_group);

   std::cout << out << std::endl;
 }

mlx/distributed/distributed.h

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ namespace detail {
 Stream communication_stream();

 /* Perform an all reduce sum operation */
-void all_reduce_sum(Group group, const array& input, array& output);
+void all_sum(Group group, const array& input, array& output);

 /* Perform an all gather operation */
 void all_gather(Group group, const array& input, array& output);

mlx/distributed/mpi/mpi.cpp

Lines changed: 1 addition & 1 deletion
@@ -260,7 +260,7 @@ Stream communication_stream() {
   return comm_stream;
 }

-void all_reduce_sum(Group group, const array& input_, array& output) {
+void all_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
       (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE

mlx/distributed/no_distributed.cpp

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ Stream communication_stream() {
   return comm_stream;
 }

-void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_sum(Group group, const array& input, array& output) {}
 void all_gather(Group group, const array& input, array& output) {}

 } // namespace detail

mlx/distributed/ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ Group to_group(std::optional<Group> group) {

 } // namespace

-array all_reduce_sum(const array& x, std::optional<Group> group_) {
+array all_sum(const array& x, std::optional<Group> group_) {
   auto group = to_group(group_);

   if (group.size() == 1) {

mlx/distributed/ops.h

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@

 namespace mlx::core::distributed {

-array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+array all_sum(const array& x, std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);

 } // namespace mlx::core::distributed

mlx/distributed/primitives.cpp

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ void AllReduce::eval_cpu(

   switch (reduce_type_) {
     case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+      distributed::detail::all_sum(group(), inputs[0], outputs[0]);
       break;
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
@@ -36,7 +36,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
     const std::vector<int>& axes) {
   switch (reduce_type_) {
     case Sum:
-      return {{all_reduce_sum(inputs[0], group())}, axes};
+      return {{all_sum(inputs[0], group())}, axes};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
@@ -48,7 +48,7 @@ std::vector<array> AllReduce::jvp(
     const std::vector<int>& argnums) {
   switch (reduce_type_) {
     case Sum:
-      return {all_reduce_sum(tangents[0], group())};
+      return {all_sum(tangents[0], group())};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
