|
| 1 | +#include <memory> |
| 2 | +#include <mutex> |
| 3 | + |
| 4 | +#include <mpi.h> |
| 5 | +#include <nccl.h> |
| 6 | +#include <cuda_runtime.h> |
| 7 | + |
| 8 | +#include "utils.hpp" |
| 9 | +#include "../logging.hpp" |
| 10 | + |
| 11 | +namespace { |
| 12 | + |
| 13 | +[[nodiscard]] auto get_local_rank(MPI_Comm comm, int my_rank) -> int { |
| 14 | + MPI_Comm node_comm; |
| 15 | + MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, my_rank, MPI_INFO_NULL, &node_comm); |
| 16 | + |
| 17 | + int node_rank; |
| 18 | + MPI_Comm_rank(node_comm, &node_rank); |
| 19 | + |
| 20 | + MPI_Comm_free(&node_comm); |
| 21 | + return node_rank; |
| 22 | +} |
| 23 | + |
| 24 | +} // namespace |
| 25 | + |
| 26 | +namespace test_utils { |
| 27 | + |
| 28 | +std::unique_ptr<NcclCtx> NcclCtx::instance_{}; |
| 29 | +std::once_flag NcclCtx::init_flag_{}; |
| 30 | + |
| 31 | +NcclCtx::NcclCtx(ncclComm_t comm, cudaStream_t stream, int dev, int size, int rank) |
| 32 | + : comm_(comm), stream_(stream), dev_(dev), size_(size), rank_(rank) {} |
| 33 | + |
| 34 | +NcclCtx::~NcclCtx() { |
| 35 | + if (stream_ != nullptr) { |
| 36 | + cudaStreamDestroy(stream_); |
| 37 | + } |
| 38 | + if (comm_ != nullptr) { |
| 39 | + ncclCommDestroy(comm_); |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +auto NcclCtx::init(bool verbose) -> void { |
| 44 | + std::call_once(init_flag_, [verbose]() { |
| 45 | + int flag = 0; |
| 46 | + KC_MPI_CHECK(MPI_Initialized(&flag)); |
| 47 | + KC_CHECK(flag != 0, "MPI is not initialized"); |
| 48 | + |
| 49 | + MPI_Comm mpi_comm = MPI_COMM_WORLD; |
| 50 | + |
| 51 | + int size = 0; |
| 52 | + KC_MPI_CHECK(MPI_Comm_size(mpi_comm, &size)); |
| 53 | + int rank = 0; |
| 54 | + KC_MPI_CHECK(MPI_Comm_rank(mpi_comm, &rank)); |
| 55 | + |
| 56 | + int local_rank = get_local_rank(mpi_comm, rank); |
| 57 | + |
| 58 | + int devs = 0; |
| 59 | + KC_CUDA_CHECK(cudaGetDeviceCount(&devs)); |
| 60 | + |
| 61 | + if (verbose) { |
| 62 | + KC_INFO("P{} found {} CUDA devices", rank, devs); |
| 63 | + } |
| 64 | + |
| 65 | + KC_CHECK(local_rank < devs, "P{} needs device #{} but only {} devices available", rank, local_rank, devs); |
| 66 | + |
| 67 | + KC_CUDA_CHECK(cudaSetDevice(local_rank)); |
| 68 | + |
| 69 | + if (verbose) { |
| 70 | + KC_INFO("P{} assigned to CUDA device #{}", rank, local_rank); |
| 71 | + } |
| 72 | + |
| 73 | + ncclUniqueId nccl_id{}; |
| 74 | + if (rank == 0) { |
| 75 | + KC_NCCL_CHECK(ncclGetUniqueId(&nccl_id)); |
| 76 | + } |
| 77 | + |
| 78 | + KC_MPI_CHECK(MPI_Bcast(&nccl_id, NCCL_UNIQUE_ID_BYTES, MPI_CHAR, 0, mpi_comm)); |
| 79 | + |
| 80 | + ncclComm_t nccl_comm = nullptr; |
| 81 | + KC_NCCL_CHECK(ncclCommInitRank(&nccl_comm, size, nccl_id, rank)); |
| 82 | + |
| 83 | + cudaStream_t stream = nullptr; |
| 84 | + KC_CUDA_CHECK(cudaStreamCreate(&stream)); |
| 85 | + |
| 86 | + instance_ = std::unique_ptr<NcclCtx>(new NcclCtx(nccl_comm, stream, local_rank, size, rank)); |
| 87 | + }); |
| 88 | +} |
| 89 | + |
| 90 | +auto NcclCtx::fini() -> void { instance_.reset(); } |
| 91 | + |
| 92 | +auto NcclCtx::get() -> NcclCtx& { |
| 93 | + KC_CHECK(instance_ != nullptr, "NCCL context not initialized"); |
| 94 | + return *instance_; |
| 95 | +} |
| 96 | + |
| 97 | +auto NcclCtx::comm() const -> ncclComm_t { return comm_; } |
| 98 | + |
| 99 | +auto NcclCtx::stream() const -> cudaStream_t { return stream_; } |
| 100 | + |
| 101 | +auto NcclCtx::size() const -> int { return size_; } |
| 102 | + |
| 103 | +auto NcclCtx::rank() const -> int { return rank_; } |
| 104 | + |
| 105 | +auto NcclCtx::device() const -> int { return dev_; } |
| 106 | + |
| 107 | +} // namespace test_utils |
0 commit comments