diff --git a/src/windows/common/Dmesg.cpp b/src/windows/common/Dmesg.cpp index ca4bd67c6..3039ed73d 100644 --- a/src/windows/common/Dmesg.cpp +++ b/src/windows/common/Dmesg.cpp @@ -15,6 +15,8 @@ Module Name: #include "precomp.h" #include "Dmesg.h" +namespace io = wsl::windows::common::io; + DmesgCollector::DmesgCollector( GUID VmId, const wil::unique_event& ExitEvent, bool EnableTelemetry, bool EnableDebugConsole, const std::wstring& Com1PipeName, wil::unique_handle&& OutputHandle) : m_com1PipeName(Com1PipeName), @@ -80,25 +82,27 @@ std::pair DmesgCollector::StartDmesgThread(InputSourc // When the pipe connects, start reading data. wsl::windows::common::helpers::ConnectPipe(Pipe.get(), INFINITE, m_exitEvents); - std::vector buffer(LX_RELAY_BUFFER_SIZE); - const auto allBuffer = gsl::make_span(buffer); - OVERLAPPED overlapped = {}; - const wil::unique_event overlappedEvent(wil::EventOptions::ManualReset); - overlapped.hEvent = overlappedEvent.get(); - for (;;) - { - overlappedEvent.ResetEvent(); - const auto bytesRead = wsl::windows::common::relay::InterruptableRead( - Pipe.get(), gslhelpers::convert_span(allBuffer), m_exitEvents, &overlapped); - - if (bytesRead == 0) - { - break; - } - - auto validBuffer = allBuffer.subspan(0, bytesRead); - ProcessInput(Source, validBuffer); - } + io::MultiHandleWait ioWait; + ioWait.AddHandle( + std::make_unique( + io::HandleWrapper{Pipe.get()}, + [this, Source](const gsl::span& buffer) { + if (!buffer.empty()) + { + ProcessInput(Source, buffer); + } + }), + io::MultiHandleWait::IgnoreErrors); + + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{m_exitEvents[0]}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); + + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{m_exitEvents[1]}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); + + ioWait.Run(std::nullopt); } catch (...) { @@ -161,8 +165,7 @@ void DmesgCollector::ProcessInput(InputSource Source, const gsl::span& Inp if (m_outputHandle != nullptr) { m_overlappedEvent.ResetEvent(); - if (wsl::windows::common::relay::InterruptableWrite( - m_outputHandle.get(), gslhelpers::convert_span(Input), m_exitEvents, &m_overlapped) == 0) + if (io::InterruptableWrite(m_outputHandle.get(), gslhelpers::convert_span(Input), m_exitEvents, &m_overlapped) == 0) { m_outputHandle = nullptr; } @@ -192,7 +195,7 @@ void DmesgCollector::WriteToCom1(const gsl::span& Input) m_overlappedEvent.ResetEvent(); const auto buffer = gslhelpers::convert_span(Input); - if (wsl::windows::common::relay::InterruptableWrite(m_com1Pipe.get(), buffer, m_exitEvents, &m_overlapped) == 0) + if (io::InterruptableWrite(m_com1Pipe.get(), buffer, m_exitEvents, &m_overlapped) == 0) { if (m_debugConsole || !m_pipeServer) { diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index 664dd64e6..325278226 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -209,8 +209,8 @@ HANDLE EventHandle::GetHandle() const // ReadHandle -ReadHandle::ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead) : - Handle(std::move(MovedHandle)), OnRead(OnRead), Offset(InitializeFileOffset(Handle.Get())) +ReadHandle::ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead, size_t BufferSize) : + Handle(std::move(MovedHandle)), OnRead(std::move(OnRead)), Buffer(BufferSize), Offset(InitializeFileOffset(Handle.Get())) { Overlapped.hEvent = Event.get(); } @@ -1133,3 +1133,114 @@ bool MultiHandleWait::Run(std::optional Timeout) return !m_cancel; } + +DWORD wsl::windows::common::io::InterruptableRead( + _In_ HANDLE InputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_opt_ LPOVERLAPPED Overlapped) +{ + // Initialize an overlapped structure if one was not provided by the caller. + OVERLAPPED overlapped = {}; + wil::unique_event overlappedEvent = {}; + if (!ARGUMENT_PRESENT(Overlapped)) + { + overlappedEvent.create(wil::EventOptions::ManualReset); + overlapped.hEvent = overlappedEvent.get(); + Overlapped = &overlapped; + } + + DWORD bytesRead = 0; + if (!ReadFile(InputHandle, Buffer.data(), gsl::narrow_cast(Buffer.size()), &bytesRead, Overlapped)) + { + auto lastError = GetLastError(); + if ((lastError == ERROR_HANDLE_EOF) || (lastError == ERROR_BROKEN_PIPE)) + { + return 0; + } + + THROW_LAST_ERROR_IF_MSG(lastError != ERROR_IO_PENDING, "Handle: 0x%p", (void*)InputHandle); + + auto cancelRead = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + CancelIoEx(InputHandle, Overlapped); + GetOverlappedResult(InputHandle, Overlapped, &bytesRead, TRUE); + }); + + // Wait for the read to complete, or the client to exit. + if (!InterruptableWait(Overlapped->hEvent, ExitHandles)) + { + return 0; + } + + if (!GetOverlappedResult(InputHandle, Overlapped, &bytesRead, FALSE)) + { + lastError = GetLastError(); + if ((lastError == ERROR_HANDLE_EOF) || (lastError == ERROR_BROKEN_PIPE)) + { + return 0; + } + + THROW_LAST_ERROR(); + } + + cancelRead.release(); + } + + return bytesRead; +} + +bool wsl::windows::common::io::InterruptableWait(_In_ HANDLE WaitObject, _In_ const std::vector& ExitHandles) +{ + // Wait for the object to become signaled or one of the exit handles to be signaled. + std::vector waitObjects{WaitObject}; + for (const auto& exitHandle : ExitHandles) + { + waitObjects.push_back(exitHandle); + } + + const DWORD waitResult = WaitForMultipleObjects(gsl::narrow_cast(waitObjects.size()), waitObjects.data(), FALSE, INFINITE); + if (waitResult != WAIT_OBJECT_0) + { + if (waitResult > WAIT_OBJECT_0 && waitResult < WAIT_OBJECT_0 + waitObjects.size()) + { + return false; + } + + THROW_HR_MSG(E_FAIL, "WaitForMultipleObjects %d", waitResult); + } + + return true; +} + +DWORD wsl::windows::common::io::InterruptableWrite( + _In_ HANDLE OutputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_ LPOVERLAPPED Overlapped) +{ + const DWORD bytesToWrite = gsl::narrow_cast(Buffer.size()); + DWORD bytesWritten = 0; + BOOL success = WriteFile(OutputHandle, Buffer.data(), bytesToWrite, &bytesWritten, Overlapped); + if (!success) + { + const auto lastError = GetLastError(); + if (lastError == ERROR_NO_DATA) + { + return 0; + } + + THROW_LAST_ERROR_IF(lastError != ERROR_IO_PENDING); + + auto cancelWrite = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + CancelIoEx(OutputHandle, Overlapped); + GetOverlappedResult(OutputHandle, Overlapped, &bytesWritten, TRUE); + }); + + if (InterruptableWait(Overlapped->hEvent, ExitHandles)) + { + success = GetOverlappedResult(OutputHandle, Overlapped, &bytesWritten, FALSE); + if (success) + { + cancelWrite.release(); + } + } + } + + WI_ASSERT(!success || (bytesWritten == bytesToWrite)); + + return bytesWritten; +} diff --git a/src/windows/common/HandleIO.h b/src/windows/common/HandleIO.h index c197d21ee..7fedd1f4b 100644 --- a/src/windows/common/HandleIO.h +++ b/src/windows/common/HandleIO.h @@ -99,7 +99,7 @@ class ReadHandle : public OverlappedIOHandle NON_COPYABLE(ReadHandle); NON_MOVABLE(ReadHandle); - ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead); + ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead, size_t BufferSize = LX_RELAY_BUFFER_SIZE); virtual ~ReadHandle(); void Schedule() override; @@ -111,7 +111,7 @@ class ReadHandle : public OverlappedIOHandle std::function& Buffer)> OnRead; wil::unique_event Event{wil::EventOptions::ManualReset}; OVERLAPPED Overlapped{}; - BufferWrapper Buffer{LX_RELAY_BUFFER_SIZE}; + BufferWrapper Buffer; LARGE_INTEGER Offset{}; }; @@ -234,8 +234,16 @@ class RelayHandle : public OverlappedIOHandle NON_COPYABLE(RelayHandle); NON_MOVABLE(RelayHandle); - RelayHandle(HandleWrapper&& Input, HandleWrapper&& Output) : - Read(std::move(Input), [this](const gsl::span& Buffer) { return OnRead(Buffer); }), Write(std::move(Output)) + RelayHandle(HandleWrapper&& Input, HandleWrapper&& Output, size_t BufferSize = LX_RELAY_BUFFER_SIZE) + requires std::is_same_v + : + Read(std::move(Input), [this](const gsl::span& Buffer) { return OnRead(Buffer); }, BufferSize), Write(std::move(Output)) + { + } + + RelayHandle(HandleWrapper&& Input, HandleWrapper&& Output) + requires(!std::is_same_v) + : Read(std::move(Input), [this](const gsl::span& Buffer) { return OnRead(Buffer); }), Write(std::move(Output)) { } @@ -397,4 +405,12 @@ class MultiHandleWait DEFINE_ENUM_FLAG_OPERATORS(MultiHandleWait::Flags); +// Standalone cancellable IO operations. These perform a single overlapped read/write +// and can be cancelled via exit event handles. +DWORD InterruptableRead(_In_ HANDLE InputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_opt_ LPOVERLAPPED Overlapped = nullptr); + +DWORD InterruptableWrite(_In_ HANDLE OutputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_ LPOVERLAPPED Overlapped); + +bool InterruptableWait(_In_ HANDLE WaitObject, _In_ const std::vector& ExitHandles = {}); + } // namespace wsl::windows::common::io diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index 9801edccf..a88e1b016 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -18,395 +18,73 @@ Module Name: using wsl::windows::common::relay::ScopedMultiRelay; using wsl::windows::common::relay::ScopedRelay; +namespace io = wsl::windows::common::io; -namespace { - -LARGE_INTEGER InitializeFileOffset(HANDLE File) +void wsl::windows::common::relay::InterruptableRelay(_In_ HANDLE InputHandle, _In_opt_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) { - LARGE_INTEGER Offset{}; - if (GetFileType(File) == FILE_TYPE_DISK) - { - LOG_IF_WIN32_BOOL_FALSE(SetFilePointerEx(File, {}, &Offset, FILE_CURRENT)); - } + io::MultiHandleWait ioWait; - return Offset; -} - -void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped) -{ - DWORD bytesTransferred{}; - if (CancelIoEx((HANDLE)Handle, &Overlapped)) + if (OutputHandle) { - if constexpr (std::is_same_v) - { - if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, nullptr)) - { - auto error = WSAGetLastError(); - LOG_LAST_ERROR_IF(error != WSAECONNABORTED && error != WSA_OPERATION_ABORTED && error != WSAECONNRESET); - } - } - else - { - static_assert(std::is_same_v); - if (!GetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true)) - { - auto error = GetLastError(); - LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED); - } - } + ioWait.AddHandle( + std::make_unique>(io::HandleWrapper{InputHandle}, io::HandleWrapper{OutputHandle}, BufferSize), + io::MultiHandleWait::IgnoreErrors | io::MultiHandleWait::CancelOnCompleted); } else { - // ERROR_NOT_FOUND is returned if there was no IO to cancel. - LOG_LAST_ERROR_IF(GetLastError() != ERROR_NOT_FOUND); - } -} - -} // namespace - -std::thread wsl::windows::common::relay::CreateThread(_In_ HANDLE InputHandle, _In_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) -{ - return std::thread([InputHandle, OutputHandle, ExitHandle, BufferSize]() { - try - { - wsl::windows::common::wslutil::SetThreadDescription(L"IO Relay"); - InterruptableRelay(InputHandle, OutputHandle, ExitHandle, BufferSize); - } - CATCH_LOG() - }); -} - -std::thread wsl::windows::common::relay::CreateThread( - _In_ wil::unique_handle&& InputHandle, _In_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) -{ - return std::thread([InputHandle = std::move(InputHandle), OutputHandle, ExitHandle, BufferSize]() { - try - { - wsl::windows::common::wslutil::SetThreadDescription(L"IO Relay"); - InterruptableRelay(InputHandle.get(), OutputHandle, ExitHandle, BufferSize); - } - CATCH_LOG() - }); -} - -std::thread wsl::windows::common::relay::CreateThread( - _In_ HANDLE InputHandle, _In_ wil::unique_handle&& OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) -{ - return std::thread([InputHandle, OutputHandle = std::move(OutputHandle), ExitHandle, BufferSize]() { - try - { - wsl::windows::common::wslutil::SetThreadDescription(L"IO Relay"); - InterruptableRelay(InputHandle, OutputHandle.get(), ExitHandle, BufferSize); - } - CATCH_LOG() - }); -} - -std::thread wsl::windows::common::relay::CreateThread( - _In_ wil::unique_handle&& InputHandle, _In_ wil::unique_handle&& OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) -{ - return std::thread([InputHandle = std::move(InputHandle), OutputHandle = std::move(OutputHandle), ExitHandle, BufferSize]() { - try - { - wsl::windows::common::wslutil::SetThreadDescription(L"IO Relay"); - InterruptableRelay(InputHandle.get(), OutputHandle.get(), ExitHandle, BufferSize); - } - CATCH_LOG() - }); -} - -DWORD -wsl::windows::common::relay::InterruptableRead( - _In_ HANDLE InputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_opt_ LPOVERLAPPED Overlapped) -{ - // Initialize an overlapped structure if one was not provided by the caller. - OVERLAPPED overlapped = {}; - wil::unique_event overlappedEvent = {}; - if (!ARGUMENT_PRESENT(Overlapped)) - { - overlappedEvent.create(wil::EventOptions::ManualReset); - overlapped.hEvent = overlappedEvent.get(); - Overlapped = &overlapped; - } - - DWORD bytesRead = 0; - if (!ReadFile(InputHandle, Buffer.data(), gsl::narrow_cast(Buffer.size()), &bytesRead, Overlapped)) - { - auto lastError = GetLastError(); - if ((lastError == ERROR_HANDLE_EOF) || (lastError == ERROR_BROKEN_PIPE)) - { - return 0; - } - - THROW_LAST_ERROR_IF_MSG(lastError != ERROR_IO_PENDING, "Handle: 0x%p", (void*)InputHandle); - - auto cancelRead = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { - CancelIoEx(InputHandle, Overlapped); - GetOverlappedResult(InputHandle, Overlapped, &bytesRead, TRUE); - }); - - // Wait for the read to complete, or the client to exit. - if (!InterruptableWait(Overlapped->hEvent, ExitHandles)) - { - return 0; - } - - if (!GetOverlappedResult(InputHandle, Overlapped, &bytesRead, FALSE)) - { - lastError = GetLastError(); - if ((lastError == ERROR_HANDLE_EOF) || (lastError == ERROR_BROKEN_PIPE)) - { - return 0; - } - - THROW_LAST_ERROR(); - } - - cancelRead.release(); + // No output handle — drain input until EOF without writing. + ioWait.AddHandle( + std::make_unique( + io::HandleWrapper{InputHandle}, [](const gsl::span&) {}, BufferSize), + io::MultiHandleWait::IgnoreErrors | io::MultiHandleWait::CancelOnCompleted); } - return bytesRead; -} - -void wsl::windows::common::relay::InterruptableRelay(_In_ HANDLE InputHandle, _In_opt_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) -{ - // If the handle file is seekable, make sure to respect the offset. - // This is useful in cases when WSL is invoked on an existing file, like: wsl.exe echo foo >> file - // See: https://github.com/microsoft/WSL/issues/11799 - - LARGE_INTEGER writeOffset = InitializeFileOffset(OutputHandle); - LARGE_INTEGER readOffset = InitializeFileOffset(InputHandle); - - std::vector buffer(BufferSize); - const auto readSpan = gsl::make_span(buffer); - - std::vector exitHandles; if (ExitHandle) { - exitHandles.push_back(ExitHandle); + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{ExitHandle}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); } - OVERLAPPED overlapped = {0}; - const wil::unique_event overlappedEvent(wil::EventOptions::ManualReset); - overlapped.hEvent = overlappedEvent.get(); - for (;;) - { - overlapped.Offset = readOffset.LowPart; - overlapped.OffsetHigh = readOffset.HighPart; - const auto bytesRead = InterruptableRead(InputHandle, readSpan, exitHandles, &overlapped); - if (bytesRead == 0) - { - break; - } - - readOffset.QuadPart += bytesRead; - - if (OutputHandle) - { - overlapped.Offset = writeOffset.LowPart; - overlapped.OffsetHigh = writeOffset.HighPart; - auto writeSpan = readSpan.first(bytesRead); - const auto bytesWritten = InterruptableWrite(OutputHandle, writeSpan, exitHandles, &overlapped); - if (bytesWritten == 0) - { - break; - } - - WI_ASSERT(bytesWritten == bytesRead); - } - - writeOffset.QuadPart += bytesRead; - } -} - -bool wsl::windows::common::relay::InterruptableWait(_In_ HANDLE WaitObject, _In_ const std::vector& ExitHandles) -{ - // Wait for the object to become signaled or one of the exit handles to be signaled. - std::vector waitObjects{WaitObject}; - for (const auto& exitHandle : ExitHandles) - { - waitObjects.push_back(exitHandle); - } - - const DWORD waitResult = WaitForMultipleObjects(gsl::narrow_cast(waitObjects.size()), waitObjects.data(), FALSE, INFINITE); - if (waitResult != WAIT_OBJECT_0) - { - if (waitResult > WAIT_OBJECT_0 && waitResult < WAIT_OBJECT_0 + waitObjects.size()) - { - return false; - } - - THROW_HR_MSG(E_FAIL, "WaitForMultipleObjects %d", waitResult); - } - - return true; -} - -DWORD -wsl::windows::common::relay::InterruptableWrite( - _In_ HANDLE OutputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_ LPOVERLAPPED Overlapped) -{ - const DWORD bytesToWrite = gsl::narrow_cast(Buffer.size()); - DWORD bytesWritten = 0; - BOOL success = WriteFile(OutputHandle, Buffer.data(), bytesToWrite, &bytesWritten, Overlapped); - if (!success) - { - const auto lastError = GetLastError(); - if (lastError == ERROR_NO_DATA) - { - return 0; - } - - THROW_LAST_ERROR_IF(lastError != ERROR_IO_PENDING); - - auto cancelWrite = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { - CancelIoEx(OutputHandle, Overlapped); - GetOverlappedResult(OutputHandle, Overlapped, &bytesWritten, TRUE); - }); - - if (InterruptableWait(Overlapped->hEvent, ExitHandles)) - { - success = GetOverlappedResult(OutputHandle, Overlapped, &bytesWritten, FALSE); - if (success) - { - cancelWrite.release(); - } - } - } - - WI_ASSERT(!success || (bytesWritten == bytesToWrite)); - - return bytesWritten; + ioWait.Run(std::nullopt); } void wsl::windows::common::relay::BidirectionalRelay(_In_ HANDLE LeftHandle, _In_ HANDLE RightHandle, _In_ size_t BufferSize, _In_ RelayFlags Flags) { - std::vector leftBuffer(BufferSize); - const auto leftReadSpan = gsl::make_span(leftBuffer); - OVERLAPPED leftOverlapped = {0}; - const wil::unique_event leftOverlappedEvent(wil::EventOptions::None); - leftOverlapped.hEvent = leftOverlappedEvent.get(); - LARGE_INTEGER leftOffset{}; - - std::vector rightBuffer(BufferSize); - const auto rightReadSpan = gsl::make_span(rightBuffer); - OVERLAPPED rightOverlapped = {0}; - const wil::unique_event rightOverlappedEvent(wil::EventOptions::None); - rightOverlapped.hEvent = rightOverlappedEvent.get(); - LARGE_INTEGER rightOffset{}; - - bool leftReadPending = false; - bool rightReadPending = false; - auto cancelReads = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { - DWORD bytes; - if (leftReadPending) - { - CancelIoEx(LeftHandle, &leftOverlapped); - GetOverlappedResult(LeftHandle, &leftOverlapped, &bytes, TRUE); - } - - if (rightReadPending) - { - CancelIoEx(RightHandle, &rightOverlapped); - GetOverlappedResult(RightHandle, &rightOverlapped, &bytes, TRUE); - } - }); - - DWORD bytesWritten; - const HANDLE waitObjects[] = {leftOverlapped.hEvent, rightOverlapped.hEvent}; - for (;;) - { - if ((LeftHandle == nullptr) || (RightHandle == nullptr)) - { - break; - } - - DWORD leftBytesRead = 0; - if (!leftReadPending && LeftHandle) - { - if (!ReadFile(LeftHandle, leftReadSpan.data(), gsl::narrow_cast(leftReadSpan.size()), &leftBytesRead, &leftOverlapped)) - { - THROW_LAST_ERROR_IF(GetLastError() != ERROR_IO_PENDING); - } - - leftReadPending = true; - } - - DWORD rightBytesRead = 0; - if (!rightReadPending && RightHandle) - { - if (!ReadFile(RightHandle, rightReadSpan.data(), gsl::narrow_cast(rightReadSpan.size()), &rightBytesRead, &rightOverlapped)) - { - THROW_LAST_ERROR_IF(GetLastError() != ERROR_IO_PENDING); - } - - rightReadPending = true; - } - - const DWORD waitResult = WaitForMultipleObjects(RTL_NUMBER_OF(waitObjects), waitObjects, FALSE, INFINITE); - if (waitResult == WAIT_OBJECT_0) - { - LOG_LAST_ERROR_IF_MSG( - !GetOverlappedResult(LeftHandle, &leftOverlapped, &leftBytesRead, FALSE), "WSAGetLastError %d", WSAGetLastError()); - - leftReadPending = false; - if (leftBytesRead == 0) - { - LeftHandle = nullptr; - if (WI_IsFlagSet(Flags, RelayFlags::RightIsSocket)) - { - LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(RightHandle), SD_SEND) == SOCKET_ERROR); - } - } - else if (RightHandle != nullptr) - { - auto writeSpan = leftReadSpan.first(leftBytesRead); - bytesWritten = InterruptableWrite(RightHandle, writeSpan, {}, &leftOverlapped); - if (bytesWritten == 0) - { - break; - } - - leftOffset.QuadPart += leftBytesRead; - leftOverlapped.Offset = leftOffset.LowPart; - leftOverlapped.OffsetHigh = leftOffset.HighPart; - } - } - else if (waitResult == (WAIT_OBJECT_0 + 1)) - { - LOG_LAST_ERROR_IF_MSG( - !GetOverlappedResult(RightHandle, &rightOverlapped, &rightBytesRead, FALSE), "WSAGetLastError %d", WSAGetLastError()); - - rightReadPending = false; - if (rightBytesRead == 0) - { - RightHandle = nullptr; - if (WI_IsFlagSet(Flags, RelayFlags::LeftIsSocket)) - { - LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(LeftHandle), SD_SEND) == SOCKET_ERROR); - } - } - else if (LeftHandle != nullptr) - { - auto writeSpan = rightReadSpan.first(rightBytesRead); - bytesWritten = InterruptableWrite(LeftHandle, writeSpan, {}, &rightOverlapped); - if (bytesWritten == 0) - { - break; - } + io::MultiHandleWait ioWait; + + // Left-to-right relay. OnClose on the input handle fires shutdown(SD_SEND) on the peer when this direction completes. + ioWait.AddHandle( + std::make_unique>( + io::HandleWrapper{ + LeftHandle, + [&]() { + if (WI_IsFlagSet(Flags, RelayFlags::RightIsSocket)) + { + LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(RightHandle), SD_SEND) == SOCKET_ERROR); + } + }}, + io::HandleWrapper{RightHandle}, + BufferSize), + io::MultiHandleWait::IgnoreErrors); + + // Right-to-left relay. + ioWait.AddHandle( + std::make_unique>( + io::HandleWrapper{ + RightHandle, + [&]() { + if (WI_IsFlagSet(Flags, RelayFlags::LeftIsSocket)) + { + LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(LeftHandle), SD_SEND) == SOCKET_ERROR); + } + }}, + io::HandleWrapper{LeftHandle}, + BufferSize), + io::MultiHandleWait::IgnoreErrors); - rightOffset.QuadPart += rightBytesRead; - rightOverlapped.Offset = rightOffset.LowPart; - rightOverlapped.OffsetHigh = rightOffset.HighPart; - } - } - else - { - THROW_HR_MSG(E_FAIL, "WaitForMultipleObjects %d", waitResult); - } - } + ioWait.Run(std::nullopt); } - #define TTY_ALT_NUMPAD_VK_MENU (0x12) #define TTY_ESCAPE_CHARACTER (L'\x1b') #define TTY_INPUT_EVENT_BUFFER_SIZE (16) @@ -763,7 +441,7 @@ bool wsl::windows::common::relay::StandardInputRelay( WORD RepeatIndex; for (RepeatIndex = 0; RepeatIndex < InputRecordBuffer[0].Event.KeyEvent.wRepeatCount; RepeatIndex += 1) { - BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); + BytesWritten = io::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); if (BytesWritten == 0) { break; @@ -772,7 +450,7 @@ bool wsl::windows::common::relay::StandardInputRelay( } else if (Utf8StringSize > 0) { - BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); + BytesWritten = io::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); if (BytesWritten == 0) { break; @@ -815,11 +493,16 @@ void ScopedRelay::Run(_In_ HANDLE Input, _In_ HANDLE Output, size_t BufferSize) { wsl::windows::common::wslutil::SetThreadDescription(L"ScopedRelay"); - try - { - InterruptableRelay(Input, Output, m_exitEvent.get(), BufferSize); - } - CATCH_LOG(); + io::MultiHandleWait ioWait; + ioWait.AddHandle( + std::make_unique>(io::HandleWrapper{Input}, io::HandleWrapper{Output}, BufferSize), + io::MultiHandleWait::IgnoreErrors | io::MultiHandleWait::CancelOnCompleted); + + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{m_exitEvent.get()}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); + + ioWait.Run(std::nullopt); } ScopedMultiRelay::ScopedMultiRelay(const std::vector& Inputs, const TWriteMethod& Write, size_t BufferSize) @@ -846,135 +529,27 @@ ScopedMultiRelay::~ScopedMultiRelay() void ScopedMultiRelay::Run(const std::vector& Handles, const TWriteMethod& Write, size_t BufferSize) const try { - enum State - { - Standby, - Pending, - Eof - }; + io::MultiHandleWait ioWait; - struct Input + for (size_t i = 0; i < Handles.size(); i++) { - HANDLE Handle; - LARGE_INTEGER Offset; - std::vector Buffer; - wil::unique_event Event{wil::EventOptions::ManualReset}; - OVERLAPPED Overlapped; - State State = Standby; - - Input(Input&&) = default; - Input& operator=(Input&&) = default; - - Input(HANDLE Handle, LARGE_INTEGER Offset, size_t BufferSize) : Handle(Handle), Offset(Offset), Buffer(BufferSize) - { - Overlapped.hEvent = Event.get(); - } - - ~Input() - { - // Cancel outstanding IO, if any. - if (State == Pending) - { - CancelIoEx(Handle, &Overlapped); - DWORD bytesRead{}; - GetOverlappedResult(Handle, &Overlapped, &bytesRead, TRUE); - } - } - }; - - std::vector Inputs; - for (const auto& e : Handles) - { - Inputs.emplace_back(e, InitializeFileOffset(e), BufferSize); - } - - while (true) - { - // Exit if all inputs are completed, or if the exit event is set. - if (m_exitEvent.is_signaled() || std::all_of(Inputs.begin(), Inputs.end(), [](const auto& e) { return e.State == Eof; })) - { - return; - } - - for (size_t i = 0; i < Inputs.size(); i++) - { - auto& e = Inputs[i]; - - // If a read has been scheduled, check if IO is available. - if (e.State == Pending) - { - if (e.Event.is_signaled()) - { - DWORD Transferred{}; - if (!GetOverlappedResult(e.Handle, &e.Overlapped, &Transferred, TRUE)) - { - auto lastError = GetLastError(); - if ((lastError == ERROR_HANDLE_EOF) || (lastError == ERROR_BROKEN_PIPE)) - { - e.State = Eof; - continue; - } - - THROW_LAST_ERROR_IF(lastError != ERROR_IO_PENDING); - } - - // IO is available. - Write(i, gsl::make_span(e.Buffer.data(), Transferred)); - - // Update input state. - e.Offset.QuadPart += Transferred; - e.State = Standby; - } - } - - // If no read is pending, start one. - if (e.State == Standby) - { - e.Event.ResetEvent(); - - e.Overlapped.Offset = e.Offset.LowPart; - e.Overlapped.OffsetHigh = e.Offset.HighPart; - - DWORD BytesRead{}; - if (ReadFile(e.Handle, e.Buffer.data(), static_cast(e.Buffer.size()), &BytesRead, &e.Overlapped)) - { - // IO is available. - Write(i, gsl::make_span(e.Buffer.data(), BytesRead)); - - // Update input state. - e.Offset.QuadPart += BytesRead; - e.State = Standby; - } - else - { - auto lastError = GetLastError(); - if ((lastError == ERROR_HANDLE_EOF) || (lastError == ERROR_BROKEN_PIPE)) + ioWait.AddHandle( + std::make_unique( + io::HandleWrapper{Handles[i]}, + [i, &Write](const gsl::span& buffer) { + if (!buffer.empty()) { - e.State = Eof; - continue; + Write(i, gsl::make_span(reinterpret_cast(buffer.data()), buffer.size())); } + }, + BufferSize), + io::MultiHandleWait::IgnoreErrors); + } - THROW_LAST_ERROR_IF(lastError != ERROR_IO_PENDING); - e.State = Pending; - } - } - } + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{m_exitEvent.get()}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); - // Only wait if all non-completed inputs have a scheduled ReadFile to avoid a pipe hang. - if (std::all_of(Inputs.begin(), Inputs.end(), [](const auto& e) { return e.State == Eof || e.State == Pending; })) - { - // Wait until a handle is signaled. - std::vector waits{m_exitEvent.get()}; - for (const auto& e : Inputs) - { - if (e.State == Pending) - { - waits.emplace_back(e.Event.get()); - } - } - - THROW_LAST_ERROR_IF(WaitForMultipleObjects(static_cast(waits.size()), waits.data(), false, INFINITE) == WAIT_FAILED); - } - } + ioWait.Run(std::nullopt); } CATCH_LOG() diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index d723ce7d0..420bdc24c 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -22,27 +22,33 @@ namespace wsl::windows::common::relay { using namespace wsl::windows::common::io; -std::thread CreateThread(_In_ HANDLE InputHandle, _In_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle = nullptr, _In_ size_t BufferSize = LX_RELAY_BUFFER_SIZE); - -std::thread CreateThread(_In_ wil::unique_handle&& InputHandle, _In_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle = nullptr, _In_ size_t BufferSize = LX_RELAY_BUFFER_SIZE); - -std::thread CreateThread(_In_ HANDLE InputHandle, _In_ wil::unique_handle&& OutputHandle, _In_opt_ HANDLE ExitHandle = nullptr, _In_ size_t BufferSize = LX_RELAY_BUFFER_SIZE); - -std::thread CreateThread( - _In_ wil::unique_handle&& InputHandle, - _In_ wil::unique_handle&& OutputHandle, - _In_opt_ HANDLE ExitHandle = nullptr, - _In_ size_t BufferSize = LX_RELAY_BUFFER_SIZE); - -DWORD -InterruptableRead(_In_ HANDLE InputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_opt_ LPOVERLAPPED Overlapped = nullptr); - void InterruptableRelay(_In_ HANDLE InputHandle, _In_opt_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle = nullptr, _In_ size_t BufferSize = LX_RELAY_BUFFER_SIZE); -bool InterruptableWait(_In_ HANDLE WaitObject, _In_ const std::vector& ExitHandles = {}); - -DWORD -InterruptableWrite(_In_ HANDLE OutputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_ LPOVERLAPPED Overlapped); +template +std::thread CreateThread(TInput&& InputHandle, TOutput&& OutputHandle, _In_opt_ HANDLE ExitHandle = nullptr, _In_ size_t BufferSize = LX_RELAY_BUFFER_SIZE) +{ + return std::thread( + [InputHandle = std::forward(InputHandle), OutputHandle = std::forward(OutputHandle), ExitHandle, BufferSize]() { + try + { + wsl::windows::common::wslutil::SetThreadDescription(L"IO Relay"); + + auto getHandle = [](const auto& h) -> HANDLE { + if constexpr (std::is_same_v, HANDLE>) + { + return h; + } + else + { + return h.get(); + } + }; + + InterruptableRelay(getHandle(InputHandle), getHandle(OutputHandle), ExitHandle, BufferSize); + } + CATCH_LOG() + }); +} bool StandardInputRelay( HANDLE ConsoleHandle, diff --git a/src/windows/service/exe/GuestTelemetryLogger.cpp b/src/windows/service/exe/GuestTelemetryLogger.cpp index 950eb54bd..256361680 100644 --- a/src/windows/service/exe/GuestTelemetryLogger.cpp +++ b/src/windows/service/exe/GuestTelemetryLogger.cpp @@ -61,23 +61,28 @@ void GuestTelemetryLogger::Start(const wil::unique_event& ExitEvent) const std::vector exitEvents = {m_threadExit.get(), ExitEvent.get()}; wsl::windows::common::helpers::ConnectPipe(Pipe.get(), INFINITE, exitEvents); - std::vector buffer(LX_RELAY_BUFFER_SIZE); - OVERLAPPED overlapped = {}; - const wil::unique_event overlappedEvent(wil::EventOptions::ManualReset); - overlapped.hEvent = overlappedEvent.get(); - for (;;) - { - overlappedEvent.ResetEvent(); - const auto bytesRead = - wsl::windows::common::relay::InterruptableRead(Pipe.get(), gsl::make_span(buffer), exitEvents, &overlapped); - - if (bytesRead == 0) - { - break; - } - - ProcessInput(std::string_view{reinterpret_cast(buffer.data()), bytesRead}); - } + namespace io = wsl::windows::common::io; + io::MultiHandleWait ioWait; + ioWait.AddHandle( + std::make_unique( + io::HandleWrapper{Pipe.get()}, + [this](const gsl::span& buffer) { + if (!buffer.empty()) + { + ProcessInput(std::string_view{buffer.data(), static_cast(buffer.size())}); + } + }), + io::MultiHandleWait::IgnoreErrors); + + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{m_threadExit.get()}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); + + ioWait.AddHandle( + std::make_unique(io::HandleWrapper{ExitEvent.get()}), + io::MultiHandleWait::CancelOnCompleted | io::MultiHandleWait::NeedNotComplete); + + ioWait.Run(std::nullopt); } CATCH_LOG() }); diff --git a/src/windows/service/exe/LxssUserSession.cpp b/src/windows/service/exe/LxssUserSession.cpp index 831b9403d..46508af72 100644 --- a/src/windows/service/exe/LxssUserSession.cpp +++ b/src/windows/service/exe/LxssUserSession.cpp @@ -1653,7 +1653,7 @@ HRESULT LxssUserSessionImpl::RegisterDistribution( wil::unique_handle clientProcess = wsl::windows::common::wslutil::OpenCallingProcess(GENERIC_READ | SYNCHRONIZE); MESSAGE_HEADER header{}; const auto headerSpan = gslhelpers::struct_as_writeable_bytes(header); - auto bytesRead = wsl::windows::common::relay::InterruptableRead( + auto bytesRead = wsl::windows::common::io::InterruptableRead( output.first.get(), gslhelpers::struct_as_writeable_bytes(header), {clientProcess.get()}); THROW_HR_IF(WSL_E_IMPORT_FAILED, bytesRead != headerSpan.size() || header.MessageSize <= headerSpan.size() || header.MessageType != LxMiniInitMessageImportResult); @@ -1666,7 +1666,7 @@ HRESULT LxssUserSessionImpl::RegisterDistribution( while (offset < span.size()) { bytesRead = - wsl::windows::common::relay::InterruptableRead(output.first.get(), span.subspan(offset), {clientProcess.get()}); + wsl::windows::common::io::InterruptableRead(output.first.get(), span.subspan(offset), {clientProcess.get()}); if (bytesRead <= 0) { break; @@ -3172,7 +3172,7 @@ LONG LxssUserSessionImpl::_GetElfExitStatus(_In_ const LXSS_RUN_ELF_CONTEXT& Con { // Wait for the instance to terminate or the client process to exit. const wil::unique_handle clientProcess = wsl::windows::common::wslutil::OpenCallingProcess(GENERIC_READ | SYNCHRONIZE); - THROW_HR_IF(E_ABORT, !wsl::windows::common::relay::InterruptableWait(Context.instanceTerminatedEvent.get(), {clientProcess.get()})); + THROW_HR_IF(E_ABORT, !wsl::windows::common::io::InterruptableWait(Context.instanceTerminatedEvent.get(), {clientProcess.get()})); // Ensure that the process exited successfully. If the process encountered // an error, wait for the stderr worker thread and log the error message.