Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/KokkosComm/concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@ template <typename T>
concept KokkosExecutionSpace = Kokkos::is_execution_space_v<T>;

template <typename T>
concept CommunicationSpace = KokkosComm::Impl::is_communication_space<T>::value;
concept CommunicationSpace = requires {
KokkosComm::Impl::is_communication_space<T>::value;
typename T::communication_space;
typename T::handle_type;
typename T::request_type;
typename T::datatype_type;
typename T::reduction_op_type;
typename T::rank_type;
};

template <typename T>
concept ReductionOperator = KokkosComm::Impl::is_reduction_operator<T>::value;
Expand Down
19 changes: 10 additions & 9 deletions src/KokkosComm/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@

namespace KokkosComm {

// NCCL backend also implicitly declares MPI
#if defined(KOKKOSCOMM_ENABLE_NCCL)
namespace Experimental {
class Nccl;
} // namespace Experimental
class Mpi;
using DefaultCommunicationSpace = Experimental::Nccl;
using FallbackCommunicationSpace = Mpi;
struct NcclSpace;
}
// NCCL backend also declares the MPI space as fallback
struct MpiSpace;

using DefaultCommunicationSpace = Experimental::NcclSpace;
using FallbackCommunicationSpace = MpiSpace;
#elif defined(KOKKOSCOMM_ENABLE_MPI)
class Mpi;
using DefaultCommunicationSpace = Mpi;
using FallbackCommunicationSpace = Mpi;
struct MpiSpace;
using DefaultCommunicationSpace = MpiSpace;
using FallbackCommunicationSpace = MpiSpace;
#else
#error at least one communication space must be enabled
#endif
Expand Down
14 changes: 7 additions & 7 deletions src/KokkosComm/mpi/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Channel {

void wait() {
Kokkos::Tools::pushRegion("KokkosComm::Channel::wait");
std::vector<Req<Mpi>> reqs;
std::vector<Req<MpiSpace>> reqs;
reqs.reserve(send_reqs_.size() + recv_reqs_.size());
reqs.insert(reqs.end(), send_reqs_.begin(), send_reqs_.end());
reqs.insert(reqs.end(), recv_reqs_.begin(), recv_reqs_.end());
Expand All @@ -66,12 +66,12 @@ class Channel {
}

private:
std::vector<Req<Mpi>> send_reqs_; // Queue for send requests
std::vector<Req<Mpi>> recv_reqs_; // Queue for receive requests
int dest_rank_; // Destination rank for send
int src_rank_; // Source rank for receive
int tag_; // MPI tag
MPI_Comm comm_; // MPI communicator
std::vector<Req<MpiSpace>> send_reqs_; // Queue for send requests
std::vector<Req<MpiSpace>> recv_reqs_; // Queue for receive requests
int dest_rank_; // Destination rank for send
int src_rank_; // Source rank for receive
int tag_; // MPI tag
MPI_Comm comm_; // MPI communicator
};

} // namespace KokkosComm
4 changes: 2 additions & 2 deletions src/KokkosComm/mpi/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ namespace KokkosComm {
- post-wait
*/
template <KokkosExecutionSpace ExecSpace>
class Handle<ExecSpace, Mpi> {
class Handle<ExecSpace, MpiSpace> {
public:
using execution_space = ExecSpace;
using transport_type = Mpi;
using transport_type = MpiSpace;
using size_type = int;

explicit Handle(const execution_space &space, MPI_Comm comm) : space_(space), comm_(comm) {}
Expand Down
6 changes: 3 additions & 3 deletions src/KokkosComm/mpi/irecv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ namespace Impl {

// Recv implementation for Mpi
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
struct Recv<RecvView, ExecSpace, Mpi> {
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const RecvView &rv, int src) {
struct Recv<RecvView, ExecSpace, MpiSpace> {
static Req<MpiSpace> execute(Handle<ExecSpace, MpiSpace> &h, const RecvView &rv, int src) {
using KCPT = KokkosComm::PackTraits<RecvView>;
using Packer = typename KCPT::packer_type;
using Args = typename Packer::args_type;

const ExecSpace &space = h.space();

Req<Mpi> req;
Req<MpiSpace> req;
if (KokkosComm::is_contiguous(rv)) {
space.fence("fence before irecv");
MPI_Irecv(KokkosComm::data_handle(rv), KokkosComm::span(rv), mpi_type_v<typename RecvView::value_type>, src,
Expand Down
12 changes: 6 additions & 6 deletions src/KokkosComm/mpi/isend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace KokkosComm {
namespace Impl {

template <KokkosExecutionSpace ExecSpace, KokkosView SendView, mpi::CommunicationMode SendMode>
Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag, SendMode) {
Req<MpiSpace> isend_impl(Handle<ExecSpace, MpiSpace> &h, const SendView &sv, int dest, int tag, SendMode) {
auto mpi_isend_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag,
MPI_Comm mpi_comm, MPI_Request *mpi_req) {
if constexpr (std::is_same_v<SendMode, mpi::CommModeStandard>) {
Expand All @@ -34,7 +34,7 @@ Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int
}
};

Req<Mpi> req;
Req<MpiSpace> req;
if (KokkosComm::is_contiguous(sv)) {
h.space().fence("fence before isend");
mpi_isend_fn(KokkosComm::data_handle(sv), KokkosComm::span(sv), mpi_type_v<typename SendView::value_type>, dest,
Expand All @@ -55,8 +55,8 @@ Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int

// Implementation of KokkosComm::Send
template <KokkosExecutionSpace ExecSpace, KokkosView SendView>
struct Send<SendView, ExecSpace, Mpi> {
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest) {
struct Send<SendView, ExecSpace, MpiSpace> {
static Req<MpiSpace> execute(Handle<ExecSpace, MpiSpace> &h, const SendView &sv, int dest) {
return isend_impl<ExecSpace, SendView>(h, sv, dest, POINTTOPOINT_TAG, mpi::DefaultCommMode{});
}
};
Expand All @@ -65,12 +65,12 @@ struct Send<SendView, ExecSpace, Mpi> {
namespace mpi {

template <KokkosExecutionSpace ExecSpace, KokkosView SendView, CommunicationMode SendMode>
Req<Mpi> isend(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag, SendMode) {
Req<MpiSpace> isend(Handle<ExecSpace, MpiSpace> &h, const SendView &sv, int dest, int tag, SendMode) {
return KokkosComm::Impl::isend_impl<ExecSpace, SendView>(h, sv, dest, tag, SendMode{});
}

template <KokkosExecutionSpace ExecSpace, KokkosView SendView>
Req<Mpi> isend(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag) {
Req<MpiSpace> isend(Handle<ExecSpace, MpiSpace> &h, const SendView &sv, int dest, int tag) {
return isend<ExecSpace, SendView>(h, sv, dest, tag, DefaultCommMode{});
}

Expand Down
32 changes: 12 additions & 20 deletions src/KokkosComm/mpi/mpi_space.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,18 @@

namespace KokkosComm {

// TODO: not sure what members this thing needs
struct Mpi {
// TODO: just an example
static int world_size() {
int size;
MPI_Comm_size(MPI_COMM_WORLD, &size);
return size;
}

// TODO: just an example
static int world_rank() {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}

}; // struct Mpi

// KokkosComm::Mpi is a KokkosComm::CommunicationSpace
/// The MPI communication space.
struct MpiSpace {
using communication_space = MpiSpace;
using handle_type = MPI_Comm;
using request_type = MPI_Request;
using datatype_type = MPI_Datatype;
using reduction_op_type = MPI_Op;
using rank_type = int;
};

// KokkosComm::MpiSpace is a KokkosComm::CommunicationSpace
template <>
struct Impl::is_communication_space<KokkosComm::Mpi> : public std::true_type {};
struct Impl::is_communication_space<MpiSpace> : public std::true_type {};

} // namespace KokkosComm
72 changes: 69 additions & 3 deletions src/KokkosComm/mpi/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,67 @@

#include <KokkosComm/concepts.hpp>
#include <KokkosComm/traits.hpp>
#include <KokkosComm/reduction_op.hpp>
#include "mpi_space.hpp"
#include "req.hpp"

#include "impl/error_handling.hpp"
#include "impl/pack_traits.hpp"
#include "impl/types.hpp"
#include "impl/error_handling.hpp"

namespace KokkosComm::mpi {
namespace KokkosComm {
namespace mpi {

template <KokkosView SView, KokkosView RView, KokkosExecutionSpace ExecSpace>
auto ireduce(const ExecSpace &space, const SView sv, RView rv, MPI_Op op, int root, MPI_Comm comm) -> Req<MpiSpace> {
using ST = typename SView::non_const_value_type;
using RT = typename RView::non_const_value_type;
using SPkr = typename PackTraits<SView>::packer_type;
using RPkr = typename PackTraits<RView>::packer_type;
static_assert(std::is_same_v<ST, RT>, "KokkosComm::mpi::ireduce: View value types must be identical");
static_assert(rank<SView>() <= 1 and rank<RView>() <= 1,
"KokkosComm::mpi::ireduce: Views with rank higher than 1 are not supported");
Kokkos::Tools::pushRegion("KokkosComm::mpi::ireduce");

const int rank = [=]() {
int _r;
MPI_Comm_rank(comm, &_r);
return _r;
}();

Req<MpiSpace> req;
if (is_contiguous(sv)) {
if (rank == root and not is_contiguous(rv)) {
auto pkd_rv = RPkr::allocate_packed_for(space, "KC::mpi::ireduce c_sv pkd_rv", rv);
space.fence("fence allocation before MPI call");
MPI_Ireduce(data_handle(sv), data_handle(pkd_rv.view), span(sv), Impl::mpi_type_v<ST>, op, root, comm,
&req.mpi_request());
RPkr::unpack_into(space, rv, pkd_rv.view);
req.extend_view_lifetime(pkd_rv.view);
} else {
MPI_Ireduce(data_handle(sv), data_handle(rv), span(sv), Impl::mpi_type_v<ST>, op, root, comm, &req.mpi_request());
}
} else {
auto send_args = SPkr::pack(space, sv);
if (rank == root and !is_contiguous(rv)) {
auto pkd_rv = RPkr::allocate_packed_for(space, "KC::mpi::ireduce nc_sv pkd_rv", rv);
space.fence("fence allocation before MPI call");
MPI_Ireduce(data_handle(send_args.view), data_handle(pkd_rv.view), send_args.count, Impl::mpi_type_v<ST>, op,
root, comm, &req.mpi_request());
RPkr::unpack_into(space, rv, pkd_rv.view);
req.extend_view_lifetime(pkd_rv.view);
} else {
MPI_Ireduce(data_handle(send_args.view), data_handle(rv), send_args.count, Impl::mpi_type_v<ST>, op, root, comm,
&req.mpi_request());
}
req.extend_view_lifetime(send_args.view);
}
req.extend_view_lifetime(sv);
req.extend_view_lifetime(rv);

Kokkos::Tools::popRegion();
return req;
}

template <KokkosView SendView, KokkosView RecvView>
void reduce(const SendView &sv, const RecvView &rv, MPI_Op op, int root, MPI_Comm comm) {
Expand Down Expand Up @@ -73,4 +128,15 @@ void reduce(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_
Kokkos::Tools::popRegion();
}

} // namespace KokkosComm::mpi
} // namespace mpi
namespace Experimental::Impl {

template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp, KokkosExecutionSpace ExecSpace>
struct Reduce<SendView, RecvView, RedOp, ExecSpace, MpiSpace> {
static auto execute(Handle<ExecSpace, MpiSpace> &h, const SendView sv, RecvView rv, int root) -> Req<MpiSpace> {
return mpi::ireduce(h.space(), sv, rv, reduction_op<MpiSpace, RedOp>(), root, h.comm());
}
};

} // namespace Experimental::Impl
} // namespace KokkosComm
22 changes: 11 additions & 11 deletions src/KokkosComm/mpi/req.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace KokkosComm {

template <>
class Req<Mpi> {
class Req<MpiSpace> {
// a type-erased view. Request uses these to keep temporary views alive for
// the lifetime of "Immediate" MPI operations
struct ViewHolderBase {
Expand Down Expand Up @@ -52,34 +52,34 @@ class Req<Mpi> {
private:
std::shared_ptr<Record> record_;

friend void wait(Req<Mpi> &req);
friend void wait(Req<Mpi> &&req);
friend void wait_all(std::span<Req<Mpi>> reqs);
friend void wait_any(std::span<Req<Mpi>> reqs);
friend void wait(Req<MpiSpace> &req);
friend void wait(Req<MpiSpace> &&req);
friend void wait_all(std::span<Req<MpiSpace>> reqs);
friend void wait_any(std::span<Req<MpiSpace>> reqs);
};

inline void wait(Req<Mpi> &req) {
inline void wait(Req<MpiSpace> &req) {
MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);
for (auto &f : req.record_->postWaits_) {
f();
}
req.record_->postWaits_.clear();
}

inline void wait(Req<Mpi> &&req) { wait(req); }
inline void wait(Req<MpiSpace> &&req) { wait(req); }

inline void wait_all(std::span<Req<Mpi>> reqs) {
for (Req<Mpi> &req : reqs) {
inline void wait_all(std::span<Req<MpiSpace>> reqs) {
for (Req<MpiSpace> &req : reqs) {
wait(req);
}
}

/// FIXME: This function will loop indefinitely if all requests in the list are equivalent to `MPI_REQUEST_NULL`.
/// FIXME: This function should return the index of the completed request, if any.
inline void wait_any(std::span<Req<Mpi>> reqs) {
inline void wait_any(std::span<Req<MpiSpace>> reqs) {
// FIXME: Active wait-loop
while (true) {
for (Req<Mpi> &req : reqs) {
for (Req<MpiSpace> &req : reqs) {
int flag;
MPI_Test(&(req.mpi_request()), &flag, MPI_STATUS_IGNORE);
if (flag) {
Expand Down
8 changes: 4 additions & 4 deletions src/KokkosComm/nccl/allgather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace nccl {
namespace KC = KokkosComm;

template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
auto allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, ncclComm_t comm) -> Req<Nccl> {
auto allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, ncclComm_t comm) -> Req<NcclSpace> {
using ST = typename SendView::non_const_value_type;
using RT = typename RecvView::non_const_value_type;
static_assert(std::is_same_v<ST, RT>,
Expand All @@ -27,7 +27,7 @@ auto allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, n
"KokkosComm::nccl::allgather: Views with rank higher than 1 are not supported");
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::allgather");

Req<Nccl> req{space.cuda_stream()};
Req<NcclSpace> req{space.cuda_stream()};
if (KC::is_contiguous(sv) and KC::is_contiguous(rv)) {
ncclAllGather(KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), Impl::datatype_v<ST>, comm,
space.cuda_stream());
Expand All @@ -45,8 +45,8 @@ auto allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, n
namespace Impl {

template <KokkosView SendView, KokkosView RecvView>
struct AllGather<SendView, RecvView, Kokkos::Cuda, Nccl> {
static auto execute(Handle<Kokkos::Cuda, Nccl> &h, const SendView sv, RecvView rv) -> Req<Nccl> {
struct AllGather<SendView, RecvView, Kokkos::Cuda, NcclSpace> {
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv) -> Req<NcclSpace> {
return nccl::allgather(h.space(), sv, rv, h.comm());
}
};
Expand Down
11 changes: 6 additions & 5 deletions src/KokkosComm/nccl/allreduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <KokkosComm/concepts.hpp>
#include <KokkosComm/traits.hpp>
#include <KokkosComm/reduction_op.hpp>

#include "impl/pack_traits.hpp"
#include "impl/types.hpp"
Expand All @@ -19,7 +20,7 @@ namespace KC = KokkosComm;

template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
auto allreduce(const ExecSpace &space, const SendView &sv, const RecvView &rv, ncclRedOp_t op, ncclComm_t comm)
-> Req<Nccl> {
-> Req<NcclSpace> {
using ST = typename SendView::non_const_value_type;
using RT = typename RecvView::non_const_value_type;
static_assert(std::is_same_v<ST, RT>,
Expand All @@ -28,7 +29,7 @@ auto allreduce(const ExecSpace &space, const SendView &sv, const RecvView &rv, n
"KokkosComm::Experimental::nccl::allreduce: Views with rank higher than 1 are not supported");
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::allreduce");

Req<Nccl> req{space.cuda_stream()};
Req<NcclSpace> req{space.cuda_stream()};
if (KC::is_contiguous(sv) and KC::is_contiguous(rv)) {
ncclAllReduce(KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), Impl::datatype_v<ST>, op, comm,
space.cuda_stream());
Expand All @@ -46,9 +47,9 @@ auto allreduce(const ExecSpace &space, const SendView &sv, const RecvView &rv, n
namespace Impl {

template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
struct AllReduce<SendView, RecvView, RedOp, Kokkos::Cuda, Nccl> {
static auto execute(Handle<Kokkos::Cuda, Nccl> &h, const SendView sv, RecvView rv) -> Req<Nccl> {
return nccl::allreduce(h.space(), sv, rv, nccl::Impl::reduction_op_v<RedOp>, h.comm());
struct AllReduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv) -> Req<NcclSpace> {
return nccl::allreduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), h.comm());
}
};

Expand Down
Loading