|
48 | 48 | #include <memory> |
49 | 49 | #include <tuple> |
50 | 50 | #include <unordered_map> |
| 51 | +#include <atomic> |
| 52 | +#include <cstring> |
| 53 | +#include <vector> |
| 54 | +#include <type_traits> |
| 55 | +#include <algorithm> |
51 | 56 |
|
| 57 | +#include <vt/collective/reduce/operators/default_msg.h> |
52 | 58 | #include <vt/transport.h> |
53 | 59 |
|
54 | 60 | namespace vt_lb::comm { |
55 | 61 |
|
56 | | -template <typename U> |
57 | | -void reduceCb(U value) { |
58 | | - |
59 | | -} |
60 | | - |
61 | 62 | template <typename ProxyT> |
62 | 63 | struct ProxyWrapper : ProxyT { |
63 | 64 | ProxyWrapper(ProxyT proxy) : ProxyT(proxy) { } |
64 | 65 |
|
| 66 | + struct ReduceCtx { |
| 67 | + void* out_ptr = nullptr; |
| 68 | + std::atomic<bool> done{false}; |
| 69 | + std::size_t count = 0; |
| 70 | + }; |
| 71 | + |
| 72 | + template <typename T> |
| 73 | + static void reduceAnonCb(vt::collective::ReduceTMsg<T>* msg, ReduceCtx* ctx) { |
| 74 | + auto const& val = msg->getVal(); |
| 75 | + |
| 76 | + printf("%d: callback invoked\n", vt::theContext()->getNode()); |
| 77 | + |
| 78 | + if constexpr ( |
| 79 | + std::is_same_v<std::decay_t<T>, int> || std::is_same_v<std::decay_t<T>, double> || |
| 80 | + std::is_same_v<std::decay_t<T>, float> || std::is_same_v<std::decay_t<T>, long> || |
| 81 | + std::is_same_v<std::decay_t<T>, long long> |
| 82 | + ) { |
| 83 | + // scalar case |
| 84 | + *static_cast<std::decay_t<T>*>(ctx->out_ptr) = val; |
| 85 | + } else { |
| 86 | + // vector case |
| 87 | + using ValT = typename T::value_type; |
| 88 | + static_assert( |
| 89 | + std::is_trivially_copyable_v<ValT> || std::is_arithmetic_v<ValT>, |
| 90 | + "Reduce value must be trivially copyable for this helper" |
| 91 | + ); |
| 92 | + std::memcpy(ctx->out_ptr, std::addressof(val), sizeof(ValT) * std::max<std::size_t>(1, ctx->count)); |
| 93 | + } |
| 94 | + |
| 95 | + // Publish completion after writing result. |
| 96 | + ctx->done.store(true, std::memory_order_release); |
| 97 | + } |
| 98 | + |
| 99 | + // Runtime-dispatch wrapper: map a few common MPI datatypes to C++ types. |
65 | 100 | template <typename U, typename V> |
66 | 101 | void reduce(int root, MPI_Datatype datatype, MPI_Op op, U sendbuf, V recvbuf, int count) { |
67 | | - // @todo finish this?? |
68 | | - // if (op == MPI_MAX) { |
69 | | - // auto cb = vt::theCB()->makeSend<reduceCb>(vt::pipe::LifetimeEnum::Once, reduceCb); |
70 | | - // this->template reduce<vt::collective::MaxOp>(); |
71 | | - // } else if (op == MPI_MIN) { |
| 102 | + switch (datatype) { |
| 103 | + case MPI_INT: |
| 104 | + if constexpr (std::is_same_v<U, int*>) { |
| 105 | + reduce_impl<int>(root, op, sendbuf, recvbuf, count); |
| 106 | + } |
| 107 | + break; |
| 108 | + case MPI_DOUBLE: |
| 109 | + if constexpr (std::is_same_v<U, double*>) { |
| 110 | + reduce_impl<double>(root, op, sendbuf, recvbuf, count); |
| 111 | + } |
| 112 | + break; |
| 113 | + case MPI_FLOAT: |
| 114 | + if constexpr (std::is_same_v<U, float*>) { |
| 115 | + reduce_impl<float>(root, op, sendbuf, recvbuf, count); |
| 116 | + } |
| 117 | + break; |
| 118 | + case MPI_LONG: |
| 119 | + if constexpr (std::is_same_v<U, long*>) { |
| 120 | + reduce_impl<long>(root, op, sendbuf, recvbuf, count); |
| 121 | + } |
| 122 | + break; |
| 123 | + case MPI_LONG_LONG: |
| 124 | + if constexpr (std::is_same_v<U, long long*>) { |
| 125 | + reduce_impl<long long>(root, op, sendbuf, recvbuf, count); |
| 126 | + } |
| 127 | + break; |
| 128 | + default: |
| 129 | + vtAbort("ProxyWrapper::reduce: unsupported MPI_Datatype"); |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | +private: |
| 134 | + // Map a few MPI_Op values to VT operator choices. |
| 135 | + // Extend as needed. |
| 136 | + enum class VTOp { Plus, Max, Min }; |
| 137 | + |
| 138 | + inline static VTOp mapOp(MPI_Op mpio) { |
| 139 | + if (mpio == MPI_SUM) return VTOp::Plus; |
| 140 | + if (mpio == MPI_MAX) return VTOp::Max; |
| 141 | + if (mpio == MPI_MIN) return VTOp::Min; |
| 142 | + return VTOp::Plus; |
| 143 | + } |
| 144 | + |
| 145 | + // Concrete implementation for element type T. |
| 146 | + template <typename T, typename SendBufT, typename RecvBufT> |
| 147 | + void reduce_impl(int root, MPI_Op op, SendBufT sendbuf, RecvBufT recvbuf, int count) { |
| 148 | + VTOp vk = mapOp(op); |
| 149 | + |
| 150 | + auto ctx = std::make_unique<ReduceCtx>(); |
| 151 | + ctx->out_ptr = static_cast<void*>(recvbuf); |
| 152 | + ctx->count = static_cast<std::size_t>(std::max(1, count)); |
| 153 | + ctx->done.store(false); |
| 154 | + |
| 155 | + printf("%d: initiating reduce\n", vt::theContext()->getNode()); |
| 156 | + |
| 157 | + if (count == 1) { |
| 158 | + // scalar path |
| 159 | + T value = *static_cast<T const*>(sendbuf); |
| 160 | + |
| 161 | + using MsgT = vt::collective::ReduceTMsg<T>; |
| 162 | + auto cb = vt::theCB()->makeCallbackSingleAnon<MsgT, ReduceCtx>( |
| 163 | + vt::pipe::LifetimeEnum::Once, ctx.get(), &ProxyWrapper::reduceAnonCb<T> |
| 164 | + ); |
| 165 | + auto msg = vt::makeMessage<MsgT>(value); |
| 166 | + if (vt::theContext()->getNode() == root) { |
| 167 | + msg->setCallback(cb); |
| 168 | + } |
| 169 | + ProxyT proxy = ProxyT(*this); |
| 170 | + |
| 171 | + if (vk == VTOp::Plus) { |
| 172 | + vt::theObjGroup()->template reduce< |
| 173 | + typename ProxyT::ObjGroupType, |
| 174 | + MsgT, |
| 175 | + &MsgT::template msgHandler< |
| 176 | + MsgT, vt::collective::PlusOp<T>, vt::collective::reduce::operators::ReduceCallback<MsgT> |
| 177 | + > |
| 178 | + >(proxy, msg, vt::collective::reduce::ReduceStamp{}); |
| 179 | + } else if (vk == VTOp::Max) { |
| 180 | + vt::theObjGroup()->template reduce< |
| 181 | + typename ProxyT::ObjGroupType, |
| 182 | + MsgT, |
| 183 | + &MsgT::template msgHandler< |
| 184 | + MsgT, vt::collective::MaxOp<T>, vt::collective::reduce::operators::ReduceCallback<MsgT> |
| 185 | + > |
| 186 | + >(proxy, msg, vt::collective::reduce::ReduceStamp{}); |
| 187 | + } else if (vk == VTOp::Min) { |
| 188 | + vt::theObjGroup()->template reduce< |
| 189 | + typename ProxyT::ObjGroupType, |
| 190 | + MsgT, |
| 191 | + &MsgT::template msgHandler< |
| 192 | + MsgT, vt::collective::MinOp<T>, vt::collective::reduce::operators::ReduceCallback<MsgT> |
| 193 | + > |
| 194 | + >(proxy, msg, vt::collective::reduce::ReduceStamp{}); |
| 195 | + } else { |
| 196 | + throw new std::runtime_error("Unsupported VTOp in reduce_impl"); |
| 197 | + } |
| 198 | + |
| 199 | + } else { |
| 200 | + // array path -> reduce a vector<T> |
| 201 | + std::vector<T> v(static_cast<std::size_t>(count)); |
| 202 | + std::memcpy(v.data(), static_cast<void const*>(sendbuf), sizeof(T) * static_cast<std::size_t>(count)); |
72 | 203 |
|
73 | | - // } else if (op == MPI_SUM) { |
| 204 | + using MsgT = vt::collective::ReduceTMsg<std::vector<T>>; |
| 205 | + auto cb = vt::theCB()->makeCallbackSingleAnon<MsgT, ReduceCtx>( |
| 206 | + vt::pipe::LifetimeEnum::Once, ctx.get(), &ProxyWrapper::reduceAnonCb<std::vector<T>> |
| 207 | + ); |
| 208 | + auto msg = vt::makeMessage<MsgT>(std::move(v)); |
| 209 | + if (vt::theContext()->getNode() == root) { |
| 210 | + msg->setCallback(cb); |
| 211 | + } |
74 | 212 |
|
75 | | - // } |
| 213 | + ProxyT proxy = ProxyT(*this); |
| 214 | + if (vk == VTOp::Plus) { |
| 215 | + vt::theObjGroup()->template reduce< |
| 216 | + typename ProxyT::ObjGroupType, |
| 217 | + MsgT, |
| 218 | + &MsgT::template msgHandler< |
| 219 | + MsgT, vt::collective::PlusOp<std::vector<T>>, vt::collective::reduce::operators::ReduceCallback<MsgT> |
| 220 | + > |
| 221 | + >(proxy, msg, vt::collective::reduce::ReduceStamp{}); |
| 222 | + } else if (vk == VTOp::Max) { |
| 223 | + vt::theObjGroup()->template reduce< |
| 224 | + typename ProxyT::ObjGroupType, |
| 225 | + MsgT, |
| 226 | + &MsgT::template msgHandler< |
| 227 | + MsgT, vt::collective::MaxOp<std::vector<T>>, vt::collective::reduce::operators::ReduceCallback<MsgT> |
| 228 | + > |
| 229 | + >(proxy, msg, vt::collective::reduce::ReduceStamp{}); |
| 230 | + } else if (vk == VTOp::Min) { |
| 231 | + vt::theObjGroup()->template reduce< |
| 232 | + typename ProxyT::ObjGroupType, |
| 233 | + MsgT, |
| 234 | + &MsgT::template msgHandler< |
| 235 | + MsgT, vt::collective::MinOp<std::vector<T>>, vt::collective::reduce::operators::ReduceCallback<MsgT> |
| 236 | + > |
| 237 | + >(proxy, msg, vt::collective::reduce::ReduceStamp{}); |
| 238 | + } else { |
| 239 | + throw new std::runtime_error("Unsupported VTOp in reduce_impl"); |
| 240 | + } |
| 241 | + } |
76 | 242 |
|
| 243 | + // Blocking wait: make VT progress until callback marks done on root rank. |
| 244 | + while (vt::theContext()->getNode() == root && !ctx->done.load(std::memory_order_acquire)) { |
| 245 | + vt::theSched()->runSchedulerOnceImpl(); |
| 246 | + } |
77 | 247 | } |
78 | 248 |
|
79 | 249 | }; |
|
0 commit comments