Skip to content

Commit 77b037c

Browse files
dssgabriel authored and cedricchevalier19 committed
refactor: comm space generic reduction operator conversion
Signed-off-by: Gabriel Dos Santos <[email protected]>
1 parent 95d974f commit 77b037c

File tree

4 files changed

+101
-87
lines changed

4 files changed

+101
-87
lines changed

src/KokkosComm/nccl/allreduce.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
1111
#include <KokkosComm/datatype.hpp>
12+
#include <KokkosComm/reduction_op.hpp>
1213
#include "nccl_space.hpp"
1314

1415
#include "impl/pack_traits.hpp"
@@ -49,7 +50,7 @@ namespace Impl {
4950
// NCCL specialization of the AllReduce implementation for the Kokkos::Cuda
// execution space and the NcclSpace communication space.
//
// execute() forwards to nccl::allreduce with the handle's execution space
// (h.space()), the send and receive views, the reduction operator and the
// handle's communicator (h.comm()), and returns the resulting Req<NcclSpace>.
// In this diff the operator argument changes from the NCCL-specific
// nccl::Impl::reduction_op_v<RedOp> constant (removed line) to the generic
// reduction_op<NcclSpace, RedOp>() conversion (added line).
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
5051
struct AllReduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
5152
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv) -> Req<NcclSpace> {
52-
return nccl::allreduce(h.space(), sv, rv, nccl::Impl::reduction_op_v<RedOp>, h.comm());
53+
return nccl::allreduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), h.comm());
5354
}
5455
};
5556

src/KokkosComm/nccl/broadcast.hpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,7 @@ namespace nccl {
1919
namespace KC = KokkosComm;
2020

2121
template <KokkosView View>
22-
auto broadcast(const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm)
23-
-> Req<NcclSpace> {
22+
auto broadcast(const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm) -> Req<NcclSpace> {
2423
using T = typename View::non_const_value_type;
2524
static_assert(KC::rank<View>() <= 1,
2625
"KokkosComm::Experimental::nccl::broadcast: Views with rank higher than 1 are not supported");

src/KokkosComm/nccl/reduce.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
1111
#include <KokkosComm/datatype.hpp>
12+
#include <KokkosComm/reduction_op.hpp>
1213
#include "nccl_space.hpp"
1314

1415
#include <KokkosComm/impl/contiguous.hpp>
@@ -70,7 +71,7 @@ namespace Impl {
7071
// NCCL specialization of the Reduce implementation for the Kokkos::Cuda
// execution space and the NcclSpace communication space.
//
// execute() forwards to nccl::reduce with the handle's execution space
// (h.space()), the send and receive views, the reduction operator, the root
// rank, the caller's own rank (h.rank()) and the handle's communicator
// (h.comm()), and returns the resulting Req<NcclSpace>.
// In this diff the operator argument changes from the NCCL-specific
// nccl::Impl::reduction_op_v<RedOp> constant (removed line) to the generic
// reduction_op<NcclSpace, RedOp>() conversion (added line).
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
7172
struct Reduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
7273
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv, int root) -> Req<NcclSpace> {
73-
return nccl::reduce(h.space(), sv, rv, nccl::Impl::reduction_op_v<RedOp>, root, h.rank(), h.comm());
74+
return nccl::reduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), root, h.rank(), h.comm());
7475
}
7576
};
7677

src/KokkosComm/reduction_op.hpp

Lines changed: 96 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -1,107 +1,120 @@
1-
//@HEADER
2-
// ************************************************************************
3-
//
4-
// Kokkos v. 4.0
5-
// Copyright (2022) National Technology & Engineering
6-
// Solutions of Sandia, LLC (NTESS).
7-
//
8-
// Under the terms of Contract DE-NA0003525 with NTESS,
9-
// the U.S. Government retains certain rights in this software.
10-
//
11-
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12-
// See https://kokkos.org/LICENSE for license information.
131
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14-
//
15-
//@HEADER
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
163

174
#pragma once
185

6+
#include <type_traits>
7+
198
#include <Kokkos_Core.hpp>
209
#ifdef KOKKOSCOMM_ENABLE_NCCL
2110
#include <nccl.h>
2211
#endif
2312

2413
#include <KokkosComm/concepts.hpp>
14+
#include "mpi/mpi_space.hpp"
15+
#ifdef KOKKOSCOMM_ENABLE_NCCL
16+
#include "nccl/nccl_space.hpp"
17+
#endif
2518

2619
namespace KokkosComm {
20+
namespace {
21+
22+
// clang-format off
23+
// Declares an empty tag type named `OP` and registers it as a reduction
// operator by specializing Impl::is_reduction_operator<OP> to std::true_type.
//
// The tag is *defined* (`struct OP {}`), not merely forward-declared: an
// incomplete type cannot be instantiated, so values such as `Sum{}` could not
// be created and by-value helpers like reduction_op_for(RO&&) would be
// unusable. The tags replaced by this macro were complete (`struct BAnd {};`).
//
// Expand at namespace scope where `Impl::is_reduction_operator` is visible.
// The expansion deliberately omits the final semicolon; the call site
// supplies it. The parameter is called `OP` rather than `operator` to avoid
// shadowing a C++ keyword inside the macro body.
#define DECL_REDUCTION_OP_FOR(OP) \
  struct OP {}; \
  template <> struct Impl::is_reduction_operator<OP> : public std::true_type {}
26+
// clang-format on
27+
28+
} // namespace
29+
30+
// Reduction-operator tag types exposed in namespace KokkosComm. Each line
// declares one tag and marks it as satisfying Impl::is_reduction_operator.
// Not every tag is convertible for every backend: mpi_reduction_op has no
// branch for Average, and nccl_reduction_op only handles Sum, Prod, Min, Max
// and Average; unsupported combinations fail at compile time via static_assert.
DECL_REDUCTION_OP_FOR(BAnd);
31+
DECL_REDUCTION_OP_FOR(BOr);
32+
DECL_REDUCTION_OP_FOR(BXor);
33+
DECL_REDUCTION_OP_FOR(LAnd);
34+
DECL_REDUCTION_OP_FOR(LOr);
35+
DECL_REDUCTION_OP_FOR(LXor);
36+
DECL_REDUCTION_OP_FOR(Max);
37+
DECL_REDUCTION_OP_FOR(MaxLoc);
38+
DECL_REDUCTION_OP_FOR(Min);
39+
DECL_REDUCTION_OP_FOR(MinLoc);
40+
DECL_REDUCTION_OP_FOR(Sum);
41+
DECL_REDUCTION_OP_FOR(Prod);
42+
DECL_REDUCTION_OP_FOR(Average);
43+
44+
namespace Impl {
45+
46+
// Converts a KokkosComm reduction-operator tag type into the corresponding
// predefined MPI_Op handle.
//
// Template parameter:
//   RO - reduction-operator tag (BAnd, BOr, BXor, LAnd, LOr, LXor, Max,
//        MaxLoc, Min, MinLoc, Sum or Prod).
// Returns: the matching MPI_* operation constant.
//
// The selection is done with `if constexpr`, so only the branch for RO is
// instantiated. Any tag without a branch (for example Average) reaches the
// final `else`, whose static_assert condition depends on RO and is false for
// every non-void RO, turning the unsupported conversion into a compile error.
template <ReductionOperator RO>
47+
constexpr auto mpi_reduction_op() -> MPI_Op {
48+
if constexpr (std::is_same_v<RO, BAnd>) {
49+
return MPI_BAND;
50+
} else if constexpr (std::is_same_v<RO, BOr>) {
51+
return MPI_BOR;
52+
} else if constexpr (std::is_same_v<RO, BXor>) {
53+
return MPI_BXOR;
54+
} else if constexpr (std::is_same_v<RO, LAnd>) {
55+
return MPI_LAND;
56+
} else if constexpr (std::is_same_v<RO, LOr>) {
57+
return MPI_LOR;
58+
} else if constexpr (std::is_same_v<RO, LXor>) {
59+
return MPI_LXOR;
60+
} else if constexpr (std::is_same_v<RO, Max>) {
61+
return MPI_MAX;
62+
} else if constexpr (std::is_same_v<RO, MaxLoc>) {
63+
return MPI_MAXLOC;
64+
} else if constexpr (std::is_same_v<RO, Min>) {
65+
return MPI_MIN;
66+
} else if constexpr (std::is_same_v<RO, MinLoc>) {
67+
return MPI_MINLOC;
68+
} else if constexpr (std::is_same_v<RO, Sum>) {
69+
return MPI_SUM;
70+
} else if constexpr (std::is_same_v<RO, Prod>) {
71+
return MPI_PROD;
72+
} else {
73+
static_assert(std::is_void_v<RO>, "KokkosComm::Impl::mpi_reduction_op: operator not implemented");
74+
return MPI_SUM; // unreachable
75+
}
76+
}
2777

28-
struct BAnd {};
29-
template <>
30-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::BAnd> : public std::true_type {};
31-
32-
struct BOr {};
33-
template <>
34-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::BOr> : public std::true_type {};
35-
36-
struct LAnd {};
37-
template <>
38-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::LAnd> : public std::true_type {};
39-
40-
struct LOr {};
41-
template <>
42-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::LOr> : public std::true_type {};
43-
44-
struct Max {};
45-
template <>
46-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Max> : public std::true_type {};
47-
48-
struct MaxLoc {};
49-
template <>
50-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MaxLoc> : public std::true_type {};
51-
52-
struct Min {};
53-
template <>
54-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Min> : public std::true_type {};
55-
56-
struct MinLoc {};
57-
template <>
58-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MinLoc> : public std::true_type {};
59-
60-
struct MinMax {};
61-
template <>
62-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MinMax> : public std::true_type {};
63-
64-
struct MinMaxLoc {};
65-
template <>
66-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::MinMaxLoc> : public std::true_type {};
67-
68-
struct Sum {};
69-
template <>
70-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Sum> : public std::true_type {};
71-
72-
struct Prod {};
73-
template <>
74-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Prod> : public std::true_type {};
75-
76-
struct Average {};
77-
template <>
78-
struct KokkosComm::Impl::is_reduction_operator<KokkosComm::Average> : public std::true_type {};
79-
80-
#ifdef KOKKOSCOMM_ENABLE_NCCL
81-
namespace Experimental::nccl::Impl {
82-
83-
template <ReductionOperator RedOp>
84-
constexpr auto reduction_op() -> ncclRedOp_t {
85-
if constexpr (std::is_same_v<RedOp, KokkosComm::Max>) {
86-
return ncclMax;
87-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Min>) {
88-
return ncclMin;
89-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Sum>) {
78+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
79+
// Converts a KokkosComm reduction-operator tag type into the corresponding
// ncclRedOp_t enumerator. Only compiled when KOKKOSCOMM_ENABLE_NCCL is defined.
//
// Template parameter:
//   RO - reduction-operator tag; Sum, Prod, Min, Max and Average are handled.
// Returns: ncclSum, ncclProd, ncclMin, ncclMax or ncclAvg respectively.
//
// Any other tag reaches the final `else`, whose static_assert condition
// depends on RO and is false for every non-void RO, so the unsupported
// conversion is rejected at compile time.
//
// Note: in this diff view the added lines of this function are interleaved
// with removed lines of the former Experimental::nccl::Impl::reduction_op.
template <ReductionOperator RO>
80+
constexpr auto nccl_reduction_op() -> ncclRedOp_t {
81+
if constexpr (std::is_same_v<RO, Sum>) {
9082
return ncclSum;
91-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Prod>) {
83+
} else if constexpr (std::is_same_v<RO, Prod>) {
9284
return ncclProd;
93-
} else if constexpr (std::is_same_v<RedOp, KokkosComm::Average>) {
85+
} else if constexpr (std::is_same_v<RO, Min>) {
86+
return ncclMin;
87+
} else if constexpr (std::is_same_v<RO, Max>) {
88+
return ncclMax;
89+
} else if constexpr (std::is_same_v<RO, Average>) {
9490
return ncclAvg;
9591
} else {
96-
static_assert(std::is_void_v<RedOp>, "nccl::reduction_op: operator not implemented");
97-
return ncclMax; // unreachable
92+
static_assert(std::is_void_v<RO>, "KokkosComm::Impl::nccl_reduction_op: operator not implemented");
93+
return ncclSum; // unreachable
9894
}
9995
}
96+
#endif
10097

101-
template <ReductionOperator RedOp>
102-
inline constexpr ncclRedOp_t reduction_op_v = reduction_op<RedOp>();
98+
} // namespace Impl
10399

104-
} // namespace Experimental::nccl::Impl
100+
// Generic entry point that converts a reduction-operator tag into the native
// operator handle of a given communication space.
//
// Template parameters:
//   CS - communication space; selects the backend conversion.
//   RO - reduction-operator tag to convert.
// Returns: a value of CS::reduction_op_type, produced by
//   Impl::mpi_reduction_op<RO>() when CS is MpiSpace, or by
//   Impl::nccl_reduction_op<RO>() when CS is Experimental::NcclSpace (that
//   branch exists only when KOKKOSCOMM_ENABLE_NCCL is defined).
//
// Any other communication space reaches the final `else`, whose static_assert
// condition depends on CS and is false for every non-void CS, so the missing
// conversion is reported at compile time. The result is [[nodiscard]] because
// the function has no effect other than producing the handle.
template <CommunicationSpace CS, ReductionOperator RO>
101+
[[nodiscard]] constexpr auto reduction_op() -> typename CS::reduction_op_type {
102+
if constexpr (std::is_same_v<CS, MpiSpace>) {
103+
return Impl::mpi_reduction_op<RO>();
104+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
105+
} else if constexpr (std::is_same_v<CS, Experimental::NcclSpace>) {
106+
return Impl::nccl_reduction_op<RO>();
105107
#endif
108+
} else {
109+
static_assert(std::is_void_v<CS>,
110+
"KokkosComm::reduction_op: conversion not implemented for this communication space");
111+
return Impl::mpi_reduction_op<RO>(); // unreachable
112+
}
113+
}
114+
115+
// Value-based convenience wrapper around reduction_op<CS, RO>(): the operator
// tag is deduced from the (unused) argument instead of being spelled as a
// template argument. CS cannot be deduced and must be given explicitly.
//
// Parameter: an operator tag object; only its type is used.
// Returns: reduction_op<CS, std::remove_cvref_t<RO>>(), i.e. the native
//   operator handle of communication space CS.
//
// NOTE(review): RO is a forwarding reference, so for an lvalue argument it
// deduces to a reference type and the ReductionOperator constraint is checked
// on that reference type before remove_cvref_t is applied. Confirm that the
// concept (defined elsewhere) accepts reference/cv-qualified types; otherwise
// only prvalue/xvalue arguments such as `Sum{}` will be accepted.
template <CommunicationSpace CS, ReductionOperator RO>
116+
[[nodiscard]] constexpr auto reduction_op_for(RO&&) -> typename CS::reduction_op_type {
117+
return reduction_op<CS, std::remove_cvref_t<RO>>();
118+
}
106119

107120
} // namespace KokkosComm

0 commit comments

Comments
 (0)