Skip to content

Commit aca0c9f

Browse files
authored
Merge pull request #2494 from DARMA-tasking/2493-fix-callback-bugs-for-reducing-in-lb-subcomponent
2493 fix callback bugs for reducing in lb subcomponent
2 parents 5fe1b46 + bd89f87 commit aca0c9f

File tree

7 files changed

+59
-66
lines changed

7 files changed

+59
-66
lines changed

examples/callback/callback_context.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static void handler(CallbackMsg* msg) {
8989
auto data_msg = vt::makeMessage<DataMsg>();
9090
data_msg->vec_ = std::vector<int>{18,45,28,-1,344};
9191
fmt::print("handler: vec.size={}\n", data_msg->vec_.size());
92-
cb.sendMsg(data_msg);
92+
cb.sendMsg<DataMsg>(data_msg);
9393
}
9494

9595
// Some instance of the context

src/vt/collective/reduce/operators/default_op.impl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ template <typename MsgT, typename Op, typename ActOp>
9292
);
9393
while (cur_msg != nullptr) {
9494
ReduceCombine<>::combine<MsgT,Op,ActOp>(fst_msg, cur_msg);
95+
if (!fst_msg->hasValidCallback() && cur_msg->hasValidCallback()) {
96+
if (cur_msg->isParamCallback()) {
97+
fst_msg->setCallback(cur_msg->getParamCallback());
98+
} else {
99+
fst_msg->setCallback(cur_msg->getMsgCallback());
100+
}
101+
}
95102
cur_msg = cur_msg->template getNext<MsgT>();
96103
}
97104
}

src/vt/pipe/callback/cb_union/cb_raw_base.h

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -229,75 +229,16 @@ struct CallbackTyped : CallbackRawBaseSingle {
229229
}
230230

231231
template <typename... Params>
232-
void sendTuple(std::tuple<Params...> tup) {
233-
using Trait = CBTraits<Args...>;
234-
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
235-
auto msg = vt::makeMessage<MsgT>();
236-
msg->setParams(std::move(tup));
237-
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
238-
}
232+
void sendTuple(std::tuple<Params...> tup);
239233

240234
template <typename... Params>
241-
void send(Params&&... params) {
242-
using Trait = CBTraits<Args...>;
243-
if constexpr (std::is_same_v<typename Trait::MsgT, NoMsg>) {
244-
// We have to go through some tricky code to make the MsgProps case work
245-
// If we use the type for Params to send, it's possible that we have a
246-
// type mismatch in the actual handler type. A possible edge case is when
247-
// a char const* is sent, but the handler is a std::string. In this case,
248-
// the ParamMsg will be cast incorrectly during the virual dispatch to a
249-
// collection because callbacks don't have the collection type. Thus, the
250-
// wrong ParamMsg will be cast to which requires serialization, leading to
251-
// a failure.
252-
if constexpr (sizeof...(Params) == sizeof...(Args) + 1) {
253-
using MsgT = messaging::ParamMsg<
254-
std::tuple<
255-
std::decay_t<std::tuple_element_t<0, std::tuple<Params...>>>,
256-
std::decay_t<Args>...
257-
>
258-
>;
259-
auto msg = vt::makeMessage<MsgT>();
260-
msg->setParams(std::forward<Params>(params)...);
261-
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
262-
} else {
263-
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
264-
auto msg = vt::makeMessage<MsgT>();
265-
msg->setParams(std::forward<Params>(params)...);
266-
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
267-
}
268-
} else {
269-
using MsgT = typename Trait::MsgT;
270-
auto msg = makeMessage<MsgT>(std::forward<Params>(params)...);
271-
sendMsg(msg.get());
272-
}
273-
}
274-
275-
void send(typename CBTraits<Args...>::MsgT* msg) {
276-
using MsgT = typename CBTraits<Args...>::MsgT;
277-
if constexpr (not std::is_same_v<MsgT, NoMsg>) {
278-
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
279-
}
280-
}
235+
void send(Params&&... params);
281236

282237
template <typename MsgT>
283238
void send(messaging::MsgPtrThief<MsgT> msg) {
284239
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
285240
}
286241

287-
void sendMsg(messaging::MsgPtrThief<typename CBTraits<Args...>::MsgT> msg) {
288-
using MsgT = typename CBTraits<Args...>::MsgT;
289-
if constexpr (not std::is_same_v<MsgT, NoMsg>) {
290-
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
291-
}
292-
}
293-
294-
void sendMsg(typename CBTraits<Args...>::MsgT* msg) {
295-
using MsgT = typename CBTraits<Args...>::MsgT;
296-
if constexpr (not std::is_same_v<MsgT, NoMsg>) {
297-
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
298-
}
299-
}
300-
301242
template <typename SerializerT>
302243
void serialize(SerializerT& s) {
303244
CallbackRawBaseSingle::serialize(s);

src/vt/pipe/callback/cb_union/cb_raw_base.impl.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,51 @@ void CallbackRawBaseSingle::serialize(SerializerT& s) {
109109
s | cb_ | pipe_;
110110
}
111111

112+
template <typename... Args>
113+
template <typename... Params>
114+
void CallbackTyped<Args...>::sendTuple(std::tuple<Params...> tup) {
115+
using Trait = CBTraits<Args...>;
116+
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
117+
auto msg = vt::makeMessage<MsgT>();
118+
msg->setParams(std::move(tup));
119+
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
120+
}
121+
122+
template <typename... Args>
123+
template <typename... Params>
124+
void CallbackTyped<Args...>::send(Params&&... params) {
125+
using Trait = CBTraits<Args...>;
126+
if constexpr (std::is_same_v<typename Trait::MsgT, NoMsg>) {
127+
// We have to go through some tricky code to make the MsgProps case work
128+
// If we use the type for Params to send, it's possible that we have a
129+
// type mismatch in the actual handler type. A possible edge case is when
130+
// a char const* is sent, but the handler is a std::string. In this case,
131+
// the ParamMsg will be cast incorrectly during the virual dispatch to a
132+
// collection because callbacks don't have the collection type. Thus, the
133+
// wrong ParamMsg will be cast to which requires serialization, leading to
134+
// a failure.
135+
if constexpr (sizeof...(Params) == sizeof...(Args) + 1) {
136+
using MsgT = messaging::ParamMsg<
137+
std::tuple<
138+
std::decay_t<std::tuple_element_t<0, std::tuple<Params...>>>,
139+
std::decay_t<Args>...
140+
>
141+
>;
142+
auto msg = vt::makeMessage<MsgT>();
143+
msg->setParams(std::forward<Params>(params)...);
144+
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
145+
} else {
146+
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
147+
auto msg = vt::makeMessage<MsgT>();
148+
msg->setParams(std::forward<Params>(params)...);
149+
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
150+
}
151+
} else {
152+
using MsgT = typename Trait::MsgT;
153+
auto msg = makeMessage<MsgT>(std::forward<Params>(params)...);
154+
sendMsg(msg.get());
155+
}
156+
}
112157

113158
}}}} /* end namespace vt::pipe::callback::cbunion */
114159

src/vt/rdma/rdma.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ RDMAManager::RDMAManager()
275275

276276
if (not msg->has_bytes) {
277277
auto cbmsg = makeMessage<GetInfoChannel>(num_bytes);
278-
msg->cb.sendMsg(cbmsg);
278+
msg->cb.sendMsg<GetInfoChannel>(cbmsg);
279279
}
280280

281281
theRDMA()->createDirectChannelInternal(

tests/unit/active/test_active_send_large.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ template <typename MsgT>
9999
void myHandler(MsgT* m) {
100100
checkMsg(m);
101101
auto msg = makeMessage<RecvMsg>();
102-
m->cb_.send(msg.get());
102+
m->cb_.template sendMsg<RecvMsg>(msg.get());
103103
}
104104

105105
template <typename T>

tests/unit/memory/test_memory_lifetime.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ TEST_F(TestMemoryLifetime, test_active_bcast_normal_lifetime_msgptr) {
211211
////////////////////////////////////////////////////////////////////////////////
212212
static void callbackHan(CallbackMsg<NormalTestMsg>* msg) {
213213
auto send_msg = makeMessage<NormalTestMsg>();
214-
msg->cb_.send(send_msg.get());
214+
msg->cb_.template sendMsg<NormalTestMsg>(send_msg.get());
215215

216216
theTerm()->addAction([send_msg]{
217217
// Call event cleanup all pending MPI requests to clear
@@ -244,7 +244,7 @@ TEST_F(TestMemoryLifetime, test_active_send_callback_lifetime_1) {
244244
////////////////////////////////////////////////////////////////////////////////
245245
static void callbackHan(CallbackMsg<SerialTestMsg>* msg) {
246246
auto send_msg = makeMessage<SerialTestMsg>();
247-
msg->cb_.send(send_msg.get());
247+
msg->cb_.template sendMsg<SerialTestMsg>(send_msg.get());
248248

249249
theTerm()->addAction([send_msg]{
250250
// Call event cleanup all pending MPI requests to clear

0 commit comments

Comments
 (0)