diff --git a/heat/core/communication.py b/heat/core/communication.py
index 8e41cff59..b2e1209cf 100644
--- a/heat/core/communication.py
+++ b/heat/core/communication.py
@@ -8,23 +8,34 @@
 import os
 import subprocess
 import torch
+import warnings
 from mpi4py import MPI
 
 from typing import Any, Callable, Optional, List, Tuple, Union
 from .stride_tricks import sanitize_axis
 
 CUDA_AWARE_MPI = False
-# check whether OpenMPI support CUDA-aware MPI
-if "openmpi" in os.environ.get("MPI_SUFFIX", "").lower():
+# check whether there is CUDA-aware OpenMPI
+try:
     buffer = subprocess.check_output(["ompi_info", "--parsable", "--all"])
     CUDA_AWARE_MPI = b"mpi_built_with_cuda_support:value:true" in buffer
-# MVAPICH
+except (OSError, subprocess.SubprocessError):
+    # ompi_info could not be started or exited with an error; keep the default and probe the other MPI flavours
+    pass
+# do the same for MVAPICH
 CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("MV2_USE_CUDA") == "1"
-# MPICH
+# do the same for MPICH
 CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("MPIR_CVAR_ENABLE_HCOLL") == "1"
-# ParaStationMPI
+# do the same for ParaStationMPI
 CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("PSP_CUDA") == "1"
 
+# warn the user if CUDA-aware MPI is not available, but PyTorch can use GPUs
+if torch.cuda.is_available() and not CUDA_AWARE_MPI:
+    warnings.warn(
+        f"Heat has GPU-support (PyTorch version {torch.__version__}), but CUDA-awareness of MPI could not be detected. \n This may lead to performance degradation as direct MPI-communication between GPUs is not possible.",
+        UserWarning,
+    )
+
 
 class MPIRequest:
     """