Skip to content

Commit 845b50b

Browse files
committed
refactor: comm space generic reduction operator conversion
Signed-off-by: Gabriel Dos Santos <[email protected]>
1 parent a1f4b36 commit 845b50b

File tree

3 files changed

+101
-85
lines changed

3 files changed

+101
-85
lines changed

src/KokkosComm/nccl/allreduce.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include <KokkosComm/reduction_op.hpp>
1112

1213
#include "impl/pack_traits.hpp"
1314
#include "impl/types.hpp"
@@ -48,7 +49,7 @@ namespace Impl {
4849
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
4950
struct AllReduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
5051
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv) -> Req<NcclSpace> {
51-
return nccl::allreduce(h.space(), sv, rv, nccl::Impl::reduction_op_v<RedOp>, h.comm());
52+
return nccl::allreduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), h.comm());
5253
}
5354
};
5455

src/KokkosComm/nccl/reduce.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include <KokkosComm/reduction_op.hpp>
1112

1213
#include <KokkosComm/impl/contiguous.hpp>
1314
#include "impl/pack_traits.hpp"
@@ -69,7 +70,7 @@ namespace Impl {
6970
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
7071
struct Reduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
7172
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv, int root) -> Req<NcclSpace> {
72-
return nccl::reduce(h.space(), sv, rv, nccl::Impl::reduction_op_v<RedOp>, root, h.rank(), h.comm());
73+
return nccl::reduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), root, h.rank(), h.comm());
7374
}
7475
};
7576

src/KokkosComm/reduction_op.hpp

Lines changed: 97 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,121 @@
1-
//@HEADER
2-
// ************************************************************************
3-
//
4-
// Kokkos v. 4.0
5-
// Copyright (2022) National Technology & Engineering
6-
// Solutions of Sandia, LLC (NTESS).
7-
//
8-
// Under the terms of Contract DE-NA0003525 with NTESS,
9-
// the U.S. Government retains certain rights in this software.
10-
//
11-
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12-
// See https://kokkos.org/LICENSE for license information.
131
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14-
//
15-
//@HEADER
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
163

174
#pragma once
185

6+
#include <type_traits>
7+
198
#include <Kokkos_Core.hpp>
209
#ifdef KOKKOSCOMM_ENABLE_NCCL
2110
#include <nccl.h>
2211
#endif
2312

2413
#include <KokkosComm/concepts.hpp>
14+
#include "mpi/mpi_space.hpp"
15+
#ifdef KOKKOSCOMM_ENABLE_NCCL
16+
#include "nccl/nccl_space.hpp"
17+
#endif
2518

19+
// clang-format off
2620
namespace KokkosComm {
21+
namespace {
22+
23+
#define DECL_REDUCTION_OP_FOR(operator) \
24+
struct operator; \
25+
template <> struct Impl::is_reduction_operator<operator> : public std::true_type {}
26+
27+
} // namespace
28+
29+
DECL_REDUCTION_OP_FOR(BAnd);
30+
DECL_REDUCTION_OP_FOR(BOr);
31+
DECL_REDUCTION_OP_FOR(BXor);
32+
DECL_REDUCTION_OP_FOR(LAnd);
33+
DECL_REDUCTION_OP_FOR(LOr);
34+
DECL_REDUCTION_OP_FOR(LXor);
35+
DECL_REDUCTION_OP_FOR(Max);
36+
DECL_REDUCTION_OP_FOR(MaxLoc);
37+
DECL_REDUCTION_OP_FOR(Min);
38+
DECL_REDUCTION_OP_FOR(MinLoc);
39+
DECL_REDUCTION_OP_FOR(Sum);
40+
DECL_REDUCTION_OP_FOR(Prod);
41+
DECL_REDUCTION_OP_FOR(Average);
42+
43+
namespace Impl {
44+
45+
template <ReductionOperator RO>
46+
constexpr auto mpi_reduction_op() -> MPI_Op {
47+
if constexpr (std::is_same_v<RO, BAnd>) {
48+
return MPI_BAND;
49+
} else if constexpr (std::is_same_v<RO, BOr>) {
50+
return MPI_BOR;
51+
} else if constexpr (std::is_same_v<RO, BXor>) {
52+
return MPI_BXOR;
53+
} else if constexpr (std::is_same_v<RO, LAnd>) {
54+
return MPI_LAND;
55+
} else if constexpr (std::is_same_v<RO, LOr>) {
56+
return MPI_LOR;
57+
} else if constexpr (std::is_same_v<RO, LXor>) {
58+
return MPI_LXOR;
59+
} else if constexpr (std::is_same_v<RO, Max>) {
60+
return MPI_MAX;
61+
} else if constexpr (std::is_same_v<RO, MaxLoc>) {
62+
return MPI_MAXLOC;
63+
} else if constexpr (std::is_same_v<RO, Min>) {
64+
return MPI_MIN;
65+
} else if constexpr (std::is_same_v<RO, MinLoc>) {
66+
return MPI_MINLOC;
67+
} else if constexpr (std::is_same_v<RO, Sum>) {
68+
return MPI_SUM;
69+
} else if constexpr (std::is_same_v<RO, Prod>) {
70+
return MPI_PROD;
71+
} else {
72+
static_assert(std::is_void_v<RO>, "KokkosComm::Impl::mpi_reduction_op: operator not implemented");
73+
return MPI_SUM; // unreachable
74+
}
75+
}
2776

28-
struct BAnd {};
29-
template <>
30-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::BAnd> : public std::true_type {};
31-
32-
struct BOr {};
33-
template <>
34-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::BOr> : public std::true_type {};
35-
36-
struct LAnd {};
37-
template <>
38-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::LAnd> : public std::true_type {};
39-
40-
struct LOr {};
41-
template <>
42-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::LOr> : public std::true_type {};
43-
44-
struct Max {};
45-
template <>
46-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Max> : public std::true_type {};
47-
48-
struct MaxLoc {};
49-
template <>
50-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MaxLoc> : public std::true_type {};
51-
52-
struct Min {};
53-
template <>
54-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Min> : public std::true_type {};
55-
56-
struct MinLoc {};
57-
template <>
58-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MinLoc> : public std::true_type {};
59-
60-
struct MinMax {};
61-
template <>
62-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MinMax> : public std::true_type {};
63-
64-
struct MinMaxLoc {};
65-
template <>
66-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MinMaxLoc> : public std::true_type {};
67-
68-
struct Sum {};
69-
template <>
70-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Sum> : public std::true_type {};
71-
72-
struct Prod {};
73-
template <>
74-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Prod> : public std::true_type {};
75-
76-
struct Average {};
77-
template <>
78-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Average> : public std::true_type {};
79-
80-
#ifdef KOKKOSCOMM_ENABLE_NCCL
81-
namespace Experimental::nccl::Impl {
82-
83-
template <ReductionOperator RedOp>
84-
constexpr auto reduction_op() -> ncclRedOp_t {
85-
if constexpr (std::is_same_v<RedOp, KokkosComm::Max>) {
86-
return ncclMax;
87-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Min>) {
88-
return ncclMin;
89-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Sum>) {
77+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
78+
template <ReductionOperator RO>
79+
constexpr auto nccl_reduction_op() -> ncclRedOp_t {
80+
if constexpr (std::is_same_v<RO, Sum>) {
9081
return ncclSum;
91-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Prod>) {
82+
} else if constexpr (std::is_same_v<RO, Prod>) {
9283
return ncclProd;
93-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Average>) {
84+
} else if constexpr (std::is_same_v<RO, Min>) {
85+
return ncclMin;
86+
} else if constexpr (std::is_same_v<RO, Max>) {
87+
return ncclMax;
88+
} else if constexpr (std::is_same_v<RO, Average>) {
9489
return ncclAvg;
9590
} else {
96-
static_assert(std::is_void_v<RedOp>, "nccl::reduction_op: operator not implemented");
97-
return ncclMax; // unreachable
91+
static_assert(std::is_void_v<RO>, "KokkosComm::Impl::nccl_reduction_op: operator not implemented");
92+
return ncclSum; // unreachable
9893
}
9994
}
95+
#endif
10096

101-
template <ReductionOperator RedOp>
102-
inline constexpr ncclRedOp_t reduction_op_v = reduction_op<RedOp>();
97+
// clang-format on
10398

104-
} // namespace Experimental::nccl::Impl
99+
} // namespace Impl
100+
101+
template <CommunicationSpace CS, ReductionOperator RO>
102+
[[nodiscard]] constexpr auto reduction_op() -> typename CS::reduction_op_type {
103+
if constexpr (std::is_same_v<CS, MpiSpace>) {
104+
return Impl::mpi_reduction_op<RO>();
105+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
106+
} else if constexpr (std::is_same_v<CS, NcclSpace>) {
107+
return Impl::nccl_reduction_op<RO>();
105108
#endif
109+
} else {
110+
static_assert(std::is_void_v<CS>,
111+
"KokkosComm::reduction_op: conversion not implemented for this communication space");
112+
return Impl::mpi_reduction_op<RO>(); // unreachable
113+
}
114+
}
115+
116+
template <CommunicationSpace CS, ReductionOperator RO>
117+
[[nodiscard]] constexpr auto reduction_op_for(RO&&) -> typename CS::reduction_op_type {
118+
return reduction_op<CS, std::remove_cvref_t<RO>>();
119+
}
106120

107121
} // namespace KokkosComm

0 commit comments

Comments
 (0)