diff --git a/.pipelines/build-stage.yml b/.pipelines/build-stage.yml
index 0d0c30b175..61d4a6dbc7 100644
--- a/.pipelines/build-stage.yml
+++ b/.pipelines/build-stage.yml
@@ -32,8 +32,8 @@ parameters:
- name: targets
type: object
default:
- - target: "wsl;libwsl;wslg;wslservice;wslhost;wslrelay;wslinstaller;wslinstall;initramfs;wslserviceproxystub;wslsettings;wslinstallerproxystub;testplugin;wslcsession;wslc;wsltests;wslcsdk"
- pattern: "wsl.exe,libwsl.dll,wslg.exe,wslservice.exe,wslhost.exe,wslrelay.exe,wslinstaller.exe,wslinstall.dll,wslserviceproxystub.dll,wslsettings/wslsettings.dll,wslsettings/wslsettings.exe,wslinstallerproxystub.dll,WSLDVCPlugin.dll,testplugin.dll,wsldeps.dll,wslcsession.exe,wslc.exe,wslcsdk.dll"
+ - target: "wsl;libwsl;wslg;wslservice;wslhost;wslrelay;wslpluginhost;wslinstaller;wslinstall;initramfs;wslserviceproxystub;wslsettings;wslinstallerproxystub;testplugin;wslcsession;wslc;wsltests;wslcsdk"
+ pattern: "wsl.exe,libwsl.dll,wslg.exe,wslservice.exe,wslhost.exe,wslrelay.exe,wslpluginhost.exe,wslinstaller.exe,wslinstall.dll,wslserviceproxystub.dll,wslsettings/wslsettings.dll,wslsettings/wslsettings.exe,wslinstallerproxystub.dll,WSLDVCPlugin.dll,testplugin.dll,wsldeps.dll,wslcsession.exe,wslc.exe,wslcsdk.dll"
- target: "msixgluepackage"
pattern: "gluepackage.msix"
- target: "msipackage;wslcsdkcs"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c88fbb8e90..d0b607f423 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -545,6 +545,7 @@ add_subdirectory(src/windows/wsl)
add_subdirectory(src/windows/wslg)
add_subdirectory(src/windows/wslhost)
add_subdirectory(src/windows/wslrelay)
+add_subdirectory(src/windows/wslpluginhost)
add_subdirectory(src/windows/wslinstall)
add_subdirectory(src/windows/wslc)
add_subdirectory(src/windows/WslcSDK)
diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt
index 99586c9727..ab025346e9 100644
--- a/msipackage/CMakeLists.txt
+++ b/msipackage/CMakeLists.txt
@@ -17,7 +17,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi)
set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in)
set(PACKAGE_WIX ${BIN}/package.wix)
set(CAB_CACHE ${BIN}/cab)
-set(WINDOWS_BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;wslinstall.dll;wslc.exe;wslcsession.exe)
+set(WINDOWS_BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslpluginhost.exe;wslservice.exe;wslserviceproxystub.dll;wslinstall.dll;wslc.exe;wslcsession.exe)
if (WSL_BUILD_WSL_SETTINGS)
list(APPEND WINDOWS_BINARIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll")
endif()
@@ -57,7 +57,7 @@ add_custom_command(
add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE})
set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN})
-add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslc wslcsession)
+add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslpluginhost wslserviceproxystub init initramfs wslinstall msixgluepackage wslc wslcsession)
if (WSL_BUILD_WSL_SETTINGS)
add_dependencies(msipackage wslsettings libwsl)
diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in
index 6b2cd0c5e9..38934c076e 100644
--- a/msipackage/package.wix.in
+++ b/msipackage/package.wix.in
@@ -31,6 +31,7 @@
+
@@ -177,6 +178,35 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/windows/common/precomp.h b/src/windows/common/precomp.h
index c7b53ee547..6d098a5c92 100644
--- a/src/windows/common/precomp.h
+++ b/src/windows/common/precomp.h
@@ -84,6 +84,7 @@ Module Name:
#include
#include
#include
+#include
#include
#include
#include
diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt
index c7b4ebb0dd..d5cdb06cb9 100644
--- a/src/windows/service/exe/CMakeLists.txt
+++ b/src/windows/service/exe/CMakeLists.txt
@@ -7,6 +7,7 @@ set(SOURCES
LxssHttpProxy.cpp
LxssInstance.cpp
PluginManager.cpp
+ PluginCallPump.cpp
ServiceMain.cpp
BridgedNetworking.cpp
GnsRpcServer.cpp
@@ -37,6 +38,7 @@ set(HEADERS
LxssIptables.h
LxssHttpProxy.h
PluginManager.h
+ PluginCallPump.h
LxssInstance.h
BridgedNetworking.h
GnsRpcServer.h
@@ -58,7 +60,7 @@ set(HEADERS
WSLCPluginNotifier.h)
add_executable(wslservice ${SOURCES} ${HEADERS})
-add_dependencies(wslservice wslserviceidl wslservicemc)
+add_dependencies(wslservice wslserviceidl wslservicemc wslpluginhostidl)
add_compile_definitions(__WRL_CLASSIC_COM__)
add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__)
add_compile_definitions(USE_COM_CONTEXT_DEF=1)
diff --git a/src/windows/service/exe/LxssUserSession.cpp b/src/windows/service/exe/LxssUserSession.cpp
index ba22e10e03..578ebd1ac7 100644
--- a/src/windows/service/exe/LxssUserSession.cpp
+++ b/src/windows/service/exe/LxssUserSession.cpp
@@ -3653,6 +3653,18 @@ HRESULT LxssUserSessionImpl::MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In
return S_OK;
}
+bool LxssUserSessionImpl::TryInvokeUnderInstanceLock(std::chrono::milliseconds Timeout, const std::function& Work, _Out_ HRESULT& Result)
+{
+ std::unique_lock lock(m_instanceLock, std::defer_lock);
+ if (!lock.try_lock_for(Timeout))
+ {
+ return false;
+ }
+
+ Result = Work();
+ return true;
+}
+
HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* Socket)
{
std::lock_guard lock(m_instanceLock);
diff --git a/src/windows/service/exe/LxssUserSession.h b/src/windows/service/exe/LxssUserSession.h
index 6e2d416873..fbeac06923 100644
--- a/src/windows/service/exe/LxssUserSession.h
+++ b/src/windows/service/exe/LxssUserSession.h
@@ -445,6 +445,18 @@ class LxssUserSessionImpl
HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name);
+ ///
+ /// Attempts to run Work while holding m_instanceLock, acquired with the
+ /// given timeout. Returns true and sets Result to Work()'s HRESULT if the
+ /// lock was acquired and Work ran; returns false (without running Work) if
+ /// the lock could not be acquired within Timeout. Used by the plugin
+ /// callback pump to run an out-of-hook callback directly without blocking
+ /// indefinitely on the instance lock — a notification thread may hold it in
+ /// its pre-notification phase and later wait for this very callback, which
+ /// would deadlock if we blocked on the lock unconditionally.
+ ///
+ bool TryInvokeUnderInstanceLock(std::chrono::milliseconds Timeout, const std::function& Work, _Out_ HRESULT& Result);
+
///
/// Registers a distribution.
///
diff --git a/src/windows/service/exe/PluginCallPump.cpp b/src/windows/service/exe/PluginCallPump.cpp
new file mode 100644
index 0000000000..2bac10e9e9
--- /dev/null
+++ b/src/windows/service/exe/PluginCallPump.cpp
@@ -0,0 +1,147 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#include "precomp.h"
+#include "PluginCallPump.h"
+#include
+
+using wsl::windows::service::PluginCallPump;
+
+PluginCallPump::PluginCallPump() = default;
+
+void PluginCallPump::DrainQueue()
+{
+ for (;;)
+ {
+ Call* call = nullptr;
+ {
+ auto lock = m_lock.lock_exclusive();
+ if (m_queue.empty())
+ {
+ break;
+ }
+ call = m_queue.front();
+ m_queue.pop_front();
+ }
+
+ // Run the plugin's service-side work on the pump (notification) thread,
+ // which still holds the notifying lock — recursive locks re-enter here.
+ // The closures already translate exceptions to HRESULTs via CATCH_RETURN,
+ // but guard defensively so a throw can never skip done.SetEvent() and
+ // strand the waiting RPC thread.
+ try
+ {
+ call->result = call->work ? call->work() : E_UNEXPECTED;
+ }
+ catch (...)
+ {
+ call->result = wil::ResultFromCaughtException();
+ }
+
+ call->done.SetEvent();
+ }
+}
+
+HRESULT PluginCallPump::Run(const std::function& Notification)
+try
+{
+ HRESULT notificationResult = E_FAIL;
+
+ // The worker makes the outbound cross-process notification call. It runs on
+ // its own thread so THIS thread is free to pump the plugin's callbacks while
+ // the notification is in flight. (Thread creation can throw std::system_error
+ // under resource exhaustion; the surrounding try/CATCH_RETURN converts that to
+ // an HRESULT so it never escapes into a non-throwing teardown notification.)
+ wil::unique_event workerDone(wil::EventOptions::ManualReset);
+ std::thread worker([&]() {
+ // Guard defensively: a throw escaping the thread's top-level function
+ // would call std::terminate() and crash the service, and would skip
+ // workerDone.SetEvent() — stranding the pumping thread. Translate to an
+ // HRESULT and always signal completion (mirrors DrainQueue).
+ try
+ {
+ notificationResult = Notification();
+ }
+ catch (...)
+ {
+ notificationResult = wil::ResultFromCaughtException();
+ }
+
+ workerDone.SetEvent();
+ });
+
+ auto join = wil::scope_exit([&]() {
+ if (worker.joinable())
+ {
+ worker.join();
+ }
+ });
+
+ const HANDLE waits[] = {m_callAvailable.get(), workerDone.get()};
+ for (;;)
+ {
+ const DWORD wait = ::WaitForMultipleObjects(ARRAYSIZE(waits), waits, FALSE, INFINITE);
+
+ // A kernel wait failure must never spin this thread. Stop the pump and
+ // fail any queued/future calls so their RPC threads aren't stranded.
+ FAIL_FAST_LAST_ERROR_IF(wait == WAIT_FAILED);
+
+ // Drain regardless of which handle woke us: the auto-reset event
+ // coalesces multiple enqueues into a single signal, and a final call may
+ // race in just as the worker completes.
+ DrainQueue();
+
+ // Check worker completion independently of the wait result.
+ // WaitForMultipleObjects reports the LOWEST signaled index, so a steady
+ // stream of callbacks on m_callAvailable (index 0) would otherwise starve
+ // the workerDone (index 1) branch and keep this thread pumping (and the
+ // notifying lock held) forever. workerDone is manual-reset, so this is a
+ // cheap non-consuming poll.
+ if (::WaitForSingleObject(workerDone.get(), 0) == WAIT_OBJECT_0)
+ {
+ // Worker (notification) finished. Stop accepting further work and
+ // fail any call that raced in after this point so its RPC thread is
+ // never stranded, then drain whatever was already queued.
+ {
+ auto lock = m_lock.lock_exclusive();
+ m_stopped = true;
+ }
+ DrainQueue();
+ break;
+ }
+ }
+
+ return notificationResult;
+}
+CATCH_RETURN()
+
+bool PluginCallPump::Invoke(std::function Work, _Out_ HRESULT& Result)
+{
+ Call call;
+ call.work = std::move(Work);
+
+ const HRESULT createHr = call.done.create();
+ if (FAILED(createHr))
+ {
+ // Could not create the completion event (resource exhaustion). Surface
+ // the failure as the executed result rather than reporting "not run":
+ // retrying on the direct path would not help and risks double-execution.
+ Result = createHr;
+ return true;
+ }
+
+ {
+ auto lock = m_lock.lock_exclusive();
+ if (m_stopped)
+ {
+ // The notification already returned; there is no pump thread left to
+ // run this. Report "not run" so the caller executes it directly.
+ return false;
+ }
+ m_queue.push_back(&call);
+ }
+
+ m_callAvailable.SetEvent();
+ call.done.wait();
+ Result = call.result;
+ return true;
+}
diff --git a/src/windows/service/exe/PluginCallPump.h b/src/windows/service/exe/PluginCallPump.h
new file mode 100644
index 0000000000..2128fdf80b
--- /dev/null
+++ b/src/windows/service/exe/PluginCallPump.h
@@ -0,0 +1,85 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace wsl::windows::service {
+
+//
+// PluginCallPump implements the threaded-callback model for out-of-process
+// plugin notifications.
+//
+// Problem: a plugin lifecycle notification (OnVMStarted, OnDistributionStarted,
+// ...) is an outbound cross-process COM call. While the service thread is
+// blocked inside that call, the plugin may call back into the service
+// (MountFolder, ExecuteBinary, ...). That callback arrives on a *different* COM
+// RPC thread, so it cannot re-enter the locks held by the notifying thread
+// (m_instanceLock etc.) without a second, parallel locking scheme.
+//
+// Solution (matches the in-process model): make the outbound notification on a
+// worker thread and "pump" the plugin's service-side API calls back onto the
+// ORIGINAL notifying thread. Because the work then runs on the lock-holding
+// thread, recursive locks (std::recursive_timed_mutex m_instanceLock) re-enter
+// exactly as they did when plugins were loaded in-process — no second lock
+// (m_callbackLock) and no out-of-band session registry are needed.
+//
+// Notifying thread: pump.Run([&]{ return host->OnVMStarted(...); });
+// RPC callback: return pump.Invoke([&]{ return session->Mount...(); });
+//
+// A single pump instance services one outbound notification at a time. The
+// pump thread is the single consumer; any number of RPC threads may Invoke().
+//
+class PluginCallPump
+{
+public:
+ PluginCallPump();
+ ~PluginCallPump() = default;
+
+ PluginCallPump(const PluginCallPump&) = delete;
+ PluginCallPump& operator=(const PluginCallPump&) = delete;
+ PluginCallPump(PluginCallPump&&) = delete;
+ PluginCallPump& operator=(PluginCallPump&&) = delete;
+
+ // Runs `Notification` (the outbound host->On... COM call) on a dedicated
+ // worker thread and pumps queued Invoke() calls on the CALLING thread until
+ // the worker completes. The worker performs its own COM initialization, so
+ // `Notification` should acquire the apartment-local host proxy itself.
+ //
+ // Returns the HRESULT returned by `Notification`. Any service-side work
+ // requested by the plugin runs on the calling (lock-holding) thread, so
+ // recursive locks behave exactly as in the in-process model.
+ HRESULT Run(const std::function& Notification);
+
+ // Called from a COM RPC callback thread. Marshals `Work` to the pump thread,
+ // blocks until it has executed there, and reports its HRESULT via `Result`.
+ // Returns true if `Work` was executed (`Result` is set); returns false if the
+ // pump is no longer running (the notification already returned), in which case
+ // `Work` was NOT run and the caller must run it itself. Reporting "not run"
+ // out-of-band (rather than via a sentinel HRESULT) lets `Work` legitimately
+ // return any HRESULT, including RPC_E_DISCONNECTED, without ambiguity.
+ bool Invoke(std::function Work, _Out_ HRESULT& Result);
+
+private:
+ struct Call
+ {
+ std::function work;
+ HRESULT result{E_FAIL};
+ wil::unique_event_nothrow done;
+ };
+
+ // Runs every currently-queued call on the calling (pump) thread.
+ void DrainQueue();
+
+ wil::srwlock m_lock;
+ _Guarded_by_(m_lock) std::deque m_queue;
+ _Guarded_by_(m_lock) bool m_stopped { false };
+
+ // Auto-reset event signaled whenever a call is enqueued.
+ wil::unique_event m_callAvailable{wil::EventOptions::None};
+};
+
+} // namespace wsl::windows::service
diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp
index 0d6b6ccbb7..1b31d19555 100644
--- a/src/windows/service/exe/PluginManager.cpp
+++ b/src/windows/service/exe/PluginManager.cpp
@@ -9,6 +9,8 @@ Module Name:
Abstract:
This file contains the PluginManager helper class implementation.
+ Plugins are loaded in isolated wslpluginhost.exe processes via COM,
+ so a crashing plugin cannot take down the WSL service.
--*/
@@ -16,678 +18,1326 @@ Module Name:
#include "install.h"
#include "PluginManager.h"
#include "WslPluginApi.h"
+#include "WslPluginHost.h"
#include "LxssUserSessionFactory.h"
#include "WSLCSessionManager.h"
using wsl::windows::common::Context;
using wsl::windows::common::ExecutionContext;
+using wsl::windows::service::PluginHostCallbackImpl;
using wsl::windows::service::PluginManager;
-constexpr auto c_pluginPath = L"SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Lxss\\Plugins";
-
-constexpr WSLVersion Version = {wsl::shared::VersionMajor, wsl::shared::VersionMinor, wsl::shared::VersionRevision};
-
-thread_local std::optional g_pluginErrorMessage;
+// Acquire an apartment-local IWslPluginHost proxy for `plugin` (named `host`).
+// On a host-process crash, latch a fatal plugin error and `continue` the
+// surrounding loop: teardown/notification hooks must not block WSL, but the
+// latch makes the next start operation fail fast instead of repeatedly driving
+// a dead host. On any other failure to acquire (which would indicate a
+// fundamental COM problem, not a plugin-reported issue), log the HRESULT and
+// `continue` so a single busted plugin does not break the iteration for the
+// others. Use only inside the per-plugin loops in PluginManager hook methods.
+#define ACQUIRE_PLUGIN_HOST_OR_CONTINUE(plugin, host, stage) \
+ Microsoft::WRL::ComPtr host; \
+ { \
+ const HRESULT _acqHr = AcquireHostProxy((plugin), &(host)); \
+ if (FAILED(_acqHr)) \
+ { \
+ if (IsHostCrash(_acqHr)) \
+ { \
+ LatchHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \
+ } \
+ else \
+ { \
+ LOG_HR_MSG(_acqHr, "Failed to acquire plugin host proxy for: '%ls'", (plugin).name.c_str()); \
+ } \
+ continue; \
+ } \
+ }
-extern "C" {
-HRESULT MountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name)
-try
-{
- const auto session = FindSessionByCookie(Session);
- RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
+// Same as ACQUIRE_PLUGIN_HOST_OR_CONTINUE, but for hook methods that surface a
+// plugin error to abort the guarded operation (e.g. OnVmStarted). A host crash
+// is fatal: it is latched and thrown as a fatal plugin error, matching the
+// pre-refactor behavior where an in-process plugin crash brought down WSL. Any
+// other acquisition failure is likewise thrown so it is surfaced exactly like a
+// plugin-reported error would be, rather than silently allowing the operation to
+// proceed without consulting the plugin.
+#define ACQUIRE_PLUGIN_HOST_OR_THROW(plugin, host, stage) \
+ Microsoft::WRL::ComPtr host; \
+ { \
+ const HRESULT _acqHr = AcquireHostProxy((plugin), &(host)); \
+ if (FAILED(_acqHr)) \
+ { \
+ if (IsHostCrash(_acqHr)) \
+ { \
+ ThrowHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \
+ } \
+ THROW_HR_MSG(_acqHr, "Failed to acquire plugin host proxy for: '%ls'", (plugin).name.c_str()); \
+ } \
+ }
- auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name);
+constexpr auto c_pluginPath = L"SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Lxss\\Plugins";
- WSL_LOG(
- "PluginMountFolderCall",
- TraceLoggingValue(WindowsPath, "WindowsPath"),
- TraceLoggingValue(LinuxPath, "LinuxPath"),
- TraceLoggingValue(ReadOnly, "ReadOnly"),
- TraceLoggingValue(Name, "Name"),
- TraceLoggingValue(result, "Result"));
+// --- IWslPluginHostCallback implementation (service-side) ---
+// These methods handle API calls from the plugin host process.
- return result;
-}
-CATCH_RETURN();
+// Returned to the plugin when a session cookie no longer resolves (the session
+// was torn down). Deliberately not an RPC_* status: a plugin that propagates
+// this from its hook must not be mistaken for a dead host by IsHostCrash.
+constexpr HRESULT c_pluginSessionNotFound = HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
-HRESULT ExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket)
+STDMETHODIMP PluginHostCallbackImpl::MountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCWSTR LinuxPath, _In_ BOOL ReadOnly, _In_ LPCWSTR Name)
try
{
+ RETURN_HR_IF(E_INVALIDARG, WindowsPath == nullptr || LinuxPath == nullptr || Name == nullptr);
- const auto session = FindSessionByCookie(Session);
- RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
+ // Marshal the work onto the notifying thread so MountRootNamespaceFolder runs
+ // under the session's recursive m_instanceLock (see InvokeOnWslPump). The
+ // captured pointers reference the caller's stack frame, which stays alive
+ // because InvokeOnWslPump blocks until the work has run.
+ return m_owner.InvokeOnWslPump(SessionId, [=]() -> HRESULT {
+ const auto session = FindSessionByCookie(SessionId);
+ RETURN_HR_IF(c_pluginSessionNotFound, !session);
- auto result = session->CreateLinuxProcess(nullptr, Path, Arguments, Socket);
+ auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name);
- WSL_LOG("PluginExecuteBinaryCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result"));
- return result;
+ WSL_LOG(
+ "PluginCallbackMountFolderEnd", TraceLoggingValue(WindowsPath, "WindowsPath"), TraceLoggingValue(result, "Result"));
+ return result;
+ });
}
CATCH_RETURN();
-HRESULT PluginError(LPCWSTR UserMessage)
+STDMETHODIMP PluginHostCallbackImpl::ExecuteBinary(
+ _In_ DWORD SessionId, _In_ LPCSTR Path, _In_ DWORD ArgumentCount, _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _Out_ HANDLE* Socket)
try
{
- const auto* context = ExecutionContext::Current();
- THROW_HR_IF(E_INVALIDARG, UserMessage == nullptr);
- THROW_HR_IF_MSG(
- E_ILLEGAL_METHOD_CALL, context == nullptr || WI_IsFlagClear(context->CurrentContext(), Context::Plugin), "Message: %ls", UserMessage);
+ RETURN_HR_IF(E_POINTER, Socket == nullptr);
+ *Socket = nullptr;
+ RETURN_HR_IF(E_INVALIDARG, Path == nullptr);
+ RETURN_HR_IF(E_INVALIDARG, ArgumentCount > 0 && Arguments == nullptr);
+
+ return m_owner.InvokeOnWslPump(SessionId, [=]() -> HRESULT {
+ WSL_LOG("PluginCallbackExecuteBinaryBegin", TraceLoggingValue(Path, "Path"), TraceLoggingValue(SessionId, "SessionId"));
+ const auto session = FindSessionByCookie(SessionId);
+ RETURN_HR_IF(c_pluginSessionNotFound, !session);
+
+ // Build NULL-terminated argument array expected by CreateLinuxProcess.
+ std::vector args;
+ if (Arguments != nullptr)
+ {
+ args.assign(Arguments, Arguments + ArgumentCount);
+ }
+ args.push_back(nullptr);
- // Logs when a WSL plugin hits an error and what that error message is
- WSL_LOG_TELEMETRY("PluginError", PDT_ProductAndServicePerformance, TraceLoggingValue(UserMessage, "Message"));
+ WSL_LOG("PluginCallbackExecuteBinaryCallingCreateProcess", TraceLoggingValue(Path, "Path"));
+ wil::unique_socket sock;
+ auto result = session->CreateLinuxProcess(nullptr, Path, args.data(), &sock);
- THROW_HR_IF(E_ILLEGAL_STATE_CHANGE, g_pluginErrorMessage.has_value());
+ WSL_LOG("PluginCallbackExecuteBinaryEnd", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result"));
- g_pluginErrorMessage.emplace(UserMessage);
+ if (SUCCEEDED(result))
+ {
+ // Return socket as HANDLE — COM's system_handle marshaling will
+ // duplicate it into the host process automatically.
+ *Socket = reinterpret_cast(sock.release());
+ }
- return S_OK;
+ return result;
+ });
}
CATCH_RETURN();
-HRESULT ExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket)
+STDMETHODIMP PluginHostCallbackImpl::ExecuteBinaryInDistribution(
+ _In_ DWORD SessionId,
+ _In_ const GUID* DistributionId,
+ _In_ LPCSTR Path,
+ _In_ DWORD ArgumentCount,
+ _In_reads_opt_(ArgumentCount) LPCSTR* Arguments,
+ _Out_ HANDLE* Socket)
try
{
- THROW_HR_IF(E_INVALIDARG, Distro == nullptr);
+ RETURN_HR_IF(E_POINTER, Socket == nullptr);
+ *Socket = nullptr;
+ RETURN_HR_IF(E_INVALIDARG, DistributionId == nullptr);
+ RETURN_HR_IF(E_INVALIDARG, Path == nullptr);
+ RETURN_HR_IF(E_INVALIDARG, ArgumentCount > 0 && Arguments == nullptr);
+
+ return m_owner.InvokeOnWslPump(SessionId, [=]() -> HRESULT {
+ const auto session = FindSessionByCookie(SessionId);
+ RETURN_HR_IF(c_pluginSessionNotFound, !session);
+
+ std::vector args;
+ if (Arguments != nullptr)
+ {
+ args.assign(Arguments, Arguments + ArgumentCount);
+ }
+ args.push_back(nullptr);
- const auto session = FindSessionByCookie(Session);
- RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
+ wil::unique_socket sock;
+ auto result = session->CreateLinuxProcess(DistributionId, Path, args.data(), &sock);
- auto result = session->CreateLinuxProcess(Distro, Path, Arguments, Socket);
+ WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result"));
- WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result"));
+ if (SUCCEEDED(result))
+ {
+ *Socket = reinterpret_cast(sock.release());
+ }
- return result;
+ return result;
+ });
}
CATCH_RETURN();
-}
-
-namespace {
-// Opaque wrapper around IWSLCProcess, handed out as WSLCProcessHandle to plugins.
-struct WslcProcessWrapper
-{
- wil::com_ptr Process;
-};
+// --- PluginManager implementation ---
-wil::com_ptr ResolveWslcSession(WSLCSessionId Session)
+PluginManager::~PluginManager()
{
- auto* mgr = wsl::windows::service::wslc::WSLCSessionManagerImpl::Instance();
- THROW_HR_IF(RPC_E_DISCONNECTED, mgr == nullptr);
+ // m_plugins should have been cleared by Shutdown() while COM was still
+ // initialized. Releasing GIT cookies and the GIT itself here (at global
+ // teardown, after CoUninitialize and after the proxy/stub DLL has been
+ // unloaded) crashes inside the marshaler. If anything is left here, leak
+ // it on purpose rather than dereferencing a torn-down proxy vtable.
+ if (!m_plugins.empty())
+ {
+ LOG_HR_MSG(E_UNEXPECTED, "PluginManager destroyed without Shutdown(); leaking %zu host registrations", m_plugins.size());
+ for (auto& e : m_plugins)
+ {
+ // Drop the cookie without revoking — calling GIT after CoUninitialize crashes.
+ e.hostCookie = 0;
+ (void)e.callback.Detach();
+ }
+ m_plugins.clear();
+ }
+ if (m_git)
+ {
+ // Same reasoning: leak the GIT reference rather than releasing it after teardown.
+ (void)m_git.Detach();
+ }
- return mgr->FindSession(static_cast(Session));
+ // WSLC session references are COM proxies too. Shutdown() should have
+ // drained them while COM was still initialized; if any survive, detach
+ // (leak) them rather than releasing a torn-down proxy.
+ if (!m_wslcSessionRefs.empty())
+ {
+ LOG_HR_MSG(E_UNEXPECTED, "PluginManager destroyed with %zu WSLC session references still registered", m_wslcSessionRefs.size());
+ for (auto& [id, ref] : m_wslcSessionRefs)
+ {
+ (void)ref.detach();
+ }
+ m_wslcSessionRefs.clear();
+ }
+ m_jobObject.reset();
}
-} // namespace
-
-extern "C" {
-
-HRESULT WSLCMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, LPCSTR Mountpoint, BOOL ReadOnly)
-try
+void PluginManager::Shutdown()
{
- // TODO: Once plugins are out of proc, add logic to validate that the mountpoint isn't in use by another plugin.
- RETURN_HR_IF(E_POINTER, WindowsPath == nullptr || Mountpoint == nullptr);
-
- auto session = ResolveWslcSession(Session);
- auto result = session->MountWindowsFolder(WindowsPath, Mountpoint, ReadOnly);
+ // Must be called while COM is still initialized. Revoking each GIT cookie
+ // releases the underlying marshaled IWslPluginHost reference, which causes
+ // the wslpluginhost.exe processes to exit; the job object below makes that
+ // guaranteed even if a revoke fails.
+ if (m_git)
+ {
+ for (auto& e : m_plugins)
+ {
+ if (e.hostCookie != 0)
+ {
+ LOG_IF_FAILED(m_git->RevokeInterfaceFromGlobal(e.hostCookie));
+ e.hostCookie = 0;
+ }
+ }
+ m_git.Reset();
+ }
+ m_plugins.clear();
- WSL_LOG(
- "WslcPluginMountFolderCall",
- TraceLoggingValue(Session, "SessionId"),
- TraceLoggingValue(WindowsPath, "WindowsPath"),
- TraceLoggingValue(Mountpoint, "Mountpoint"),
- TraceLoggingValue(ReadOnly, "ReadOnly"),
- TraceLoggingValue(result, "Result"));
+ // Release any remaining WSLC session references while COM is still
+ // initialized. By this point all sessions should already have been torn
+ // down (and thus unregistered), but drain defensively so we never release
+ // these proxies from the destructor after CoUninitialize.
+ {
+ std::lock_guard lock(m_wslcSessionRefLock);
+ m_wslcSessionRefs.clear();
+ }
- return result;
+ m_jobObject.reset();
}
-CATCH_RETURN();
-HRESULT WSLCUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint)
-try
+void PluginManager::LoadPlugins()
{
- // TODO: Once plugins are out of proc, add logic to validate that the mountpoint is actually owned by the plugin.
- RETURN_HR_IF(E_POINTER, Mountpoint == nullptr);
-
- auto session = ResolveWslcSession(Session);
-
- auto result = session->UnmountWindowsFolder(Mountpoint);
-
- WSL_LOG(
- "WslcPluginUnmountFolderCall",
- TraceLoggingValue(Session, "SessionId"),
- TraceLoggingValue(Mountpoint, "Mountpoint"),
- TraceLoggingValue(result, "Result"));
-
- return result;
-}
-CATCH_RETURN();
+ ExecutionContext context(Context::Plugin);
-HRESULT WSLCCreateProcess(WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno)
-try
-{
- RETURN_HR_IF(E_POINTER, Executable == nullptr || Process == nullptr);
+ const auto key = common::registry::CreateKey(HKEY_LOCAL_MACHINE, c_pluginPath, KEY_READ);
+ const auto values = common::registry::EnumValues(key.get());
- *Process = nullptr;
- if (Errno != nullptr)
+ std::set loaded;
+ for (const auto& e : values)
{
- *Errno = 0;
- }
-
- auto session = ResolveWslcSession(Session);
-
- // Count NULL-terminated arrays.
- auto countArray = [](LPCSTR* arr) -> ULONG {
- if (arr == nullptr)
- {
- return 0;
- }
- ULONG count = 0;
- while (arr[count] != nullptr)
+ if (e.second != REG_SZ)
{
- ++count;
+ LOG_HR_MSG(E_UNEXPECTED, "Plugin value: '%ls' has incorrect type: %lu, skipping", e.first.c_str(), e.second);
+ continue;
}
- return count;
- };
-
- WSLCProcessOptions options{};
- options.CommandLine.Values = Arguments;
- options.CommandLine.Count = countArray(Arguments);
- options.Environment.Values = Env;
- options.Environment.Count = countArray(Env);
- options.Flags = WSLCProcessFlagsStdin;
- wil::com_ptr process;
- int errnoValue = 0;
- auto result = session->CreateRootNamespaceProcess(Executable, &options, 0, 0, &process, &errnoValue);
+ auto path = common::registry::ReadString(key.get(), nullptr, e.first.c_str());
- if (Errno != nullptr)
- {
- *Errno = errnoValue;
- }
+ if (!loaded.insert(path).second)
+ {
+ LOG_HR_MSG(E_UNEXPECTED, "Module '%ls' has already been loaded, skipping plugin '%ls'", path.c_str(), e.first.c_str());
+ continue;
+ }
- if (FAILED(result))
- {
- WSL_LOG(
- "WslcPluginCreateProcessCall",
- TraceLoggingValue(Session, "SessionId"),
- TraceLoggingValue(Executable, "Executable"),
- TraceLoggingValue(result, "Result"),
- TraceLoggingValue(errnoValue, "Errno"));
- return result;
+ // Record the plugin for deferred activation. The actual COM host process
+ // is created in EnsureInitialized(), which runs after the service's COM
+ // initialization is complete (CoInitializeSecurity must happen first).
+ OutOfProcPlugin plugin{};
+ plugin.name = e.first;
+ plugin.path = path;
+ m_plugins.emplace_back(std::move(plugin));
+
+ // Discovery-only event. The plugin DLL is no longer loaded into
+ // wslservice.exe — actual load happens out-of-process via COM
+ // activation in EnsureInitialized(). See "PluginLoad" emitted from
+ // that path for the real load result.
+ WSL_LOG_TELEMETRY(
+ "PluginDiscovered",
+ PDT_ProductAndServiceUsage,
+ TraceLoggingValue(e.first.c_str(), "Name"),
+ TraceLoggingValue(path.c_str(), "Path"));
}
-
- auto wrapper = std::make_unique();
- wrapper->Process = std::move(process);
- *Process = wrapper.release();
-
- WSL_LOG(
- "WslcPluginCreateProcessCall",
- TraceLoggingValue(Session, "SessionId"),
- TraceLoggingValue(Executable, "Executable"),
- TraceLoggingValue(*Process, "Process"),
- TraceLoggingValue(S_OK, "Result"));
-
- return S_OK;
}
-CATCH_RETURN();
-HRESULT WSLCProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle)
-try
+PluginManager::ScopedComInit::ScopedComInit() : initHr(::CoInitializeEx(nullptr, COINIT_MULTITHREADED))
{
- RETURN_HR_IF(E_POINTER, Process == nullptr || Handle == nullptr);
-
- *Handle = nullptr;
-
- auto* wrapper = static_cast(Process);
+}
- WSLCFD wslcFd{};
- switch (Fd)
+PluginManager::ScopedComInit::~ScopedComInit()
+{
+ if (SUCCEEDED(initHr))
{
- case WSLCProcessFdStdin:
- wslcFd = WSLCFDStdin;
- break;
- case WSLCProcessFdStdout:
- wslcFd = WSLCFDStdout;
- break;
- case WSLCProcessFdStderr:
- wslcFd = WSLCFDStderr;
- break;
- default:
- WSL_LOG(
- "WslcPluginProcessGetFd", TraceLoggingValue(static_cast(Fd), "Fd"), TraceLoggingValue(E_INVALIDARG, "Result"));
- return E_INVALIDARG;
+ ::CoUninitialize();
}
-
- WSLCHandle handle{};
- auto result = wrapper->Process->GetStdHandle(wslcFd, &handle);
-
- WSL_LOG(
- "WslcPluginProcessGetFd",
- TraceLoggingValue(static_cast(Fd), "Fd"),
- TraceLoggingValue(handle.Handle.Socket, "Handle"),
- TraceLoggingValue(result, "Result"));
-
- RETURN_IF_FAILED(result);
- WI_ASSERT(handle.Type == WSLCHandleTypeSocket);
-
- *Handle = handle.Handle.Socket;
- return S_OK;
}
-CATCH_RETURN();
-HRESULT WSLCProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent)
-try
+PluginManager::ScopedComInit::ScopedComInit(ScopedComInit&& other) noexcept : initHr(other.initHr)
{
- RETURN_HR_IF(E_POINTER, Process == nullptr || ExitEvent == nullptr);
-
- *ExitEvent = nullptr;
+ // Suppress uninit in moved-from instance.
+ other.initHr = RPC_E_CHANGED_MODE;
+}
- auto* wrapper = static_cast(Process);
- auto result = wrapper->Process->GetExitEvent(ExitEvent);
+HRESULT PluginManager::ScopedComInit::Result() const noexcept
+{
+ return (initHr == RPC_E_CHANGED_MODE) ? S_OK : initHr;
+}
- WSL_LOG("WslcPluginProcessGetExitEvent", TraceLoggingValue(*ExitEvent, "ExitEvent"), TraceLoggingValue(result, "Result"));
+PluginManager::ScopedComInit PluginManager::EnsureInitialized()
+{
+ // Join the calling thread to the MTA for the duration of the dispatch. The
+ // returned guard must outlive any subsequent plugin host calls because
+ // those are cross-process COM calls that require an initialized apartment,
+ // and so does the GIT-based proxy acquisition below.
+ ScopedComInit coInit;
+ THROW_IF_FAILED(coInit.Result());
+
+ // Lazily create the process-wide Global Interface Table accessor. The GIT
+ // itself is a process-global COM-provided singleton; we just need an
+ // in-process accessor to register and look up cookies.
+ std::call_once(m_gitOnce, [this]() {
+ THROW_IF_FAILED(CoCreateInstance(CLSID_StdGlobalInterfaceTable, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&m_git)));
+ });
+
+ std::call_once(m_initOnce, [this]() {
+ for (auto& e : m_plugins)
+ {
+ auto loadResult = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { LoadPlugin(e); });
+
+ // Canonical "plugin was actually loaded" telemetry — matches the
+ // semantics of the pre-refactor PluginLoad event (emitted after
+ // the entry point ran). PluginHostActivation below is the more
+ // granular event covering the COM activation path specifically.
+ WSL_LOG_TELEMETRY(
+ "PluginLoad",
+ PDT_ProductAndServiceUsage,
+ TraceLoggingValue(e.name.c_str(), "Name"),
+ TraceLoggingValue(e.path.c_str(), "Path"),
+ TraceLoggingValue(loadResult, "Result"));
+
+ WSL_LOG_TELEMETRY(
+ "PluginHostActivation",
+ PDT_ProductAndServiceUsage,
+ TraceLoggingValue(e.name.c_str(), "Name"),
+ TraceLoggingValue(e.path.c_str(), "Path"),
+ TraceLoggingValue(loadResult, "Result"));
+
+ if (FAILED(loadResult))
+ {
+ // Any load failure is fatal: the plugin is recorded so that
+ // subsequent operations block with a clear error rather than a
+ // silently-disabled plugin. This includes host-process crashes
+ // and benign-looking COM activation failures (the server is
+ // shutting down or its exec failed) — matching the pre-refactor
+ // behavior where a plugin that failed to load blocked WSL.
+ m_pluginError.emplace(PluginError{e.name, loadResult});
+ }
+ }
+ });
- return result;
+ return coInit;
}
-CATCH_RETURN();
-HRESULT WSLCProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode)
-try
+void PluginManager::LoadPlugin(OutOfProcPlugin& plugin)
{
- RETURN_HR_IF(E_POINTER, Process == nullptr || ExitCode == nullptr);
+ // One callback object per plugin host so the WSLC process-cookie map is
+ // isolated per plugin. When the plugin host process is released, the
+ // callback's refcount drops and any unreleased IWSLCProcess refs are freed.
+ plugin.callback = Microsoft::WRL::Make(*this);
+ THROW_IF_NULL_ALLOC(plugin.callback);
+
+ // Activate the plugin host via COM. The LocalServer32 registration causes COM
+ // to spawn wslpluginhost.exe automatically.
+ Microsoft::WRL::ComPtr host;
+ HRESULT activationHr = CoCreateInstance(CLSID_WslPluginHost, nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&host));
+ WSL_LOG(
+ "PluginHostActivation",
+ TraceLoggingValue(plugin.name.c_str(), "Plugin"),
+ TraceLoggingValue(plugin.path.c_str(), "Path"),
+ TraceLoggingValue(activationHr, "CoCreateInstanceResult"));
+ THROW_IF_FAILED_MSG(activationHr, "Failed to create plugin host for: '%ls'", plugin.path.c_str());
+
+ // Create the job object before initializing the host so we can hand it to
+ // Initialize. The host assigns itself to the job before running any plugin
+ // code, so any child processes the plugin spawns inherit the job and are
+ // killed when the service exits. Job assignment is fatal in the host: a host
+ // that isn't in the job would escape the kill-on-close guarantee, so a
+ // failure surfaces here (alongside activation and plugin entry-point errors)
+ // and the host process exits when its proxy is released on unwind.
+ // system_handle(sh_job) marshaling duplicates the job handle into the host
+ // for the duration of the Initialize call.
+ EnsureJobObjectCreated();
+
+ THROW_IF_FAILED_MSG(
+ host->Initialize(plugin.callback.Get(), m_jobObject.get(), plugin.path.c_str(), plugin.name.c_str()),
+ "Plugin host failed to initialize: '%ls'",
+ plugin.path.c_str());
+
+ // Stash the IWslPluginHost in the Global Interface Table. The proxy returned
+ // by CoCreateInstance is bound to the apartment of this thread (MTA, since
+ // EnsureInitialized() joined us to the MTA). Hook dispatch can arrive on
+ // threads in any apartment — in particular NTA-on-MTA, which is what wslservice's
+ // RPC dispatcher uses when the incoming call is via a WinRT-style interface
+ // (e.g. IWSLCSession). MTA-bound proxies are NOT callable from NTA, so storing
+ // the raw proxy here and reusing it cross-apartment fails with RPC_E_WRONG_THREAD.
+ // Storing in the GIT and re-unmarshaling per call yields an apartment-local
+ // proxy that works from any apartment.
+ DWORD cookie = 0;
+ THROW_IF_FAILED(m_git->RegisterInterfaceInGlobal(host.Get(), __uuidof(IWslPluginHost), &cookie));
+
+ auto revokeOnFailure = wil::scope_exit([&] {
+ if (cookie != 0)
+ {
+ LOG_IF_FAILED(m_git->RevokeInterfaceFromGlobal(cookie));
+ }
+ });
- *ExitCode = -1;
- auto* wrapper = static_cast(Process);
+ // Drop the raw proxy — future access goes through the GIT to get an
+ // apartment-local proxy. The GIT keeps the underlying marshaled stream
+ // alive, so the host process stays running.
+ host.Reset();
- WSLCProcessState state{};
- auto result = wrapper->Process->GetState(&state, ExitCode);
+ plugin.hostCookie = cookie;
+ cookie = 0;
+ revokeOnFailure.release();
+}
- if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled)
+HRESULT PluginManager::AcquireHostProxy(const OutOfProcPlugin& plugin, _COM_Outptr_ IWslPluginHost** host)
+{
+ *host = nullptr;
+ if (plugin.hostCookie == 0 || !m_git)
{
- result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE);
+ return E_NOT_VALID_STATE;
}
+ return m_git->GetInterfaceFromGlobal(plugin.hostCookie, __uuidof(IWslPluginHost), reinterpret_cast(host));
+}
- WSL_LOG(
- "WslcPluginProcessGetExitCode",
- TraceLoggingValue(*ExitCode, "ExitCode"),
- TraceLoggingValue(static_cast(state), "State"),
- TraceLoggingValue(result, "Result"));
-
- return result;
+void PluginManager::EnsureJobObjectCreated()
+{
+ std::call_once(m_jobObjectOnce, [this]() {
+ m_jobObject.reset(CreateJobObjectW(nullptr, nullptr));
+ THROW_LAST_ERROR_IF(!m_jobObject);
+
+ JOBOBJECT_EXTENDED_LIMIT_INFORMATION jobInfo{};
+ jobInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
+ THROW_IF_WIN32_BOOL_FALSE(SetInformationJobObject(m_jobObject.get(), JobObjectExtendedLimitInformation, &jobInfo, sizeof(jobInfo)));
+ });
}
-CATCH_RETURN();
-void WSLCReleaseProcess(WSLCProcessHandle Process)
+std::vector PluginManager::SerializeSid(PSID Sid)
{
- if (Process != nullptr)
- {
- WSL_LOG("WslcPluginReleaseProcess", TraceLoggingValue(Process, "Process"));
- delete static_cast(Process);
- }
+ const DWORD sidLength = GetLengthSid(Sid);
+ std::vector buffer(sidLength);
+ THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLength, buffer.data(), Sid));
+ return buffer;
}
-} // extern "C"
+// Registers the pump servicing the in-flight WSL notification for a session, so
+// that plugin callbacks arriving on RPC threads can be marshaled onto the
+// notifying thread (see InvokeOnWslPump). Invariant: a session has at most one
+// WSL notification in flight at a time — notifications for a session are issued
+// serially by that session, and the callbacks marshaled by the pump
+// (Mount/CreateLinuxProcess) never trigger further session notifications — so a
+// single SessionId-keyed entry never needs to nest. Register/Unregister are
+// bracketed by a scope_exit in RunHostNotification.
+void PluginManager::RegisterWslPump(ULONG SessionId, const std::shared_ptr& Pump)
+{
+ auto lock = m_wslPumpLock.lock_exclusive();
+ m_wslPumps[SessionId] = Pump;
+}
-static constexpr WSLPluginAPIV1 ApiV1 = {
- Version,
- &MountFolder,
- &ExecuteBinary,
- &PluginError,
- &ExecuteBinaryInDistribution,
- &WSLCMountFolder,
- &WSLCUnmountFolder,
- &WSLCCreateProcess,
- &WSLCProcessGetFd,
- &WSLCProcessGetExitEvent,
- &WSLCProcessGetExitCode,
- &WSLCReleaseProcess};
+void PluginManager::UnregisterWslPump(ULONG SessionId)
+{
+ auto lock = m_wslPumpLock.lock_exclusive();
+ m_wslPumps.erase(SessionId);
+}
-void PluginManager::LoadPlugins()
+HRESULT PluginManager::InvokeOnWslPump(ULONG SessionId, std::function Work)
{
- ExecutionContext context(Context::Plugin);
+ RETURN_HR_IF(E_UNEXPECTED, !Work);
- const auto key = common::registry::CreateKey(HKEY_LOCAL_MACHINE, c_pluginPath, KEY_READ);
- const auto values = common::registry::EnumValues(key.get());
+ // Granularity for the direct path's timed lock acquisition (see below).
+ constexpr auto c_directLockPollInterval = std::chrono::milliseconds(25);
- std::set loaded;
- for (const auto& e : values)
+ for (;;)
{
- if (e.second != REG_SZ)
+ // 1. If a notification hook for this session is in flight, marshal the
+ // work onto the notifying thread (which holds the session's recursive
+ // m_instanceLock) via that hook's pump. Copy the shared_ptr out under
+ // the lock and release it before the (blocking) Invoke, so the
+ // registry lock is never held across a callback.
+ std::shared_ptr pump;
{
- LOG_HR_MSG(E_UNEXPECTED, "Plugin value: '%ls' has incorrect type: %lu, skipping", e.first.c_str(), e.second);
- continue;
+ auto lock = m_wslPumpLock.lock_shared();
+ const auto it = m_wslPumps.find(SessionId);
+ if (it != m_wslPumps.end())
+ {
+ pump = it->second;
+ }
}
- auto path = common::registry::ReadString(key.get(), nullptr, e.first.c_str());
-
- if (!loaded.insert(path).second)
+ if (pump)
{
- LOG_HR_MSG(E_UNEXPECTED, "Module '%ls' has already been loaded, skipping plugin '%ls'", path.c_str(), e.first.c_str());
- continue;
+ // Pass Work by copy (not moved): if the pump has already stopped we
+ // fall through to the direct path below and still need it. Invoke
+ // reports out-of-band whether it actually ran the work, so a real
+ // HRESULT from Work (even RPC_E_DISCONNECTED) is never mistaken for
+ // "pump stopped" — which would otherwise double-run the work.
+ HRESULT hr = E_FAIL;
+ if (pump->Invoke(Work, hr))
+ {
+ return hr;
+ }
+
+ // The hook returned (the pump stopped) between our lookup and the
+ // Invoke, but the entry is not yet unregistered. The notifying thread
+ // is no longer waiting on this callback, so fall through and run the
+ // work directly. The timed lock acquisition below backs off until that
+ // thread releases m_instanceLock, so a live call is never spuriously
+ // failed.
}
- auto loadResult = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { LoadPlugin(e.first.c_str(), path.c_str()); });
-
- // Logs when a WSL plugin is loaded, used for evaluating plugin populations
- WSL_LOG_TELEMETRY(
- "PluginLoad",
- PDT_ProductAndServiceUsage,
- TraceLoggingValue(e.first.c_str(), "Name"),
- TraceLoggingValue(path.c_str(), "Path"),
- TraceLoggingValue(loadResult, "Result"));
+ // 2. No hook is in flight for this session (or its pump just stopped).
+ // Run the work directly on this RPC thread, like an in-process plugin
+ // calling the API from one of its own worker threads. We must NOT just
+ // block acquiring m_instanceLock: a notification thread may already
+ // hold it (in its pre-notification phase, before registering its pump)
+ // and then wait for THIS callback - e.g. a hook that joins the worker
+ // thread that issued it. Blocking would deadlock (the holder waits for
+ // us; we wait for the holder). Instead acquire with a short timeout; on
+ // failure, loop back and re-check for a pump, which the notifying
+ // thread registers before it can wait on us, so step 1 then routes the
+ // work onto it.
+ const auto session = FindSessionByCookie(SessionId);
+ if (!session)
+ {
+ // Unknown session and no hook in flight: run Work() so it surfaces
+ // the appropriate "session not found" failure.
+ return Work();
+ }
- if (FAILED(loadResult))
+ HRESULT result = E_FAIL;
+ if (session->TryInvokeUnderInstanceLock(c_directLockPollInterval, Work, result))
{
- // If this plugin reported an error, record it to display it to the user
- m_pluginError.emplace(PluginError{e.first, loadResult});
+ return result;
}
+
+ // The instance lock is held by another thread that has not registered a
+ // pump (most likely a session operation in its pre-notification phase).
+ // Loop: re-check for a pump or retry the direct acquisition.
}
}
-void PluginManager::LoadPlugin(LPCWSTR Name, LPCWSTR ModulePath)
+HRESULT PluginManager::RunHostNotification(OutOfProcPlugin& Plugin, ULONG SessionId, const std::function& Notify)
{
- // Validate the plugin signature before loading it.
- // The handle to the module is kept open after validating the signature so the file can't be written to
- // after the signature check.
- wil::unique_hfile pluginHandle;
- if constexpr (wsl::shared::OfficialBuild)
- {
- pluginHandle = wsl::windows::common::install::ValidateFileSignature(ModulePath);
- WI_ASSERT(pluginHandle.is_valid());
- }
-
- LoadedPlugin plugin{};
- plugin.name = Name;
-
- plugin.module.reset(LoadLibrary(ModulePath));
- THROW_LAST_ERROR_IF_NULL(plugin.module);
-
- const WSLPluginAPI_EntryPointV1 entryPoint =
- reinterpret_cast(GetProcAddress(plugin.module.get(), GSL_STRINGIFY(WSLPLUGINAPI_ENTRYPOINTV1)));
-
- THROW_LAST_ERROR_IF_NULL(entryPoint);
- THROW_IF_FAILED_MSG(entryPoint(&ApiV1, &plugin.hooks), "Error returned by plugin: '%ls'", ModulePath);
+ auto pump = std::make_shared();
+ RegisterWslPump(SessionId, pump);
+ auto unregister = wil::scope_exit([&]() { UnregisterWslPump(SessionId); });
+
+ return pump->Run([&]() -> HRESULT {
+ // Runs on the pump's worker thread. The host proxy is apartment-local,
+ // so acquire it here (and COM-init this thread) rather than on the
+ // notifying thread.
+ ScopedComInit init;
+ RETURN_IF_FAILED(init.Result());
+ Microsoft::WRL::ComPtr host;
+ const HRESULT acquire = AcquireHostProxy(Plugin, &host);
+ if (FAILED(acquire))
+ {
+ // Surface the acquire failure as the notification result; the caller
+ // routes it through IsHostCrash exactly like a call-time failure.
+ if (!IsHostCrash(acquire))
+ {
+ LOG_HR_MSG(acquire, "Failed to acquire plugin host proxy for: '%ls'", Plugin.name.c_str());
+ }
+ return acquire;
+ }
- m_plugins.emplace_back(std::move(plugin));
+ return Notify(host.Get());
+ });
}
void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnVMStarted != nullptr)
+ if (e.hostCookie == 0)
{
- WSL_LOG(
- "PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid"));
-
- SlowOperationWatcher slowOperation{"PluginOnVmStarted"};
- ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), e.name.c_str());
+ continue;
+ }
+ WSL_LOG("PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid"));
+
+ wil::unique_cotaskmem_string errorMessage;
+ SlowOperationWatcher slowOperation{"PluginOnVmStarted"};
+ WSL_LOG("PluginOnVmStartedBeginRpc", TraceLoggingValue(e.name.c_str(), "Plugin"));
+ const HRESULT hr = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) {
+ return host->OnVMStarted(
+ Session->SessionId,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ static_cast(Settings->CustomConfigurationFlags),
+ &errorMessage);
+ });
+ WSL_LOG("PluginOnVmStartedEndRpc", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(hr, "Result"));
+
+ if (IsHostCrash(hr))
+ {
+ ThrowHostCrash(e, hr, "OnVmStarted");
}
+
+ ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str());
}
}
-void PluginManager::OnVmStopping(const WSLSessionInformation* Session) const
+void PluginManager::OnVmStopping(const WSLSessionInformation* Session)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
- for (const auto& e : m_plugins)
+ auto sidData = SerializeSid(Session->UserSid);
+
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnVMStopping != nullptr)
+ if (e.hostCookie == 0)
{
- WSL_LOG(
- "PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid"));
+ continue;
+ }
+ WSL_LOG("PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid"));
- const auto result = e.hooks.OnVMStopping(Session);
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) {
+ return host->OnVMStopping(Session->SessionId, Session->UserToken, static_cast(sidData.size()), sidData.data());
+ });
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnVmStopping");
+ continue;
}
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnDistributionStarted != nullptr)
+ if (e.hostCookie == 0)
{
- WSL_LOG(
- "PluginOnDistroStartedCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->UserSid, "Sid"),
- TraceLoggingValue(Distribution->Id, "DistributionId"));
-
- SlowOperationWatcher slowOperation{"PluginOnDistributionStarted"};
- ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), e.name.c_str());
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnDistroStartedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->UserSid, "Sid"),
+ TraceLoggingValue(Distribution->Id, "DistributionId"));
+
+ wil::unique_cotaskmem_string errorMessage;
+ SlowOperationWatcher slowOperation{"PluginOnDistributionStarted"};
+ const HRESULT hr = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) {
+ return host->OnDistributionStarted(
+ Session->SessionId,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ &Distribution->Id,
+ Distribution->Name,
+ Distribution->PidNamespace,
+ Distribution->PackageFamilyName,
+ Distribution->InitPid,
+ Distribution->Flavor,
+ Distribution->Version,
+ &errorMessage);
+ });
+
+ if (IsHostCrash(hr))
+ {
+ ThrowHostCrash(e, hr, "OnDistributionStarted");
}
+
+ ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str());
}
}
-void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) const
+void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnDistributionStopping != nullptr)
+ if (e.hostCookie == 0)
{
- WSL_LOG(
- "PluginOnDistroStoppingCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->UserSid, "Sid"),
- TraceLoggingValue(Distribution->Id, "DistributionId"));
-
- const auto result = e.hooks.OnDistributionStopping(Session, Distribution);
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnDistroStoppingCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->UserSid, "Sid"),
+ TraceLoggingValue(Distribution->Id, "DistributionId"));
+
+ const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) {
+ return host->OnDistributionStopping(
+ Session->SessionId,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ &Distribution->Id,
+ Distribution->Name,
+ Distribution->PidNamespace,
+ Distribution->PackageFamilyName,
+ Distribution->InitPid,
+ Distribution->Flavor,
+ Distribution->Version);
+ });
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnDistributionStopping");
+ continue;
}
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
-void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) const
+void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
- for (const auto& e : m_plugins)
+ auto sidData = SerializeSid(Session->UserSid);
+
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnDistributionRegistered != nullptr)
+ if (e.hostCookie == 0)
{
- WSL_LOG(
- "PluginOnDistributionRegisteredCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->UserSid, "Sid"),
- TraceLoggingValue(Distribution->Id, "DistributionId"));
-
- const auto result = e.hooks.OnDistributionRegistered(Session, Distribution);
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnDistributionRegisteredCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->UserSid, "Sid"),
+ TraceLoggingValue(Distribution->Id, "DistributionId"));
+
+ const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) {
+ return host->OnDistributionRegistered(
+ Session->SessionId,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ &Distribution->Id,
+ Distribution->Name,
+ Distribution->PackageFamilyName,
+ Distribution->Flavor,
+ Distribution->Version);
+ });
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnDistributionRegistered");
+ continue;
}
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
-void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) const
+void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnDistributionUnregistered != nullptr)
+ if (e.hostCookie == 0)
{
- WSL_LOG(
- "PluginOnDistributionUnregisteredCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->UserSid, "Sid"),
- TraceLoggingValue(Distribution->Id, "DistributionId"));
-
- const auto result = e.hooks.OnDistributionUnregistered(Session, Distribution);
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ continue;
}
+ WSL_LOG(
+ "PluginOnDistributionUnregisteredCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->UserSid, "Sid"),
+ TraceLoggingValue(Distribution->Id, "DistributionId"));
+
+ const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) {
+ return host->OnDistributionUnregistered(
+ Session->SessionId,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ &Distribution->Id,
+ Distribution->Name,
+ Distribution->PackageFamilyName,
+ Distribution->Flavor,
+ Distribution->Version);
+ });
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnDistributionUnregistered");
+ continue;
+ }
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
-void PluginManager::ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin)
+void PluginManager::ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId Session, LPCWSTR Plugin)
{
- const auto message = std::move(g_pluginErrorMessage);
- g_pluginErrorMessage.reset(); // std::move() doesn't clear the previous std::optional
-
if (FAILED(Result))
{
- if (message.has_value())
+ if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0')
{
- THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginErrorWithMessage(Plugin, message->c_str()));
+ THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginErrorWithMessage(Plugin, ErrorMessage));
}
else
{
THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginError(Plugin));
}
}
- else if (message.has_value())
+ else if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0')
{
THROW_HR_MSG(E_ILLEGAL_STATE_CHANGE, "Plugin '%ls' emitted an error message but returned success", Plugin);
}
}
-void PluginManager::ThrowIfFatalPluginError() const
+bool PluginManager::IsHostCrash(HRESULT hr)
{
- ExecutionContext context(Context::Plugin);
+ // Each of these unambiguously indicates the COM server process has gone
+ // away. RPC_E_CALL_REJECTED is deliberately excluded: it means a busy
+ // server rejected the call, not that the server died — treating it as a
+ // crash would silently disable the plugin for the rest of the session.
+ switch (hr)
+ {
+ case RPC_E_DISCONNECTED:
+ case RPC_E_SERVER_DIED:
+ case RPC_E_SERVER_DIED_DNE:
+ case CO_E_OBJNOTCONNECTED:
+ case HRESULT_FROM_WIN32(RPC_S_SERVER_UNAVAILABLE):
+ case HRESULT_FROM_WIN32(RPC_S_CALL_FAILED):
+ case HRESULT_FROM_WIN32(RPC_S_CALL_FAILED_DNE):
+ return true;
+ default:
+ return false;
+ }
+}
+
+void PluginManager::LatchHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage)
+{
+ LOG_HR_MSG(result, "Plugin host crashed at %hs for: '%ls'", stage, plugin.name.c_str());
+
+ // Fire telemetry only on first observation per plugin: a dead plugin will
+ // hit this path on every subsequent VM/distro lifecycle event, and we
+ // don't want to flood the telemetry channel with duplicates. Any WSLC
+ // processes the dead host created are released automatically by COM when
+ // the host process exits, so there is nothing to drain here.
+ if (!plugin.crashTelemetryFired.exchange(true))
+ {
+ WSL_LOG_TELEMETRY(
+ "PluginHostCrash",
+ PDT_ProductAndServiceUsage,
+ TraceLoggingValue(plugin.name.c_str(), "Name"),
+ TraceLoggingValue(plugin.path.c_str(), "Path"),
+ TraceLoggingValue(result, "Result"),
+ TraceLoggingValue(stage, "Stage"));
+ }
+ // Latch a fatal plugin error so subsequent operations fail fast with a
+ // single consistent message instead of repeatedly driving a dead host that
+ // is never re-activated for this service lifetime. The first crash wins; a
+ // load-time failure already recorded in m_pluginError is left untouched.
+ auto lock = m_pluginErrorLock.lock_exclusive();
if (!m_pluginError.has_value())
+ {
+ m_pluginError.emplace(PluginError{plugin.name, result});
+ }
+}
+
+void PluginManager::ThrowHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage)
+{
+ // Record the crash (telemetry + fatal latch), then throw a fatal plugin
+ // error so the guarded start/veto operation (VM/distro/session/container
+ // creation) is aborted. The HRESULT is whichever RPC/CO_E_* code COM
+ // surfaced for the dead host; it is reported the same way a plugin-returned
+ // fatal error would be.
+ LatchHostCrash(plugin, result, stage);
+ THROW_HR_WITH_USER_ERROR(result, wsl::shared::Localization::MessageFatalPluginError(plugin.name.c_str()));
+}
+
+void PluginManager::ThrowIfFatalPluginError()
+{
+ ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ // m_pluginError can be set at load time (single-threaded, under m_initOnce)
+ // or latched at runtime from a hook on any RPC thread when a host crash is
+ // observed, so read it under the lock and act on a local copy.
+ std::optional error;
+ {
+ auto lock = m_pluginErrorLock.lock_shared();
+ error = m_pluginError;
+ }
+
+ if (!error.has_value())
{
return;
}
- else if (m_pluginError->error == WSL_E_PLUGIN_REQUIRES_UPDATE)
+ else if (error->error == WSL_E_PLUGIN_REQUIRES_UPDATE)
{
- THROW_HR_WITH_USER_ERROR(
- WSL_E_PLUGIN_REQUIRES_UPDATE, wsl::shared::Localization::MessagePluginRequiresUpdate(m_pluginError->plugin));
+ THROW_HR_WITH_USER_ERROR(WSL_E_PLUGIN_REQUIRES_UPDATE, wsl::shared::Localization::MessagePluginRequiresUpdate(error->plugin));
}
else
{
- THROW_HR_WITH_USER_ERROR(m_pluginError->error, wsl::shared::Localization::MessageFatalPluginError(m_pluginError->plugin));
+ THROW_HR_WITH_USER_ERROR(error->error, wsl::shared::Localization::MessageFatalPluginError(error->plugin));
}
}
+void PluginManager::RegisterWslcSession(ULONG SessionId, IWSLCSessionReference* Reference)
+{
+ THROW_HR_IF(E_POINTER, Reference == nullptr);
+
+ std::lock_guard lock(m_wslcSessionRefLock);
+ m_wslcSessionRefs[SessionId] = Reference;
+}
+
+void PluginManager::UnregisterWslcSession(ULONG SessionId)
+{
+ std::lock_guard lock(m_wslcSessionRefLock);
+ m_wslcSessionRefs.erase(SessionId);
+}
+
+wil::com_ptr PluginManager::ResolveWslcSession(ULONG SessionId)
+{
+ wil::com_ptr reference;
+ {
+ std::lock_guard lock(m_wslcSessionRefLock);
+ auto it = m_wslcSessionRefs.find(SessionId);
+ THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_wslcSessionRefs.end());
+ reference = it->second;
+ }
+
+ // OpenSession() is called OUTSIDE m_wslcSessionRefLock: the reference is a
+ // weak handle, so opening it can fail or block and must not hold the lock.
+ wil::com_ptr session;
+ const auto hr = reference->OpenSession(&session);
+ THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), FAILED(hr) || session == nullptr);
+
+ return session;
+}
+
void PluginManager::OnWslcSessionCreated(const WSLCSessionInformation* Session)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnSessionCreated != nullptr)
+ if (e.hostCookie == 0)
{
- auto result = e.hooks.OnSessionCreated(Session);
- WSL_LOG(
- "PluginOnWslcSessionCreatedCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->SessionId, "SessionId"),
- TraceLoggingValue(Session->DisplayName, "DisplayName"),
- TraceLoggingValue(result, "Result"));
-
- ThrowIfPluginError(result, e.name.c_str());
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnWslcSessionCreatedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(Session->DisplayName, "DisplayName"));
+
+ ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnWslcSessionCreated");
+
+ wil::unique_cotaskmem_string errorMessage;
+ HRESULT hr = host->OnWslcSessionCreated(
+ Session->SessionId,
+ Session->DisplayName,
+ Session->ApplicationPid,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ &errorMessage);
+
+ if (IsHostCrash(hr))
+ {
+ ThrowHostCrash(e, hr, "OnWslcSessionCreated");
}
+
+ ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str());
}
}
-void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session) const
+void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.OnSessionStopping != nullptr)
+ if (e.hostCookie == 0)
{
- const auto result = e.hooks.OnSessionStopping(Session);
- WSL_LOG(
- "PluginOnWslcSessionStoppingCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->SessionId, "SessionId"),
- TraceLoggingValue(result, "Result"));
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnWslcSessionStoppingCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"));
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcSessionStopping");
+
+ const auto result = host->OnWslcSessionStopping(
+ Session->SessionId,
+ Session->DisplayName,
+ Session->ApplicationPid,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data());
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnWslcSessionStopping");
+ continue;
}
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
-HRESULT PluginManager::OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const
+HRESULT PluginManager::OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson)
try
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.ContainerStarted != nullptr)
+ if (e.hostCookie == 0)
{
- // Failure here aborts the container creation. Surface the first error.
- const auto result = e.hooks.ContainerStarted(Session, InspectJson);
- WSL_LOG(
- "PluginOnWslcContainerStartedCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->SessionId, "SessionId"),
- TraceLoggingValue(result, "Result"));
-
- ThrowIfPluginError(result, e.name.c_str());
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnWslcContainerStartedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"));
+
+ ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnWslcContainerStarted");
+
+ // Failure here aborts the container creation. Surface the first error.
+ wil::unique_cotaskmem_string errorMessage;
+ const HRESULT hr = host->OnWslcContainerStarted(
+ Session->SessionId,
+ Session->DisplayName,
+ Session->ApplicationPid,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ InspectJson,
+ &errorMessage);
+
+ if (IsHostCrash(hr))
+ {
+ ThrowHostCrash(e, hr, "OnWslcContainerStarted");
}
+
+ ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str());
}
return S_OK;
}
CATCH_RETURN()
-void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const
+void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.ContainerStopping != nullptr)
+ if (e.hostCookie == 0)
{
-
- const auto result = e.hooks.ContainerStopping(Session, ContainerId);
- WSL_LOG(
- "PluginOnWslcContainerStoppingCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->SessionId, "SessionId"),
- TraceLoggingValue(ContainerId, "ContainerId"),
- TraceLoggingValue(result, "Result"));
-
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ continue;
}
+ WSL_LOG(
+ "PluginOnWslcContainerStoppingCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(ContainerId, "ContainerId"));
+
+ ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcContainerStopping");
+
+ const auto result = host->OnWslcContainerStopping(
+ Session->SessionId,
+ Session->DisplayName,
+ Session->ApplicationPid,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ ContainerId);
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnWslcContainerStopping");
+ continue;
+ }
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
-void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const
+void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
+
+ auto sidData = SerializeSid(Session->UserSid);
- for (const auto& e : m_plugins)
+ for (auto& e : m_plugins)
{
- if (e.hooks.ImageCreated != nullptr)
+ if (e.hostCookie == 0)
{
- const auto result = e.hooks.ImageCreated(Session, InspectJson);
- WSL_LOG(
- "PluginOnWslcImageCreatedCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->SessionId, "SessionId"),
- TraceLoggingValue(result, "Result"));
-
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ continue;
}
+ WSL_LOG(
+ "PluginOnWslcImageCreatedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"));
+
+ ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcImageCreated");
+
+ const auto result = host->OnWslcImageCreated(
+ Session->SessionId,
+ Session->DisplayName,
+ Session->ApplicationPid,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ InspectJson);
+
+ if (IsHostCrash(result))
+ {
+ LatchHostCrash(e, result, "OnWslcImageCreated");
+ continue;
+ }
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
}
}
-void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const
+void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId)
{
ExecutionContext context(Context::Plugin);
+ auto coInit = EnsureInitialized();
- for (const auto& e : m_plugins)
+ auto sidData = SerializeSid(Session->UserSid);
+
+ for (auto& e : m_plugins)
{
- if (e.hooks.ImageDeleted != nullptr)
+ if (e.hostCookie == 0)
+ {
+ continue;
+ }
+ WSL_LOG(
+ "PluginOnWslcImageDeletedCall",
+ TraceLoggingValue(e.name.c_str(), "Plugin"),
+ TraceLoggingValue(Session->SessionId, "SessionId"),
+ TraceLoggingValue(ImageId, "ImageId"));
+
+ ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcImageDeleted");
+
+ const auto result = host->OnWslcImageDeleted(
+ Session->SessionId,
+ Session->DisplayName,
+ Session->ApplicationPid,
+ Session->UserToken,
+ static_cast(sidData.size()),
+ sidData.data(),
+ ImageId);
+
+ if (IsHostCrash(result))
{
- const auto result = e.hooks.ImageDeleted(Session, ImageId);
- WSL_LOG(
- "PluginOnWslcImageDeletedCall",
- TraceLoggingValue(e.name.c_str(), "Plugin"),
- TraceLoggingValue(Session->SessionId, "SessionId"),
- TraceLoggingValue(ImageId, "ImageId"),
- TraceLoggingValue(result, "Result"));
- LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ LatchHostCrash(e, result, "OnWslcImageDeleted");
+ continue;
}
+
+ LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str());
+ }
+}
+
+// --- IWslPluginHostCallback WSLC implementations (service-side) ---
+
+STDMETHODIMP PluginHostCallbackImpl::WslcMountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCSTR Mountpoint, _In_ BOOL ReadOnly)
+try
+{
+ // TODO: Once plugins are out of proc, add logic to validate that the mountpoint isn't in use by another plugin.
+ RETURN_HR_IF(E_POINTER, WindowsPath == nullptr || Mountpoint == nullptr);
+
+ auto session = m_owner.ResolveWslcSession(SessionId);
+
+ auto result = session->MountWindowsFolder(WindowsPath, Mountpoint, ReadOnly);
+
+ WSL_LOG(
+ "WslcPluginMountFolderCall",
+ TraceLoggingValue(SessionId, "SessionId"),
+ TraceLoggingValue(WindowsPath, "WindowsPath"),
+ TraceLoggingValue(Mountpoint, "Mountpoint"),
+ TraceLoggingValue(ReadOnly, "ReadOnly"),
+ TraceLoggingValue(result, "Result"));
+
+ return result;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHostCallbackImpl::WslcUnmountFolder(_In_ DWORD SessionId, _In_ LPCSTR Mountpoint)
+try
+{
+ RETURN_HR_IF(E_POINTER, Mountpoint == nullptr);
+
+ auto session = m_owner.ResolveWslcSession(SessionId);
+ auto result = session->UnmountWindowsFolder(Mountpoint);
+
+ WSL_LOG(
+ "WslcPluginUnmountFolderCall",
+ TraceLoggingValue(SessionId, "SessionId"),
+ TraceLoggingValue(Mountpoint, "Mountpoint"),
+ TraceLoggingValue(result, "Result"));
+
+ return result;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHostCallbackImpl::WslcCreateProcess(
+ _In_ DWORD SessionId,
+ _In_ LPCSTR Executable,
+ _In_ DWORD ArgumentCount,
+ _In_reads_opt_(ArgumentCount) LPCSTR* Arguments,
+ _In_ DWORD EnvCount,
+ _In_reads_opt_(EnvCount) LPCSTR* Environment,
+ _COM_Outptr_ IWSLCProcess** Process,
+ _Out_ int* Errno)
+try
+{
+ RETURN_HR_IF(E_POINTER, Executable == nullptr || Process == nullptr || Errno == nullptr);
+ *Process = nullptr;
+ *Errno = 0;
+ RETURN_HR_IF(E_INVALIDARG, (ArgumentCount > 0 && Arguments == nullptr) || (EnvCount > 0 && Environment == nullptr));
+
+ auto session = m_owner.ResolveWslcSession(SessionId);
+
+ // Build NULL-terminated argument/env arrays expected by CreateRootNamespaceProcess.
+ std::vector argsTerminated;
+ if (Arguments != nullptr)
+ {
+ argsTerminated.assign(Arguments, Arguments + ArgumentCount);
}
+ argsTerminated.push_back(nullptr);
+
+ std::vector envTerminated;
+ if (Environment != nullptr)
+ {
+ envTerminated.assign(Environment, Environment + EnvCount);
+ envTerminated.push_back(nullptr);
+ }
+
+ WSLCProcessOptions options{};
+ options.CommandLine.Values = argsTerminated.data();
+ options.CommandLine.Count = ArgumentCount;
+ if (!envTerminated.empty())
+ {
+ options.Environment.Values = envTerminated.data();
+ options.Environment.Count = EnvCount;
+ }
+ options.Flags = WSLCProcessFlagsStdin;
+
+ wil::com_ptr process;
+ int errnoValue = 0;
+ auto result = session->CreateRootNamespaceProcess(Executable, &options, 0, 0, &process, &errnoValue);
+ *Errno = errnoValue;
+
+ if (FAILED(result))
+ {
+ WSL_LOG(
+ "WslcPluginCreateProcessCall",
+ TraceLoggingValue(SessionId, "SessionId"),
+ TraceLoggingValue(Executable, "Executable"),
+ TraceLoggingValue(result, "Result"),
+ TraceLoggingValue(errnoValue, "Errno"));
+ return result;
+ }
+
+ // Hand the IWSLCProcess back to the host. COM marshals it across the host
+ // boundary; the host then calls it directly (GetStdHandle/GetExitEvent/
+ // GetState) and owns its lifetime, so the service keeps no per-process state.
+ *Process = process.detach();
+
+ WSL_LOG(
+ "WslcPluginCreateProcessCall",
+ TraceLoggingValue(SessionId, "SessionId"),
+ TraceLoggingValue(Executable, "Executable"),
+ TraceLoggingValue(S_OK, "Result"));
+
+ return S_OK;
}
+CATCH_RETURN();
diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h
index a99a3e332d..ecf1744a08 100644
--- a/src/windows/service/exe/PluginManager.h
+++ b/src/windows/service/exe/PluginManager.h
@@ -9,17 +9,93 @@ Module Name:
Abstract:
This file contains the PluginManager class definition.
+ Plugins are loaded out-of-process in wslpluginhost.exe via COM
+ to isolate the service from plugin crashes.
--*/
#pragma once
#include
+#include
+#include
+#include
+#include
#include
+#include
#include
#include "WslPluginApi.h"
+#include "WslPluginHost.h"
+#include "wslc.h"
+#include "PluginCallPump.h"
namespace wsl::windows::service {
+
+class PluginManager;
+
+//
+// IWslPluginHostCallback implementation — lives in the service process and
+// handles API calls coming from the plugin host (MountFolder, ExecuteBinary,
+// WSLC* APIs etc.). WslcCreateProcess returns a marshaled IWSLCProcess directly
+// to the host, which calls it (GetStdHandle/GetExitEvent/GetState) without any
+// service-side process bookkeeping; the remote process is released by COM when
+// the host releases it or the host process exits.
+//
+class PluginHostCallbackImpl
+ : public Microsoft::WRL::RuntimeClass, IWslPluginHostCallback>
+{
+public:
+ explicit PluginHostCallbackImpl(PluginManager& owner) : m_owner(owner)
+ {
+ }
+ ~PluginHostCallbackImpl() override = default;
+
+ PluginHostCallbackImpl(const PluginHostCallbackImpl&) = delete;
+ PluginHostCallbackImpl& operator=(const PluginHostCallbackImpl&) = delete;
+
+ // WSL plugin API (host -> service)
+ STDMETHODIMP MountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCWSTR LinuxPath, _In_ BOOL ReadOnly, _In_ LPCWSTR Name) override;
+
+ STDMETHODIMP ExecuteBinary(
+ _In_ DWORD SessionId, _In_ LPCSTR Path, _In_ DWORD ArgumentCount, _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _Out_ HANDLE* Socket) override;
+
+ STDMETHODIMP ExecuteBinaryInDistribution(
+ _In_ DWORD SessionId,
+ _In_ const GUID* DistributionId,
+ _In_ LPCSTR Path,
+ _In_ DWORD ArgumentCount,
+ _In_reads_opt_(ArgumentCount) LPCSTR* Arguments,
+ _Out_ HANDLE* Socket) override;
+
+ // WSLC plugin API (host -> service)
+ STDMETHODIMP WslcMountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCSTR Mountpoint, _In_ BOOL ReadOnly) override;
+
+ STDMETHODIMP WslcUnmountFolder(_In_ DWORD SessionId, _In_ LPCSTR Mountpoint) override;
+
+ STDMETHODIMP WslcCreateProcess(
+ _In_ DWORD SessionId,
+ _In_ LPCSTR Executable,
+ _In_ DWORD ArgumentCount,
+ _In_reads_opt_(ArgumentCount) LPCSTR* Arguments,
+ _In_ DWORD EnvCount,
+ _In_reads_opt_(EnvCount) LPCSTR* Environment,
+ _COM_Outptr_ IWSLCProcess** Process,
+ _Out_ int* Errno) override;
+
+private:
+ // The PluginManager that owns this callback. Used to resolve a WSLC
+ // SessionId to a live IWSLCSession via the manager's registered session
+ // reference map (see PluginManager::ResolveWslcSession).
+ PluginManager& m_owner;
+};
+
+///
+/// Manages out-of-process plugin hosts (wslpluginhost.exe) via COM activation.
+/// Each plugin DLL is loaded in a separate process to isolate the service from
+/// plugin crashes. Communication uses IWslPluginHost (service → host) for lifecycle
+/// notifications and IWslPluginHostCallback (host → service) for plugin API calls.
+/// A job object ensures all hosts are terminated if the service exits unexpectedly.
+///
class PluginManager
{
public:
@@ -30,6 +106,7 @@ class PluginManager
};
PluginManager() = default;
+ ~PluginManager();
PluginManager(const PluginManager&) = delete;
PluginManager& operator=(const PluginManager&) = delete;
@@ -37,37 +114,206 @@ class PluginManager
PluginManager& operator=(PluginManager&&) = delete;
void LoadPlugins();
+
+ // Releases all out-of-process plugin host COM proxies and tears down the
+ // job object that contains the host processes. MUST be called from the
+ // service shutdown path WHILE COM IS STILL INITIALIZED on this thread
+ // (i.e. before CoUninitialize). Releasing IWslPluginHost proxies after
+ // COM has been uninitialized — for example, when this PluginManager is
+ // destroyed as a global at process exit — crashes inside the marshaler
+ // because the proxy/stub DLL has been unloaded.
+ void Shutdown();
+
void OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings);
- void OnVmStopping(const WSLSessionInformation* Session) const;
+ void OnVmStopping(const WSLSessionInformation* Session);
void OnDistributionStarted(const WSLSessionInformation* Session, const WSLDistributionInformation* distro);
- void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro) const;
- void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const;
- void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const;
+ void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro);
+ void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro);
+ void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro);
- // WSLC notifications. Returning failure from OnSessionCreated/OnContainerStarted causes the
- // corresponding operation to be aborted. Other notifications log errors and continue.
+ // WSLC notifications. Returning failure from OnWslcSessionCreated / OnWslcContainerStarted causes
+ // the corresponding operation to be aborted. Other notifications log errors and continue.
void OnWslcSessionCreated(const WSLCSessionInformation* Session);
- void OnWslcSessionStopping(const WSLCSessionInformation* Session) const;
- HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const;
- void OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const;
- void OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const;
- void OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const;
+ void OnWslcSessionStopping(const WSLCSessionInformation* Session);
+ HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson);
+ void OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId);
+ void OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson);
+ void OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId);
- void ThrowIfFatalPluginError() const;
+ void ThrowIfFatalPluginError();
-private:
- void LoadPlugin(LPCWSTR Name, LPCWSTR Path);
- static void ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin);
+ // WSLC session reference registry. The WSLCSessionManager registers a weak
+ // IWSLCSessionReference for each live session here (keyed by SessionId)
+ // before notifying plugins of the session, and unregisters it when the
+ // session stops or creation is rolled back. Plugin API callbacks that take
+ // a SessionId (WslcMountFolder/WslcCreateProcess/etc.) resolve the live
+ // session through this registry instead of reaching back into the
+ // WSLCSessionManager — this breaks the lock-reentrancy cycle that would
+ // otherwise occur when an out-of-process plugin calls back on a different
+ // RPC thread while the manager holds its session lock.
+ void RegisterWslcSession(ULONG SessionId, IWSLCSessionReference* Reference);
+ void UnregisterWslcSession(ULONG SessionId);
+
+ // Resolves a SessionId to a live IWSLCSession. The reference is copied out
+ // under m_wslcSessionRefLock and OpenSession() is then called OUTSIDE the
+ // lock (the reference is a weak handle, so OpenSession can fail/block).
+ // Throws HRESULT_FROM_WIN32(ERROR_NOT_FOUND) if the session is unknown or
+ // can no longer be opened.
+ wil::com_ptr ResolveWslcSession(ULONG SessionId);
+
+ // Routes a WSL-session plugin API callback (MountFolder / ExecuteBinary /
+ // ExecuteBinaryInDistribution) so it runs with in-process semantics. If a
+ // notification hook for SessionId is in flight, the work is marshaled onto
+ // the notifying thread (which holds the session's recursive m_instanceLock)
+ // via that hook's PluginCallPump — so out-of-process callbacks re-enter the
+ // lock exactly as in-process plugins did, and no second (m_callbackLock)
+ // lock is needed. Otherwise the work runs directly on the calling RPC thread
+ // (acquiring m_instanceLock itself via a timed try-acquire that re-checks for
+ // a pump on contention, so it can never deadlock against a notification
+ // thread that holds m_instanceLock and later waits for this callback). This
+ // supports plugin API calls made from a plugin's own worker threads outside
+ // any hook.
+ HRESULT InvokeOnWslPump(ULONG SessionId, std::function Work);
- struct LoadedPlugin
+private:
+ struct OutOfProcPlugin
{
- wil::unique_hmodule module;
+ // GIT cookie for the IWslPluginHost proxy. Zero means "not loaded".
+ // Resolved per-call via AcquireHostProxy() because the raw proxy
+ // returned by CoCreateInstance is apartment-bound to the LoadPlugin
+ // thread, but hook dispatch can arrive on threads in any apartment
+ // (MTA, NTA-on-MTA via WinRT-style RPC dispatch, etc.). Storing in
+ // the Global Interface Table and re-unmarshaling per call yields an
+ // apartment-local proxy that COM can dispatch from any apartment.
+ DWORD hostCookie{0};
+ Microsoft::WRL::ComPtr callback;
std::wstring name;
- WSLPluginHooksV1 hooks{};
+ std::wstring path;
+
+ // Set the first time a host crash is observed for this plugin during
+ // runtime (i.e. after successful load). Subsequent crash sites only
+ // log to ETL and skip the telemetry event, avoiding flood from a
+ // dead plugin that we'd notify on every distro/VM lifecycle event.
+ std::atomic crashTelemetryFired{false};
+
+ OutOfProcPlugin() = default;
+ OutOfProcPlugin(const OutOfProcPlugin&) = delete;
+ OutOfProcPlugin& operator=(const OutOfProcPlugin&) = delete;
+ OutOfProcPlugin(OutOfProcPlugin&& other) noexcept :
+ hostCookie(std::exchange(other.hostCookie, 0)),
+ callback(std::move(other.callback)),
+ name(std::move(other.name)),
+ path(std::move(other.path)),
+ crashTelemetryFired(other.crashTelemetryFired.load())
+ {
+ }
+ OutOfProcPlugin& operator=(OutOfProcPlugin&&) = delete;
+ };
+
+ // RAII helper that joins the calling thread to the MTA for the duration of the
+ // scope. Plugin host activation and per-call dispatch through plugin host proxies
+ // are both cross-process COM calls that require the calling thread to be COM-init'd,
+ // and hook dispatch can run on threadpool threads (e.g. _VmIdleTerminate path) that
+ // haven't called CoInitializeEx. RPC_E_CHANGED_MODE means the thread is already
+ // initialized to a different apartment (typically STA), which is still fine: the
+ // existing apartment is preserved and COM dispatches proxy calls accordingly.
+ struct ScopedComInit
+ {
+ HRESULT initHr{RPC_E_CHANGED_MODE};
+
+ ScopedComInit();
+ ~ScopedComInit();
+ ScopedComInit(const ScopedComInit&) = delete;
+ ScopedComInit& operator=(const ScopedComInit&) = delete;
+ ScopedComInit(ScopedComInit&& other) noexcept;
+ ScopedComInit& operator=(ScopedComInit&&) = delete;
+
+ HRESULT Result() const noexcept;
};
- std::vector m_plugins;
+ void LoadPlugin(OutOfProcPlugin& plugin);
+ [[nodiscard]] ScopedComInit EnsureInitialized();
+ void EnsureJobObjectCreated();
+
+ // Returns an apartment-local IWslPluginHost proxy for the given plugin by
+ // re-unmarshaling from the Global Interface Table. The returned proxy is
+ // only valid on the current thread / apartment and must not be cached.
+ // On failure (host process crashed, host released, etc.) returns the
+ // HRESULT from GIT; caller should treat host-crash HRESULTs via
+ // IsHostCrash() exactly as if the call itself had failed.
+ HRESULT AcquireHostProxy(const OutOfProcPlugin& plugin, _COM_Outptr_ IWslPluginHost** host);
+
+ static void ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId session, LPCWSTR Plugin);
+ static std::vector SerializeSid(PSID Sid);
+ static bool IsHostCrash(HRESULT hr);
+
+ // Records an observed host-process crash: logs it to ETL, fires the
+ // PluginHostCrash telemetry event at most once per plugin per service
+ // lifetime (so a single bad plugin does not flood telemetry across every
+ // subsequent VM/distro lifecycle), and latches a fatal plugin error so
+ // later operations fail fast. The host is not re-activated for the rest of
+ // the service lifetime.
+ void LatchHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage);
+
+ // LatchHostCrash, then throws the latched error as a fatal plugin error to
+ // abort a start/veto operation. Use from start hooks (OnVmStarted, etc.);
+ // teardown hooks latch but cannot block, so they call LatchHostCrash + skip.
+ [[noreturn]] void ThrowHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage);
+
+ // Registers/unregisters the active PluginCallPump for a WSL session while a
+ // notification (OnVmStarted, etc.) is in flight. Keyed by the session cookie
+ // that is handed to the plugin host and echoed back on callbacks. Plugin
+ // notifications for a given session are serialized by m_instanceLock, so at
+ // most one pump is registered per SessionId at a time.
+ void RegisterWslPump(ULONG SessionId, const std::shared_ptr& Pump);
+ void UnregisterWslPump(ULONG SessionId);
+
+ // Drives one outbound plugin-host notification (OnVMStarted, etc.) through a
+ // PluginCallPump registered under SessionId. The pump runs `Notify` on a
+ // worker thread (which acquires the apartment-local host proxy and performs
+ // its own COM init) while THIS thread pumps the plugin's API callbacks — so
+ // they execute back here, under the session's recursive m_instanceLock.
+ // Returns the HRESULT from `Notify`, or the proxy-acquire failure (the
+ // caller routes it through IsHostCrash exactly as before).
+ HRESULT RunHostNotification(OutOfProcPlugin& Plugin, ULONG SessionId, const std::function& Notify);
+
+ std::once_flag m_initOnce;
+ std::vector m_plugins;
+
+ // Guards m_pluginError. The error is written once at load time under
+ // m_initOnce and may later be latched from a hook on any RPC thread when a
+ // host crash is observed, so all accesses outside the initial load take it.
+ wil::srwlock m_pluginErrorLock;
std::optional m_pluginError;
+
+ // Global Interface Table used to make IWslPluginHost proxies callable from
+ // any apartment. Created lazily inside EnsureInitialized() while COM is
+ // guaranteed to be initialized on the calling thread.
+ std::once_flag m_gitOnce;
+ Microsoft::WRL::ComPtr m_git;
+
+ // Job object that automatically terminates all plugin host processes
+ // when wslservice exits or crashes (JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE).
+ std::once_flag m_jobObjectOnce;
+ wil::unique_handle m_jobObject;
+
+ // Maps a WSLC SessionId to a weak reference to the live session. Populated
+ // by RegisterWslcSession / drained by UnregisterWslcSession and Shutdown.
+ // The references are COM proxies and MUST be released while COM is still
+ // initialized (see Shutdown); the destructor detaches (leaks) any survivors
+ // rather than releasing them after CoUninitialize.
+ std::mutex m_wslcSessionRefLock;
+ std::unordered_map> m_wslcSessionRefs;
+
+ // Active WSL-session notification pumps, keyed by session cookie. Populated
+ // for the duration of an out-of-process notification call so that plugin API
+ // callbacks (which arrive on a different RPC thread) can be marshaled back
+ // onto the notifying thread. Held by shared_ptr so InvokeOnWslPump can copy
+ // out a stable reference under the lock and then release the lock before the
+ // (blocking) Invoke — the lock is never held across a callback. See
+ // InvokeOnWslPump.
+ wil::srwlock m_wslPumpLock;
+ _Guarded_by_(m_wslPumpLock) std::unordered_map> m_wslPumps;
};
} // namespace wsl::windows::service
diff --git a/src/windows/service/exe/ServiceMain.cpp b/src/windows/service/exe/ServiceMain.cpp
index 40852b5511..f18c3321dc 100644
--- a/src/windows/service/exe/ServiceMain.cpp
+++ b/src/windows/service/exe/ServiceMain.cpp
@@ -258,6 +258,12 @@ void WslService::ServiceStopped()
LxssClientUninitialize();
}
+ // Release plugin host COM proxies BEFORE CoUninitialize. The IWslPluginHost
+ // proxies must be torn down while COM is still initialized; otherwise the
+ // proxy/stub DLL is unloaded and releasing the proxies later (during global
+ // destruction) crashes inside the marshaler.
+ g_pluginManager.Shutdown();
+
// There is a potential deadlock if CoUninitialize() is called before the LanguageChangeNotifyThread
// isn't done initializing. Clearing the COM objects before calling CoUninitialize() works around the issue.
winrt::clear_factory_cache();
diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp
index 3966245df5..e1a8531036 100644
--- a/src/windows/service/exe/WSLCSessionManager.cpp
+++ b/src/windows/service/exe/WSLCSessionManager.cpp
@@ -134,26 +134,41 @@ WSLCSessionManagerImpl::~WSLCSessionManagerImpl()
g_managerInstance.store(nullptr);
// Terminate all sessions on shutdown.
- // Call Terminate() directly rather than going through ForEachSession(),
- // which would needlessly resolve weak references and call GetState().
- // Terminate() already handles the "session is gone" case gracefully.
- std::lock_guard lock(m_wslcSessionsLock);
- for (auto& entry : m_sessions)
+ // Plugin notifications are dispatched out-of-process to wslpluginhost.exe.
+ // Move the sessions out under the lock; notify and terminate after release
+ // so we never invoke a plugin callback while m_wslcSessionsLock is held.
+ std::vector snapshot;
{
- NotifySessionStoppingLockHeld(entry);
+ std::lock_guard lock(m_wslcSessionsLock);
+ snapshot = std::move(m_sessions);
+ m_sessions.clear();
+ m_persistentSessions.clear();
+ }
+
+ for (auto& entry : snapshot)
+ {
+ DispatchSessionStopping(entry);
LOG_IF_FAILED(entry.Ref->Terminate());
}
}
-void WSLCSessionManagerImpl::NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept
+void WSLCSessionManagerImpl::DispatchSessionStopping(SessionEntry& entry) noexcept
try
{
- if (entry.StoppingNotified)
+ // Fire OnWslcSessionStopping exactly once. std::exchange guarantees only
+ // the first caller dispatches even if the entry is observed from multiple
+ // teardown paths (ForEachSession cleanup, destructor, etc.).
+ if (std::exchange(entry.StoppingNotified, true))
{
return;
}
- entry.StoppingNotified = true;
+ // Drop the session from the plugin session reference map once stopping
+ // completes, even if the notification throws. Unregistered AFTER the
+ // notification so a plugin's OnWslcSessionStopping handler can still
+ // resolve the session (e.g. to unmount folders).
+ auto unregister = wil::scope_exit([&] { g_pluginManager.UnregisterWslcSession(entry.SessionId); });
+
WSLCSessionInformation info{};
info.SessionId = static_cast(entry.SessionId);
info.DisplayName = entry.DisplayName.c_str();
@@ -168,10 +183,15 @@ void WSLCSessionManagerImpl::CreateSession(
_In_ const WSLCSessionSettings* Settings, _In_ WSLCSessionFlags Flags, _In_opt_ IWarningCallback* WarningCallback, _Out_ IWSLCSession** WslcSession)
{
THROW_HR_IF_NULL(E_POINTER, WslcSession);
+ *WslcSession = nullptr;
auto tokenInfo = GetCallingProcessTokenInfo();
const auto callerToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
+ // Snapshot before tokenInfo is moved into the SessionEntry below; read by
+ // the telemetry event at the end of this function.
+ const bool elevated = tokenInfo.Elevated;
+
// Resolve display name upfront (for both default and custom sessions).
std::wstring resolvedDisplayName;
if (Settings == nullptr)
@@ -200,126 +220,253 @@ void WSLCSessionManagerImpl::CreateSession(
resolvedDisplayName = Settings->DisplayName;
}
- std::lock_guard lock(m_wslcSessionsLock);
+ std::vector deadSessions;
+ std::optional existingResult;
- // Check for an existing session first.
- auto result = ForEachSession([&](auto& entry, const wil::com_ptr& session) noexcept -> std::optional {
+ wslutil::StopWatch stopWatch;
+ std::wstring callerFileName;
+ HRESULT creationResult = S_OK;
+ const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent);
+
+ // State for a created-but-unconfirmed session, carried across the unlocked
+ // OnWslcSessionCreated notification. The session is intentionally NOT added
+ // to m_sessions until the notification succeeds, so another caller can never
+ // observe (or be handed) a session that a plugin might still veto.
+ wil::com_ptr createdRef;
+ wil::com_ptr createdSession;
+ Microsoft::WRL::ComPtr createdNotifier;
+ wil::shared_handle createdToken;
+ std::vector createdSid;
+ ULONG createdSessionId = 0;
+ DWORD createdPid = 0;
+ wil::unique_handle createdJob;
+ bool created = false;
+
+ // Set once the created session's outcome (publish / veto / race-loss) has
+ // been handled. The guard only runs if an unexpected throw escapes before
+ // then, dropping the ref-map registration and terminating the orphan.
+ bool outcomeHandled = false;
+ auto createdCleanup = wil::scope_exit([&] {
+ if (created && !outcomeHandled)
+ {
+ g_pluginManager.UnregisterWslcSession(createdSessionId);
+ if (createdRef)
+ {
+ LOG_IF_FAILED(createdRef->Terminate());
+ }
+ }
+ });
+
+ // Matches an existing session by name and, if the caller has access, copies
+ // it out. Used both for the initial existence check and, after the unlocked
+ // notification, to detect a concurrent create of the same name.
+ auto openExisting = [&](SessionEntry& entry, const wil::com_ptr& session) noexcept -> std::optional {
if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), resolvedDisplayName.c_str()))
{
return {};
}
RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), WI_IsFlagClear(Flags, WSLCSessionFlagsOpenExisting));
-
RETURN_IF_FAILED(CheckTokenAccess(entry, tokenInfo));
-
RETURN_IF_FAILED(wil::com_copy_to_nothrow(session, WslcSession));
-
return S_OK;
- });
+ };
- if (result.has_value())
{
- THROW_IF_FAILED(result.value());
- return; // Existing session was opened.
- }
-
- wslutil::StopWatch stopWatch;
+ std::lock_guard lock(m_wslcSessionsLock);
- // Initialize settings for the default session.
- std::unique_ptr defaultSettings;
- if (Settings == nullptr)
- {
- defaultSettings = SessionSettings::Default(callerToken.get(), resolvedDisplayName);
- Settings = &defaultSettings->Settings;
- }
+ existingResult = ForEachSessionLockHeld(openExisting, deadSessions);
- std::wstring callerFileName;
-
- HRESULT creationResult = wil::ResultFromException([&]() {
- // Get caller info.
- const auto callerProcess = wslutil::OpenCallingProcess(PROCESS_QUERY_LIMITED_INFORMATION);
- const ULONG sessionId = m_nextSessionId++;
- const DWORD creatorPid = GetProcessId(callerProcess.get());
-
- // Query the full image path of the calling process and extract just the file name.
- std::wstring callerFilePath;
- if (SUCCEEDED_LOG(wil::QueryFullProcessImageNameW(callerProcess.get(), 0, callerFilePath)))
+ if (!existingResult.has_value())
{
- callerFileName = std::filesystem::path(callerFilePath).filename().wstring();
+ // Initialize settings for the default session.
+ std::unique_ptr defaultSettings;
+ if (Settings == nullptr)
+ {
+ defaultSettings = SessionSettings::Default(callerToken.get(), resolvedDisplayName);
+ Settings = &defaultSettings->Settings;
+ }
+
+ creationResult = wil::ResultFromException([&]() {
+ // Get caller info.
+ const auto callerProcess = wslutil::OpenCallingProcess(PROCESS_QUERY_LIMITED_INFORMATION);
+ const ULONG sessionId = m_nextSessionId++;
+ const DWORD creatorPid = GetProcessId(callerProcess.get());
+
+ // Query the full image path of the calling process and extract just the file name.
+ std::wstring callerFilePath;
+ if (SUCCEEDED_LOG(wil::QueryFullProcessImageNameW(callerProcess.get(), 0, callerFilePath)))
+ {
+ callerFileName = std::filesystem::path(callerFilePath).filename().wstring();
+ }
+
+ const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
+
+ // Capture a duplicated user token + raw SID so PluginManager can build
+ // WSLCSessionInformation later (e.g. on shutdown) without re-impersonating.
+ // The token is shared between the SessionEntry and the WSLCPluginNotifier.
+ wil::unique_handle dupToken;
+ THROW_IF_WIN32_BOOL_FALSE(DuplicateTokenEx(
+ userToken.get(), TOKEN_QUERY | TOKEN_DUPLICATE, nullptr, SecurityImpersonation, TokenImpersonation, &dupToken));
+ wil::shared_handle sharedToken{dupToken.release()};
+
+ const DWORD sidLen = GetLengthSid(tokenInfo.TokenInfo->User.Sid);
+ std::vector storedSid(sidLen);
+ THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLen, storedSid.data(), tokenInfo.TokenInfo->User.Sid));
+
+ // Build the plugin notifier service-side. Lifetime tracked via the SessionEntry.
+ Microsoft::WRL::ComPtr notifier;
+ notifier = wil::MakeOrThrow(
+ g_pluginManager, sessionId, creatorPid, std::wstring(resolvedDisplayName), wil::shared_handle(sharedToken), std::vector(storedSid));
+
+ // Create the VM in the SYSTEM service (privileged).
+ auto vm = Microsoft::WRL::Make(Settings);
+
+ // Launch per-user COM server factory and add it to a fresh per-session job object for crash cleanup.
+ auto factory = wslutil::CreateComServerAsUser(__uuidof(WSLCSessionFactory), userToken.get());
+ wil::unique_handle sessionJob = CreateSessionProcessJob(factory.get());
+
+ auto sessionSettings = CreateSessionSettings(sessionId, callerFileName.c_str(), Settings, resolvedDisplayName.c_str());
+ wil::com_ptr session;
+ wil::com_ptr serviceRef;
+ THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), notifier.Get(), WarningCallback, &session, &serviceRef));
+
+ // Register the session reference so plugin API callbacks (WslcMountFolder etc.)
+ // can resolve it WITHOUT taking m_wslcSessionsLock during the OnWslcSessionCreated
+ // notification, even though the plugin may call back on a different RPC thread. The
+ // session is not yet in m_sessions, so it stays invisible to other callers until
+ // creation is confirmed below.
+ g_pluginManager.RegisterWslcSession(sessionId, serviceRef.get());
+
+ // Carry the created (but unconfirmed) session out to the notification and commit.
+ createdRef = serviceRef;
+ createdSession = session;
+ createdNotifier = notifier;
+ createdToken = sharedToken;
+ createdSid = std::move(storedSid);
+ createdSessionId = sessionId;
+ createdPid = creatorPid;
+ createdJob = std::move(sessionJob);
+ created = true;
+ });
}
+ }
- const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
-
- // Capture a duplicated user token + raw SID so PluginManager can build
- // WSLCSessionInformation later (e.g. on shutdown) without re-impersonating.
- // The token is shared between the SessionEntry and the WSLCPluginNotifier.
- wil::unique_handle dupToken;
- THROW_IF_WIN32_BOOL_FALSE(DuplicateTokenEx(
- userToken.get(), TOKEN_QUERY | TOKEN_DUPLICATE, nullptr, SecurityImpersonation, TokenImpersonation, &dupToken));
- wil::shared_handle sharedToken{dupToken.release()};
-
- const DWORD sidLen = GetLengthSid(tokenInfo.TokenInfo->User.Sid);
- std::vector storedSid(sidLen);
- THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLen, storedSid.data(), tokenInfo.TokenInfo->User.Sid));
-
- // Build the plugin notifier service-side. Lifetime tracked via the SessionEntry.
- Microsoft::WRL::ComPtr notifier;
- notifier = wil::MakeOrThrow(
- g_pluginManager, sessionId, creatorPid, std::wstring(resolvedDisplayName), wil::shared_handle(sharedToken), std::vector(storedSid));
-
- // Create the VM in the SYSTEM service (privileged).
- auto vm = Microsoft::WRL::Make(Settings);
-
- // Launch per-user COM server factory and add it to a fresh per-session job object for crash cleanup.
- auto factory = wslutil::CreateComServerAsUser(__uuidof(WSLCSessionFactory), userToken.get());
- wil::unique_handle sessionJob = CreateSessionProcessJob(factory.get());
-
- const auto sessionSettings = CreateSessionSettings(sessionId, callerFileName.c_str(), Settings, resolvedDisplayName.c_str());
- wil::com_ptr session;
- wil::com_ptr serviceRef;
- THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), notifier.Get(), WarningCallback, &session, &serviceRef));
-
- // Track the session via its service ref, along with metadata and security info.
- m_sessions.push_back(SessionEntry{
- std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo), notifier, false, sharedToken, std::move(storedSid), std::move(sessionJob)});
+ // Dispatch OnWslcSessionStopping for any dead sessions found during the
+ // existence check, now that the lock is released.
+ for (auto& entry : deadSessions)
+ {
+ DispatchSessionStopping(entry);
+ }
+ deadSessions.clear();
- // For persistent sessions, also hold a strong reference to keep them alive.
- const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent);
- if (persistent)
- {
- m_persistentSessions.emplace_back(sessionId, session);
- }
+ if (existingResult.has_value())
+ {
+ THROW_IF_FAILED(existingResult.value());
+ return; // Existing session was opened.
+ }
- // Notify plugins that the session was created. A failure here aborts session creation.
+ if (created)
+ {
+ // Notify plugins with the lock RELEASED: a plugin runs arbitrary code and
+ // may call back into the service (including the externally-activatable
+ // session manager) on another thread, so holding m_wslcSessionsLock across
+ // the notification could deadlock. Plugin API callbacks resolve this
+ // session through the ref-map registered above instead.
+ WSLCSessionInformation info{};
+ info.SessionId = static_cast(createdSessionId);
+ info.DisplayName = resolvedDisplayName.c_str();
+ info.ApplicationPid = createdPid;
+ info.UserToken = createdToken.get();
+ info.UserSid = createdSid.data();
+
+ HRESULT pluginHr = S_OK;
try
{
- auto& entry = m_sessions.back();
- WSLCSessionInformation info{};
- info.SessionId = static_cast(entry.SessionId);
- info.DisplayName = entry.DisplayName.c_str();
- info.ApplicationPid = entry.CreatorPid;
- info.UserToken = entry.UserToken.get();
- info.UserSid = entry.UserSid.data();
g_pluginManager.OnWslcSessionCreated(&info);
}
catch (...)
{
- const auto error = wil::ResultFromCaughtException();
+ pluginHr = wil::ResultFromCaughtException();
+ }
- // Plugin rejected the session: tear it down before propagating.
- m_sessions.back().StoppingNotified = true; // Don't fire stopping for a session that never started successfully.
- LOG_IF_FAILED(m_sessions.back().Ref->Terminate());
- m_sessions.pop_back();
+ enum class Outcome
+ {
+ Commit,
+ Veto,
+ RaceLost
+ } outcome = Outcome::Commit;
- auto remove = std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == sessionId; });
- m_persistentSessions.erase(remove.begin(), remove.end());
+ {
+ std::lock_guard lock(m_wslcSessionsLock);
+
+ if (FAILED(pluginHr))
+ {
+ outcome = Outcome::Veto;
+ creationResult = pluginHr;
+ }
+ else if (auto raceWinner = ForEachSessionLockHeld(openExisting, deadSessions); raceWinner.has_value())
+ {
+ outcome = Outcome::RaceLost;
+ creationResult = raceWinner.value();
+ }
+ else
+ {
+ // Creation confirmed: publish the session and hand it to the
+ // caller. Reserve first so the inserts can't throw and leave a
+ // half-published, externally-visible session behind.
+ m_sessions.reserve(m_sessions.size() + 1);
+ if (persistent)
+ {
+ m_persistentSessions.reserve(m_persistentSessions.size() + 1);
+ }
+
+ m_sessions.push_back(SessionEntry{
+ std::move(createdRef),
+ createdSessionId,
+ createdPid,
+ resolvedDisplayName,
+ std::move(tokenInfo),
+ createdNotifier,
+ false,
+ createdToken,
+ std::move(createdSid),
+ std::move(createdJob)});
+
+ if (persistent)
+ {
+ m_persistentSessions.emplace_back(createdSessionId, createdSession);
+ }
+
+ *WslcSession = createdSession.detach();
+ outcomeHandled = true;
+ }
+ }
- THROW_HR(error);
+ // Tear down vetoed / race-lost sessions with the lock released: both
+ // Terminate() and the stopping notification are out-of-process calls
+ // that must not run under m_wslcSessionsLock.
+ if (outcome == Outcome::Veto || outcome == Outcome::RaceLost)
+ {
+ // A plugin vetoed the session in OnWslcSessionCreated, or a concurrent
+ // CreateSession won the name. Either way, fire OnWslcSessionStopping so
+ // every plugin that already observed OnWslcSessionCreated sees a matching
+ // stopping (the established veto convention, mirroring the VM path's
+ // OnVmStarted -> _VmTerminate -> OnVmStopping), then drop and terminate.
+ SessionEntry teardown{
+ createdRef, createdSessionId, createdPid, resolvedDisplayName, std::move(tokenInfo), createdNotifier, false, createdToken, std::move(createdSid)};
+ DispatchSessionStopping(teardown);
+ LOG_IF_FAILED(createdRef->Terminate());
+ outcomeHandled = true;
}
+ }
- *WslcSession = session.detach();
- });
+ // Dispatch OnWslcSessionStopping for any dead sessions found during the
+ // post-notification re-check, now that the lock is released.
+ for (auto& entry : deadSessions)
+ {
+ DispatchSessionStopping(entry);
+ }
// This telemetry event is used to keep track of session creation performance (via CreationTimeMs) and failure reasons (via Result).
WSL_LOG(
@@ -330,7 +477,7 @@ void WSLCSessionManagerImpl::CreateSession(
TraceLoggingValue(WSL_PACKAGE_VERSION, "wslVersion"),
TraceLoggingValue(stopWatch.ElapsedMilliseconds(), "CreationTimeMs"),
TraceLoggingValue(creationResult, "Result"),
- TraceLoggingValue(tokenInfo.Elevated, "Elevated"),
+ TraceLoggingValue(elevated, "Elevated"),
TraceLoggingValue(static_cast(Flags), "Flags"),
TraceLoggingValue(callerFileName.c_str(), "CallerFileName"),
TraceLoggingLevel(WINEVENT_LEVEL_INFO));
@@ -605,22 +752,4 @@ WSLCSessionManagerImpl* WSLCSessionManagerImpl::Instance() noexcept
return g_managerInstance.load();
}
-wil::com_ptr WSLCSessionManagerImpl::FindSession(ULONG Id)
-{
- wil::com_ptr result;
-
- ForEachSession([&](SessionEntry& entry, const wil::com_ptr& session) noexcept -> std::optional {
- if (entry.SessionId != Id)
- {
- return std::nullopt;
- }
-
- result = session;
- return S_OK;
- });
-
- THROW_HR_IF_MSG(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), !result, "WSLC session %lu not found", Id);
- return result;
-}
-
} // namespace wsl::windows::service::wslc
diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h
index d48f52e430..ab8ae387cf 100644
--- a/src/windows/service/exe/WSLCSessionManager.h
+++ b/src/windows/service/exe/WSLCSessionManager.h
@@ -91,9 +91,6 @@ class WSLCSessionManagerImpl
void OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session);
void OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session);
- // Resolves a session by ID for plugin->API calls. Throws ERROR_NOT_FOUND if no session matches.
- wil::com_ptr FindSession(ULONG Id);
-
static WSLCSessionManagerImpl* Instance() noexcept;
private:
@@ -104,13 +101,16 @@ class WSLCSessionManagerImpl
// Returns true if the name matches a reserved default session prefix.
static bool IsReservedSessionName(LPCWSTR Name);
- // Iterates over all sessions, cleaning up released sessions.
- // The routine receives a SessionEntry& and can return an optional to stop iteration.
+ // Iterates over all sessions, cleaning up released sessions. The CALLER
+ // MUST hold m_wslcSessionsLock. Sessions whose backing process is gone are
+ // moved out of tracking into `DeadSessions` so the caller can dispatch
+ // their OnWslcSessionStopping plugin notification (via DispatchSessionStopping)
+ // AFTER releasing the lock — the notification is an out-of-process call and
+ // must not run under the lock. The routine receives a SessionEntry& and can
+ // return an optional to stop iteration.
template
- inline auto ForEachSession(const auto& Routine)
+ inline auto ForEachSessionLockHeld(const auto& Routine, std::vector& DeadSessions)
{
- std::lock_guard lock(m_wslcSessionsLock);
-
// Enforce noexcept: remove_if leaves the container in an unspecified
// (partially-moved) state if the predicate throws. Callers must handle
// errors via return values, not exceptions.
@@ -118,7 +118,7 @@ class WSLCSessionManagerImpl
std::is_nothrow_invocable_v&>,
"ForEachSession routine must be noexcept to preserve container invariants during remove_if");
- using TResult = std::conditional_t, nullptr_t, std::optional>;
+ using TResult = std::conditional_t, std::nullptr_t, std::optional>;
TResult result{};
auto each = [&](SessionEntry& entry) {
@@ -128,8 +128,28 @@ class WSLCSessionManagerImpl
wil::com_ptr lockedSession;
if (FAILED_LOG(entry.Ref->OpenSession(&lockedSession)))
{
- // Session is gone: notify plugins (if not already), then drop persistent reference if any.
- NotifySessionStoppingLockHeld(entry);
+ // Session is gone: move it out for deferred OnWslcSessionStopping
+ // dispatch (must happen outside the lock) and drop any persistent
+ // reference. The StoppingNotified flag is intentionally left
+ // untouched here; DispatchSessionStopping flips it so the
+ // notification fires exactly once.
+ // Reserve the slot up front so push_back cannot reallocate (and
+ // therefore cannot throw) after `entry` is moved-from. Combined
+ // with SessionEntry's noexcept move, this guarantees `entry` is
+ // only consumed once the destination slot is known to exist.
+ DeadSessions.reserve(DeadSessions.size() + 1);
+ try
+ {
+ DeadSessions.push_back(std::move(entry));
+ }
+ catch (...)
+ {
+ // Defensive: if queuing for the deferred stopping dispatch
+ // still fails, keep the session tracked and reap it on a
+ // later pass rather than dropping it without unregistering.
+ LOG_CAUGHT_EXCEPTION();
+ return false;
+ }
auto remove =
std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == entry.SessionId; });
@@ -166,15 +186,59 @@ class WSLCSessionManagerImpl
}
[[nodiscard]] wil::unique_handle CreateSessionProcessJob(_In_ IWSLCSessionFactory* Factory);
+
+ // Convenience wrapper that takes m_wslcSessionsLock internally and dispatches
+ // OnWslcSessionStopping for any dead sessions after releasing the lock. Use
+ // this from call sites that do NOT already hold the lock.
+ template
+ inline auto ForEachSession(const auto& Routine)
+ {
+ std::vector deadSessions;
+
+ using TResult = std::conditional_t, std::nullptr_t, std::optional>;
+ TResult result{};
+
+ {
+ std::lock_guard lock(m_wslcSessionsLock);
+ if constexpr (std::is_same_v)
+ {
+ ForEachSessionLockHeld(Routine, deadSessions);
+ }
+ else
+ {
+ result = ForEachSessionLockHeld(Routine, deadSessions);
+ }
+ }
+
+ // Safe to invoke plugin callbacks now that the lock is released.
+ for (auto& entry : deadSessions)
+ {
+ DispatchSessionStopping(entry);
+ }
+
+ if constexpr (std::is_same_v)
+ {
+ return;
+ }
+ else
+ {
+ return result;
+ }
+ }
+
WSLCSessionInitSettings CreateSessionSettings(
_In_ ULONG SessionId, _In_ LPCWSTR CreatorProcessName, _In_ const WSLCSessionSettings* Settings, _In_ LPCWSTR ResolvedDisplayName);
static CallingProcessTokenInfo GetCallingProcessTokenInfo();
static HRESULT CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo);
- void NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept;
+ // Fires OnWslcSessionStopping for `entry` exactly once (guarded by
+ // entry.StoppingNotified) and then unregisters the session from the plugin
+ // session reference map. MUST be called WITHOUT m_wslcSessionsLock held —
+ // the plugin notification is dispatched out-of-process to wslpluginhost.exe.
+ void DispatchSessionStopping(SessionEntry& entry) noexcept;
std::atomic m_nextSessionId{1};
- std::recursive_mutex m_wslcSessionsLock;
+ std::mutex m_wslcSessionsLock;
// All sessions tracked via SessionEntry (which holds weak refs and service-side security info).
// Sessions are automatically cleaned up when the underlying session is released.
diff --git a/src/windows/service/inc/CMakeLists.txt b/src/windows/service/inc/CMakeLists.txt
index 59888c441f..e8cbc1dc24 100644
--- a/src/windows/service/inc/CMakeLists.txt
+++ b/src/windows/service/inc/CMakeLists.txt
@@ -1,3 +1,5 @@
add_idl(wslserviceidl "wslservice.idl;wslc.idl" "windowsdefs.idl")
+add_idl(wslpluginhostidl "WslPluginHost.idl" "")
set_target_properties(wslserviceidl PROPERTIES FOLDER windows)
+set_target_properties(wslpluginhostidl PROPERTIES FOLDER windows)
diff --git a/src/windows/service/inc/WslPluginHost.idl b/src/windows/service/inc/WslPluginHost.idl
new file mode 100644
index 0000000000..81abaf5df1
--- /dev/null
+++ b/src/windows/service/inc/WslPluginHost.idl
@@ -0,0 +1,245 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ WslPluginHost.idl
+
+Abstract:
+
+ This file contains the COM interface definitions for out-of-process
+ plugin hosting. IWslPluginHost is implemented by the plugin host process
+ and called by the service. IWslPluginHostCallback is implemented by the
+ service and called by the plugin host when a plugin invokes API functions.
+
+--*/
+
+import "unknwn.idl";
+import "wtypes.idl";
+import "wslc.idl";
+
+cpp_quote("const GUID CLSID_WslPluginHost = {0x7a1d2c3e, 0x4b5f, 0x6a7d, {0x8e, 0x9f, 0x0a, 0x1b, 0x2c, 0x3d, 0x4e, 0x5f}};")
+cpp_quote("#ifdef __cplusplus")
+cpp_quote("class DECLSPEC_UUID(\"7a1d2c3e-4b5f-6a7d-8e9f-0a1b2c3d4e5f\") WslPluginHost;")
+cpp_quote("#endif")
+
+//
+// IWslPluginHostCallback - implemented by the service, called by the plugin host
+// when a plugin invokes WSLPluginAPIV1 functions (MountFolder, ExecuteBinary, etc.)
+//
+
+[
+ uuid(A2B3C4D5-E6F7-4890-AB12-CD34EF56A789),
+ pointer_default(unique),
+ object
+]
+interface IWslPluginHostCallback : IUnknown
+{
+ HRESULT MountFolder(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR WindowsPath,
+ [in, string] LPCWSTR LinuxPath,
+ [in] BOOL ReadOnly,
+ [in, string] LPCWSTR Name);
+
+ HRESULT ExecuteBinary(
+ [in] DWORD SessionId,
+ [in, string] LPCSTR Path,
+ [in] DWORD ArgumentCount,
+ [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments,
+ [out, system_handle(sh_socket)] HANDLE* Socket);
+
+ HRESULT ExecuteBinaryInDistribution(
+ [in] DWORD SessionId,
+ [in] const GUID* DistributionId,
+ [in, string] LPCSTR Path,
+ [in] DWORD ArgumentCount,
+ [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments,
+ [out, system_handle(sh_socket)] HANDLE* Socket);
+
+ //
+ // WSLC plugin API. WslcCreateProcess returns a marshaled IWSLCProcess that
+ // the host wraps in an opaque WSLCProcessHandle for the plugin and calls
+ // directly (GetStdHandle/GetExitEvent/GetState), so there is no service-side
+ // process bookkeeping: lifetime is owned by the host via COM reference
+ // counting and the remote process is released when the host releases it (or
+ // when the host process exits).
+ //
+
+ HRESULT WslcMountFolder(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR WindowsPath,
+ [in, string] LPCSTR Mountpoint,
+ [in] BOOL ReadOnly);
+
+ HRESULT WslcUnmountFolder(
+ [in] DWORD SessionId,
+ [in, string] LPCSTR Mountpoint);
+
+ HRESULT WslcCreateProcess(
+ [in] DWORD SessionId,
+ [in, string] LPCSTR Executable,
+ [in] DWORD ArgumentCount,
+ [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments,
+ [in] DWORD EnvCount,
+ [in, unique, size_is(EnvCount), string] LPCSTR* Environment,
+ [out] IWSLCProcess** Process,
+ [out] int* Errno);
+};
+
+//
+// IWslPluginHost - implemented by the plugin host process, called by the service
+// to deliver lifecycle notifications to the plugin.
+//
+
+[
+ uuid(B3C4D5E6-F7A8-4901-BC23-DE45FA67B890),
+ pointer_default(unique),
+ object
+]
+interface IWslPluginHost : IUnknown
+{
+ //
+ // Initialize the plugin host: load the plugin DLL and call its entry point.
+ // The Callback interface is used by the plugin to call back into the service.
+ // JobObject is the service's job object; the host assigns itself to it before
+ // running any plugin code so that processes spawned by the plugin inherit the
+ // job and are terminated if the service exits unexpectedly.
+ //
+
+ HRESULT Initialize(
+ [in] IWslPluginHostCallback* Callback,
+ [in, system_handle(sh_job)] HANDLE JobObject,
+ [in, string] LPCWSTR PluginPath,
+ [in, string] LPCWSTR PluginName);
+
+ //
+ // Lifecycle hook dispatchers - mirror WSLPluginHooksV1.
+ // UserToken is duplicated into the host process by the service before calling.
+ // UserSid is serialized as a byte array.
+ //
+
+ HRESULT OnVMStarted(
+ [in] DWORD SessionId,
+ [in, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, size_is(SidSize)] BYTE* SidData,
+ [in] DWORD CustomConfigurationFlags,
+ [out, string] LPWSTR* ErrorMessage);
+
+ HRESULT OnVMStopping(
+ [in] DWORD SessionId,
+ [in, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, size_is(SidSize)] BYTE* SidData);
+
+ HRESULT OnDistributionStarted(
+ [in] DWORD SessionId,
+ [in, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, size_is(SidSize)] BYTE* SidData,
+ [in] const GUID* DistributionId,
+ [in, string] LPCWSTR DistributionName,
+ [in] ULONGLONG PidNamespace,
+ [in, unique, string] LPCWSTR PackageFamilyName,
+ [in] DWORD InitPid,
+ [in, unique, string] LPCWSTR Flavor,
+ [in, unique, string] LPCWSTR Version,
+ [out, string] LPWSTR* ErrorMessage);
+
+ HRESULT OnDistributionStopping(
+ [in] DWORD SessionId,
+ [in, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, size_is(SidSize)] BYTE* SidData,
+ [in] const GUID* DistributionId,
+ [in, string] LPCWSTR DistributionName,
+ [in] ULONGLONG PidNamespace,
+ [in, unique, string] LPCWSTR PackageFamilyName,
+ [in] DWORD InitPid,
+ [in, unique, string] LPCWSTR Flavor,
+ [in, unique, string] LPCWSTR Version);
+
+ HRESULT OnDistributionRegistered(
+ [in] DWORD SessionId,
+ [in, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, size_is(SidSize)] BYTE* SidData,
+ [in] const GUID* DistributionId,
+ [in, string] LPCWSTR DistributionName,
+ [in, unique, string] LPCWSTR PackageFamilyName,
+ [in, unique, string] LPCWSTR Flavor,
+ [in, unique, string] LPCWSTR Version);
+
+ HRESULT OnDistributionUnregistered(
+ [in] DWORD SessionId,
+ [in, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, size_is(SidSize)] BYTE* SidData,
+ [in] const GUID* DistributionId,
+ [in, string] LPCWSTR DistributionName,
+ [in, unique, string] LPCWSTR PackageFamilyName,
+ [in, unique, string] LPCWSTR Flavor,
+ [in, unique, string] LPCWSTR Version);
+
+ //
+ // WSLC plugin hooks. Mirror WSLPluginHooksV1 WSLC entries.
+ // UserToken is duplicated into the host by COM's system_handle marshaling.
+ // UserSid is serialized as a byte array (same convention as WSL hooks).
+ //
+
+ HRESULT OnWslcSessionCreated(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR DisplayName,
+ [in] DWORD ApplicationPid,
+ [in, unique, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, unique, size_is(SidSize)] BYTE* SidData,
+ [out, string] LPWSTR* ErrorMessage);
+
+ HRESULT OnWslcSessionStopping(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR DisplayName,
+ [in] DWORD ApplicationPid,
+ [in, unique, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, unique, size_is(SidSize)] BYTE* SidData);
+
+ HRESULT OnWslcContainerStarted(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR DisplayName,
+ [in] DWORD ApplicationPid,
+ [in, unique, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, unique, size_is(SidSize)] BYTE* SidData,
+ [in, string] LPCSTR InspectJson,
+ [out, string] LPWSTR* ErrorMessage);
+
+ HRESULT OnWslcContainerStopping(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR DisplayName,
+ [in] DWORD ApplicationPid,
+ [in, unique, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, unique, size_is(SidSize)] BYTE* SidData,
+ [in, string] LPCSTR ContainerId);
+
+ HRESULT OnWslcImageCreated(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR DisplayName,
+ [in] DWORD ApplicationPid,
+ [in, unique, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, unique, size_is(SidSize)] BYTE* SidData,
+ [in, string] LPCSTR InspectJson);
+
+ HRESULT OnWslcImageDeleted(
+ [in] DWORD SessionId,
+ [in, string] LPCWSTR DisplayName,
+ [in] DWORD ApplicationPid,
+ [in, unique, system_handle(sh_token)] HANDLE UserToken,
+ [in] DWORD SidSize,
+ [in, unique, size_is(SidSize)] BYTE* SidData,
+ [in, string] LPCSTR ImageId);
+};
diff --git a/src/windows/service/stub/CMakeLists.txt b/src/windows/service/stub/CMakeLists.txt
index c0308d6dcd..c8ef86f3f9 100644
--- a/src/windows/service/stub/CMakeLists.txt
+++ b/src/windows/service/stub/CMakeLists.txt
@@ -3,6 +3,8 @@ set(SOURCES
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslservice_p_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslc_i_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslc_p_${TARGET_PLATFORM}.c
+ ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/WslPluginHost_i_${TARGET_PLATFORM}.c
+ ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/WslPluginHost_p_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_LIST_DIR}/WslServiceProxyStub.def
${CMAKE_CURRENT_LIST_DIR}/WslServiceProxyStub.rc)
@@ -10,6 +12,6 @@ set(SOURCES
set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE)
add_library(wslserviceproxystub SHARED ${SOURCES})
-add_dependencies(wslserviceproxystub wslserviceidl)
+add_dependencies(wslserviceproxystub wslserviceidl wslpluginhostidl)
target_link_libraries(wslserviceproxystub ${COMMON_LINK_LIBRARIES})
set_target_properties(wslserviceproxystub PROPERTIES FOLDER windows)
\ No newline at end of file
diff --git a/src/windows/wslinstall/DllMain.cpp b/src/windows/wslinstall/DllMain.cpp
index cc42baf018..01d39d4d39 100644
--- a/src/windows/wslinstall/DllMain.cpp
+++ b/src/windows/wslinstall/DllMain.cpp
@@ -827,7 +827,7 @@ void RegisterLspCategoriesImpl(DWORD flags)
const auto installRoot = wsl::windows::common::wslutil::GetMsiPackagePath();
THROW_HR_IF(E_INVALIDARG, !installRoot.has_value());
- for (const auto& e : {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslg.exe", L"wslservice.exe"})
+ for (const auto& e : {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslpluginhost.exe", L"wslg.exe", L"wslservice.exe"})
{
auto executable = installRoot.value() + e;
INT error{};
diff --git a/src/windows/wslpluginhost/CMakeLists.txt b/src/windows/wslpluginhost/CMakeLists.txt
new file mode 100644
index 0000000000..ce8dc8f770
--- /dev/null
+++ b/src/windows/wslpluginhost/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(exe)
diff --git a/src/windows/wslpluginhost/exe/CMakeLists.txt b/src/windows/wslpluginhost/exe/CMakeLists.txt
new file mode 100644
index 0000000000..fc16ec3bc2
--- /dev/null
+++ b/src/windows/wslpluginhost/exe/CMakeLists.txt
@@ -0,0 +1,25 @@
+set(SOURCES
+ main.cpp
+ main.rc
+ PluginHost.cpp)
+
+set(HEADERS
+ PluginHost.h
+ resource.h)
+
+add_executable(wslpluginhost WIN32 ${SOURCES} ${HEADERS})
+add_dependencies(wslpluginhost
+ wslpluginhostidl
+ common)
+
+target_include_directories(wslpluginhost PRIVATE
+ ${CMAKE_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
+
+target_link_libraries(wslpluginhost
+ ${COMMON_LINK_LIBRARIES}
+ ${MSI_LINK_LIBRARIES}
+ common
+ ole32.lib)
+
+target_precompile_headers(wslpluginhost REUSE_FROM common)
+set_target_properties(wslpluginhost PROPERTIES FOLDER windows)
diff --git a/src/windows/wslpluginhost/exe/PluginHost.cpp b/src/windows/wslpluginhost/exe/PluginHost.cpp
new file mode 100644
index 0000000000..b451ea090c
--- /dev/null
+++ b/src/windows/wslpluginhost/exe/PluginHost.cpp
@@ -0,0 +1,976 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ PluginHost.cpp
+
+Abstract:
+
+ This file contains the IWslPluginHost COM class implementation.
+ It loads a plugin DLL in this (host) process and forwards lifecycle
+ notifications from the service to the plugin, while routing plugin API
+ callbacks back to the service via IWslPluginHostCallback.
+
+--*/
+
+#include "precomp.h"
+#include "PluginHost.h"
+#include "install.h"
+
+using namespace wsl::windows::pluginhost;
+
+// Defined in main.cpp — part of the COM local server lifecycle.
+extern void AddComRef();
+extern void ReleaseComRef();
+
+std::atomic wsl::windows::pluginhost::g_pluginHost{nullptr};
+
+// Thread ID of the thread currently dispatching a plugin hook.
+// Only that thread may call PluginError. Using thread ID instead of
+// thread_local to avoid TLS initialization issues across DLL/EXE boundaries.
+static std::atomic g_hookThreadId{0};
+
+namespace {
+
+// Plugins may invoke API stubs from worker threads they create themselves,
+// which often haven't called CoInitializeEx. Without COM initialization, the
+// cross-process m_callback proxy can't be used and calls would fail with
+// CO_E_NOTINITIALIZED. This RAII helper joins the calling thread to the MTA
+// for the duration of the API call. RPC_E_CHANGED_MODE means the thread is
+// already STA-initialized; we leave its apartment unchanged and still proceed,
+// relying on COM cross-apartment marshaling to dispatch the call to the MTA
+// where m_callback was marshaled. The primary (and tested) path is an
+// uninitialized worker thread joining the MTA directly.
+struct ScopedComInitForCallback
+{
+ HRESULT initHr;
+ ScopedComInitForCallback() : initHr(::CoInitializeEx(nullptr, COINIT_MULTITHREADED))
+ {
+ }
+ ~ScopedComInitForCallback()
+ {
+ if (SUCCEEDED(initHr))
+ {
+ ::CoUninitialize();
+ }
+ }
+ ScopedComInitForCallback(const ScopedComInitForCallback&) = delete;
+ ScopedComInitForCallback& operator=(const ScopedComInitForCallback&) = delete;
+
+ HRESULT Result() const
+ {
+ return (initHr == RPC_E_CHANGED_MODE) ? S_OK : initHr;
+ }
+};
+
+} // namespace
+
+// Tracks how many PluginHost instances have ever been constructed in this
+// process. Used by main() to distinguish "we were activated and the client
+// disconnected normally" from "we sat here for the entire startup window
+// without ever being activated by anyone" (orphaned host).
+std::atomic wsl::windows::pluginhost::g_activationCount{0};
+
+PluginHost::PluginHost()
+{
+ // Increment the COM server reference count so the process stays alive while
+ // this instance exists. Pairs with ReleaseComRef() in ~PluginHost(); tying
+ // both to the object's lifetime guarantees they always balance regardless of
+ // whether the factory's CopyTo() succeeds or fails.
+ AddComRef();
+ g_activationCount.fetch_add(1, std::memory_order_release);
+}
+
+PluginHost::~PluginHost()
+{
+ // Clear globally reachable state so late plugin API calls fail with
+ // E_UNEXPECTED instead of dereferencing freed memory. Release-store
+ // pairs with the acquire-loads in the Local* API stubs.
+ PluginHost* expected = this;
+ g_pluginHost.compare_exchange_strong(expected, nullptr, std::memory_order_acq_rel);
+
+ // Module unloads automatically via wil::unique_hmodule destructor.
+
+ // Decrement the COM server reference count. When it reaches zero,
+ // the process will exit. Matches AddComRef() in PluginHost::PluginHost().
+ ReleaseComRef();
+}
+
+// --- IWslPluginHost implementation ---
+
+STDMETHODIMP PluginHost::Initialize(_In_ IWslPluginHostCallback* Callback, _In_ HANDLE JobObject, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName)
+try
+{
+ RETURN_HR_IF(E_INVALIDARG, Callback == nullptr || JobObject == nullptr || PluginPath == nullptr || PluginName == nullptr);
+ RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, m_module.is_valid()); // Already initialized
+
+ // Join the service's job object before loading or running any plugin code, so that
+ // any child processes the plugin spawns inherit the job and are terminated when the
+ // service exits (the service sets JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE). The host runs
+ // as SYSTEM, so it can assign itself. JobObject is a duplicate owned by the marshaler
+ // for the duration of this call and is freed on return; the assignment persists.
+ // A failure here is fatal: a host that isn't in the job escapes the kill-on-close
+ // guarantee and would be orphaned (along with any children it spawns) when the
+ // service closes the job. Failing Initialize makes the service release this host's
+ // proxy, which exits the host process before any plugin code runs.
+ RETURN_IF_WIN32_BOOL_FALSE_MSG(
+ AssignProcessToJobObject(JobObject, GetCurrentProcess()), "Failed to assign plugin host to job object: '%ls'", PluginName);
+
+ m_callback = Callback;
+
+ // Validate the plugin signature before loading it.
+ // Keep the file handle open to prevent TOCTOU (swap between validation and load).
+ wil::unique_hfile signatureHandle;
+ if constexpr (wsl::shared::OfficialBuild)
+ {
+ signatureHandle = wsl::windows::common::install::ValidateFileSignature(PluginPath);
+ }
+
+ m_module.reset(LoadLibrary(PluginPath));
+ THROW_LAST_ERROR_IF_NULL(m_module);
+ signatureHandle.reset(); // Safe to release after LoadLibrary has mapped the DLL
+
+ const auto entryPoint =
+ reinterpret_cast(GetProcAddress(m_module.get(), GSL_STRINGIFY(WSLPLUGINAPI_ENTRYPOINTV1)));
+ THROW_LAST_ERROR_IF_NULL(entryPoint);
+
+ // Build the API vtable that the plugin will use to call back into the service.
+ // The function pointers are static methods on this class that route through g_pluginHost.
+ static const WSLPluginAPIV1 api = {
+ {wsl::shared::VersionMajor, wsl::shared::VersionMinor, wsl::shared::VersionRevision},
+ &LocalMountFolder,
+ &LocalExecuteBinary,
+ &LocalPluginError,
+ &LocalExecuteBinaryInDistribution,
+ &LocalWslcMountFolder,
+ &LocalWslcUnmountFolder,
+ &LocalWslcCreateProcess,
+ &LocalWslcProcessGetFd,
+ &LocalWslcProcessGetExitEvent,
+ &LocalWslcProcessGetExitCode,
+ &LocalWslcReleaseProcess};
+
+ // Publish g_pluginHost with release semantics so an acquire-load in a stub
+ // observes a fully-constructed m_callback / m_hooks. Only publish on success
+ // so a failed entry point never leaves a dangling pointer for late stub calls.
+ g_pluginHost.store(this, std::memory_order_release);
+ HRESULT hr = entryPoint(&api, &m_hooks);
+
+ if (FAILED(hr))
+ {
+ g_pluginHost.store(nullptr, std::memory_order_release);
+ RETURN_HR_MSG(hr, "Plugin entry point failed: '%ls'", PluginPath);
+ }
+
+ return S_OK;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnVMStarted(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ DWORD CustomConfigurationFlags,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage)
+try
+{
+ RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr);
+ *ErrorMessage = nullptr;
+
+ if (m_hooks.OnVMStarted == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData);
+ WSLVmCreationSettings settings{};
+ settings.CustomConfigurationFlags = static_cast(CustomConfigurationFlags);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnVMStarted(&ctx.info, &settings);
+
+ // If the plugin called PluginError during the hook, return the message.
+ if (m_pluginErrorMessage.has_value())
+ {
+ *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release();
+ m_pluginErrorMessage.reset();
+ }
+
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnVMStopping(_In_ DWORD SessionId, _In_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_(SidSize) BYTE* SidData)
+try
+{
+ if (m_hooks.OnVMStopping == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnVMStopping(&ctx.info);
+
+ // Plugin must not call PluginError outside hooks that surface ErrorMessage;
+ // silently drop any value so it can't poison a later fatal hook.
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnDistributionStarted(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_ ULONGLONG PidNamespace,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_ DWORD InitPid,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage)
+try
+{
+ RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr);
+ *ErrorMessage = nullptr;
+
+ if (m_hooks.OnDistributionStarted == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData);
+
+ // Preserve nullability of PackageFamilyName/Flavor/Version per the public
+ // ABI documented in WslPluginApi.h. Coalescing to L"" would silently change
+ // behavior of plugins that check `if (Distribution->PackageFamilyName)`.
+ WSLDistributionInformation distro{};
+ distro.Id = *DistributionId;
+ distro.Name = DistributionName;
+ distro.PidNamespace = PidNamespace;
+ distro.PackageFamilyName = PackageFamilyName;
+ distro.InitPid = InitPid;
+ distro.Flavor = Flavor;
+ distro.Version = Version;
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnDistributionStarted(&ctx.info, &distro);
+
+ if (m_pluginErrorMessage.has_value())
+ {
+ *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release();
+ m_pluginErrorMessage.reset();
+ }
+
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnDistributionStopping(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_ ULONGLONG PidNamespace,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_ DWORD InitPid,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version)
+try
+{
+ if (m_hooks.OnDistributionStopping == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData);
+
+ WSLDistributionInformation distro{};
+ distro.Id = *DistributionId;
+ distro.Name = DistributionName;
+ distro.PidNamespace = PidNamespace;
+ distro.PackageFamilyName = PackageFamilyName;
+ distro.InitPid = InitPid;
+ distro.Flavor = Flavor;
+ distro.Version = Version;
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnDistributionStopping(&ctx.info, &distro);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnDistributionRegistered(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version)
+try
+{
+ if (m_hooks.OnDistributionRegistered == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData);
+
+ WslOfflineDistributionInformation distro{};
+ distro.Id = *DistributionId;
+ distro.Name = DistributionName;
+ distro.PackageFamilyName = PackageFamilyName;
+ distro.Flavor = Flavor;
+ distro.Version = Version;
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnDistributionRegistered(&ctx.info, &distro);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnDistributionUnregistered(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version)
+try
+{
+ if (m_hooks.OnDistributionUnregistered == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData);
+
+ WslOfflineDistributionInformation distro{};
+ distro.Id = *DistributionId;
+ distro.Name = DistributionName;
+ distro.PackageFamilyName = PackageFamilyName;
+ distro.Flavor = Flavor;
+ distro.Version = Version;
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnDistributionUnregistered(&ctx.info, &distro);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+// --- Helpers ---
+
+PluginHost::SessionContext PluginHost::BuildSessionContext(DWORD SessionId, HANDLE UserToken, DWORD SidSize, BYTE* SidData)
+{
+ SessionContext ctx{};
+ ctx.info.SessionId = SessionId;
+
+ // The marshaled UserToken is owned by the RPC stub, which closes it when the
+ // call returns. Borrow it for the duration of the hook; do not take ownership.
+ ctx.info.UserToken = UserToken;
+
+ // Reconstruct the SID. Reject malformed inputs at the COM boundary —
+ // pointer arithmetic on a null pointer (even with size 0) is undefined
+ // behavior, and a malformed SID would be dereferenced by plugin code.
+ THROW_HR_IF(E_INVALIDARG, SidData == nullptr || SidSize == 0);
+ ctx.sidBuffer.assign(SidData, SidData + SidSize);
+ THROW_HR_IF(E_INVALIDARG, !::IsValidSid(reinterpret_cast(ctx.sidBuffer.data())));
+ ctx.info.UserSid = reinterpret_cast(ctx.sidBuffer.data());
+
+ return ctx;
+}
+
+// --- Static API stubs ---
+
+HRESULT CALLBACK PluginHost::LocalMountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ auto hr = host->m_callback->MountFolder(Session, WindowsPath, LinuxPath, ReadOnly, Name);
+ return hr;
+}
+
+HRESULT CALLBACK PluginHost::LocalExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_POINTER, Socket == nullptr);
+ // Initialize the out-param so callers don't observe an uninitialized
+ // socket value on any error return below.
+ *Socket = INVALID_SOCKET;
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ // Count arguments (NULL-terminated array)
+ DWORD count = 0;
+ if (Arguments != nullptr)
+ {
+ for (const LPCSTR* p = Arguments; *p != nullptr; ++p)
+ {
+ ++count;
+ }
+ }
+
+ HANDLE socketResult = nullptr;
+ HRESULT hr = host->m_callback->ExecuteBinary(Session, Path, count, Arguments, &socketResult);
+
+ if (SUCCEEDED(hr))
+ {
+ // COM's system_handle marshaling duplicated the socket into our process.
+ *Socket = reinterpret_cast(socketResult);
+ }
+ else if (socketResult != nullptr)
+ {
+ if (closesocket(reinterpret_cast(socketResult)) == SOCKET_ERROR)
+ {
+ LOG_WIN32(WSAGetLastError());
+ }
+ }
+
+ return hr;
+}
+
+HRESULT CALLBACK PluginHost::LocalPluginError(LPCWSTR UserMessage)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr)
+ {
+ // Not on a hook thread — PluginError must only be called
+ // synchronously from within OnVMStarted/OnDistributionStarted.
+ return E_ILLEGAL_METHOD_CALL;
+ }
+
+ RETURN_HR_IF(E_INVALIDARG, UserMessage == nullptr);
+ RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, GetCurrentThreadId() != g_hookThreadId.load());
+ RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, host->m_pluginErrorMessage.has_value());
+
+ // Store locally — returned to service alongside the hook HRESULT.
+ host->m_pluginErrorMessage.emplace(UserMessage);
+ return S_OK;
+}
+
+HRESULT CALLBACK PluginHost::LocalExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_INVALIDARG, Distro == nullptr);
+ RETURN_HR_IF(E_POINTER, Socket == nullptr);
+ // Initialize the out-param so callers don't observe an uninitialized
+ // socket value on any error return below.
+ *Socket = INVALID_SOCKET;
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ DWORD count = 0;
+ if (Arguments != nullptr)
+ {
+ for (const LPCSTR* p = Arguments; *p != nullptr; ++p)
+ {
+ ++count;
+ }
+ }
+
+ HANDLE socketResult = nullptr;
+ HRESULT hr = host->m_callback->ExecuteBinaryInDistribution(Session, Distro, Path, count, Arguments, &socketResult);
+
+ if (SUCCEEDED(hr))
+ {
+ *Socket = reinterpret_cast(socketResult);
+ }
+ else if (socketResult != nullptr)
+ {
+ if (closesocket(reinterpret_cast(socketResult)) == SOCKET_ERROR)
+ {
+ LOG_WIN32(WSAGetLastError());
+ }
+ }
+
+ return hr;
+}
+
+// --- WSLC hook implementations ---
+
+PluginHost::WslcSessionContext PluginHost::BuildWslcSessionContext(
+ DWORD SessionId, LPCWSTR DisplayName, DWORD ApplicationPid, HANDLE UserToken, DWORD SidSize, BYTE* SidData)
+{
+ WslcSessionContext ctx{};
+ ctx.info.SessionId = SessionId;
+ ctx.info.DisplayName = DisplayName;
+ ctx.info.ApplicationPid = ApplicationPid;
+
+ // UserToken / SidData are unique in the IDL — both may be null in stopping
+ // paths invoked during session teardown. The marshaled token is owned by the
+ // RPC stub (closed when the call returns), so borrow it without taking ownership.
+ ctx.info.UserToken = UserToken;
+
+ if (SidData != nullptr && SidSize > 0)
+ {
+ ctx.sidBuffer.assign(SidData, SidData + SidSize);
+ THROW_HR_IF(E_INVALIDARG, !::IsValidSid(reinterpret_cast(ctx.sidBuffer.data())));
+ ctx.info.UserSid = reinterpret_cast(ctx.sidBuffer.data());
+ }
+
+ return ctx;
+}
+
+STDMETHODIMP PluginHost::OnWslcSessionCreated(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage)
+try
+{
+ RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr);
+ *ErrorMessage = nullptr;
+
+ if (m_hooks.OnSessionCreated == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnSessionCreated(&ctx.info);
+
+ if (m_pluginErrorMessage.has_value())
+ {
+ *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release();
+ m_pluginErrorMessage.reset();
+ }
+
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnWslcSessionStopping(
+ _In_ DWORD SessionId, _In_ LPCWSTR DisplayName, _In_ DWORD ApplicationPid, _In_opt_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_opt_(SidSize) BYTE* SidData)
+try
+{
+ if (m_hooks.OnSessionStopping == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.OnSessionStopping(&ctx.info);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnWslcContainerStarted(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR InspectJson,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage)
+try
+{
+ RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr);
+ *ErrorMessage = nullptr;
+
+ if (m_hooks.ContainerStarted == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.ContainerStarted(&ctx.info, InspectJson);
+
+ if (m_pluginErrorMessage.has_value())
+ {
+ *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release();
+ m_pluginErrorMessage.reset();
+ }
+
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnWslcContainerStopping(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR ContainerId)
+try
+{
+ if (m_hooks.ContainerStopping == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.ContainerStopping(&ctx.info, ContainerId);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnWslcImageCreated(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR InspectJson)
+try
+{
+ if (m_hooks.ImageCreated == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.ImageCreated(&ctx.info, InspectJson);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+STDMETHODIMP PluginHost::OnWslcImageDeleted(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR ImageId)
+try
+{
+ if (m_hooks.ImageDeleted == nullptr)
+ {
+ return S_OK;
+ }
+
+ auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData);
+
+ std::lock_guard hookLock(m_hookLock);
+ g_hookThreadId.store(GetCurrentThreadId());
+ m_pluginErrorMessage.reset();
+ auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); });
+
+ HRESULT hr = m_hooks.ImageDeleted(&ctx.info, ImageId);
+
+ m_pluginErrorMessage.reset();
+ return hr;
+}
+CATCH_RETURN();
+
+// --- WSLC API stubs ---
+
+namespace {
+
+// Opaque wrapper handed to the plugin as WSLCProcessHandle. It owns the
+// IWSLCProcess COM proxy marshaled from wslcsession (via the service), so the
+// host calls the process interface directly and the remote process is released
+// when the plugin releases this wrapper (or when the host process exits).
+struct WslcProcessWrapper
+{
+ wil::com_ptr process;
+};
+
+} // namespace
+
+HRESULT CALLBACK PluginHost::LocalWslcMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, LPCSTR Mountpoint, BOOL ReadOnly)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_POINTER, Mountpoint == nullptr);
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ return host->m_callback->WslcMountFolder(Session, WindowsPath, Mountpoint, ReadOnly);
+}
+
+HRESULT CALLBACK PluginHost::LocalWslcUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ return host->m_callback->WslcUnmountFolder(Session, Mountpoint);
+}
+
+HRESULT CALLBACK PluginHost::LocalWslcCreateProcess(
+ WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_POINTER, Process == nullptr);
+ *Process = nullptr;
+
+ int localErrno = 0;
+ if (Errno != nullptr)
+ {
+ *Errno = 0;
+ }
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ DWORD argCount = 0;
+ if (Arguments != nullptr)
+ {
+ for (const LPCSTR* p = Arguments; *p != nullptr; ++p)
+ {
+ ++argCount;
+ }
+ }
+
+ DWORD envCount = 0;
+ if (Env != nullptr)
+ {
+ for (const LPCSTR* p = Env; *p != nullptr; ++p)
+ {
+ ++envCount;
+ }
+ }
+
+ // Allocate the wrapper before creating the remote process so a throwing
+ // allocation can't strand a remote process that only WslcReleaseProcess
+ // frees. Nothing between the remote create and release() below can throw.
+ auto wrapper = std::make_unique();
+
+ HRESULT hr =
+ host->m_callback->WslcCreateProcess(Session, Executable, argCount, Arguments, envCount, Env, wrapper->process.put(), &localErrno);
+ if (Errno != nullptr)
+ {
+ *Errno = localErrno;
+ }
+
+ if (FAILED(hr))
+ {
+ return hr;
+ }
+
+ *Process = wrapper.release();
+ return S_OK;
+}
+
+HRESULT CALLBACK PluginHost::LocalWslcProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_INVALIDARG, Process == nullptr);
+ RETURN_HR_IF(E_POINTER, Handle == nullptr);
+ *Handle = nullptr;
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ auto* wrapper = static_cast(Process);
+ RETURN_HR_IF(E_INVALIDARG, wrapper->process == nullptr);
+
+ WSLCFD wslcFd{};
+ switch (Fd)
+ {
+ case WSLCProcessFdStdin:
+ wslcFd = WSLCFDStdin;
+ break;
+ case WSLCProcessFdStdout:
+ wslcFd = WSLCFDStdout;
+ break;
+ case WSLCProcessFdStderr:
+ wslcFd = WSLCFDStderr;
+ break;
+ default:
+ return E_INVALIDARG;
+ }
+
+ WSLCHandle handle{};
+ RETURN_IF_FAILED(wrapper->process->GetStdHandle(wslcFd, &handle));
+
+ WI_ASSERT(handle.Type == WSLCHandleTypeSocket);
+
+ // Pass through as HANDLE; COM's system_handle(sh_socket) marshaling already
+ // duplicated it into this process.
+ *Handle = handle.Handle.Socket;
+ return S_OK;
+}
+
+HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_INVALIDARG, Process == nullptr);
+ RETURN_HR_IF(E_POINTER, ExitEvent == nullptr);
+ *ExitEvent = nullptr;
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ auto* wrapper = static_cast(Process);
+ RETURN_HR_IF(E_INVALIDARG, wrapper->process == nullptr);
+ return wrapper->process->GetExitEvent(ExitEvent);
+}
+
+HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode)
+{
+ auto* host = g_pluginHost.load(std::memory_order_acquire);
+ if (host == nullptr || host->m_callback == nullptr)
+ {
+ return E_UNEXPECTED;
+ }
+
+ RETURN_HR_IF(E_INVALIDARG, Process == nullptr);
+ RETURN_HR_IF(E_POINTER, ExitCode == nullptr);
+ *ExitCode = -1;
+
+ ScopedComInitForCallback coInit;
+ RETURN_IF_FAILED(coInit.Result());
+
+ auto* wrapper = static_cast(Process);
+ RETURN_HR_IF(E_INVALIDARG, wrapper->process == nullptr);
+
+ WSLCProcessState state{};
+ auto result = wrapper->process->GetState(&state, ExitCode);
+ if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled)
+ {
+ result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE);
+ }
+
+ return result;
+}
+
+void CALLBACK PluginHost::LocalWslcReleaseProcess(WSLCProcessHandle Process)
+{
+ if (Process == nullptr)
+ {
+ return;
+ }
+
+ // Initialize COM before taking ownership: destroying the wrapper releases
+ // the IWSLCProcess proxy, which marshals a Release back to wslcsession and
+ // needs COM initialized on this thread. coInit is declared first so it
+ // outlives the wrapper (reverse destruction order).
+ ScopedComInitForCallback coInit;
+ LOG_IF_FAILED(coInit.Result());
+
+ std::unique_ptr wrapper{static_cast(Process)};
+}
\ No newline at end of file
diff --git a/src/windows/wslpluginhost/exe/PluginHost.h b/src/windows/wslpluginhost/exe/PluginHost.h
new file mode 100644
index 0000000000..d092917680
--- /dev/null
+++ b/src/windows/wslpluginhost/exe/PluginHost.h
@@ -0,0 +1,221 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ PluginHost.h
+
+Abstract:
+
+ This file contains the COM class that implements IWslPluginHost.
+ It loads a plugin DLL and dispatches lifecycle notifications to it,
+ forwarding plugin API callbacks to the service via IWslPluginHostCallback.
+
+--*/
+
+#pragma once
+
+#include "WslPluginApi.h"
+#include "WslPluginHost.h"
+
+namespace wsl::windows::pluginhost {
+
+class PluginHost : public Microsoft::WRL::RuntimeClass, IWslPluginHost>
+{
+public:
+ PluginHost();
+ ~PluginHost();
+
+ PluginHost(const PluginHost&) = delete;
+ PluginHost& operator=(const PluginHost&) = delete;
+
+ // IWslPluginHost
+ STDMETHODIMP Initialize(_In_ IWslPluginHostCallback* Callback, _In_ HANDLE JobObject, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) override;
+
+ STDMETHODIMP OnVMStarted(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ DWORD CustomConfigurationFlags,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override;
+
+ STDMETHODIMP OnVMStopping(_In_ DWORD SessionId, _In_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_(SidSize) BYTE* SidData) override;
+
+ STDMETHODIMP OnDistributionStarted(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_ ULONGLONG PidNamespace,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_ DWORD InitPid,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override;
+
+ STDMETHODIMP OnDistributionStopping(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_ ULONGLONG PidNamespace,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_ DWORD InitPid,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version) override;
+
+ STDMETHODIMP OnDistributionRegistered(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version) override;
+
+ STDMETHODIMP OnDistributionUnregistered(
+ _In_ DWORD SessionId,
+ _In_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_(SidSize) BYTE* SidData,
+ _In_ const GUID* DistributionId,
+ _In_ LPCWSTR DistributionName,
+ _In_opt_ LPCWSTR PackageFamilyName,
+ _In_opt_ LPCWSTR Flavor,
+ _In_opt_ LPCWSTR Version) override;
+
+ //
+ // WSLC plugin hooks. UserToken and SidData are marshaled the same way as
+ // the existing WSL hooks. The string DisplayName plus ApplicationPid make
+ // up the rest of the WSLCSessionInformation struct.
+ //
+
+ STDMETHODIMP OnWslcSessionCreated(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override;
+
+ STDMETHODIMP OnWslcSessionStopping(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData) override;
+
+ STDMETHODIMP OnWslcContainerStarted(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR InspectJson,
+ _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override;
+
+ STDMETHODIMP OnWslcContainerStopping(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR ContainerId) override;
+
+ STDMETHODIMP OnWslcImageCreated(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR InspectJson) override;
+
+ STDMETHODIMP OnWslcImageDeleted(
+ _In_ DWORD SessionId,
+ _In_ LPCWSTR DisplayName,
+ _In_ DWORD ApplicationPid,
+ _In_opt_ HANDLE UserToken,
+ _In_ DWORD SidSize,
+ _In_reads_opt_(SidSize) BYTE* SidData,
+ _In_ LPCSTR ImageId) override;
+
+private:
+ // Build a WSLSessionInformation struct from the marshaled parameters.
+ // The returned struct and its SID allocation are valid for the lifetime of the wil::unique_handle.
+ struct SessionContext
+ {
+ WSLSessionInformation info{};
+ std::vector sidBuffer;
+ };
+
+ SessionContext BuildSessionContext(DWORD SessionId, HANDLE UserToken, DWORD SidSize, BYTE* SidData);
+
+ // WSLC counterpart — builds a WSLCSessionInformation. DisplayName is held
+ // by the caller (the IDL marshaled buffer lives for the call), so the info
+ // struct's LPCWSTR DisplayName points at it directly. UserToken/SidData are
+ // optional in the WSLC hooks (NotifySessionStopping is invoked without an
+ // application token in dtor paths).
+ struct WslcSessionContext
+ {
+ WSLCSessionInformation info{};
+ std::vector sidBuffer;
+ };
+
+ WslcSessionContext BuildWslcSessionContext(DWORD SessionId, LPCWSTR DisplayName, DWORD ApplicationPid, HANDLE UserToken, DWORD SidSize, BYTE* SidData);
+
+ // Local stubs for the WSLPluginAPIV1 function pointers.
+ // These forward calls to the service via m_callback.
+ static HRESULT CALLBACK LocalMountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name);
+ static HRESULT CALLBACK LocalExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket);
+ static HRESULT CALLBACK LocalPluginError(LPCWSTR UserMessage);
+ static HRESULT CALLBACK LocalExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket);
+
+ // WSLC API stubs. WSLCProcessHandle is a heap-allocated wrapper that owns
+ // the IWSLCProcess COM proxy marshaled from wslcsession via the service.
+ static HRESULT CALLBACK LocalWslcMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, LPCSTR Mountpoint, BOOL ReadOnly);
+ static HRESULT CALLBACK LocalWslcUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint);
+ static HRESULT CALLBACK LocalWslcCreateProcess(
+ WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno);
+ static HRESULT CALLBACK LocalWslcProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle);
+ static HRESULT CALLBACK LocalWslcProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent);
+ static HRESULT CALLBACK LocalWslcProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode);
+ static void CALLBACK LocalWslcReleaseProcess(WSLCProcessHandle Process);
+
+ wil::unique_hmodule m_module;
+ WSLPluginHooksV1 m_hooks{};
+ Microsoft::WRL::ComPtr m_callback;
+
+ // Serializes hook dispatch so m_pluginErrorMessage and g_hookThreadId
+ // are not raced when multiple sessions call hooks concurrently (MTA).
+ std::mutex m_hookLock;
+
+ // Error message captured by LocalPluginError during hook execution
+ std::optional m_pluginErrorMessage;
+};
+
+// Process-wide pointer to the single PluginHost instance. Safe because
+// REGCLS_SINGLEUSE guarantees one PluginHost per wslpluginhost.exe process.
+// This allows plugin DLLs to call API functions from any thread, not just
+// the thread dispatching the current hook. Atomic so concurrent stub calls
+// from plugin worker threads observe a coherent value during ctor/dtor.
+extern std::atomic g_pluginHost;
+
+// Number of PluginHost instances ever constructed in this process. Used by
+// main() to detect orphan hosts that COM-activated wslpluginhost.exe but
+// never followed through with a successful CreateInstance.
+extern std::atomic g_activationCount;
+
+} // namespace wsl::windows::pluginhost
diff --git a/src/windows/wslpluginhost/exe/main.cpp b/src/windows/wslpluginhost/exe/main.cpp
new file mode 100644
index 0000000000..bb2a708452
--- /dev/null
+++ b/src/windows/wslpluginhost/exe/main.cpp
@@ -0,0 +1,139 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ main.cpp
+
+Abstract:
+
+ This file contains the entry point for wslpluginhost.exe.
+ This process acts as a COM local server that loads a single WSL plugin DLL
+ in an isolated process, preventing a buggy or malicious plugin from crashing
+ the main WSL service.
+
+ The host is activated through COM local-server activation. It registers its
+ COM class factory, serves activation requests, and remains alive until all
+ COM server-process references are released, at which point it exits.
+
+--*/
+
+#include "precomp.h"
+#include "PluginHost.h"
+#include "WslPluginHost.h"
+
+using namespace Microsoft::WRL;
+
+static wil::unique_event g_exitEvent(wil::EventOptions::ManualReset);
+
+void AddComRef()
+{
+ CoAddRefServerProcess();
+}
+
+void ReleaseComRef()
+{
+ if (CoReleaseServerProcess() == 0)
+ {
+ g_exitEvent.SetEvent();
+ }
+}
+
+class PluginHostFactory : public RuntimeClass, IClassFactory>
+{
+public:
+ STDMETHODIMP CreateInstance(_In_opt_ IUnknown* pUnkOuter, _In_ REFIID riid, _Outptr_ void** ppCreated) override
+ try
+ {
+ RETURN_HR_IF_NULL(E_POINTER, ppCreated);
+ *ppCreated = nullptr;
+ RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr);
+
+ auto host = Make();
+ RETURN_IF_NULL_ALLOC(host);
+
+ // The PluginHost ctor/dtor pair manages the process keep-alive ref;
+ // no manual AddComRef/ReleaseComRef needed here. If CopyTo fails, the
+ // local ComPtr destructor releases the only reference, which destroys
+ // the PluginHost and decrements the keep-alive count.
+ RETURN_IF_FAILED(host.CopyTo(riid, ppCreated));
+ return S_OK;
+ }
+ CATCH_RETURN();
+
+ STDMETHODIMP LockServer(BOOL lock) noexcept override
+ {
+ if (lock)
+ {
+ AddComRef();
+ }
+ else
+ {
+ ReleaseComRef();
+ }
+ return S_OK;
+ }
+};
+
+int WINAPI wWinMain(_In_ HINSTANCE, _In_opt_ HINSTANCE, _In_ LPWSTR, _In_ int)
+try
+{
+ wsl::windows::common::wslutil::ConfigureCrt();
+ wsl::windows::common::wslutil::InitializeWil();
+
+ // Initialize logging.
+ WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild);
+ auto cleanupTracing = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [] { WslTraceLoggingUninitialize(); });
+
+ // Harden the process before loading any third-party plugin code. Match the
+ // mitigation set applied by the other WSL COM server processes
+ // (wslservice.exe / wslcsession.exe).
+ wsl::windows::common::security::ApplyProcessMitigationPolicies();
+
+ auto coInit = wil::CoInitializeEx(COINIT_MULTITHREADED);
+ wsl::windows::common::wslutil::CoInitializeSecurity();
+
+ // Initialize Winsock — plugins receive sockets from ExecuteBinary and need
+ // Winsock to be initialized for recv/send/closesocket to work.
+ WSADATA wsaData{};
+ THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &wsaData));
+ auto cleanupWinsock = wil::scope_exit([] { WSACleanup(); });
+
+ // Register the class factory so the service can CoCreateInstance on us.
+ DWORD cookie = 0;
+ auto factory = Make();
+ THROW_IF_NULL_ALLOC(factory);
+
+ THROW_IF_FAILED(::CoRegisterClassObject(CLSID_WslPluginHost, factory.Get(), CLSCTX_LOCAL_SERVER, REGCLS_SINGLEUSE, &cookie));
+
+ auto revokeOnExit = wil::scope_exit([&]() { ::CoRevokeClassObject(cookie); });
+
+ // Bounded shutdown for orphaned hosts: if COM activates wslpluginhost.exe
+ // but no client ever successfully creates an instance (e.g., the service
+ // crashes between launch and CreateInstance, or activation is abandoned
+ // by COM), exit instead of blocking on g_exitEvent forever. Once at least
+ // one PluginHost has been constructed, AddComRef/ReleaseComRef govern
+ // shutdown and we wait indefinitely for that ref count to drop to 0.
+ constexpr DWORD c_startupTimeoutMs = 60'000;
+ const DWORD waitResult = ::WaitForSingleObject(g_exitEvent.get(), c_startupTimeoutMs);
+ if (waitResult == WAIT_TIMEOUT && wsl::windows::pluginhost::g_activationCount.load(std::memory_order_acquire) == 0)
+ {
+ LOG_HR_MSG(
+ HRESULT_FROM_WIN32(ERROR_TIMEOUT),
+ "wslpluginhost.exe startup timeout: no client activated the host within %u ms; exiting",
+ c_startupTimeoutMs);
+ return 1;
+ }
+
+ // Either the event was signaled, or a host was activated. Wait for the
+ // exit signal (which fires when CoReleaseServerProcess returns 0).
+ g_exitEvent.wait();
+
+ return 0;
+}
+catch (...)
+{
+ LOG_CAUGHT_EXCEPTION();
+ return 1;
+}
diff --git a/src/windows/wslpluginhost/exe/main.rc b/src/windows/wslpluginhost/exe/main.rc
new file mode 100644
index 0000000000..84d6b8b8c6
--- /dev/null
+++ b/src/windows/wslpluginhost/exe/main.rc
@@ -0,0 +1,25 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ main.rc
+
+Abstract:
+
+ This file contains resources for wslpluginhost.
+
+--*/
+
+#include
+#include "resource.h"
+#include "wslversioninfo.h"
+
+#define VER_INTERNALNAME_STR "wslpluginhost.exe"
+#define VER_ORIGINALFILENAME_STR "wslpluginhost.exe"
+
+#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux"
+ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\..\Images\wsl.ico"
+
+#include
diff --git a/src/windows/wslpluginhost/exe/resource.h b/src/windows/wslpluginhost/exe/resource.h
new file mode 100644
index 0000000000..355437a443
--- /dev/null
+++ b/src/windows/wslpluginhost/exe/resource.h
@@ -0,0 +1,15 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ resource.h
+
+Abstract:
+
+ This file contains resource declarations for wslpluginhost.exe
+
+--*/
+
+#define ID_ICON 1
diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp
index f5aa3b1a0e..ebc941782e 100644
--- a/test/windows/Common.cpp
+++ b/test/windows/Common.cpp
@@ -831,6 +831,7 @@ void CreateWerReports()
L"wsl.exe",
L"wslhost.exe",
L"wslrelay.exe",
+ L"wslpluginhost.exe",
L"wslservice.exe",
L"wslg.exe",
L"vmcompute.exe",
@@ -1329,6 +1330,48 @@ void StopWslService()
StopService(service.get());
}
+DWORD GetWslServicePid()
+{
+ const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)};
+ VERIFY_IS_NOT_NULL(manager);
+
+ const wil::unique_schandle service{OpenService(manager.get(), L"wslservice", SERVICE_QUERY_STATUS | SERVICE_START)};
+ VERIFY_IS_NOT_NULL(service);
+
+ auto [state, pid] = GetServiceState(service.get());
+ if (state == SERVICE_STOPPED)
+ {
+ // The service is on-demand; start it so we can capture a stable PID.
+ if (!StartService(service.get(), 0, nullptr))
+ {
+ const auto error = GetLastError();
+ VERIFY_IS_TRUE(error == ERROR_SERVICE_ALREADY_RUNNING);
+ }
+ }
+
+ if (state != SERVICE_RUNNING)
+ {
+ WaitForServiceState(service.get(), SERVICE_RUNNING, 0);
+ std::tie(state, pid) = GetServiceState(service.get());
+ }
+
+ VERIFY_ARE_EQUAL(static_cast(SERVICE_RUNNING), state);
+ return pid;
+}
+
+DWORD GetWslServiceRunningPid()
+{
+ const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)};
+ VERIFY_IS_NOT_NULL(manager);
+
+ const wil::unique_schandle service{OpenService(manager.get(), L"wslservice", SERVICE_QUERY_STATUS)};
+ VERIFY_IS_NOT_NULL(service);
+
+ auto [state, pid] = GetServiceState(service.get());
+ VERIFY_ARE_EQUAL(static_cast(SERVICE_RUNNING), state);
+ return pid;
+}
+
wil::unique_handle GetNonElevatedToken(TOKEN_TYPE Type)
{
auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS);
diff --git a/test/windows/Common.h b/test/windows/Common.h
index ad2706ebf2..611900d95c 100644
--- a/test/windows/Common.h
+++ b/test/windows/Common.h
@@ -120,6 +120,7 @@ using namespace std::chrono_literals;
TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"WslServiceProxyStub.dll") \
TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslhost.exe") \
TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslrelay.exe") \
+ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslpluginhost.exe") \
TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslconfig.exe") \
TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wsl.exe") \
TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslg.exe") \
@@ -501,6 +502,13 @@ std::vector LxssSplitString(_In_ const std::wstring& string, _In_
void RestartWslService();
+DWORD GetWslServicePid();
+
+// Returns the PID of wslservice, asserting it is already RUNNING. Unlike
+// GetWslServicePid it never starts the service, so a crashed/stopped service
+// fails at capture instead of being silently restarted with a new PID.
+DWORD GetWslServiceRunningPid();
+
wil::unique_handle GetNonElevatedToken(TOKEN_TYPE Type = TokenPrimary);
std::wstring LxssWriteWslConfig(const std::wstring& Content);
diff --git a/test/windows/InstallerTests.cpp b/test/windows/InstallerTests.cpp
index 91d9c699c1..13f77be5c4 100644
--- a/test/windows/InstallerTests.cpp
+++ b/test/windows/InstallerTests.cpp
@@ -836,7 +836,7 @@ class InstallerTests
return flags;
};
- const std::vector executables = {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslg.exe"};
+ const std::vector executables = {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslpluginhost.exe", L"wslg.exe"};
for (const auto& e : executables)
{
auto fullPath = installPath.value() + e;
diff --git a/test/windows/PluginTests.cpp b/test/windows/PluginTests.cpp
index 61b28a74e2..3fb16b9166 100644
--- a/test/windows/PluginTests.cpp
+++ b/test/windows/PluginTests.cpp
@@ -34,6 +34,17 @@ class PluginTests
std::wstring pluginDll;
std::optional config;
+ // Returns true if the file does not exist, cannot be stat'd, or is zero
+ // bytes. Uses the non-throwing std::filesystem overloads so racy file
+ // deletion / access denials between the existence check and size query
+ // can't surface as exceptions out of test code.
+ static bool LogFileAbsentOrEmpty(const std::filesystem::path& path)
+ {
+ std::error_code ec;
+ const auto size = std::filesystem::file_size(path, ec);
+ return static_cast(ec) || size == 0;
+ }
+
WSL_TEST_CLASS(PluginTests)
TEST_CLASS_SETUP(TestClassSetup)
@@ -340,11 +351,13 @@ class PluginTests
WSL1_TEST_METHOD(SuccessWSL1)
{
- constexpr auto ExpectedOutput = LR"(Plugin loaded. TestMode=1)";
-
+ // Plugins are not loaded for WSL1-only sessions (no VM, no plugin hooks).
+ // Verify the plugin log file is absent/empty to assert no plugin code ran.
ConfigurePlugin(PluginTestType::Success);
StartWsl(0);
- ValidateLogFile(ExpectedOutput);
+
+ VERIFY_IS_TRUE(
+ LogFileAbsentOrEmpty(logFile), std::format(L"Expected plugin log file '{}' to be absent or empty for WSL1", logFile).c_str());
}
WSL2_TEST_METHOD(LoadFailureFatalWSL2)
@@ -363,13 +376,14 @@ class PluginTests
WSL1_TEST_METHOD(LoadFailureNonFatalWSL1)
{
- constexpr auto ExpectedOutput =
- LR"(Plugin loaded. TestMode=2
- OnLoad: E_UNEXPECTED)";
-
+ // Plugins are not loaded for WSL1-only sessions, so a plugin that
+ // would fail to load on WSL2 has no effect on WSL1. Assert the plugin
+ // log file is absent/empty to confirm no plugin code ran.
ConfigurePlugin(PluginTestType::FailToLoad);
StartWsl(0);
- ValidateLogFile(ExpectedOutput);
+
+ VERIFY_IS_TRUE(
+ LogFileAbsentOrEmpty(logFile), std::format(L"Expected plugin log file '{}' to be absent or empty for WSL1", logFile).c_str());
}
WSL2_TEST_METHOD(VmStartFailure)
@@ -598,6 +612,7 @@ class PluginTests
StartWsl(0);
ValidateLogFile(ExpectedOutput);
}
+
static wil::com_ptr OpenWslcSessionManager()
{
wil::com_ptr sessionManager;
@@ -744,7 +759,8 @@ class PluginTests
constexpr auto ExpectedOutput =
LR"(Plugin loaded. TestMode=19
WSLC Session created, name=plugin-wslc-rejected, id=*, pid=*, token=set, sid=set
- OnWslcSessionCreated: ERROR_ACCESS_DENIED)";
+ OnWslcSessionCreated: ERROR_ACCESS_DENIED
+ WSLC Session stopping, name=plugin-wslc-rejected, id=*)";
ValidateLogFile(ExpectedOutput);
}
@@ -776,6 +792,150 @@ class PluginTests
ValidateLogFile(ExpectedOutput);
}
+ // --- PR #40120 (out-of-process plugin host) coverage ---
+ //
+ // These tests validate the isolation and callback model:
+ // * HostCrashIsFatal — host process crash aborts the guarded operation (fatal).
+ // * ConcurrentCallbacks — many plugin threads issue API callbacks during a hook.
+ // * AsyncApiCallFromWorker — plugin API call from a worker thread OUTSIDE any hook.
+ // * CallbacksDuringTerminationDoNotCrash — callbacks racing VM teardown fail gracefully, never crash.
+ //
+ // Callback model (PluginCallPump): a plugin API callback that arrives WHILE a
+ // notification hook is in flight is marshaled back onto the notifying thread
+ // (which holds the session's recursive m_instanceLock), reproducing in-process
+ // re-entrancy; a callback that arrives with no hook in flight runs directly on
+ // the RPC thread (taking m_instanceLock itself). There is no separate callback
+ // lock — m_instanceLock alone serializes callbacks against VM teardown.
+
+ WSL2_TEST_METHOD(HostCrashIsFatal)
+ {
+ // A plugin host process crash during a veto hook (OnVmStarted) is fatal:
+ // the guarded operation is aborted with a fatal plugin error rather than
+ // silently continuing (matching the pre-refactor behavior where an
+ // in-process plugin crash took down WSL). The exact HRESULT is whichever
+ // RPC/CO_E_* code COM surfaces for the dead host, so assert on the
+ // user-facing prefix rather than an exact error code.
+ ConfigurePlugin(PluginTestType::HostCrash);
+
+ constexpr auto fatalPrefix = L"A fatal error was returned by plugin 'TestPlugin'";
+
+ auto [output, error] = LxsstuLaunchWslAndCaptureOutput(L"echo -n OK", -1);
+ VERIFY_IS_TRUE(
+ output.find(fatalPrefix) != std::wstring::npos, std::format(L"Expected a fatal plugin error, got: '{}'", output).c_str());
+
+ // The crash latches a fatal plugin error, so a subsequent operation also
+ // fails — the host is not re-activated for this service lifetime. Whether
+ // it fails fast via the latch (service still up) or re-crashes after the
+ // on-demand service idle-restarts and reloads the plugin, the user-facing
+ // result is the same fatal plugin error.
+ auto [output2, error2] = LxsstuLaunchWslAndCaptureOutput(L"echo -n OK", -1);
+ VERIFY_IS_TRUE(
+ output2.find(fatalPrefix) != std::wstring::npos,
+ std::format(L"Expected a fatal plugin error on the second attempt, got: '{}'", output2).c_str());
+
+ // Confirm the plugin actually ran up to the crash point.
+ StopWslService();
+
+ std::wifstream file(logFile);
+ const auto fileContent = std::wstring{std::istreambuf_iterator(file), {}};
+ LogInfo("Logfile: %ls", fileContent.c_str());
+
+ VERIFY_IS_TRUE(
+ fileContent.find(L"Crashing host") != std::wstring::npos,
+ std::format(L"Expected the plugin to reach the crash point, log: '{}'", fileContent).c_str());
+ }
+
+ WSL2_TEST_METHOD(ConcurrentCallbacks)
+ {
+ // The hook spawns 4 threads behind two gates: the first ensures all
+ // workers are spawned, the second rendezvouses them at the callback
+ // boundary so maxConcurrent deterministically reaches 4 (proving the
+ // plugin issues 4 callbacks concurrently). Per-thread logging is
+ // intentionally suppressed; only the deterministic summary line is
+ // asserted. Lifecycle pre/post lines validate the hook itself ran
+ // to completion.
+ constexpr auto ExpectedOutput =
+ LR"(Plugin loaded. TestMode=23
+ VM created (settings->CustomConfigurationFlags=0)
+ Concurrent callbacks complete: success=4 failures=0 maxConcurrent=4
+ Distribution started, name=test_distro, package=, PidNs=*, InitPid=*, Flavor=debian, Version=13
+ Distribution Stopping, name=test_distro, package=, PidNs=*, Flavor=debian, Version=13
+ VM Stopping)";
+
+ ConfigurePlugin(PluginTestType::ConcurrentApiCalls);
+ StartWsl(0);
+ ValidateLogFile(ExpectedOutput);
+ }
+
+ WSL2_TEST_METHOD(AsyncApiCallFromWorker)
+ {
+ // The worker thread is created in OnDistroStarted, sleeps briefly to
+ // ensure it runs after the hook has returned, then calls
+ // ExecuteBinaryInDistribution. It's joined unconditionally in
+ // OnDistroStopping (which defers its own "Distribution Stopping" log
+ // until after the join), so the worker-output line is guaranteed to
+ // precede "Distribution Stopping" in the log.
+ //
+ // The wsl command sleeps for 1s so the distro is alive long enough
+ // for the post-hook worker call to land before shutdown.
+ constexpr auto ExpectedOutput =
+ LR"(Plugin loaded. TestMode=24
+ VM created (settings->CustomConfigurationFlags=0)
+ Distribution started, name=test_distro, package=, PidNs=*, InitPid=*, Flavor=debian, Version=13
+ Async worker output: hello-from-worker
+ Distribution Stopping, name=test_distro, package=, PidNs=*, Flavor=debian, Version=13
+ VM Stopping)";
+
+ ConfigurePlugin(PluginTestType::AsyncApiCall);
+
+ auto [output, error] = LxsstuLaunchWslAndCaptureOutput(L"sh -c \"sleep 1; echo -n OK\"", 0);
+ VERIFY_ARE_EQUAL(output, L"OK");
+
+ ValidateLogFile(ExpectedOutput);
+ }
+
+ WSL2_TEST_METHOD(CallbacksDuringTerminationDoNotCrash)
+ {
+ // Drain test: 4 workers loop ExecuteBinaryInDistribution (with /bin/true,
+ // sub-ms callback) while the distro is alive. They keep calling across
+ // OnDistroStopping and _VmTerminate. Because callbacks run under (or block
+ // on) the session's recursive m_instanceLock, _VmTerminate's m_utilityVm
+ // reset is naturally serialized against them: a racing callback either
+ // runs before the reset (valid VM) or after it (returns E_NOT_VALID_STATE).
+ // Either way the service must not crash. After OnVmStopping signals
+ // wind-down, workers run a bounded number of further iterations and exit
+ // (see Plugin.cpp), so termination is deterministic and no worker can
+ // revive against a later VM.
+ //
+ // The post-shutdown StartWsl below triggers a second OnDistroStarted that
+ // joins the finished workers, so this test needs no fixed sleep.
+ //
+ // Scope:
+ // - Validates: callbacks racing teardown never crash; service survives.
+ // - Does NOT validate: drain semantics when a callback is genuinely
+ // stuck (e.g. service-side CreateLinuxProcess waiting on a hung
+ // Linux init). That requires cancellation plumbing through
+ // WslCoreInstance::CreateLinuxProcess and is tracked separately.
+ ConfigurePlugin(PluginTestType::CallbackDuringTermination);
+
+ // Use a 1s sleep so workers ramp up while the distro is alive.
+ auto [output, error] = LxsstuLaunchWslAndCaptureOutput(L"sh -c \"sleep 1; echo -n OK\"", 0);
+ VERIFY_ARE_EQUAL(output, L"OK");
+
+ const DWORD pidBefore = GetWslServiceRunningPid();
+ VERIFY_IS_TRUE(pidBefore != 0);
+
+ // Trigger VM termination with workers still running.
+ WslShutdown();
+
+ // Subsequent WSL command must still succeed — service survived. Its
+ // OnDistroStarted also joins the wound-down workers.
+ StartWsl(0);
+
+ const DWORD pidAfter = GetWslServiceRunningPid();
+ VERIFY_ARE_EQUAL(pidBefore, pidAfter);
+ }
+
// This test must run last so it doesn't break test cases that depends on plugin signature.
WSL2_TEST_METHOD(InvalidPluginSignature)
{
diff --git a/test/windows/PluginTests.h b/test/windows/PluginTests.h
index c9a779031c..9fb68a00ab 100644
--- a/test/windows/PluginTests.h
+++ b/test/windows/PluginTests.h
@@ -41,7 +41,11 @@ enum class PluginTestType
WslcSuccess,
WslcSessionRejected,
WslcContainerRejected,
- WslcImagePull
+ WslcImagePull,
+ HostCrash,
+ ConcurrentApiCalls,
+ AsyncApiCall,
+ CallbackDuringTermination
};
constexpr auto c_testType = L"TestType";
diff --git a/test/windows/testplugin/Plugin.cpp b/test/windows/testplugin/Plugin.cpp
index 17087e6edc..30ea8e88a0 100644
--- a/test/windows/testplugin/Plugin.cpp
+++ b/test/windows/testplugin/Plugin.cpp
@@ -16,6 +16,9 @@ Module Name:
#include "WslPluginApi.h"
#include "wslc_schema.h"
+#include
+#include
+
#include "PluginTests.h"
using namespace wsl::windows::common::registry;
@@ -31,6 +34,47 @@ PluginTestType g_testType = PluginTestType::Invalid;
std::optional g_previousInitPid;
+// Serializes writes to g_logfile from multiple threads in modes that spawn
+// worker threads (ConcurrentApiCalls, AsyncApiCall, CallbackDuringTermination).
+// Hook-thread writes that don't overlap with worker writes don't need to take
+// this — but it's harmless to do so.
+std::mutex g_logMutex;
+
+void LogLine(const std::string& line)
+{
+ std::lock_guard guard{g_logMutex};
+ g_logfile << line << std::endl;
+}
+
+// State for AsyncApiCall: worker thread launched in OnDistroStarted, joined
+// in OnDistroStopping. The promise carries the result so the hook can log it
+// after the join. The future is retrieved exactly once (in OnDistroStarted)
+// and consumed in OnDistroStopping — std::promise::get_future() can only be
+// called once per promise instance.
+std::optional g_asyncWorker;
+std::optional> g_asyncWorkerResult;
+std::future g_asyncWorkerFuture;
+std::string g_asyncWorkerOutput;
+
+// State for CallbackDuringTermination. Workers loop ExecuteBinaryInDistribution
+// while the distro is alive. OnVmStopping (which fires before _VmTerminate's
+// exclusive-lock drain) sets g_drainWindDown; workers then keep racing the drain
+// for a bounded number of iterations and exit. Exiting on a fixed count (rather
+// than on a post-reset failure) keeps shutdown deterministic and prevents a
+// worker from "reviving" against a VM that a later StartWsl creates with the
+// same session/distro IDs.
+//
+// g_drainWorkersStarted prevents the post-shutdown StartWsl (which the test uses
+// to verify the service survived) from spawning a fresh batch; that second
+// OnDistroStarted instead joins the finished workers. Threads are kept joinable
+// in g_drainWorkers so the test needs no fixed sleep.
+constexpr int c_drainWindDownIterations = 100;
+std::atomic g_drainSuccess{0};
+std::atomic g_drainFailures{0};
+std::atomic g_drainWindDown{false};
+std::atomic g_drainWorkersStarted{false};
+std::vector g_drainWorkers;
+
std::vector ReadFromSocket(SOCKET socket)
{
// Simplified error handling for the sake of the demo.
@@ -123,7 +167,7 @@ HRESULT OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSet
}
result = g_api->ExecuteBinary(0xcafe, arguments[0], arguments.data(), &socket);
- if (result != RPC_E_DISCONNECTED)
+ if (result != HRESULT_FROM_WIN32(ERROR_NOT_FOUND))
{
g_logfile << "Unexpected error for ExecuteBinary(): " << result << std::endl;
return E_ABORT;
@@ -182,12 +226,134 @@ HRESULT OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSet
return S_OK;
}
+ else if (g_testType == PluginTestType::HostCrash)
+ {
+ // Validate plugin host crash handling. Forcefully exit the host process
+ // so the COM RPC returns one of the HRESULTs in IsHostCrash
+ // (RPC_E_DISCONNECTED / RPC_E_SERVER_DIED / ...). The service treats a
+ // crash during a veto hook as a fatal plugin error and aborts the
+ // operation.
+ LogLine("Crashing host");
+ g_logfile.flush();
+ TerminateProcess(GetCurrentProcess(), 1);
+ // Unreachable.
+ return E_UNEXPECTED;
+ }
+ else if (g_testType == PluginTestType::ConcurrentApiCalls)
+ {
+ // Validate service-side callbacks issued by multiple plugin threads
+ // during a hook. N threads call MountFolder + ExecuteBinary via a
+ // start-gate so the RPCs are all in flight at once.
+ //
+ // maxConcurrent records how many workers are simultaneously at the
+ // plugin-side callback boundary (via a second rendezvous). Reaching N
+ // proves the plugin issues N callbacks concurrently. It does NOT prove
+ // the service executes them in parallel: with the PluginCallPump the
+ // service marshals these onto the single notifying thread and runs them
+ // serially. The test asserts only that all N succeed, which is the
+ // strongest honest black-box assertion.
+ constexpr int N = 4;
+
+ std::filesystem::path modulePath = wil::GetModuleFileNameW(wil::GetModuleInstanceHandle()).get();
+ const auto mountSource = modulePath.parent_path().wstring();
+
+ std::mutex gateMutex;
+ std::condition_variable gateCv;
+ int arrived = 0;
+ bool released = false;
+ int inFlight = 0;
+ int maxConcurrent = 0;
+
+ std::atomic successes{0};
+ std::atomic failures{0};
+
+ const auto worker = [&](int index) {
+ // Gate 1: wait for all workers to be spawned so they overlap.
+ {
+ std::unique_lock lock{gateMutex};
+ ++arrived;
+ if (arrived == N)
+ {
+ released = true;
+ gateCv.notify_all();
+ }
+ else
+ {
+ gateCv.wait(lock, [&]() { return released; });
+ }
+ }
+
+ // Gate 2: rendezvous right before the callbacks so all N are at the
+ // boundary at once, making maxConcurrent == N deterministic.
+ {
+ std::unique_lock lock{gateMutex};
+ ++inFlight;
+ maxConcurrent = std::max(maxConcurrent, inFlight);
+ if (inFlight == N)
+ {
+ gateCv.notify_all();
+ }
+ else
+ {
+ gateCv.wait(lock, [&]() { return inFlight == N; });
+ }
+ }
+
+ const auto linuxPath = L"/test-plugin/concurrent-" + std::to_wstring(index);
+ const auto mountName = L"test-plugin-concurrent-" + std::to_wstring(index);
+ HRESULT hr = g_api->MountFolder(Session->SessionId, mountSource.c_str(), linuxPath.c_str(), true, mountName.c_str());
+ if (FAILED(hr))
+ {
+ ++failures;
+ return;
+ }
+
+ wil::unique_socket socket;
+ std::vector args = {"/bin/true", nullptr};
+ hr = g_api->ExecuteBinary(Session->SessionId, args[0], args.data(), &socket);
+ if (FAILED(hr))
+ {
+ ++failures;
+ return;
+ }
+
+ ++successes;
+ };
+
+ std::vector threads;
+ threads.reserve(N);
+ for (int i = 0; i < N; ++i)
+ {
+ threads.emplace_back(worker, i);
+ }
+ for (auto& t : threads)
+ {
+ t.join();
+ }
+
+ LogLine(
+ "Concurrent callbacks complete: success=" + std::to_string(successes.load()) +
+ " failures=" + std::to_string(failures.load()) + " maxConcurrent=" + std::to_string(maxConcurrent));
+
+ if (failures.load() != 0)
+ {
+ return E_FAIL;
+ }
+ }
return S_OK;
}
HRESULT OnVmStopping(const WSLSessionInformation* Session)
{
+ if (g_testType == PluginTestType::CallbackDuringTermination)
+ {
+ // Signal drain workers to begin a bounded wind-down. Fires before
+ // _VmTerminate resets m_utilityVm, so workers keep racing teardown for
+ // a fixed number of iterations before exiting.
+ g_drainWindDown = true;
+ }
+
g_logfile << "VM Stopping" << std::endl;
if (g_testType == PluginTestType::FailToStopVm)
@@ -283,16 +449,169 @@ HRESULT OnDistroStarted(const WSLSessionInformation* Session, const WSLDistribut
g_logfile << "Invalid distro launch returned: "
<< g_api->ExecuteBinaryInDistribution(Session->SessionId, &guid, arguments[0], arguments.data(), &socket) << std::endl;
}
+ else if (g_testType == PluginTestType::AsyncApiCall)
+ {
+ // Validate plugin API calls from a worker thread that outlives the
+ // hook. The worker thread is joined in OnDistroStopping — joining is
+ // unconditional (no timeout) because letting the worker outlive
+ // g_pluginHost (cleared in ~PluginHost) would dereference freed memory.
+ g_asyncWorkerOutput.clear();
+ g_asyncWorkerResult.emplace();
+ g_asyncWorkerFuture = g_asyncWorkerResult->get_future();
+
+ const DWORD sessionId = Session->SessionId;
+ const GUID distroId = Distribution->Id;
+
+ g_asyncWorker.emplace([sessionId, distroId]() {
+ // Sleep briefly so the call is guaranteed to happen after the
+ // hook has returned — exercises the cross-apartment callback
+ // path from a non-hook thread that hasn't called CoInitializeEx.
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+
+ wil::unique_socket socket;
+ std::vector args = {"/bin/echo", "hello-from-worker", nullptr};
+ const HRESULT hr = g_api->ExecuteBinaryInDistribution(sessionId, &distroId, args[0], args.data(), &socket);
+
+ if (SUCCEEDED(hr))
+ {
+ const auto output = ReadFromSocket(socket.get());
+ std::string captured(output.begin(), output.end());
+ // Strip trailing newline added by /bin/echo so the log line
+ // doesn't get split when ValidateLogFile splits on '\n'.
+ while (!captured.empty() && (captured.back() == '\n' || captured.back() == '\r'))
+ {
+ captured.pop_back();
+ }
+ std::lock_guard guard{g_logMutex};
+ g_asyncWorkerOutput = std::move(captured);
+ }
+
+ g_asyncWorkerResult->set_value(hr);
+ });
+ }
+ else if (g_testType == PluginTestType::CallbackDuringTermination)
+ {
+ // Validate that callbacks racing VM teardown never crash the service.
+ // Workers keep calling into the service across OnDistroStopping /
+ // _VmTerminate; each callback runs under (or blocks on) the session's
+ // recursive m_instanceLock, so it is naturally serialized against
+ // m_utilityVm.reset() and fails gracefully if it lands after teardown.
+ // Workers then wind down deterministically (see globals above).
+ //
+ // Scope: this test exercises only the *happy-path* race — the
+ // callback (/bin/true) returns in sub-millisecond, so workers are
+ // almost always between iterations when teardown runs. It is *not* a
+ // regression test for the hung-callback case, where a service-side
+ // callback is stuck inside CreateLinuxProcess waiting on a non-responsive
+ // Linux init; that scenario requires termination-event plumbing through
+ // WslCoreInstance::CreateLinuxProcess and is tracked separately.
+ constexpr int N = 4;
+
+ // Spawn at most once. The post-shutdown StartWsl in
+ // CallbacksDuringTerminationDoNotCrash triggers another
+ // OnDistroStarted; join the (already wound-down) workers there instead
+ // of starting a fresh batch.
+ if (g_drainWorkersStarted.exchange(true))
+ {
+ for (auto& t : g_drainWorkers)
+ {
+ if (t.joinable())
+ {
+ t.join();
+ }
+ }
+ g_drainWorkers.clear();
+ return S_OK;
+ }
+
+ g_drainSuccess = 0;
+ g_drainFailures = 0;
+ g_drainWindDown = false;
+
+ const DWORD sessionId = Session->SessionId;
+ const GUID distroId = Distribution->Id;
+
+ for (int i = 0; i < N; ++i)
+ {
+ g_drainWorkers.emplace_back([sessionId, distroId]() {
+ int windDownRemaining = -1;
+ while (true)
+ {
+ wil::unique_socket socket;
+ std::vector args = {"/bin/true", nullptr};
+ const HRESULT hr = g_api->ExecuteBinaryInDistribution(sessionId, &distroId, args[0], args.data(), &socket);
+ if (SUCCEEDED(hr))
+ {
+ ++g_drainSuccess;
+ }
+ else
+ {
+ ++g_drainFailures;
+ }
+
+ if (g_drainWindDown.load())
+ {
+ if (windDownRemaining < 0)
+ {
+ windDownRemaining = c_drainWindDownIterations;
+ }
+ if (windDownRemaining-- == 0)
+ {
+ return;
+ }
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ });
+ }
+ }
return S_OK;
}
HRESULT OnDistroStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution)
{
- g_logfile << "Distribution Stopping, name=" << wsl::shared::string::WideToMultiByte(Distribution->Name)
- << ", package=" << wsl::shared::string::WideToMultiByte(Distribution->PackageFamilyName)
- << ", PidNs=" << Distribution->PidNamespace << ", Flavor=" << wsl::shared::string::WideToMultiByte(Distribution->Flavor)
- << ", Version=" << wsl::shared::string::WideToMultiByte(Distribution->Version) << std::endl;
+ // For AsyncApiCall we defer the "Distribution Stopping" line until after
+ // the worker thread has been joined, so the worker's "Async worker output"
+ // line is guaranteed to appear before it in the log.
+ auto logDistroStopping = [&]() {
+ g_logfile << "Distribution Stopping, name=" << wsl::shared::string::WideToMultiByte(Distribution->Name)
+ << ", package=" << wsl::shared::string::WideToMultiByte(Distribution->PackageFamilyName)
+ << ", PidNs=" << Distribution->PidNamespace << ", Flavor=" << wsl::shared::string::WideToMultiByte(Distribution->Flavor)
+ << ", Version=" << wsl::shared::string::WideToMultiByte(Distribution->Version) << std::endl;
+ };
+
+ if (g_testType == PluginTestType::AsyncApiCall)
+ {
+ if (g_asyncWorker.has_value())
+ {
+ // Unconditional join — letting the worker outlive g_pluginHost
+ // (cleared in ~PluginHost) would dereference freed memory.
+ g_asyncWorker->join();
+ g_asyncWorker.reset();
+
+ HRESULT workerHr = S_OK;
+ if (g_asyncWorkerFuture.valid())
+ {
+ workerHr = g_asyncWorkerFuture.get();
+ g_asyncWorkerResult.reset();
+ }
+
+ if (SUCCEEDED(workerHr))
+ {
+ LogLine("Async worker output: " + g_asyncWorkerOutput);
+ }
+ else
+ {
+ LogLine("Async worker failed: " + std::to_string(workerHr));
+ }
+ }
+
+ logDistroStopping();
+ return S_OK;
+ }
+
+ logDistroStopping();
if (g_testType == PluginTestType::FailToStopDistro)
{
@@ -550,7 +869,7 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin
THROW_HR_IF(E_UNEXPECTED, !g_logfile);
g_testType = static_cast(ReadDword(key.get(), nullptr, c_testType, static_cast(PluginTestType::Invalid)));
- THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::WslcImagePull));
+ THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::CallbackDuringTermination));
g_logfile << "Plugin loaded. TestMode=" << static_cast(g_testType) << std::endl;
g_api = Api;