Skip to content

Commit bfec011

Browse files
committed
#1: reduce: implement MPI-like reduce in VT
1 parent 2004117 commit bfec011

File tree

2 files changed

+190
-13
lines changed

2 files changed

+190
-13
lines changed

examples/test_example.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,19 @@ struct MyClass {
1313
};
1414

1515
int main(int argc, char** argv) {
16-
auto comm = vt_lb::comm::CommMPI();
16+
auto comm = vt_lb::comm::CommVT();
1717
comm.init(argc, argv);
1818

1919
auto cls = std::make_unique<MyClass>();
2020
auto handle = comm.registerInstanceCollective(cls.get());
2121
auto rank = comm.getRank();
22+
23+
int value = 10;
24+
int recv_value = 0;
25+
handle.reduce(1, MPI_INT, MPI_SUM, &value, &recv_value, 1);
26+
27+
fmt::print("Rank {}: reduced value is {}\n", rank, recv_value);
28+
2229
if (rank == 0) {
2330
handle[1].send<&MyClass::myHandler2>(std::string{"hello from rank 0"});
2431
}

src/vt-lb/comm/comm_vt.h

Lines changed: 182 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,32 +48,202 @@
4848
#include <memory>
4949
#include <tuple>
5050
#include <unordered_map>
51+
#include <atomic>
52+
#include <cstring>
53+
#include <vector>
54+
#include <type_traits>
55+
#include <algorithm>
5156

57+
#include <vt/collective/reduce/operators/default_msg.h>
5258
#include <vt/transport.h>
5359

5460
namespace vt_lb::comm {
5561

56-
template <typename U>
57-
void reduceCb(U value) {
58-
59-
}
60-
6162
template <typename ProxyT>
6263
struct ProxyWrapper : ProxyT {
6364
ProxyWrapper(ProxyT proxy) : ProxyT(proxy) { }
6465

66+
/// Shared state between a pending reduce call and its completion callback.
struct ReduceCtx {
  /// Destination buffer the callback writes the reduced result into
  /// (set from the caller's receive buffer).
  void* out_ptr{nullptr};
  /// Set to true by the callback once the result has been written;
  /// polled by the initiating code while it waits.
  std::atomic<bool> done{false};
  /// Number of elements expected in the destination buffer.
  std::size_t count{0};
};
71+
72+
template <typename T>
73+
static void reduceAnonCb(vt::collective::ReduceTMsg<T>* msg, ReduceCtx* ctx) {
74+
auto const& val = msg->getVal();
75+
76+
printf("%d: callback invoked\n", vt::theContext()->getNode());
77+
78+
if constexpr (
79+
std::is_same_v<std::decay_t<T>, int> || std::is_same_v<std::decay_t<T>, double> ||
80+
std::is_same_v<std::decay_t<T>, float> || std::is_same_v<std::decay_t<T>, long> ||
81+
std::is_same_v<std::decay_t<T>, long long>
82+
) {
83+
// scalar case
84+
*static_cast<std::decay_t<T>*>(ctx->out_ptr) = val;
85+
} else {
86+
// vector case
87+
using ValT = typename T::value_type;
88+
static_assert(
89+
std::is_trivially_copyable_v<ValT> || std::is_arithmetic_v<ValT>,
90+
"Reduce value must be trivially copyable for this helper"
91+
);
92+
std::memcpy(ctx->out_ptr, std::addressof(val), sizeof(ValT) * std::max<std::size_t>(1, ctx->count));
93+
}
94+
95+
// Publish completion after writing result.
96+
ctx->done.store(true, std::memory_order_release);
97+
}
98+
99+
// Runtime-dispatch wrapper: map a few common MPI datatypes to C++ types.
65100
template <typename U, typename V>
66101
void reduce(int root, MPI_Datatype datatype, MPI_Op op, U sendbuf, V recvbuf, int count) {
67-
// @todo finish this??
68-
// if (op == MPI_MAX) {
69-
// auto cb = vt::theCB()->makeSend<reduceCb>(vt::pipe::LifetimeEnum::Once, reduceCb);
70-
// this->template reduce<vt::collective::MaxOp>();
71-
// } else if (op == MPI_MIN) {
102+
switch (datatype) {
103+
case MPI_INT:
104+
if constexpr (std::is_same_v<U, int*>) {
105+
reduce_impl<int>(root, op, sendbuf, recvbuf, count);
106+
}
107+
break;
108+
case MPI_DOUBLE:
109+
if constexpr (std::is_same_v<U, double*>) {
110+
reduce_impl<double>(root, op, sendbuf, recvbuf, count);
111+
}
112+
break;
113+
case MPI_FLOAT:
114+
if constexpr (std::is_same_v<U, float*>) {
115+
reduce_impl<float>(root, op, sendbuf, recvbuf, count);
116+
}
117+
break;
118+
case MPI_LONG:
119+
if constexpr (std::is_same_v<U, long*>) {
120+
reduce_impl<long>(root, op, sendbuf, recvbuf, count);
121+
}
122+
break;
123+
case MPI_LONG_LONG:
124+
if constexpr (std::is_same_v<U, long long*>) {
125+
reduce_impl<long long>(root, op, sendbuf, recvbuf, count);
126+
}
127+
break;
128+
default:
129+
vtAbort("ProxyWrapper::reduce: unsupported MPI_Datatype");
130+
}
131+
}
132+
133+
private:
134+
// Map a few MPI_Op values to VT operator choices.
135+
// Extend as needed.
136+
enum class VTOp { Plus, Max, Min };
137+
138+
inline static VTOp mapOp(MPI_Op mpio) {
139+
if (mpio == MPI_SUM) return VTOp::Plus;
140+
if (mpio == MPI_MAX) return VTOp::Max;
141+
if (mpio == MPI_MIN) return VTOp::Min;
142+
return VTOp::Plus;
143+
}
144+
145+
// Concrete implementation for element type T.
146+
template <typename T, typename SendBufT, typename RecvBufT>
147+
void reduce_impl(int root, MPI_Op op, SendBufT sendbuf, RecvBufT recvbuf, int count) {
148+
VTOp vk = mapOp(op);
149+
150+
auto ctx = std::make_unique<ReduceCtx>();
151+
ctx->out_ptr = static_cast<void*>(recvbuf);
152+
ctx->count = static_cast<std::size_t>(std::max(1, count));
153+
ctx->done.store(false);
154+
155+
printf("%d: initiating reduce\n", vt::theContext()->getNode());
156+
157+
if (count == 1) {
158+
// scalar path
159+
T value = *static_cast<T const*>(sendbuf);
160+
161+
using MsgT = vt::collective::ReduceTMsg<T>;
162+
auto cb = vt::theCB()->makeCallbackSingleAnon<MsgT, ReduceCtx>(
163+
vt::pipe::LifetimeEnum::Once, ctx.get(), &ProxyWrapper::reduceAnonCb<T>
164+
);
165+
auto msg = vt::makeMessage<MsgT>(value);
166+
if (vt::theContext()->getNode() == root) {
167+
msg->setCallback(cb);
168+
}
169+
ProxyT proxy = ProxyT(*this);
170+
171+
if (vk == VTOp::Plus) {
172+
vt::theObjGroup()->template reduce<
173+
typename ProxyT::ObjGroupType,
174+
MsgT,
175+
&MsgT::template msgHandler<
176+
MsgT, vt::collective::PlusOp<T>, vt::collective::reduce::operators::ReduceCallback<MsgT>
177+
>
178+
>(proxy, msg, vt::collective::reduce::ReduceStamp{});
179+
} else if (vk == VTOp::Max) {
180+
vt::theObjGroup()->template reduce<
181+
typename ProxyT::ObjGroupType,
182+
MsgT,
183+
&MsgT::template msgHandler<
184+
MsgT, vt::collective::MaxOp<T>, vt::collective::reduce::operators::ReduceCallback<MsgT>
185+
>
186+
>(proxy, msg, vt::collective::reduce::ReduceStamp{});
187+
} else if (vk == VTOp::Min) {
188+
vt::theObjGroup()->template reduce<
189+
typename ProxyT::ObjGroupType,
190+
MsgT,
191+
&MsgT::template msgHandler<
192+
MsgT, vt::collective::MinOp<T>, vt::collective::reduce::operators::ReduceCallback<MsgT>
193+
>
194+
>(proxy, msg, vt::collective::reduce::ReduceStamp{});
195+
} else {
196+
throw new std::runtime_error("Unsupported VTOp in reduce_impl");
197+
}
198+
199+
} else {
200+
// array path -> reduce a vector<T>
201+
std::vector<T> v(static_cast<std::size_t>(count));
202+
std::memcpy(v.data(), static_cast<void const*>(sendbuf), sizeof(T) * static_cast<std::size_t>(count));
72203

73-
// } else if (op == MPI_SUM) {
204+
using MsgT = vt::collective::ReduceTMsg<std::vector<T>>;
205+
auto cb = vt::theCB()->makeCallbackSingleAnon<MsgT, ReduceCtx>(
206+
vt::pipe::LifetimeEnum::Once, ctx.get(), &ProxyWrapper::reduceAnonCb<std::vector<T>>
207+
);
208+
auto msg = vt::makeMessage<MsgT>(std::move(v));
209+
if (vt::theContext()->getNode() == root) {
210+
msg->setCallback(cb);
211+
}
74212

75-
// }
213+
ProxyT proxy = ProxyT(*this);
214+
if (vk == VTOp::Plus) {
215+
vt::theObjGroup()->template reduce<
216+
typename ProxyT::ObjGroupType,
217+
MsgT,
218+
&MsgT::template msgHandler<
219+
MsgT, vt::collective::PlusOp<std::vector<T>>, vt::collective::reduce::operators::ReduceCallback<MsgT>
220+
>
221+
>(proxy, msg, vt::collective::reduce::ReduceStamp{});
222+
} else if (vk == VTOp::Max) {
223+
vt::theObjGroup()->template reduce<
224+
typename ProxyT::ObjGroupType,
225+
MsgT,
226+
&MsgT::template msgHandler<
227+
MsgT, vt::collective::MaxOp<std::vector<T>>, vt::collective::reduce::operators::ReduceCallback<MsgT>
228+
>
229+
>(proxy, msg, vt::collective::reduce::ReduceStamp{});
230+
} else if (vk == VTOp::Min) {
231+
vt::theObjGroup()->template reduce<
232+
typename ProxyT::ObjGroupType,
233+
MsgT,
234+
&MsgT::template msgHandler<
235+
MsgT, vt::collective::MinOp<std::vector<T>>, vt::collective::reduce::operators::ReduceCallback<MsgT>
236+
>
237+
>(proxy, msg, vt::collective::reduce::ReduceStamp{});
238+
} else {
239+
throw new std::runtime_error("Unsupported VTOp in reduce_impl");
240+
}
241+
}
76242

243+
// Blocking wait: make VT progress until callback marks done on root rank.
244+
while (vt::theContext()->getNode() == root && !ctx->done.load(std::memory_order_acquire)) {
245+
vt::theSched()->runSchedulerOnceImpl();
246+
}
77247
}
78248

79249
};

0 commit comments

Comments
 (0)