Skip to content

Commit 7c3d596

Browse files
authored
refactor: comm space generic datatype conversion (#189)
* chore: format + fix CommSpace typo in unit test Signed-off-by: Gabriel Dos Santos <[email protected]> * refactor!: declarative-style generic datatype conversion Overhaul datatype conversion: - fully generic: over communication space & type - fully compile-time known - declarative-style (similar implementation strategy to e.g., `std::is_integral`) - exposed as top-level function in the KokkosComm:: namespace - not an implementation detail anymore (some tests were relying on it) Signed-off-by: Gabriel Dos Santos <[email protected]> * refactor(mpi): propagate new datatype conversion strategy Signed-off-by: Gabriel Dos Santos <[email protected]> * refactor(nccl): propagate new datatype conversion Signed-off-by: Gabriel Dos Santos <[email protected]> * fix(nccl): correctly support `unsigned` dtype conversion Signed-off-by: Gabriel Dos Santos <[email protected]> --------- Signed-off-by: Gabriel Dos Santos <[email protected]>
1 parent 901ccc7 commit 7c3d596

24 files changed

+271
-183
lines changed

perf_tests/mpi/test_osu_latency.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ template <typename View>
6666
void osu_latency_MPI_isendirecv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
6767
MPI_Request sendreq, recvreq;
6868
if (rank == 0) {
69-
MPI_Irecv(v.data(), v.size(), KokkosComm::Impl::mpi_type<typename View::value_type>(), 1, 0, comm, &recvreq);
69+
MPI_Irecv(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
70+
&recvreq);
7071
MPI_Wait(&recvreq, MPI_STATUS_IGNORE);
7172
} else if (rank == 1) {
72-
MPI_Isend(v.data(), v.size(), KokkosComm::Impl::mpi_type<typename View::value_type>(), 0, 0, comm, &sendreq);
73+
MPI_Isend(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm,
74+
&sendreq);
7375
MPI_Wait(&sendreq, MPI_STATUS_IGNORE);
7476
}
7577
}
@@ -94,10 +96,10 @@ void benchmark_osu_latency_MPI_isendirecv(benchmark::State &state) {
9496
template <typename View>
9597
void osu_latency_MPI_sendrecv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
9698
if (rank == 0) {
97-
MPI_Recv(v.data(), v.size(), KokkosComm::Impl::mpi_type<typename View::value_type>(), 1, 0, comm,
99+
MPI_Recv(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
98100
MPI_STATUS_IGNORE);
99101
} else if (rank == 1) {
100-
MPI_Send(v.data(), v.size(), KokkosComm::Impl::mpi_type<typename View::value_type>(), 0, 0, comm);
102+
MPI_Send(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm);
101103
}
102104
}
103105

src/KokkosComm/CMakeLists.txt

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@ target_sources(
88
FILE_SET kokkoscomm_public_headers
99
TYPE HEADERS
1010
BASE_DIRS ${PROJECT_SOURCE_DIR}/src
11-
FILES KokkosComm.hpp collective.hpp concepts.hpp fwd.hpp point_to_point.hpp traits.hpp reduction_op.hpp
11+
FILES
12+
KokkosComm.hpp
13+
collective.hpp
14+
concepts.hpp
15+
fwd.hpp
16+
point_to_point.hpp
17+
traits.hpp
18+
datatype.hpp
19+
reduction_op.hpp
1220
)
1321

1422
# Implementation detail headers
@@ -70,12 +78,7 @@ if(KOKKOSCOMM_ENABLE_MPI)
7078
FILE_SET kokkoscomm_mpi_impl_headers
7179
TYPE HEADERS
7280
BASE_DIRS ${PROJECT_SOURCE_DIR}/src
73-
FILES
74-
mpi/impl/pack_traits.hpp
75-
mpi/impl/packer.hpp
76-
mpi/impl/tags.hpp
77-
mpi/impl/types.hpp
78-
mpi/impl/error_handling.hpp
81+
FILES mpi/impl/pack_traits.hpp mpi/impl/packer.hpp mpi/impl/tags.hpp mpi/impl/error_handling.hpp
7982
)
8083
endif()
8184

src/KokkosComm/datatype.hpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#pragma once
5+
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <type_traits>
9+
10+
#include <Kokkos_Core.hpp>
11+
#include <mpi.h>
12+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
13+
#include <nccl.h>
14+
#endif
15+
16+
#include "concepts.hpp"
17+
#include "mpi/mpi_space.hpp"
18+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
19+
#include "nccl/nccl_space.hpp"
20+
#endif
21+
22+
namespace KokkosComm {
23+
namespace Impl {
24+
25+
template <typename T>
26+
constexpr auto mpi_datatype() -> MPI_Datatype {
27+
if constexpr (std::is_same_v<T, std::byte>) {
28+
return MPI_BYTE;
29+
} else if constexpr (std::is_same_v<T, char>) {
30+
return MPI_CHAR;
31+
} else if constexpr (std::is_same_v<T, unsigned char>) {
32+
return MPI_UNSIGNED_CHAR;
33+
} else if constexpr (std::is_same_v<T, short>) {
34+
return MPI_SHORT;
35+
} else if constexpr (std::is_same_v<T, unsigned short>) {
36+
return MPI_UNSIGNED_SHORT;
37+
} else if constexpr (std::is_same_v<T, int>) {
38+
return MPI_INT;
39+
} else if constexpr (std::is_same_v<T, unsigned>) {
40+
return MPI_UNSIGNED;
41+
} else if constexpr (std::is_same_v<T, long>) {
42+
return MPI_LONG;
43+
} else if constexpr (std::is_same_v<T, unsigned long>) {
44+
return MPI_UNSIGNED_LONG;
45+
} else if constexpr (std::is_same_v<T, long long>) {
46+
return MPI_LONG_LONG;
47+
} else if constexpr (std::is_same_v<T, unsigned long long>) {
48+
return MPI_UNSIGNED_LONG_LONG;
49+
} else if constexpr (std::is_same_v<T, std::int8_t>) {
50+
return MPI_INT8_T;
51+
} else if constexpr (std::is_same_v<T, std::uint8_t>) {
52+
return MPI_UINT8_T;
53+
} else if constexpr (std::is_same_v<T, std::int16_t>) {
54+
return MPI_INT16_T;
55+
} else if constexpr (std::is_same_v<T, std::uint16_t>) {
56+
return MPI_UINT16_T;
57+
} else if constexpr (std::is_same_v<T, std::int32_t>) {
58+
return MPI_INT32_T;
59+
} else if constexpr (std::is_same_v<T, std::uint32_t>) {
60+
return MPI_UINT32_T;
61+
} else if constexpr (std::is_same_v<T, std::int64_t>) {
62+
return MPI_INT64_T;
63+
} else if constexpr (std::is_same_v<T, std::uint64_t>) {
64+
return MPI_UINT64_T;
65+
} else if constexpr (std::is_same_v<T, std::size_t>) {
66+
if constexpr (sizeof(std::size_t) == 1) return MPI_UINT8_T;
67+
if constexpr (sizeof(std::size_t) == 2) return MPI_UINT16_T;
68+
if constexpr (sizeof(std::size_t) == 4) return MPI_UINT32_T;
69+
if constexpr (sizeof(std::size_t) == 8) return MPI_UINT64_T;
70+
} else if constexpr (std::is_same_v<T, std::ptrdiff_t>) {
71+
if constexpr (sizeof(std::ptrdiff_t) == 1) return MPI_INT8_T;
72+
if constexpr (sizeof(std::ptrdiff_t) == 2) return MPI_INT16_T;
73+
if constexpr (sizeof(std::ptrdiff_t) == 4) return MPI_INT32_T;
74+
if constexpr (sizeof(std::ptrdiff_t) == 8) return MPI_INT64_T;
75+
} else if constexpr (std::is_same_v<T, float>) {
76+
return MPI_FLOAT;
77+
} else if constexpr (std::is_same_v<T, double>) {
78+
return MPI_DOUBLE;
79+
} else if constexpr (std::is_same_v<T, long double>) {
80+
return MPI_LONG_DOUBLE;
81+
} else if constexpr (std::is_same_v<T, Kokkos::complex<float>>) {
82+
#if defined(KOKKOSCOMM_IMPL_MPI_IS_OPENMPI)
83+
return MPI_CXX_COMPLEX;
84+
#else
85+
return MPI_COMPLEX;
86+
#endif
87+
} else if constexpr (std::is_same_v<T, Kokkos::complex<double>>) {
88+
#if defined(KOKKOSCOMM_IMPL_MPI_IS_OPENMPI)
89+
return MPI_CXX_DOUBLE_COMPLEX;
90+
#else
91+
return MPI_DOUBLE_COMPLEX;
92+
#endif
93+
} else {
94+
static_assert(std::is_void_v<T>, "KokkosComm::Impl::mpi_datatype: datatype not implemented");
95+
return MPI_CHAR; // unreachable
96+
}
97+
}
98+
99+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
100+
template <typename T>
101+
constexpr auto nccl_datatype() -> ncclDataType_t {
102+
if constexpr (std::is_same_v<T, char>) {
103+
return ncclChar;
104+
} else if constexpr (std::is_same_v<T, int>) {
105+
return ncclInt;
106+
} else if constexpr (std::is_same_v<T, unsigned> and sizeof(unsigned) == 4) {
107+
return ncclUint32;
108+
} else if constexpr (std::is_same_v<T, std::int8_t>) {
109+
return ncclInt8;
110+
} else if constexpr (std::is_same_v<T, std::uint8_t>) {
111+
return ncclUint8;
112+
} else if constexpr (std::is_same_v<T, std::int32_t>) {
113+
return ncclInt32;
114+
} else if constexpr (std::is_same_v<T, std::uint32_t>) {
115+
return ncclUint32;
116+
} else if constexpr (std::is_same_v<T, std::int64_t>) {
117+
return ncclInt64;
118+
} else if constexpr (std::is_same_v<T, std::uint64_t>) {
119+
return ncclUint64;
120+
} else if constexpr (std::is_same_v<T, std::size_t>) {
121+
if constexpr (sizeof(std::size_t) == 1) return ncclUint8;
122+
if constexpr (sizeof(std::size_t) == 4) return ncclUint32;
123+
if constexpr (sizeof(std::size_t) == 8) return ncclUint64;
124+
} else if constexpr (std::is_same_v<T, std::ptrdiff_t>) {
125+
if constexpr (sizeof(std::ptrdiff_t) == 1) return ncclInt8;
126+
if constexpr (sizeof(std::ptrdiff_t) == 4) return ncclInt32;
127+
if constexpr (sizeof(std::ptrdiff_t) == 8) return ncclInt64;
128+
} else if constexpr (std::is_same_v<T, float>) {
129+
return ncclFloat;
130+
} else if constexpr (std::is_same_v<T, double>) {
131+
return ncclDouble;
132+
} else {
133+
static_assert(std::is_void_v<T>, "KokkosComm::Impl::nccl_datatype: datatype not implemented");
134+
return ncclChar; // unreachable
135+
}
136+
}
137+
#endif
138+
139+
} // namespace Impl
140+
141+
template <CommunicationSpace CS, typename T>
142+
[[nodiscard]] constexpr auto datatype() -> typename CS::datatype_type {
143+
if constexpr (std::is_same_v<CS, MpiSpace>) {
144+
return Impl::mpi_datatype<std::remove_cv_t<T>>();
145+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
146+
} else if constexpr (std::is_same_v<CS, Experimental::NcclSpace>) {
147+
return Impl::nccl_datatype<std::remove_cv_t<T>>();
148+
#endif
149+
} else {
150+
static_assert(std::is_void_v<CS>, "KokkosComm::datatype: conversion not implemented for this communication space");
151+
return Impl::mpi_datatype<std::remove_cv_t<T>>(); // unreachable
152+
}
153+
}
154+
155+
template <CommunicationSpace CS, typename T>
156+
[[nodiscard]] constexpr auto datatype_for(T&&) -> typename CS::datatype_type {
157+
return datatype<CS, std::remove_cvref_t<T>>();
158+
}
159+
160+
} // namespace KokkosComm

src/KokkosComm/fwd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <KokkosComm/config.hpp>
77
#include "concepts.hpp"
8+
#include "datatype.hpp"
89
#include "reduction_op.hpp"
910

1011
namespace KokkosComm {

src/KokkosComm/mpi/allgather.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include <KokkosComm/datatype.hpp>
1112
#include "mpi_space.hpp"
1213
#include "req.hpp"
1314

14-
#include "impl/types.hpp"
1515
#include "impl/error_handling.hpp"
1616

1717
namespace KokkosComm::mpi {
@@ -31,8 +31,8 @@ auto iallgather(const ExecSpace &space, const SView sv, RView rv, MPI_Comm comm)
3131

3232
Req<MpiSpace> req;
3333
// All ranks send/recv same count
34-
MPI_Iallgather(data_handle(sv), span(sv), Impl::mpi_type_v<ST>, data_handle(rv), span(sv), Impl::mpi_type_v<RT>, comm,
35-
&req.mpi_request());
34+
// `datatype` is a function template: it must be called to obtain the MPI_Datatype handle.
MPI_Iallgather(data_handle(sv), span(sv), datatype<MpiSpace, ST>(), data_handle(rv), span(sv),
               datatype<MpiSpace, RT>(), comm, &req.mpi_request());
3636
req.extend_view_lifetime(sv);
3737
req.extend_view_lifetime(rv);
3838

@@ -54,8 +54,8 @@ void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
5454
KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(rv), "low-level allgather requires contiguous recv view");
5555

5656
const int count = KokkosComm::span(sv); // all ranks send/recv same count
57-
MPI_Allgather(KokkosComm::data_handle(sv), count, KokkosComm::Impl::mpi_type_v<SendScalar>,
58-
KokkosComm::data_handle(rv), count, KokkosComm::Impl::mpi_type_v<RecvScalar>, comm);
57+
MPI_Allgather(KokkosComm::data_handle(sv), count, datatype<MpiSpace, SendScalar>(), KokkosComm::data_handle(rv),
58+
count, datatype<MpiSpace, RecvScalar>(), comm);
5959

6060
Kokkos::Tools::popRegion();
6161
}
@@ -73,7 +73,7 @@ void allgather(const ExecSpace &space, const RecvView &rv, const size_t recvCoun
7373

7474
space.fence("fence before allgather"); // work in space may have been used to produce send view data
7575
MPI_Allgather(MPI_IN_PLACE, 0 /*ignored*/, MPI_DATATYPE_NULL /*ignored*/, KokkosComm::data_handle(rv), recvCount,
76-
KokkosComm::Impl::mpi_type_v<RecvScalar>, comm);
76+
datatype<MpiSpace, RecvScalar>(), comm);
7777

7878
Kokkos::Tools::popRegion();
7979
}

src/KokkosComm/mpi/allreduce.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11-
12-
#include "impl/types.hpp"
11+
#include <KokkosComm/datatype.hpp>
1312

1413
namespace KokkosComm::mpi {
1514

@@ -30,8 +29,8 @@ void allreduce(SendView const &sv, RecvView const &rv, MPI_Op op, MPI_Comm comm)
3029
KokkosComm::mpi::fail_if(sv.size() != rv.size(), "allreduce requires send and receive views to have the same size");
3130

3231
int const count = sv.size();
33-
MPI_Allreduce(KokkosComm::data_handle(sv), KokkosComm::data_handle(rv), count,
34-
KokkosComm::Impl::mpi_type_v<SendScalar>, op, comm);
32+
MPI_Allreduce(KokkosComm::data_handle(sv), KokkosComm::data_handle(rv), count, datatype<MpiSpace, SendScalar>(), op,
33+
comm);
3534

3635
Kokkos::Tools::popRegion();
3736
}
@@ -47,7 +46,7 @@ void allreduce(View const &v, MPI_Op op, MPI_Comm comm) {
4746
// In-place allreduce: the single view is both send and receive buffer, and must be contiguous.
KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(v), "low-level allreduce requires contiguous view");
4847

4948
int const count = v.size();
50-
MPI_Allreduce(MPI_IN_PLACE, KokkosComm::data_handle(v), count, KokkosComm::Impl::mpi_type_v<Scalar>, op, comm);
49+
MPI_Allreduce(MPI_IN_PLACE, KokkosComm::data_handle(v), count, datatype<MpiSpace, Scalar>(), op, comm);
5150

5251
Kokkos::Tools::popRegion();
5352
}

src/KokkosComm/mpi/alltoall.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include <KokkosComm/datatype.hpp>
1112
#include "mpi_space.hpp"
1213
#include "req.hpp"
1314

1415
#include "impl/pack_traits.hpp"
15-
#include "impl/types.hpp"
1616
#include "impl/error_handling.hpp"
1717

1818
namespace KokkosComm::mpi {
@@ -32,7 +32,7 @@ auto ialltoall(const ExecSpace &space, const SView sv, RView rv, int count, MPI_
3232

3333
Req<MpiSpace> req;
3434
// All ranks send/recv same count
35-
MPI_Ialltoall(data_handle(sv), count, Impl::mpi_type_v<ST>, data_handle(rv), count, Impl::mpi_type_v<RT>, comm,
35+
// `datatype` is a function template: it must be called to obtain the MPI_Datatype handle.
MPI_Ialltoall(data_handle(sv), count, datatype<MpiSpace, ST>(), data_handle(rv), count, datatype<MpiSpace, RT>(),
              comm, &req.mpi_request());
3737
req.extend_view_lifetime(sv);
3838
req.extend_view_lifetime(rv);
@@ -74,8 +74,8 @@ void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount
7474
KokkosComm::mpi::fail_if(true, ss.str().data());
7575
}
7676

77-
MPI_Alltoall(KokkosComm::data_handle(sv), sendCount, Impl::mpi_type_v<SendScalar>, KokkosComm::data_handle(rv),
78-
recvCount, Impl::mpi_type_v<RecvScalar>, comm);
77+
MPI_Alltoall(KokkosComm::data_handle(sv), sendCount, datatype<MpiSpace, SendScalar>(), KokkosComm::data_handle(rv),
78+
recvCount, datatype<MpiSpace, RecvScalar>(), comm);
7979

8080
Kokkos::Tools::popRegion();
8181
}
@@ -105,7 +105,7 @@ void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount
105105
}
106106

107107
MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, KokkosComm::data_handle(rv), recvCount,
108-
Impl::mpi_type_v<RecvScalar>, comm);
108+
datatype<MpiSpace, RecvScalar>(), comm);
109109

110110
Kokkos::Tools::popRegion();
111111
}

src/KokkosComm/mpi/broadcast.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include <KokkosComm/datatype.hpp>
1112
#include "mpi_space.hpp"
1213
#include "req.hpp"
1314

14-
#include "impl/types.hpp"
1515
#include "impl/error_handling.hpp"
1616

1717
namespace KokkosComm::mpi {
@@ -26,7 +26,7 @@ auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req
2626
space.fence("fence before non-blocking broadcast");
2727

2828
Req<MpiSpace> req;
29-
MPI_Ibcast(data_handle(v), span(v), Impl::mpi_type_v<T>, root, comm, &req.mpi_request());
29+
// `datatype` is a function template: it must be called to obtain the MPI_Datatype handle.
MPI_Ibcast(data_handle(v), span(v), datatype<MpiSpace, T>(), root, comm, &req.mpi_request());
3030
req.extend_view_lifetime(v);
3131

3232
Kokkos::Tools::popRegion();
@@ -41,7 +41,7 @@ void broadcast(View const& v, int root, MPI_Comm comm) {
4141

4242
KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(v), "low-level broadcast requires contiguous view");
4343

44-
MPI_Bcast(KokkosComm::data_handle(v), KokkosComm::span(v), KokkosComm::Impl::mpi_type_v<Scalar>, root, comm);
44+
MPI_Bcast(KokkosComm::data_handle(v), KokkosComm::span(v), datatype<MpiSpace, Scalar>(), root, comm);
4545

4646
Kokkos::Tools::popRegion();
4747
}

0 commit comments

Comments
 (0)