Skip to content

Commit eb44b60

Browse files
committed
Apply PR feedback
1 parent 0a6776f commit eb44b60

4 files changed

Lines changed: 37 additions & 36 deletions

File tree

src/shared/inc/SocketChannel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class SocketChannel
110110

111111
#ifdef WIN32
112112
m_exitEvents = std::move(other.m_exitEvents);
113+
m_pendingBytes = std::move(other.m_pendingBytes);
113114
#endif
114115
m_ignore_sequence = other.m_ignore_sequence;
115116
m_sent_non_transaction_messages = other.m_sent_non_transaction_messages;
@@ -722,7 +723,7 @@ class SocketChannel
722723
#ifdef WIN32
723724

724725
std::vector<HANDLE> m_exitEvents;
725-
std::optional<std::vector<gsl::byte>> m_pendingBytes;
726+
std::vector<gsl::byte> m_pendingBytes;
726727

727728
#endif
728729
uint32_t m_sent_non_transaction_messages = 0;

src/windows/common/HandleIO.cpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ DWORD CancelPendingIo(auto Handle, OVERLAPPED& Overlapped)
3838
{
3939
if constexpr (std::is_same_v<decltype(Handle), SOCKET>)
4040
{
41-
if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, nullptr))
41+
DWORD flagsReturned{};
42+
if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, &flagsReturned))
4243
{
4344
auto error = WSAGetLastError();
4445
LOG_LAST_ERROR_IF(error != WSAECONNABORTED && error != WSA_OPERATION_ABORTED && error != WSAECONNRESET);
@@ -531,7 +532,7 @@ void HTTPChunkBasedReadHandle::OnRead(const gsl::span<char>& Input)
531532
ReadSocketMessageHandle::ReadSocketMessageHandle(
532533
HandleWrapper&& MovedSocket,
533534
std::vector<gsl::byte>& Buffer,
534-
std::optional<std::vector<gsl::byte>>& PendingBytes,
535+
std::vector<gsl::byte>& PendingBytes,
535536
std::function<void(const gsl::span<gsl::byte>& Message)>&& OnMessage) :
536537
Socket(std::move(MovedSocket)), Buffer(Buffer), PendingBytes(PendingBytes), OnMessage(std::move(OnMessage))
537538
{
@@ -542,20 +543,20 @@ ReadSocketMessageHandle::ReadSocketMessageHandle(
542543
Buffer.resize(sizeof(MESSAGE_HEADER));
543544
}
544545

545-
if (!PendingBytes.has_value())
546+
if (PendingBytes.empty())
546547
{
547548
return;
548549
}
549550

550551
// If bytes from a previously cancelled transaction are passed, process them now.
551-
if (Buffer.size() < PendingBytes->size())
552+
if (Buffer.size() < PendingBytes.size())
552553
{
553-
Buffer.resize(PendingBytes->size());
554+
Buffer.resize(PendingBytes.size());
554555
}
555556

556-
std::copy(PendingBytes->begin(), PendingBytes->end(), Buffer.begin());
557-
558-
CurrentOffset = PendingBytes->size();
557+
std::copy(PendingBytes.begin(), PendingBytes.end(), Buffer.begin());
558+
CurrentOffset = PendingBytes.size();
559+
PendingBytes.clear();
559560

560561
if (CurrentOffset < sizeof(MESSAGE_HEADER))
561562
{
@@ -569,24 +570,24 @@ ReadSocketMessageHandle::ReadSocketMessageHandle(
569570

570571
ReadSocketMessageHandle::~ReadSocketMessageHandle()
571572
{
572-
if (State == IOHandleStatus::Completed)
573-
{
574-
PendingBytes.reset();
575-
}
576-
else if (State == IOHandleStatus::Pending)
573+
if (State != IOHandleStatus::Completed)
577574
{
578-
// Cancel the pending receive and move any bytes already buffered for the in-flight message into PendingBytes
579-
const auto socket = reinterpret_cast<SOCKET>(Socket.Get());
580-
auto receivedBytes = CancelPendingIo(socket, Overlapped);
575+
auto pendingSize = CurrentOffset;
576+
577+
if (State == IOHandleStatus::Pending)
578+
{
579+
// Cancel the pending receive and move any bytes already buffered for the in-flight message into PendingBytes
580+
const auto socket = reinterpret_cast<SOCKET>(Socket.Get());
581+
pendingSize += CancelPendingIo(socket, Overlapped);
582+
}
581583

582-
const auto totalBytes = CurrentOffset + receivedBytes;
583-
if (totalBytes > 0)
584+
if (pendingSize > 0)
584585
{
585-
WI_ASSERT(totalBytes <= Buffer.size());
586-
Buffer.resize(totalBytes);
587-
PendingBytes.emplace(std::move(Buffer));
586+
WI_ASSERT(pendingSize <= Buffer.size());
587+
PendingBytes = {Buffer.begin(), Buffer.begin() + pendingSize};
588588

589-
WSL_LOG("CanceledMessageRead", TraceLoggingValue(totalBytes, "TotalBytes"), TraceLoggingValue(socket, "Socket"));
589+
WSL_LOG(
590+
"CanceledMessageRead", TraceLoggingValue(pendingSize, "TotalBytes"), TraceLoggingValue(Socket.Get(), "Socket"));
590591
}
591592
}
592593
}
@@ -663,7 +664,10 @@ bool ReadSocketMessageHandle::ProcessChunk()
663664
}
664665

665666
ReadingHeader = false;
666-
BytesRemaining = messageSize - CurrentOffset;
667+
if (CurrentOffset < messageSize)
668+
{
669+
BytesRemaining = messageSize - CurrentOffset;
670+
}
667671

668672
if (BytesRemaining > 0)
669673
{
@@ -680,14 +684,10 @@ void ReadSocketMessageHandle::Schedule()
680684
{
681685
WI_ASSERT(State == IOHandleStatus::Standby);
682686

683-
// If a previous receive on this socket was aborted while bytes were already in our
684-
// application buffer (see destructor), drain those chunks before issuing a new WSARecv.
685-
while (State == IOHandleStatus::Standby && BytesRemaining == 0)
687+
// Process previously received bytes, if any.
688+
if (BytesRemaining == 0 && !ProcessChunk())
686689
{
687-
if (!ProcessChunk())
688-
{
689-
return;
690-
}
690+
return; // Message has been fully received, not need to schedule a receive.
691691
}
692692

693693
ScheduleRecv();

src/windows/common/HandleIO.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class ReadSocketMessageHandle : public OverlappedIOHandle
181181
ReadSocketMessageHandle(
182182
HandleWrapper&& Socket,
183183
std::vector<gsl::byte>& Buffer,
184-
std::optional<std::vector<gsl::byte>>& PendingBytes,
184+
std::vector<gsl::byte>& PendingBytes,
185185
std::function<void(const gsl::span<gsl::byte>& Message)>&& OnMessage);
186186
~ReadSocketMessageHandle();
187187

@@ -196,7 +196,7 @@ class ReadSocketMessageHandle : public OverlappedIOHandle
196196

197197
HandleWrapper Socket;
198198
std::vector<gsl::byte>& Buffer;
199-
std::optional<std::vector<gsl::byte>>& PendingBytes;
199+
std::vector<gsl::byte>& PendingBytes;
200200
std::function<void(const gsl::span<gsl::byte>& Message)> OnMessage;
201201
wil::unique_event Event{wil::EventOptions::ManualReset};
202202
OVERLAPPED Overlapped{};

test/windows/UnitTests.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6842,7 +6842,7 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
68426842
// Drive a ReadSocketMessageHandle until completion and return the bytes delivered to its
68436843
// OnMessage callback. If a non-success HRESULT is supplied, the call is expected to throw
68446844
// that HRESULT instead, and the OnMessage callback must not be invoked.
6845-
auto readMessage = [](wil::unique_socket&& server, HRESULT expectedHr = S_OK, std::optional<std::vector<gsl::byte>> pendingBytes = {}) {
6845+
auto readMessage = [](wil::unique_socket&& server, HRESULT expectedHr = S_OK, std::vector<gsl::byte> pendingBytes = {}) {
68466846
std::vector<gsl::byte> buffer;
68476847
bool callbackInvoked = false;
68486848
std::vector<gsl::byte> message;
@@ -7049,7 +7049,7 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
70497049
}
70507050

70517051
// Scenario 10: PendingBytes contains an invalid (too-small) message size. The
7052-
// constructor should detect this and throw E_UNEXPECTED without invoking OnMessage.
7052+
// IO should detect this and throw E_UNEXPECTED without invoking OnMessage.
70537053
{
70547054
auto [client, server] = MakeSocketPair();
70557055
client.reset();
@@ -7061,7 +7061,7 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
70617061
header.TransactionStep = 1;
70627062

70637063
const auto* headerBytes = reinterpret_cast<const gsl::byte*>(&header);
7064-
std::optional<std::vector<gsl::byte>> pendingBytes{std::vector<gsl::byte>{headerBytes, headerBytes + sizeof(header)}};
7064+
std::vector<gsl::byte> pendingBytes{headerBytes, headerBytes + sizeof(header)};
70657065

70667066
std::vector<gsl::byte> buffer;
70677067
bool callbackInvoked = false;

0 commit comments

Comments
 (0)