Skip to content

Commit c2d6d81

Browse files
authored
Merge branch 'develop' into feature/core-alltoall
2 parents 2afea42 + 498e171 commit c2d6d81

File tree

4 files changed

+20
-34
lines changed

4 files changed

+20
-34
lines changed

src/KokkosComm/collective.hpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,11 @@
2828

2929
namespace KokkosComm::Experimental {
3030

31-
/// Copy the `sv` view on the `root` rank to all ranks' `rv` view.
32-
///
33-
/// The `sv` view is only used on the `root` rank and ignored for all other ranks.
34-
template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
35-
CommunicationSpace CommSpace = DefaultCommunicationSpace>
36-
auto broadcast(Handle<ExecSpace, CommSpace>& h, const SendView sv, RecvView rv, int root) -> Req<CommSpace> {
37-
return Impl::Broadcast<SendView, RecvView, ExecSpace, CommSpace>::execute(h, sv, rv, root);
38-
}
39-
40-
/// In-place variant of `broadcast`. Copy the `v` view from the `root` rank to all ranks' `v` view.
31+
/// Copy the `v` view from the `root` rank to all ranks' `v` view.
4132
template <KokkosView View, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
4233
CommunicationSpace CommSpace = DefaultCommunicationSpace>
4334
auto broadcast(Handle<ExecSpace, CommSpace>& h, View v, int root) -> Req<CommSpace> {
44-
return Impl::Broadcast<View, View, ExecSpace, CommSpace>::execute(h, v, v, root);
35+
return Impl::Broadcast<View, ExecSpace, CommSpace>::execute(h, v, root);
4536
}
4637

4738
/// Copy the `sv` view from each rank to the `rv` view, receiving data from rank `i` at offset

src/KokkosComm/fwd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct Send;
4747
// Collectives are currently experimental functions
4848
namespace Experimental::Impl {
4949

50-
template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
50+
template <KokkosView View, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
5151
CommunicationSpace CommSpace = DefaultCommunicationSpace>
5252
struct Broadcast;
5353

src/KokkosComm/nccl/broadcast.hpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,20 @@ namespace nccl {
1717

1818
namespace KC = KokkosComm;
1919

20-
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
21-
auto broadcast(const ExecSpace &space, const SendView &sv, const RecvView &rv, int root, ncclComm_t comm) -> Req<Nccl> {
22-
using ST = typename SendView::non_const_value_type;
23-
using RT = typename RecvView::non_const_value_type;
24-
static_assert(std::is_same_v<ST, RT>,
25-
"KokkosComm::Experimental::nccl::broadcast: View value types must be identical");
26-
static_assert(KC::rank<SendView>() <= 1 and KC::rank<RecvView>() <= 1,
20+
template <KokkosView View>
21+
auto broadcast(const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm) -> Req<Nccl> {
22+
using T = typename View::non_const_value_type;
23+
static_assert(KC::rank<View>() <= 1,
2724
"KokkosComm::Experimental::nccl::broadcast: Views with rank higher than 1 are not supported");
2825
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::broadcast");
2926

3027
Req<Nccl> req{space.cuda_stream()};
31-
if (KC::is_contiguous(sv) and KC::is_contiguous(rv)) {
32-
ncclBroadcast(KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), Impl::datatype_v<ST>, root, comm,
33-
space.cuda_stream());
28+
if (KC::is_contiguous(v)) {
29+
ncclBcast(KC::data_handle(v), KC::span(v), Impl::datatype_v<T>, root, comm, space.cuda_stream());
3430
} else {
3531
Kokkos::abort("KokkosComm::Experimental::nccl::broadcast: unimplemented for non-contiguous views");
3632
}
37-
req.extend_view_lifetime(sv);
38-
req.extend_view_lifetime(rv);
33+
req.extend_view_lifetime(v);
3934

4035
Kokkos::Tools::popRegion();
4136
return req;
@@ -44,10 +39,10 @@ auto broadcast(const ExecSpace &space, const SendView &sv, const RecvView &rv, i
4439
} // namespace nccl
4540
namespace Impl {
4641

47-
template <KokkosView SendView, KokkosView RecvView>
48-
struct Broadcast<SendView, RecvView, Kokkos::Cuda, Nccl> {
49-
static auto execute(Handle<Kokkos::Cuda, Nccl> &h, const SendView sv, RecvView rv, int root) -> Req<Nccl> {
50-
return nccl::broadcast(h.space(), sv, rv, root, h.comm());
42+
template <KokkosView View>
43+
struct Broadcast<View, Kokkos::Cuda, Nccl> {
44+
static auto execute(Handle<Kokkos::Cuda, Nccl>& h, View v, int root) -> Req<Nccl> {
45+
return nccl::broadcast(h.space(), v, root, h.comm());
5146
}
5247
};
5348

unit_tests/nccl/test_broadcast.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,20 @@ auto broadcast_0d() -> void {
4242

4343
int errs;
4444
Kokkos::parallel_reduce(
45-
v.extent(0), KOKKOS_LAMBDA(const int, int &lsum) { lsum += v() != size; }, errs);
45+
v.extent(0), KOKKOS_LAMBDA(const int, int& lsum) { lsum += v() != size; }, errs);
4646
EXPECT_EQ(errs, 0);
4747
}
4848

4949
template <typename Scalar>
50-
auto broadcast_inplace_contig_1d() -> void {
50+
auto broadcast_contig_1d() -> void {
5151
auto nccl_ctx = test_utils::nccl::Ctx::init();
5252
ExecSpace space(nccl_ctx.stream());
5353
KokkosComm::Handle<ExecSpace, CommSpace> h(space, nccl_ctx.comm());
5454
int rank = h.rank();
5555
int size = h.size();
5656
int root = 0;
5757

58-
Kokkos::View<Scalar *> v("v", 100);
58+
Kokkos::View<Scalar*> v("v", 100);
5959
if (rank == root) {
6060
// Prepare broadcast view
6161
Kokkos::parallel_for(
@@ -67,11 +67,11 @@ auto broadcast_inplace_contig_1d() -> void {
6767

6868
int errs;
6969
Kokkos::parallel_reduce(
70-
v.extent(0), KOKKOS_LAMBDA(const int i, int &lsum) { lsum += (v(i) != size + i); }, errs);
70+
v.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += (v(i) != size + i); }, errs);
7171
EXPECT_EQ(errs, 0);
7272
}
7373

74-
TYPED_TEST(Broadcast, InPlace0D) { broadcast_0d<typename TestFixture::Scalar>(); }
75-
TYPED_TEST(Broadcast, InPlaceContiguous1D) { broadcast_inplace_contig_1d<typename TestFixture::Scalar>(); }
74+
TYPED_TEST(Broadcast, 0D) { broadcast_0d<typename TestFixture::Scalar>(); }
75+
TYPED_TEST(Broadcast, Contiguous1D) { broadcast_contig_1d<typename TestFixture::Scalar>(); }
7676

7777
} // namespace

0 commit comments

Comments
 (0)