Skip to content

Commit 50dfb66

Browse files
authored
Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
1 parent 0189ab6 commit 50dfb66

File tree

19 files changed

+913
-1
lines changed

19 files changed

+913
-1
lines changed

.circleci/config.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ jobs:
7171
name: Install dependencies
7272
command: |
7373
brew install python@3.8
74+
brew install openmpi
7475
python3.8 -m venv env
7576
source env/bin/activate
7677
pip install --upgrade pip
@@ -96,6 +97,7 @@ jobs:
9697
source env/bin/activate
9798
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
9899
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
100+
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
99101
- run:
100102
name: Build example extension
101103
command: |

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ else()
167167
set(MLX_BUILD_ACCELERATE OFF)
168168
endif()
169169

170+
find_package(MPI)
171+
if (MPI_FOUND)
172+
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
173+
endif()
174+
170175
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
171176

172177
target_include_directories(

examples/cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
99
build_example(linear_regression.cpp)
1010
build_example(logistic_regression.cpp)
1111
build_example(metal_capture.cpp)
12+
build_example(distributed.cpp)

examples/cpp/distributed.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#include <iostream>
4+
5+
#include "mlx/mlx.h"
6+
7+
using namespace mlx::core;
8+
9+
int main() {
10+
if (!distributed::is_available()) {
11+
std::cout << "No communication backend found" << std::endl;
12+
return 1;
13+
}
14+
15+
auto global_group = distributed::init();
16+
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
17+
18+
array x = ones({10});
19+
array out = distributed::all_reduce_sum(x, global_group);
20+
21+
std::cout << out << std::endl;
22+
}

mlx/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ else()
2525
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
2626
endif()
2727

28+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
2829
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
2930
if (MLX_BUILD_ACCELERATE)
3031
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)

mlx/distributed/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
target_sources(
2+
mlx
3+
PRIVATE
4+
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
5+
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
6+
)
7+
8+
if (MPI_FOUND AND MLX_BUILD_CPU)
9+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
10+
else()
11+
target_sources(
12+
mlx
13+
PRIVATE
14+
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
15+
)
16+
endif()

mlx/distributed/distributed.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#pragma once
4+
5+
#include <memory>
6+
7+
#include "mlx/array.h"
8+
9+
namespace mlx::core::distributed {
10+
11+
/* Check if a communication backend is available */
12+
bool is_available();
13+
14+
/**
15+
* A distributed::Group represents a group of independent mlx processes that
16+
* can communicate. We must also be able to create sub-groups from a group in
17+
* order to define more granular communication.
18+
*/
19+
struct Group {
20+
Group(std::shared_ptr<void> group) : group_(group) {}
21+
22+
int rank();
23+
int size();
24+
25+
/**
26+
* Split the group according to the provided color. Namely processes that use
27+
* the same color will go to the same group.
28+
*
29+
* The key defines the rank of the processes in the new group. The smaller
30+
* the key the smaller the rank. If the provided key is negative, then the
31+
* rank in the current group is used.
32+
*/
33+
Group split(int color, int key = -1);
34+
35+
const std::shared_ptr<void>& raw_group() {
36+
return group_;
37+
}
38+
39+
private:
40+
std::shared_ptr<void> group_{nullptr};
41+
};
42+
43+
/**
44+
* Initialize the distributed backend and return the group containing all
45+
* discoverable processes.
46+
*/
47+
Group init();
48+
49+
namespace detail {
50+
51+
/* Return the communication stream. */
52+
Stream communication_stream();
53+
54+
/* Perform an all reduce sum operation */
55+
void all_reduce_sum(Group group, const array& input, array& output);
56+
57+
/* Perform an all gather operation */
58+
void all_gather(Group group, const array& input, array& output);
59+
60+
} // namespace detail
61+
62+
} // namespace mlx::core::distributed

mlx/distributed/mpi/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
target_sources(
2+
mlx
3+
PRIVATE
4+
${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
5+
)

0 commit comments

Comments
 (0)