@@ -61,6 +61,14 @@ void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped)
6161 }
6262}
6363
64+ inline void UnregisterWait (HANDLE waitHandle) noexcept
65+ {
66+ // INVALID_HANDLE_VALUE makes UnregisterWaitEx block until any in-flight wait callback returns.
67+ LOG_LAST_ERROR_IF (!UnregisterWaitEx (waitHandle, INVALID_HANDLE_VALUE));
68+ }
69+
70+ using unique_registered_wait = wil::unique_any_handle_null<decltype (&UnregisterWait), &UnregisterWait>;
71+
6472} // namespace
6573
6674// HandleWrapper
@@ -914,133 +922,158 @@ void DockerIORelayHandle::OnRead(const gsl::span<char>& Buffer)
914922
915923// MultiHandleWait
916924
925+ MultiHandleWait::MultiHandleWait (MultiHandleWait&& other) noexcept
926+ {
927+ *this = std::move (other);
928+ }
929+
930+ MultiHandleWait& MultiHandleWait::operator =(MultiHandleWait&& other) noexcept
931+ {
932+ if (this != &other)
933+ {
934+ m_handles = std::move (other.m_handles );
935+ m_handleSignaledEvent = std::move (other.m_handleSignaledEvent );
936+ m_cancel = other.m_cancel ;
937+
938+ for (auto & entry : m_handles)
939+ {
940+ entry->self = this ;
941+ }
942+
943+ // N.B. moving a MultiHandleWait() while running is not supported
944+ WI_ASSERT (m_signaledHandles.empty ());
945+ }
946+
947+ return *this ;
948+ }
949+
917950void MultiHandleWait::AddHandle (std::unique_ptr<OverlappedIOHandle>&& handle, Flags flags)
918951{
919- m_handles.emplace_back (flags, std::move (handle));
952+ auto entry = std::make_unique<Entry>();
953+ entry->HandleFlags = flags;
954+ entry->Handle = std::move (handle);
955+ entry->self = this ;
956+ m_handles.emplace_back (std::move (entry));
920957}
921958
922959void MultiHandleWait::Cancel ()
923960{
924961 m_cancel = true ;
925962}
926963
964+ void NTAPI MultiHandleWait::WaitCallback (PVOID Context, BOOLEAN /* TimerOrWaitFired*/ )
965+ {
966+ auto * entry = static_cast <Entry*>(Context);
967+
968+ entry->self ->m_signaledHandles .push (entry);
969+ entry->self ->m_handleSignaledEvent .SetEvent ();
970+ }
971+
927972bool MultiHandleWait::Run (std::optional<std::chrono::milliseconds> Timeout)
928973{
929974 m_cancel = false ; // Run may be called multiple times.
930975
931976 std::optional<std::chrono::steady_clock::time_point> deadline;
932-
933977 if (Timeout.has_value ())
934978 {
935979 deadline = std::chrono::steady_clock::now () + Timeout.value ();
936980 }
937981
938- // Run until all handles are completed.
982+ std::vector<unique_registered_wait> callbacks;
939983
940- while (!m_handles. empty () && ! m_cancel)
984+ while (!m_cancel)
941985 {
942- // Schedule IO on each handle until all are either pending, or completed.
943- for (size_t i = 0 ; i < m_handles.size () && !m_cancel; i++)
986+ // Cancel any pending callback.
987+ callbacks.clear ();
988+
989+ Entry* signaledEntry = nullptr ;
990+ while (m_signaledHandles.try_pop (signaledEntry))
991+ {
992+ try
993+ {
994+ signaledEntry->Handle ->Collect ();
995+ }
996+ catch (...)
997+ {
998+ if (WI_IsFlagSet (signaledEntry->HandleFlags , Flags::IgnoreErrors))
999+ {
1000+ signaledEntry->Handle .reset ();
1001+ continue ;
1002+ }
1003+
1004+ throw ;
1005+ }
1006+ }
1007+
1008+ m_handleSignaledEvent.ResetEvent ();
1009+
1010+ bool hasHandleToWaitFor = false ;
1011+ for (auto it = m_handles.begin (); it != m_handles.end ();)
9441012 {
945- while (m_handles[i].second ->GetState () == IOHandleStatus::Standby && !m_cancel)
1013+ auto & entry = **it;
1014+
1015+ while (entry.Handle && entry.Handle ->GetState () == IOHandleStatus::Standby && !m_cancel)
9461016 {
9471017 try
9481018 {
949- m_handles[i]. second ->Schedule ();
1019+ entry. Handle ->Schedule ();
9501020 }
9511021 catch (...)
9521022 {
953- if (WI_IsFlagSet (m_handles[i]. first , Flags::IgnoreErrors))
1023+ if (WI_IsFlagSet (entry. HandleFlags , Flags::IgnoreErrors))
9541024 {
955- m_handles[i]. second .reset (); // Reset the handle so it can be deleted.
1025+ entry. Handle .reset ();
9561026 break ;
9571027 }
958- else
959- {
960- throw ;
961- }
1028+
1029+ throw ;
9621030 }
9631031 }
964- }
9651032
966- // Remove completed handles from m_handles.
967- bool hasHandleToWaitFor = false ;
968- for (auto it = m_handles.begin (); it != m_handles.end ();)
969- {
970- if (!it->second )
971- {
972- it = m_handles.erase (it);
973- }
974- else if (it->second ->GetState () == IOHandleStatus::Completed)
1033+ if (!entry.Handle || entry.Handle ->GetState () == IOHandleStatus::Completed)
9751034 {
976- if (WI_IsFlagSet (it-> first , Flags::CancelOnCompleted))
1035+ if (entry. Handle && WI_IsFlagSet (entry. HandleFlags , Flags::CancelOnCompleted))
9771036 {
978- m_cancel = true ; // Cancel the IO if a handle with CancelOnCompleted is in the completed state.
1037+ m_cancel = true ;
9791038 }
9801039
9811040 it = m_handles.erase (it);
1041+ continue ;
9821042 }
983- else
1043+
1044+ auto & callback = callbacks.emplace_back ();
1045+
1046+ THROW_IF_WIN32_BOOL_FALSE (RegisterWaitForSingleObject (
1047+ &callback, entry.Handle ->GetHandle (), &WaitCallback, &entry, INFINITE, WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE));
1048+
1049+ if (WI_IsFlagClear (entry.HandleFlags , Flags::NeedNotComplete))
9841050 {
985- // If only NeedNotComplete handles are left, we want to exit Run.
986- if (WI_IsFlagClear (it->first , Flags::NeedNotComplete))
987- {
988- hasHandleToWaitFor = true ;
989- }
990- ++it;
1051+ hasHandleToWaitFor = true ;
9911052 }
992- }
9931053
994- if (!hasHandleToWaitFor || m_cancel)
995- {
996- break ;
1054+ ++it;
9971055 }
9981056
999- // Wait for the next operation to complete.
1000- std::vector<HANDLE> waitHandles;
1001- for (const auto & e : m_handles)
1057+ if (m_handles.empty () || !hasHandleToWaitFor || m_cancel)
10021058 {
1003- waitHandles. emplace_back (e. second -> GetHandle ()) ;
1059+ break ;
10041060 }
10051061
10061062 DWORD waitTimeout = INFINITE;
10071063 if (deadline.has_value ())
10081064 {
1009- auto miliseconds =
1065+ auto milliseconds =
10101066 std::chrono::duration_cast<std::chrono::milliseconds>(deadline.value () - std::chrono::steady_clock::now ()).count ();
10111067
1012- waitTimeout = static_cast <DWORD>(std::max ( 0LL , miliseconds ));
1068+ waitTimeout = static_cast <DWORD>(std::max< long long >( 0 , milliseconds ));
10131069 }
10141070
1015- auto result = WaitForMultipleObjects (static_cast <DWORD>(waitHandles.size ()), waitHandles.data (), false , waitTimeout);
1016- if (result == WAIT_TIMEOUT)
1017- {
1018- THROW_WIN32 (ERROR_TIMEOUT);
1019- }
1020- else if (result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + m_handles.size ())
1021- {
1022- auto index = result - WAIT_OBJECT_0;
1023-
1024- try
1025- {
1026- m_handles[index].second ->Collect ();
1027- }
1028- catch (...)
1029- {
1030- if (WI_IsFlagSet (m_handles[index].first , Flags::IgnoreErrors))
1031- {
1032- m_handles.erase (m_handles.begin () + index);
1033- }
1034- else
1035- {
1036- throw ;
1037- }
1038- }
1039- }
1040- else
1041- {
1042- THROW_LAST_ERROR_MSG (" Timeout: %lu, Count: %llu" , waitTimeout, waitHandles.size ());
1043- }
1071+ THROW_HR_IF_MSG (
1072+ HRESULT_FROM_WIN32 (ERROR_TIMEOUT),
1073+ !m_handleSignaledEvent.wait (waitTimeout),
1074+ " Timed out waiting for %llu handles. Timeout: %lu" ,
1075+ m_handles.size (),
1076+ waitTimeout);
10441077 }
10451078
10461079 return !m_cancel;
0 commit comments