diff --git a/src/KokkosComm/collective.hpp b/src/KokkosComm/collective.hpp index 96c58d8b..fdd21a12 100644 --- a/src/KokkosComm/collective.hpp +++ b/src/KokkosComm/collective.hpp @@ -9,6 +9,12 @@ #include "fwd.hpp" #include "reduction_op.hpp" +#if defined(KOKKOSCOMM_ENABLE_MPI) +#include "mpi/mpi_space.hpp" +#include "mpi/handle.hpp" +#include "mpi/req.hpp" +#include "mpi/broadcast.hpp" +#endif #if defined(KOKKOSCOMM_ENABLE_NCCL) #include "nccl/nccl_space.hpp" #include "nccl/handle.hpp" diff --git a/src/KokkosComm/mpi/broadcast.hpp b/src/KokkosComm/mpi/broadcast.hpp index aca4edfd..6ce625aa 100644 --- a/src/KokkosComm/mpi/broadcast.hpp +++ b/src/KokkosComm/mpi/broadcast.hpp @@ -14,7 +14,8 @@ #include "impl/error_handling.hpp" -namespace KokkosComm::mpi { +namespace KokkosComm { +namespace mpi { template auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req { @@ -56,4 +57,15 @@ void broadcast(ExecSpace const& space, View const& v, int root, MPI_Comm comm) { Kokkos::Tools::popRegion(); } -} // namespace KokkosComm::mpi +} // namespace mpi +namespace Experimental::Impl { + +template +struct Broadcast { + static auto execute(Handle& h, View& v, int root) -> Req { + return KokkosComm::mpi::ibroadcast(h.space(), v, root, h.comm()); + } +}; + +} // namespace Experimental::Impl +} // namespace KokkosComm diff --git a/src/KokkosComm/nccl/broadcast.hpp b/src/KokkosComm/nccl/broadcast.hpp index ae00c5ad..682dafe9 100644 --- a/src/KokkosComm/nccl/broadcast.hpp +++ b/src/KokkosComm/nccl/broadcast.hpp @@ -27,7 +27,7 @@ auto broadcast(const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm) -> Req req{space.cuda_stream()}; if (KC::is_contiguous(v)) { - ncclBcast(KC::data_handle(v), KC::span(v), datatype, root, comm, space.cuda_stream()); + ncclBcast(KC::data_handle(v), KC::span(v), datatype(), root, comm, space.cuda_stream()); } else { Kokkos::abort("KokkosComm::Experimental::nccl::broadcast: unimplemented for non-contiguous views"); } diff --git a/unit_tests/test_broadcast.cpp b/unit_tests/test_broadcast.cpp new file mode 100644 index 00000000..d76b45e3 --- /dev/null +++ b/unit_tests/test_broadcast.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project + +#include +#include +#include + +#if defined(KOKKOSCOMM_ENABLE_NCCL) +#include "nccl/utils.hpp" +#endif + +namespace { + +template +class Broadcast : public testing::Test { + public: + using Scalar = T; +}; +#if defined(KOKKOSCOMM_ENABLE_NCCL) +using ScalarTypes = testing::Types; +#else +using ScalarTypes = + testing::Types, Kokkos::complex, int, unsigned, int64_t, size_t>; +#endif +TYPED_TEST_SUITE(Broadcast, ScalarTypes); + +template +auto broadcast_0d() -> void { +#if defined(KOKKOSCOMM_ENABLE_NCCL) + using ExecSpace = Kokkos::Cuda; + auto nccl_ctx = test_utils::nccl::Ctx::init(); + KokkosComm::Handle h(ExecSpace(), nccl_ctx.comm()); +#else + using ExecSpace = Kokkos::DefaultExecutionSpace; + KokkosComm::Handle h{}; +#endif + int rank = h.rank(); + int size = h.size(); + int root = 0; + + Kokkos::View v("v"); + if (rank == root) { + // Prepare broadcast view + Kokkos::parallel_for( + Kokkos::RangePolicy(ExecSpace(), 0, v.extent(0)), KOKKOS_LAMBDA(const int) { v() = size; }); + } + // Using the same execution space for both operations lets us not need an explicit `fence` + auto req = KokkosComm::Experimental::broadcast(h, v, root); + KokkosComm::wait(req); + + int errs; + Kokkos::parallel_reduce( + v.extent(0), KOKKOS_LAMBDA(const int, int& lsum) { lsum += v() != size; }, errs); + EXPECT_EQ(errs, 0); +} + +template +auto broadcast_contig_1d() -> void { +#if defined(KOKKOSCOMM_ENABLE_NCCL) + using ExecSpace = Kokkos::Cuda; + auto nccl_ctx = test_utils::nccl::Ctx::init(); + KokkosComm::Handle h(ExecSpace(), nccl_ctx.comm()); +#else + using ExecSpace = Kokkos::DefaultExecutionSpace; + KokkosComm::Handle h{}; +#endif + int rank = h.rank(); + int size = h.size(); + int root = 0; + + Kokkos::View v("v", 100); + if (rank == root) { + // Prepare broadcast view + Kokkos::parallel_for( + Kokkos::RangePolicy(ExecSpace(), 0, v.extent(0)), KOKKOS_LAMBDA(const int i) { v(i) = size + i; }); + } + // Using the same execution space for both operations lets us not need an explicit `fence` + auto req = KokkosComm::Experimental::broadcast(h, v, root); + KokkosComm::wait(req); + + int errs; + Kokkos::parallel_reduce( + v.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += (v(i) != size + i); }, errs); + EXPECT_EQ(errs, 0); +} + +TYPED_TEST(Broadcast, 0D) { broadcast_0d(); } +TYPED_TEST(Broadcast, Contiguous1D) { broadcast_contig_1d(); } + +} // namespace