Skip to content

Commit eba6a9d

Browse files
authored
Compatibility with pip-installed openmpi (#2741)
1 parent be9e2ae commit eba6a9d

File tree

3 files changed

+52
-17
lines changed

3 files changed

+52
-17
lines changed

docs/src/usage/distributed.rst

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ Distributed Communication
77

88
MLX supports distributed communication operations that allow the computational cost
99
of training or inference to be shared across many physical machines. At the
10-
moment we support two different communication backends:
10+
moment we support three different communication backends:
1111

1212
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
1313
full-featured and mature distributed communications library
14-
* A **ring** backend of our own that uses native TCP sockets and should be
15-
faster for thunderbolt connections.
14+
* A **ring** backend of our own that uses native TCP sockets. It should be
15+
faster for thunderbolt connections, but it also works over Ethernet.
16+
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
1617

1718
The list of all currently supported operations and their documentation can be
1819
seen in the :ref:`API docs<distributed>`.
@@ -84,9 +85,8 @@ Selecting Backend
8485
^^^^^^^^^^^^^^^^^
8586

8687
You can select the backend you want to use when calling :func:`init` by passing
87-
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
88-
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
89-
both fail then a singleton group is created.
88+
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
89+
available backends. If they all fail then a singleton group is created.
9090

9191
.. note::
9292
After a distributed backend is successfully initialized :func:`init` will
@@ -220,22 +220,24 @@ print 4 etc.
220220
Installing MPI
221221
^^^^^^^^^^^^^^
222222

223-
MPI can be installed with Homebrew, using the Anaconda package manager or
223+
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
224224
compiled from source. Most of our testing is done using ``openmpi`` installed
225225
with the Anaconda package manager as follows:
226226

227227
.. code:: shell
228228
229229
$ conda install conda-forge::openmpi
230230
231-
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
231+
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
232232
so that MLX can find it and load it at runtime. This can simply be achieved by
233233
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
234-
done automatically by ``mlx.launch``.
234+
done automatically by ``mlx.launch``. Some environments use a non-standard
235+
library filename that can be specified using the ``MPI_LIBNAME`` environment
236+
variable. This is automatically taken care of by ``mlx.launch`` as well.
235237

236238
.. code:: shell
237239
238-
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
240+
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
239241
$ # or simply
240242
$ mlx.launch -n 2 test.py
241243

mlx/distributed/mpi/mpi.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright © 2024 Apple Inc.
22

33
#include <dlfcn.h>
4+
#include <cstdlib>
45
#include <iostream>
56

67
#include "mlx/backend/cpu/encoder.h"
@@ -19,11 +20,17 @@
1920
} \
2021
}
2122

23+
static const char* get_libmpi_name() {
24+
const char* libname = std::getenv("MLX_MPI_LIBNAME");
25+
if (libname != nullptr) {
26+
return libname;
27+
}
2228
#ifdef __APPLE__
23-
static constexpr const char* libmpi_name = "libmpi.dylib";
29+
return "libmpi.dylib";
2430
#else
25-
static constexpr const char* libmpi_name = "libmpi.so";
31+
return "libmpi.so";
2632
#endif
33+
}
2734

2835
namespace mlx::core::distributed::mpi {
2936

@@ -94,7 +101,7 @@ struct MPIWrapper {
94101
MPIWrapper() {
95102
initialized_ = false;
96103

97-
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
104+
libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL);
98105
if (libmpi_handle_ == nullptr) {
99106
return;
100107
}

python/mlx/distributed_run.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ipaddress
66
import json
77
import os
8+
import platform
89
import shlex
910
import shutil
1011
import sys
@@ -386,15 +387,40 @@ def node_thread(rank, host, hostfile, input_queue):
386387
t.join()
387388

388389

390+
def get_mpi_libname():
391+
try:
392+
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
393+
ompi_info = ompi_info.stdout.strip().decode()
394+
395+
if platform.system() == "Darwin":
396+
otool_output = run(
397+
["otool", "-L", ompi_info], check=True, capture_output=True
398+
)
399+
else:
400+
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
401+
otool_output = otool_output.stdout.decode()
402+
403+
# StopIteration if not found
404+
libmpi_line = next(
405+
filter(lambda line: "libmpi" in line, otool_output.splitlines())
406+
)
407+
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
408+
except:
409+
return None
410+
411+
389412
def launch_mpi(parser, hosts, args, command):
390413
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
391414
mpirun = mpirun.stdout.strip().decode()
392415

393-
# Homebrew libmpi doesn't work with anaconda python out of the box.
394-
# TODO: Check if we should do this with every mpirun
395-
if "homebrew" in mpirun:
416+
# Compatibility with homebrew and pip installs
417+
mpi_libname = get_mpi_libname()
418+
if mpi_libname is not None:
396419
dyld = Path(mpirun).parent.parent / "lib"
397-
args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}"] + args.env
420+
args.env = [
421+
f"DYLD_LIBRARY_PATH={str(dyld)}",
422+
f"MLX_MPI_LIBNAME={mpi_libname}",
423+
] + args.env
398424

399425
log(args.verbose, f"Using '{mpirun}'")
400426
with tempfile.NamedTemporaryFile(mode="w") as f:

0 commit comments

Comments
 (0)