Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions heat/core/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,27 @@ class MPILibrary(Enum):
class MPILibraryInfo:
    """
    Description of the MPI library detected by ``_get_mpi_library``.
    """

    # Which MPI implementation was recognized (a member of ``MPILibrary``)
    name: MPILibrary
    # Version token taken from the library's version banner, or "unknown"
    version: str
    # Names of MPI operations registered as incompatible, keyed by device type (e.g. "cuda");
    # empty when nothing is registered for the detected library/version
    incompatible_operations: dict[str, list[str]] = dataclasses.field(default_factory=dict)


def _get_mpi_library() -> MPILibraryInfo:
library = mpi4py.MPI.Get_library_version().split()
match library:
case ["Open", "MPI", *_]:
return MPILibraryInfo(MPILibrary.OpenMPI, library[2])
version = library[2]
if version.startswith("v5.0."):
incompatibilities = INCOMPATIBILITIES[MPILibrary.OpenMPI].get("5.0.x", {})
elif version.startswith("v4.1."):
incompatibilities = INCOMPATIBILITIES[MPILibrary.OpenMPI].get("4.1.x", {})
return MPILibraryInfo(MPILibrary.OpenMPI, library[2], incompatibilities)
case ["Intel(R)", "MPI", *_]:
return MPILibraryInfo(MPILibrary.IntelMPI, library[3])
return MPILibraryInfo(MPILibrary.IntelMPI, library[3], {})
case ["MPICH", "Version:", *_]:
return MPILibraryInfo(MPILibrary.MPICH, library[2])
return MPILibraryInfo(MPILibrary.MPICH, library[2], {})
case ["MVAPICH", "Version:", *_]:
return MPILibraryInfo(MPILibrary.MVAPICH, library[2])
return MPILibraryInfo(MPILibrary.MVAPICH, library[2], {})
case ["===", "ParaStation", "MPI", *_]:
return MPILibraryInfo(MPILibrary.ParaStationMPI, library[3])
return MPILibraryInfo(MPILibrary.ParaStationMPI, library[3], {})
case _:
return MPILibraryInfo(MPILibrary.Other, "unknown")

Expand All @@ -68,7 +74,7 @@ def _check_gpu_aware_mpi(library: MPILibraryInfo) -> tuple[bool, bool]:
rocm = "rocm" in extensions or "hip" in extensions
# Seems to be broken, disabled by default for now
# return cuda, rocm
return False, False
return cuda, rocm
except Exception as e: # noqa E722
return False, False
case MPILibrary.IntelMPI:
Expand All @@ -93,6 +99,35 @@ def _check_gpu_aware_mpi(library: MPILibraryInfo) -> tuple[bool, bool]:
return False, False


# Known-incompatible MPI operations, nested as: library -> version series -> device type.
# The innermost lists hold operation names. They are compared against ``func.__name__`` of
# the mpi4py method being called, so each entry must use mpi4py's exact spelling
# (e.g. ``Get_accumulate``, not ``Get_Accumulate``).
INCOMPATIBILITIES: dict[MPILibrary, dict[str, dict[str, list[str]]]] = {
    MPILibrary.IntelMPI: {},
    MPILibrary.OpenMPI: {
        "5.0.x": {
            "cuda": [
                "Accumulate",
                "Compare_and_swap",
                "Fetch_and_op",
                "Get_accumulate",
                "Iallgather",
                "Iallgatherv",
                "Iallreduce",
                "Ialltoall",
                "Ialltoallv",
                "Ialltoallw",
                "Ibcast",
                "Iscan",
                "Iexscan",
                "Rget",
                "Rput",
                "Ireduce",
            ]
        },
        # No operations registered for this series
        "4.1.x": {"cuda": []},
    },
}


# Environment information, evaluated once when these assignments run
PLATFORM = platform.platform()  # platform identification string
TORCH_VERSION = torch.__version__  # version string of the imported PyTorch
TORCH_CUDA_IS_AVAILABLE = torch.cuda.is_available()  # whether PyTorch reports CUDA as usable
Expand Down
95 changes: 66 additions & 29 deletions heat/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .stride_tricks import sanitize_axis

from ._config import GPU_AWARE_MPI
from ._config import GPU_AWARE_MPI, mpi_library


class MPIRequest:
Expand Down Expand Up @@ -55,11 +55,25 @@ def Wait(self, status: MPI.Status = None):
Waits for an MPI request to complete
"""
self.handle.Wait(status)
if self.tensor is not None and isinstance(self.tensor, torch.Tensor):
if self.permutation is not None:
self.recvbuf = self.recvbuf.permute(self.permutation)
if self.tensor is not None and self.tensor.is_cuda and not GPU_AWARE_MPI:
self.tensor.copy_(self.recvbuf)

# Apply permutation if needed (for all buffer types)
if self.permutation is not None and self.recvbuf is not None:
self.recvbuf = self.recvbuf.permute(self.permutation)

# Copy result from CPU back to GPU if needed
if self.tensor is not None:
tensor_device = (
self.tensor.device
if isinstance(self.tensor, torch.Tensor)
else self.tensor.larray.device
)
recvbuf_device = self.recvbuf.device

if tensor_device != recvbuf_device:
if isinstance(self.tensor, torch.Tensor):
self.tensor.copy_(self.recvbuf.to(tensor_device))
else:
self.tensor.larray.copy_(self.recvbuf.to(tensor_device))

def __getattr__(self, name: str) -> Callable:
"""
Expand Down Expand Up @@ -382,8 +396,8 @@ def mpi_type_and_elements_of(
# chain the types based on the
for i in range(len(shape) - 1, -1, -1):
mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i])
mpi_type.Commit()

mpi_type.Commit()
if counts is not None:
return mpi_type, (counts, displs)

Expand Down Expand Up @@ -444,9 +458,25 @@ def as_buffer(
return [mpi_mem, elements, mpi_type]

def _moveToCompDevice(self, x: torch.Tensor, func: Callable | None) -> torch.Tensor:
"""Moves the torch tensor to the relevant device, in case the function is not compatible with the MPI+GPU library."""
"""
Moves the torch tensor to the relevant device, in case the function is not compatible with the MPI+GPU library.

Parameters
----------
x: torch.Tensor
The tensor to be moved to the relevant device
func: Callable
The MPI function that is intended to be called with the tensor, used to check for compatibility with the MPI+GPU library

Returns
-------
torch.Tensor
The tensor on the relevant device for the MPI function
"""
if x.is_cuda:
if GPU_AWARE_MPI:
if GPU_AWARE_MPI and func.__name__ not in mpi_library.incompatible_operations.get(
"cuda", []
):
torch.cuda.synchronize(x.device)
return x
else:
Expand Down Expand Up @@ -847,8 +877,11 @@ def Bcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> None:
Rank of the root process, that broadcasts the message
"""
ret, sbuf, rbuf, buf = self.__broadcast_like(self.handle.Bcast, buf, root)
if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not GPU_AWARE_MPI:
buf.copy_(rbuf)
if buf is not None and not GPU_AWARE_MPI:
if isinstance(buf, torch.Tensor) and buf.is_cuda:
buf.copy_(rbuf)
elif isinstance(buf, DNDarray) and buf.larray.is_cuda:
buf.larray.copy_(rbuf)
return ret

Bcast.__doc__ = MPI.Comm.Bcast.__doc__
Expand Down Expand Up @@ -1084,8 +1117,11 @@ def Allreduce(
The operation to perform upon reduction
"""
ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Allreduce, sendbuf, recvbuf, op)
if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not GPU_AWARE_MPI:
buf.copy_(rbuf)
if buf is not None and not GPU_AWARE_MPI:
if isinstance(buf, torch.Tensor) and buf.is_cuda:
buf.copy_(rbuf)
elif isinstance(buf, DNDarray) and buf.larray.is_cuda:
buf.larray.copy_(rbuf)
return ret

Allreduce.__doc__ = MPI.Comm.Allreduce.__doc__
Expand Down Expand Up @@ -1341,7 +1377,7 @@ def __allgather_like(
rbuf = recvbuf
mpi_recvbuf = recvbuf

# perform the scatter operation
# perform the allgather operation
exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs)

return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation
Expand Down Expand Up @@ -1369,8 +1405,12 @@ def Allgather(
)
if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None:
rbuf = rbuf.permute(permutation)
if isinstance(buf, torch.Tensor) and buf.is_cuda and not GPU_AWARE_MPI:
buf.copy_(rbuf)

if buf is not None and not GPU_AWARE_MPI:
if isinstance(buf, torch.Tensor) and buf.is_cuda:
buf.copy_(rbuf)
elif isinstance(buf, DNDarray) and buf.larray.is_cuda:
buf.larray.copy_(rbuf)
return ret

Allgather.__doc__ = MPI.Comm.Allgather.__doc__
Expand Down Expand Up @@ -1398,8 +1438,12 @@ def Allgatherv(
)
if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None:
rbuf = rbuf.permute(permutation)
if isinstance(buf, torch.Tensor) and buf.is_cuda and not GPU_AWARE_MPI:
buf.copy_(rbuf)
print("Unpermuted")
if buf is not None and not GPU_AWARE_MPI:
if isinstance(buf, torch.Tensor) and buf.is_cuda:
buf.copy_(rbuf)
elif isinstance(buf, DNDarray) and buf.larray.is_cuda:
buf.larray.copy_(rbuf)
return ret

Allgatherv.__doc__ = MPI.Comm.Allgatherv.__doc__
Expand Down Expand Up @@ -1796,7 +1840,6 @@ def _create_recursive_vectortype(
... datatype, tensor_stride, subarray_sizes
... )
"""
datatype_history = []
current_datatype = datatype

i = len(tensor_stride) - 1
Expand All @@ -1816,34 +1859,28 @@ def _create_recursive_vectortype(
next_size = subarray_sizes[i]
new_vector_datatype = current_datatype.Create_vector(
next_size, current_size, current_stride
).Commit()
)

else:
if i == len(tensor_stride) - 1:
new_vector_datatype = current_datatype.Create_vector(
current_size, 1, current_stride
).Commit()
)
else:
new_vector_datatype = current_datatype.Create_vector(
current_size, 1, 1
).Commit()
new_vector_datatype = current_datatype.Create_vector(current_size, 1, 1)

datatype_history.append(new_vector_datatype)
# Set extent of the new datatype to the extent of the basic datatype to allow interweaving of data
next_stride = tensor_stride[i - 1]
new_resized_vector_datatype = new_vector_datatype.Create_resized(
0, datatype.Get_extent()[1] * next_stride
).Commit()
datatype_history.append(new_resized_vector_datatype)
)
current_datatype = new_resized_vector_datatype

i -= 1

displacement = sum([x * y for x, y in zip(tensor_stride, start)]) * datatype.Get_extent()[1]
current_datatype = current_datatype.Create_hindexed_block(1, [displacement]).Commit()

for dt in datatype_history[:-1]:
dt.Free()
return current_datatype

def Ialltoall(
Expand Down
7 changes: 2 additions & 5 deletions tests/core/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2616,11 +2616,8 @@ def test_largecount_workaround_IsendRecv(self):
)
def test_largecount_workaround_Allreduce(self):
    # 2**10 * 2**11 * 2**10 = 2**31 elements, one more than a signed 32-bit count can express
    shape = (2**10, 2**11, 2**10)
    # Even ranks contribute all-False, odd ranks all-True. The buffer is built exactly once:
    # it is large, so a second throw-away allocation is worth avoiding.
    fill = torch.zeros if ht.MPI_WORLD.rank % 2 == 0 else torch.ones
    data = fill(shape, dtype=torch.bool)

    # In-place reduction: `data` is both the contribution and the result buffer
    ht.MPI_WORLD.Allreduce(ht.MPI.IN_PLACE, data, op=ht.MPI.SUM)
    self.assertTrue(data.all())

Expand Down
Loading