diff --git a/examples/callback/callback_context.cc b/examples/callback/callback_context.cc index 2f97261b73..9478e08052 100644 --- a/examples/callback/callback_context.cc +++ b/examples/callback/callback_context.cc @@ -89,7 +89,7 @@ static void handler(CallbackMsg* msg) { auto data_msg = vt::makeMessage(); data_msg->vec_ = std::vector{18,45,28,-1,344}; fmt::print("handler: vec.size={}\n", data_msg->vec_.size()); - cb.sendMsg(data_msg); + cb.sendMsg(data_msg); } // Some instance of the context diff --git a/src/vt/collective/reduce/operators/default_op.impl.h b/src/vt/collective/reduce/operators/default_op.impl.h index a8a87e72c4..ee58b3d4b0 100644 --- a/src/vt/collective/reduce/operators/default_op.impl.h +++ b/src/vt/collective/reduce/operators/default_op.impl.h @@ -92,6 +92,13 @@ template ); while (cur_msg != nullptr) { ReduceCombine<>::combine(fst_msg, cur_msg); + if (!fst_msg->hasValidCallback() && cur_msg->hasValidCallback()) { + if (cur_msg->isParamCallback()) { + fst_msg->setCallback(cur_msg->getParamCallback()); + } else { + fst_msg->setCallback(cur_msg->getMsgCallback()); + } + } cur_msg = cur_msg->template getNext(); } } diff --git a/src/vt/pipe/callback/cb_union/cb_raw_base.h b/src/vt/pipe/callback/cb_union/cb_raw_base.h index 67d80f4711..1169ab9eb3 100644 --- a/src/vt/pipe/callback/cb_union/cb_raw_base.h +++ b/src/vt/pipe/callback/cb_union/cb_raw_base.h @@ -229,75 +229,16 @@ struct CallbackTyped : CallbackRawBaseSingle { } template - void sendTuple(std::tuple tup) { - using Trait = CBTraits; - using MsgT = messaging::ParamMsg; - auto msg = vt::makeMessage(); - msg->setParams(std::move(tup)); - CallbackRawBaseSingle::sendMsg(msg); - } + void sendTuple(std::tuple tup); template - void send(Params&&... params) { - using Trait = CBTraits; - if constexpr (std::is_same_v) { - // We have to go through some tricky code to make the MsgProps case work - // If we use the type for Params to send, it's possible that we have a - // type mismatch in the actual handler type. A possible edge case is when - // a char const* is sent, but the handler is a std::string. In this case, - // the ParamMsg will be cast incorrectly during the virual dispatch to a - // collection because callbacks don't have the collection type. Thus, the - // wrong ParamMsg will be cast to which requires serialization, leading to - // a failure. - if constexpr (sizeof...(Params) == sizeof...(Args) + 1) { - using MsgT = messaging::ParamMsg< - std::tuple< - std::decay_t>>, - std::decay_t... - > - >; - auto msg = vt::makeMessage(); - msg->setParams(std::forward(params)...); - CallbackRawBaseSingle::sendMsg(msg); - } else { - using MsgT = messaging::ParamMsg; - auto msg = vt::makeMessage(); - msg->setParams(std::forward(params)...); - CallbackRawBaseSingle::sendMsg(msg); - } - } else { - using MsgT = typename Trait::MsgT; - auto msg = makeMessage(std::forward(params)...); - sendMsg(msg.get()); - } - } - - void send(typename CBTraits::MsgT* msg) { - using MsgT = typename CBTraits::MsgT; - if constexpr (not std::is_same_v) { - CallbackRawBaseSingle::sendMsg(msg); - } - } + void send(Params&&... params); template void send(messaging::MsgPtrThief msg) { CallbackRawBaseSingle::sendMsg(msg); } - void sendMsg(messaging::MsgPtrThief::MsgT> msg) { - using MsgT = typename CBTraits::MsgT; - if constexpr (not std::is_same_v) { - CallbackRawBaseSingle::sendMsg(msg); - } - } - - void sendMsg(typename CBTraits::MsgT* msg) { - using MsgT = typename CBTraits::MsgT; - if constexpr (not std::is_same_v) { - CallbackRawBaseSingle::sendMsg(msg); - } - } - template void serialize(SerializerT& s) { CallbackRawBaseSingle::serialize(s); diff --git a/src/vt/pipe/callback/cb_union/cb_raw_base.impl.h b/src/vt/pipe/callback/cb_union/cb_raw_base.impl.h index 6bad1cf4a4..4d131fc770 100644 --- a/src/vt/pipe/callback/cb_union/cb_raw_base.impl.h +++ b/src/vt/pipe/callback/cb_union/cb_raw_base.impl.h @@ -109,6 +109,51 @@ void CallbackRawBaseSingle::serialize(SerializerT& s) { s | cb_ | pipe_; } +template +template +void CallbackTyped::sendTuple(std::tuple tup) { + using Trait = CBTraits; + using MsgT = messaging::ParamMsg; + auto msg = vt::makeMessage(); + msg->setParams(std::move(tup)); + CallbackRawBaseSingle::sendMsg(msg); +} + +template +template +void CallbackTyped::send(Params&&... params) { + using Trait = CBTraits; + if constexpr (std::is_same_v) { + // We have to go through some tricky code to make the MsgProps case work + // If we use the type for Params to send, it's possible that we have a + // type mismatch in the actual handler type. A possible edge case is when + // a char const* is sent, but the handler is a std::string. In this case, + // the ParamMsg will be cast incorrectly during the virual dispatch to a + // collection because callbacks don't have the collection type. Thus, the + // wrong ParamMsg will be cast to which requires serialization, leading to + // a failure. + if constexpr (sizeof...(Params) == sizeof...(Args) + 1) { + using MsgT = messaging::ParamMsg< + std::tuple< + std::decay_t>>, + std::decay_t... + > + >; + auto msg = vt::makeMessage(); + msg->setParams(std::forward(params)...); + CallbackRawBaseSingle::sendMsg(msg); + } else { + using MsgT = messaging::ParamMsg; + auto msg = vt::makeMessage(); + msg->setParams(std::forward(params)...); + CallbackRawBaseSingle::sendMsg(msg); + } + } else { + using MsgT = typename Trait::MsgT; + auto msg = makeMessage(std::forward(params)...); + sendMsg(msg.get()); + } +} }}}} /* end namespace vt::pipe::callback::cbunion */ diff --git a/src/vt/rdma/rdma.cc b/src/vt/rdma/rdma.cc index 2a0a71039e..b501f4b20d 100644 --- a/src/vt/rdma/rdma.cc +++ b/src/vt/rdma/rdma.cc @@ -275,7 +275,7 @@ RDMAManager::RDMAManager() if (not msg->has_bytes) { auto cbmsg = makeMessage(num_bytes); - msg->cb.sendMsg(cbmsg); + msg->cb.sendMsg(cbmsg); } theRDMA()->createDirectChannelInternal( diff --git a/tests/unit/active/test_active_send_large.cc b/tests/unit/active/test_active_send_large.cc index b694930ce9..03e1648798 100644 --- a/tests/unit/active/test_active_send_large.cc +++ b/tests/unit/active/test_active_send_large.cc @@ -99,7 +99,7 @@ template void myHandler(MsgT* m) { checkMsg(m); auto msg = makeMessage(); - m->cb_.send(msg.get()); + m->cb_.template sendMsg(msg.get()); } template diff --git a/tests/unit/memory/test_memory_lifetime.cc b/tests/unit/memory/test_memory_lifetime.cc index 380628af7b..12a4d12b0a 100644 --- a/tests/unit/memory/test_memory_lifetime.cc +++ b/tests/unit/memory/test_memory_lifetime.cc @@ -211,7 +211,7 @@ TEST_F(TestMemoryLifetime, test_active_bcast_normal_lifetime_msgptr) { //////////////////////////////////////////////////////////////////////////////// static void callbackHan(CallbackMsg* msg) { auto send_msg = makeMessage(); - msg->cb_.send(send_msg.get()); + msg->cb_.template sendMsg(send_msg.get()); theTerm()->addAction([send_msg]{ // Call event cleanup all pending MPI requests to clear @@ -244,7 +244,7 @@ TEST_F(TestMemoryLifetime, test_active_send_callback_lifetime_1) { //////////////////////////////////////////////////////////////////////////////// static void callbackHan(CallbackMsg* msg) { auto send_msg = makeMessage(); - msg->cb_.send(send_msg.get()); + msg->cb_.template sendMsg(send_msg.get()); theTerm()->addAction([send_msg]{ // Call event cleanup all pending MPI requests to clear