Skip to content

Commit 0b402cd

Browse files
authored
Use RegisterWaitForSingleObject() in MultiHandleIOWait (#40658)
* Save state * Add test coverage * Cleanup for PR * Remove stale command * Apply PR feedback * Create explicit move ctor
1 parent e0198f1 commit 0b402cd

4 files changed

Lines changed: 261 additions & 75 deletions

File tree

src/windows/common/HandleIO.cpp

Lines changed: 105 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped)
6161
}
6262
}
6363

64+
inline void UnregisterWait(HANDLE waitHandle) noexcept
65+
{
66+
// INVALID_HANDLE_VALUE makes UnregisterWaitEx block until any in-flight wait callback returns.
67+
LOG_LAST_ERROR_IF(!UnregisterWaitEx(waitHandle, INVALID_HANDLE_VALUE));
68+
}
69+
70+
using unique_registered_wait = wil::unique_any_handle_null<decltype(&UnregisterWait), &UnregisterWait>;
71+
6472
} // namespace
6573

6674
// HandleWrapper
@@ -914,133 +922,158 @@ void DockerIORelayHandle::OnRead(const gsl::span<char>& Buffer)
914922

915923
// MultiHandleWait
916924

925+
MultiHandleWait::MultiHandleWait(MultiHandleWait&& other) noexcept
926+
{
927+
*this = std::move(other);
928+
}
929+
930+
MultiHandleWait& MultiHandleWait::operator=(MultiHandleWait&& other) noexcept
931+
{
932+
if (this != &other)
933+
{
934+
m_handles = std::move(other.m_handles);
935+
m_handleSignaledEvent = std::move(other.m_handleSignaledEvent);
936+
m_cancel = other.m_cancel;
937+
938+
for (auto& entry : m_handles)
939+
{
940+
entry->self = this;
941+
}
942+
943+
// N.B. moving a MultiHandleWait() while running is not supported
944+
WI_ASSERT(m_signaledHandles.empty());
945+
}
946+
947+
return *this;
948+
}
949+
917950
void MultiHandleWait::AddHandle(std::unique_ptr<OverlappedIOHandle>&& handle, Flags flags)
918951
{
919-
m_handles.emplace_back(flags, std::move(handle));
952+
auto entry = std::make_unique<Entry>();
953+
entry->HandleFlags = flags;
954+
entry->Handle = std::move(handle);
955+
entry->self = this;
956+
m_handles.emplace_back(std::move(entry));
920957
}
921958

922959
void MultiHandleWait::Cancel()
923960
{
924961
m_cancel = true;
925962
}
926963

964+
void NTAPI MultiHandleWait::WaitCallback(PVOID Context, BOOLEAN /*TimerOrWaitFired*/)
965+
{
966+
auto* entry = static_cast<Entry*>(Context);
967+
968+
entry->self->m_signaledHandles.push(entry);
969+
entry->self->m_handleSignaledEvent.SetEvent();
970+
}
971+
927972
bool MultiHandleWait::Run(std::optional<std::chrono::milliseconds> Timeout)
928973
{
929974
m_cancel = false; // Run may be called multiple times.
930975

931976
std::optional<std::chrono::steady_clock::time_point> deadline;
932-
933977
if (Timeout.has_value())
934978
{
935979
deadline = std::chrono::steady_clock::now() + Timeout.value();
936980
}
937981

938-
// Run until all handles are completed.
982+
std::vector<unique_registered_wait> callbacks;
939983

940-
while (!m_handles.empty() && !m_cancel)
984+
while (!m_cancel)
941985
{
942-
// Schedule IO on each handle until all are either pending, or completed.
943-
for (size_t i = 0; i < m_handles.size() && !m_cancel; i++)
986+
// Cancel any pending callback.
987+
callbacks.clear();
988+
989+
Entry* signaledEntry = nullptr;
990+
while (m_signaledHandles.try_pop(signaledEntry))
991+
{
992+
try
993+
{
994+
signaledEntry->Handle->Collect();
995+
}
996+
catch (...)
997+
{
998+
if (WI_IsFlagSet(signaledEntry->HandleFlags, Flags::IgnoreErrors))
999+
{
1000+
signaledEntry->Handle.reset();
1001+
continue;
1002+
}
1003+
1004+
throw;
1005+
}
1006+
}
1007+
1008+
m_handleSignaledEvent.ResetEvent();
1009+
1010+
bool hasHandleToWaitFor = false;
1011+
for (auto it = m_handles.begin(); it != m_handles.end();)
9441012
{
945-
while (m_handles[i].second->GetState() == IOHandleStatus::Standby && !m_cancel)
1013+
auto& entry = **it;
1014+
1015+
while (entry.Handle && entry.Handle->GetState() == IOHandleStatus::Standby && !m_cancel)
9461016
{
9471017
try
9481018
{
949-
m_handles[i].second->Schedule();
1019+
entry.Handle->Schedule();
9501020
}
9511021
catch (...)
9521022
{
953-
if (WI_IsFlagSet(m_handles[i].first, Flags::IgnoreErrors))
1023+
if (WI_IsFlagSet(entry.HandleFlags, Flags::IgnoreErrors))
9541024
{
955-
m_handles[i].second.reset(); // Reset the handle so it can be deleted.
1025+
entry.Handle.reset();
9561026
break;
9571027
}
958-
else
959-
{
960-
throw;
961-
}
1028+
1029+
throw;
9621030
}
9631031
}
964-
}
9651032

966-
// Remove completed handles from m_handles.
967-
bool hasHandleToWaitFor = false;
968-
for (auto it = m_handles.begin(); it != m_handles.end();)
969-
{
970-
if (!it->second)
971-
{
972-
it = m_handles.erase(it);
973-
}
974-
else if (it->second->GetState() == IOHandleStatus::Completed)
1033+
if (!entry.Handle || entry.Handle->GetState() == IOHandleStatus::Completed)
9751034
{
976-
if (WI_IsFlagSet(it->first, Flags::CancelOnCompleted))
1035+
if (entry.Handle && WI_IsFlagSet(entry.HandleFlags, Flags::CancelOnCompleted))
9771036
{
978-
m_cancel = true; // Cancel the IO if a handle with CancelOnCompleted is in the completed state.
1037+
m_cancel = true;
9791038
}
9801039

9811040
it = m_handles.erase(it);
1041+
continue;
9821042
}
983-
else
1043+
1044+
auto& callback = callbacks.emplace_back();
1045+
1046+
THROW_IF_WIN32_BOOL_FALSE(RegisterWaitForSingleObject(
1047+
&callback, entry.Handle->GetHandle(), &WaitCallback, &entry, INFINITE, WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE));
1048+
1049+
if (WI_IsFlagClear(entry.HandleFlags, Flags::NeedNotComplete))
9841050
{
985-
// If only NeedNotComplete handles are left, we want to exit Run.
986-
if (WI_IsFlagClear(it->first, Flags::NeedNotComplete))
987-
{
988-
hasHandleToWaitFor = true;
989-
}
990-
++it;
1051+
hasHandleToWaitFor = true;
9911052
}
992-
}
9931053

994-
if (!hasHandleToWaitFor || m_cancel)
995-
{
996-
break;
1054+
++it;
9971055
}
9981056

999-
// Wait for the next operation to complete.
1000-
std::vector<HANDLE> waitHandles;
1001-
for (const auto& e : m_handles)
1057+
if (m_handles.empty() || !hasHandleToWaitFor || m_cancel)
10021058
{
1003-
waitHandles.emplace_back(e.second->GetHandle());
1059+
break;
10041060
}
10051061

10061062
DWORD waitTimeout = INFINITE;
10071063
if (deadline.has_value())
10081064
{
1009-
auto miliseconds =
1065+
auto milliseconds =
10101066
std::chrono::duration_cast<std::chrono::milliseconds>(deadline.value() - std::chrono::steady_clock::now()).count();
10111067

1012-
waitTimeout = static_cast<DWORD>(std::max(0LL, miliseconds));
1068+
waitTimeout = static_cast<DWORD>(std::max<long long>(0, milliseconds));
10131069
}
10141070

1015-
auto result = WaitForMultipleObjects(static_cast<DWORD>(waitHandles.size()), waitHandles.data(), false, waitTimeout);
1016-
if (result == WAIT_TIMEOUT)
1017-
{
1018-
THROW_WIN32(ERROR_TIMEOUT);
1019-
}
1020-
else if (result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + m_handles.size())
1021-
{
1022-
auto index = result - WAIT_OBJECT_0;
1023-
1024-
try
1025-
{
1026-
m_handles[index].second->Collect();
1027-
}
1028-
catch (...)
1029-
{
1030-
if (WI_IsFlagSet(m_handles[index].first, Flags::IgnoreErrors))
1031-
{
1032-
m_handles.erase(m_handles.begin() + index);
1033-
}
1034-
else
1035-
{
1036-
throw;
1037-
}
1038-
}
1039-
}
1040-
else
1041-
{
1042-
THROW_LAST_ERROR_MSG("Timeout: %lu, Count: %llu", waitTimeout, waitHandles.size());
1043-
}
1071+
THROW_HR_IF_MSG(
1072+
HRESULT_FROM_WIN32(ERROR_TIMEOUT),
1073+
!m_handleSignaledEvent.wait(waitTimeout),
1074+
"Timed out waiting for %llu handles. Timeout: %lu",
1075+
m_handles.size(),
1076+
waitTimeout);
10441077
}
10451078

10461079
return !m_cancel;

src/windows/common/HandleIO.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#pragma once
44

5+
#include <concurrent_queue.h>
6+
57
#define LX_RELAY_BUFFER_SIZE 0x1000
68

79
namespace wsl::windows::common::io {
@@ -349,12 +351,10 @@ class DockerIORelayHandle : public OverlappedIOHandle
349351
WriteHandle* ActiveHandle = nullptr;
350352
size_t RemainingBytes = 0;
351353
};
352-
353354
class MultiHandleWait
354355
{
355356
public:
356357
NON_COPYABLE(MultiHandleWait);
357-
DEFAULT_MOVABLE(MultiHandleWait);
358358

359359
enum Flags
360360
{
@@ -365,13 +365,27 @@ class MultiHandleWait
365365
};
366366

367367
MultiHandleWait() = default;
368+
MultiHandleWait(MultiHandleWait&&) noexcept;
369+
MultiHandleWait& operator=(MultiHandleWait&&) noexcept;
368370

369371
void AddHandle(std::unique_ptr<OverlappedIOHandle>&& handle, Flags flags = Flags::None);
370372
bool Run(std::optional<std::chrono::milliseconds> Timeout);
371373
void Cancel();
372374

373375
private:
374-
std::vector<std::pair<Flags, std::unique_ptr<OverlappedIOHandle>>> m_handles;
376+
struct Entry
377+
{
378+
Flags HandleFlags{};
379+
std::unique_ptr<OverlappedIOHandle> Handle;
380+
MultiHandleWait* self;
381+
};
382+
383+
static void NTAPI WaitCallback(PVOID Context, BOOLEAN TimerOrWaitFired);
384+
385+
concurrency::concurrent_queue<Entry*> m_signaledHandles;
386+
wil::unique_event m_handleSignaledEvent{wil::EventOptions::ManualReset};
387+
388+
std::vector<std::unique_ptr<Entry>> m_handles;
375389
bool m_cancel = false;
376390
};
377391

test/windows/UnitTests.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6946,6 +6946,94 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
69466946
}
69476947
}
69486948

6949+
TEST_METHOD(MultiHandleWaitAboveMaximumWaitObjects)
6950+
{
6951+
// Validate that MultiHandleWait can wait on more than MAXIMUM_WAIT_OBJECTS (64) handles.
6952+
constexpr size_t handleCount = 100;
6953+
static_assert(handleCount > MAXIMUM_WAIT_OBJECTS);
6954+
6955+
// Scenario 1: signal every event before Run(); all callbacks must fire and Run() must return.
6956+
{
6957+
std::vector<wil::unique_event> events;
6958+
events.reserve(handleCount);
6959+
for (size_t i = 0; i < handleCount; ++i)
6960+
{
6961+
events.emplace_back(wil::EventOptions::ManualReset);
6962+
}
6963+
6964+
std::vector<bool> fired(handleCount, false);
6965+
std::atomic<size_t> firedCount{0};
6966+
std::mutex firedLock;
6967+
6968+
wsl::windows::common::io::MultiHandleWait io;
6969+
for (size_t i = 0; i < handleCount; ++i)
6970+
{
6971+
io.AddHandle(std::make_unique<wsl::windows::common::io::EventHandle>(
6972+
wsl::windows::common::io::HandleWrapper{events[i].get()}, [&fired, &firedCount, &firedLock, i]() {
6973+
std::lock_guard lock{firedLock};
6974+
VERIFY_IS_FALSE(fired[i]);
6975+
fired[i] = true;
6976+
firedCount.fetch_add(1);
6977+
}));
6978+
}
6979+
6980+
for (auto& e : events)
6981+
{
6982+
e.SetEvent();
6983+
}
6984+
6985+
VERIFY_IS_TRUE(io.Run(std::chrono::seconds(60)));
6986+
VERIFY_ARE_EQUAL(firedCount.load(), handleCount);
6987+
for (size_t i = 0; i < handleCount; ++i)
6988+
{
6989+
VERIFY_IS_TRUE(fired[i]);
6990+
}
6991+
}
6992+
6993+
// Scenario 2: signal events one at a time from another thread while Run() processes them.
6994+
{
6995+
std::vector<wil::unique_event> events;
6996+
events.reserve(handleCount);
6997+
for (size_t i = 0; i < handleCount; ++i)
6998+
{
6999+
events.emplace_back(wil::EventOptions::ManualReset);
7000+
}
7001+
7002+
std::vector<bool> fired(handleCount, false);
7003+
std::atomic<size_t> firedCount{0};
7004+
std::mutex firedLock;
7005+
7006+
wsl::windows::common::io::MultiHandleWait io;
7007+
for (size_t i = 0; i < handleCount; ++i)
7008+
{
7009+
io.AddHandle(std::make_unique<wsl::windows::common::io::EventHandle>(
7010+
wsl::windows::common::io::HandleWrapper{events[i].get()}, [&fired, &firedCount, &firedLock, i]() {
7011+
std::lock_guard lock{firedLock};
7012+
VERIFY_IS_FALSE(fired[i]);
7013+
fired[i] = true;
7014+
firedCount.fetch_add(1);
7015+
}));
7016+
}
7017+
7018+
std::thread signaller([&events]() {
7019+
for (auto& e : events)
7020+
{
7021+
e.SetEvent();
7022+
std::this_thread::sleep_for(std::chrono::milliseconds(1));
7023+
}
7024+
});
7025+
7026+
VERIFY_IS_TRUE(io.Run(std::chrono::seconds(60)));
7027+
signaller.join();
7028+
7029+
VERIFY_ARE_EQUAL(firedCount.load(), handleCount);
7030+
for (size_t i = 0; i < handleCount; ++i)
7031+
{
7032+
VERIFY_IS_TRUE(fired[i]);
7033+
}
7034+
}
7035+
}
7036+
69497037
TEST_METHOD(SocketChannel)
69507038
{
69517039
// Read exactly `size` bytes from a raw socket into the destination buffer.

0 commit comments

Comments
 (0)