From cc1f84620277d4a98f7dd8b1ada29db34d340593 Mon Sep 17 00:00:00 2001 From: Ben Hillis Date: Tue, 12 May 2026 11:10:49 -0700 Subject: [PATCH 1/3] Isolate plugins in an out-of-process COM host WSL plugin DLLs are moved out of wslservice.exe into a separate wslpluginhost.exe COM server so plugin code can no longer crash or destabilize the service. Each plugin is activated in its own host process (CLSCTX_LOCAL_SERVER, SYSTEM-only via AppID) and reached through a versioned COM interface defined in WslPluginHost.idl. All hosts are tied to a service-owned job object and terminate when wslservice exits. The plugin API is unchanged; existing plugins run unmodified. A crashing or disconnected host is classified by IsHostCrash (RPC_E_DISCONNECTED, RPC_E_SERVER_DIED[_DNE], CO_E_OBJNOTCONNECTED, RPC_S_SERVER_UNAVAILABLE, RPC_S_CALL_FAILED[_DNE]); the service logs it and continues instead of treating it as a fatal plugin error. RPC_E_CALL_REJECTED is intentionally excluded as a transient busy state rather than a dead host. Plugin->service callbacks (MountFolder, ExecuteBinary, and the WSLC session APIs) arrive on a different COM thread than the outbound hook, so they cannot re-enter the lock held during the hook: - VM path: LxssUserSessionImpl guards callbacks with a shared_mutex (shared for callbacks, exclusive in _VmTerminate after OnVmStopping drains in-flight callbacks before the utility VM is destroyed). - WSLC path: PluginManager resolves sessions through its own reference map under a dedicated lock, and WSLCSessionManager releases its session lock before any plugin notification fires, so callbacks never re-enter the session lock. A session is registered in the reference map but not published until OnWslcSessionCreated succeeds, so a vetoed or race-lost session is never handed out. Proxy/stub is consolidated into wslserviceproxystub.dll. One new exe, no new DLLs. Tests - HostCrashIsolation: kills wslpluginhost.exe mid-OnVmStarted and verifies the service survives and m_initOnce stays sticky. - ConcurrentCallbacks: four plugin threads hammer MountFolder and ExecuteBinary, exercising the shared callback lock. - AsyncApiCallFromWorker: a plugin worker thread calls into the service post-hook (cross-apartment, non-COM-initialized). - CallbacksDuringTerminationDoNotCrash: worker threads race _VmTerminate's exclusive lock and VM teardown, then wind down deterministically after OnVmStopping signals them and are joined on the next session start. - Existing WSL1 plugin tests broadened alongside the refactor. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .pipelines/build-stage.yml | 4 +- CMakeLists.txt | 1 + msipackage/CMakeLists.txt | 4 +- msipackage/package.wix.in | 30 + src/windows/common/precomp.h | 1 + src/windows/service/exe/CMakeLists.txt | 2 +- src/windows/service/exe/LxssUserSession.cpp | 73 +- src/windows/service/exe/LxssUserSession.h | 44 +- src/windows/service/exe/PluginManager.cpp | 1511 ++++++++++++----- src/windows/service/exe/PluginManager.h | 255 ++- src/windows/service/exe/ServiceMain.cpp | 6 + .../service/exe/WSLCSessionManager.cpp | 356 ++-- src/windows/service/exe/WSLCSessionManager.h | 84 +- src/windows/service/inc/CMakeLists.txt | 2 + src/windows/service/inc/WslPluginHost.idl | 260 +++ src/windows/service/stub/CMakeLists.txt | 4 +- src/windows/wslinstall/DllMain.cpp | 2 +- src/windows/wslpluginhost/CMakeLists.txt | 1 + src/windows/wslpluginhost/exe/CMakeLists.txt | 25 + src/windows/wslpluginhost/exe/PluginHost.cpp | 947 +++++++++++ src/windows/wslpluginhost/exe/PluginHost.h | 222 +++ src/windows/wslpluginhost/exe/main.cpp | 139 ++ src/windows/wslpluginhost/exe/main.rc | 25 + src/windows/wslpluginhost/exe/resource.h | 15 + test/windows/Common.cpp | 43 + test/windows/Common.h | 8 + test/windows/InstallerTests.cpp | 2 +- test/windows/PluginTests.cpp | 178 +- test/windows/PluginTests.h | 6 +- test/windows/testplugin/Plugin.cpp | 329 +++- 30 files changed, 3960 insertions(+), 619 deletions(-) create mode 100644 src/windows/service/inc/WslPluginHost.idl create mode 100644 src/windows/wslpluginhost/CMakeLists.txt create mode 100644 src/windows/wslpluginhost/exe/CMakeLists.txt create mode 100644 src/windows/wslpluginhost/exe/PluginHost.cpp create mode 100644 src/windows/wslpluginhost/exe/PluginHost.h create mode 100644 src/windows/wslpluginhost/exe/main.cpp create mode 100644 src/windows/wslpluginhost/exe/main.rc create mode 100644 src/windows/wslpluginhost/exe/resource.h diff --git a/.pipelines/build-stage.yml b/.pipelines/build-stage.yml index 0d0c30b175..61d4a6dbc7 100644 --- a/.pipelines/build-stage.yml +++ b/.pipelines/build-stage.yml @@ -32,8 +32,8 @@ parameters: - name: targets type: object default: - - target: "wsl;libwsl;wslg;wslservice;wslhost;wslrelay;wslinstaller;wslinstall;initramfs;wslserviceproxystub;wslsettings;wslinstallerproxystub;testplugin;wslcsession;wslc;wsltests;wslcsdk" - pattern: "wsl.exe,libwsl.dll,wslg.exe,wslservice.exe,wslhost.exe,wslrelay.exe,wslinstaller.exe,wslinstall.dll,wslserviceproxystub.dll,wslsettings/wslsettings.dll,wslsettings/wslsettings.exe,wslinstallerproxystub.dll,WSLDVCPlugin.dll,testplugin.dll,wsldeps.dll,wslcsession.exe,wslc.exe,wslcsdk.dll" + - target: "wsl;libwsl;wslg;wslservice;wslhost;wslrelay;wslpluginhost;wslinstaller;wslinstall;initramfs;wslserviceproxystub;wslsettings;wslinstallerproxystub;testplugin;wslcsession;wslc;wsltests;wslcsdk" + pattern: "wsl.exe,libwsl.dll,wslg.exe,wslservice.exe,wslhost.exe,wslrelay.exe,wslpluginhost.exe,wslinstaller.exe,wslinstall.dll,wslserviceproxystub.dll,wslsettings/wslsettings.dll,wslsettings/wslsettings.exe,wslinstallerproxystub.dll,WSLDVCPlugin.dll,testplugin.dll,wsldeps.dll,wslcsession.exe,wslc.exe,wslcsdk.dll" - target: "msixgluepackage" pattern: "gluepackage.msix" - target: "msipackage;wslcsdkcs" diff --git a/CMakeLists.txt b/CMakeLists.txt index c88fbb8e90..d0b607f423 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -545,6 +545,7 @@ add_subdirectory(src/windows/wsl) add_subdirectory(src/windows/wslg) add_subdirectory(src/windows/wslhost) add_subdirectory(src/windows/wslrelay) +add_subdirectory(src/windows/wslpluginhost) add_subdirectory(src/windows/wslinstall) add_subdirectory(src/windows/wslc) add_subdirectory(src/windows/WslcSDK) diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt index 99586c9727..ab025346e9 100644 --- a/msipackage/CMakeLists.txt +++ b/msipackage/CMakeLists.txt @@ -17,7 +17,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi) set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in) set(PACKAGE_WIX ${BIN}/package.wix) set(CAB_CACHE ${BIN}/cab) -set(WINDOWS_BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;wslinstall.dll;wslc.exe;wslcsession.exe) +set(WINDOWS_BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslpluginhost.exe;wslservice.exe;wslserviceproxystub.dll;wslinstall.dll;wslc.exe;wslcsession.exe) if (WSL_BUILD_WSL_SETTINGS) list(APPEND WINDOWS_BINARIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll") endif() @@ -57,7 +57,7 @@ add_custom_command( add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE}) set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN}) -add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslc wslcsession) +add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslpluginhost wslserviceproxystub init initramfs wslinstall msixgluepackage wslc wslcsession) if (WSL_BUILD_WSL_SETTINGS) add_dependencies(msipackage wslsettings libwsl) diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 6b2cd0c5e9..38934c076e 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -31,6 +31,7 @@ + @@ -177,6 +178,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/windows/common/precomp.h b/src/windows/common/precomp.h index c7b53ee547..6d098a5c92 100644 --- a/src/windows/common/precomp.h +++ b/src/windows/common/precomp.h @@ -84,6 +84,7 @@ Module Name: #include #include #include +#include #include #include #include diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt index c7b4ebb0dd..841883579e 100644 --- a/src/windows/service/exe/CMakeLists.txt +++ b/src/windows/service/exe/CMakeLists.txt @@ -58,7 +58,7 @@ set(HEADERS WSLCPluginNotifier.h) add_executable(wslservice ${SOURCES} ${HEADERS}) -add_dependencies(wslservice wslserviceidl wslservicemc) +add_dependencies(wslservice wslserviceidl wslservicemc wslpluginhostidl) add_compile_definitions(__WRL_CLASSIC_COM__) add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__) add_compile_definitions(USE_COM_CONTEXT_DEF=1) diff --git a/src/windows/service/exe/LxssUserSession.cpp b/src/windows/service/exe/LxssUserSession.cpp index ba22e10e03..ac6f5ed3d3 100644 --- a/src/windows/service/exe/LxssUserSession.cpp +++ b/src/windows/service/exe/LxssUserSession.cpp @@ -2635,13 +2635,18 @@ std::shared_ptr LxssUserSessionImpl::_CreateInstance(_In_op registration.Write(Property::OsVersion, distributionInfo->Version); } - // This needs to be done before plugins are notifed because they might try to run a command inside the distribution. - m_runningInstances[registration.Id()] = instance; + // This needs to be done before plugins are notified because they might try to run a command inside the distribution. + { + std::unique_lock callbackLock(m_callbackLock); + m_runningInstances[registration.Id()] = instance; + } if (version == LXSS_WSL_VERSION_2) { - auto cleanupOnFailure = - wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { m_runningInstances.erase(registration.Id()); }); + auto cleanupOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + std::unique_lock callbackLock(m_callbackLock); + m_runningInstances.erase(registration.Id()); + }); m_pluginManager.OnDistributionStarted(&m_session, instance->DistributionInformation()); cleanupOnFailure.release(); } @@ -2877,7 +2882,13 @@ void LxssUserSessionImpl::_CreateVm() m_vmId.store(vmId); // Create the utility VM and register for callbacks. - m_utilityVm = WslCoreVm::Create(m_userToken, std::move(config), vmId); + // Publish m_utilityVm under m_callbackLock exclusive to honor the dual-lock + // invariant for mutations of m_utilityVm; this is uncontended here because + // no plugin callbacks can race against initial creation. + { + std::unique_lock callbackLock(m_callbackLock); + m_utilityVm = WslCoreVm::Create(m_userToken, std::move(config), vmId); + } if (m_httpProxyStateTracker) { @@ -3608,17 +3619,26 @@ bool LxssUserSessionImpl::_TerminateInstanceInternal(_In_ LPCGUID DistroGuid, _I m_pluginManager.OnDistributionStopping(&m_session, wslcoreInstance->DistributionInformation()); } - instance->second->Stop(); + m_lifetimeManager.RemoveCallback(clientKey); - const auto clientId = instance->second->GetClientId(); + // Stop the instance and remove it from m_runningInstances atomically + // under m_callbackLock. This prevents plugin callbacks (which hold + // m_callbackLock shared) from finding a stopped-but-still-listed + // instance between Stop() and erase. + ULONG clientId; { - auto lock = m_terminatedInstanceLock.lock_exclusive(); - m_terminatedInstances.push_back(std::move(instance->second)); - } + std::unique_lock callbackLock(m_callbackLock); - m_lifetimeManager.RemoveCallback(clientKey); + instance->second->Stop(); + clientId = instance->second->GetClientId(); + + { + auto lock = m_terminatedInstanceLock.lock_exclusive(); + m_terminatedInstances.push_back(std::move(instance->second)); + } - m_runningInstances.erase(instance); + m_runningInstances.erase(instance); + } // If the instance that was terminated was a WSL2 instance, // check if the VM is now idle. @@ -3646,7 +3666,10 @@ void LxssUserSessionImpl::_UpdateInit(_In_ const LXSS_DISTRO_CONFIGURATION& Conf HRESULT LxssUserSessionImpl::MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name) { - std::lock_guard lock(m_instanceLock); + // Shared lock prevents _VmTerminate from destroying the VM while we use it. + // Do NOT acquire m_instanceLock — callbacks arrive on a different COM thread + // from the notification thread that holds m_instanceLock. + std::shared_lock lock(m_callbackLock); RETURN_HR_IF(E_NOT_VALID_STATE, !m_utilityVm); m_utilityVm->MountRootNamespaceFolder(HostPath, GuestPath, ReadOnly, Name); @@ -3655,7 +3678,9 @@ HRESULT LxssUserSessionImpl::MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* Socket) { - std::lock_guard lock(m_instanceLock); + // Shared lock prevents _VmTerminate from destroying the VM or instances + // while we use them. See MountRootNamespaceFolder for rationale. + std::shared_lock lock(m_callbackLock); RETURN_HR_IF(E_NOT_VALID_STATE, !m_utilityVm); if (Distro == nullptr) @@ -3664,9 +3689,16 @@ HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In } else { - const auto distro = _RunningInstance(Distro); - THROW_HR_IF(WSL_E_VM_MODE_INVALID_STATE, !distro); - + // Look up the running instance directly instead of calling _RunningInstance, + // which accesses m_lockedDistributions (guarded only by m_instanceLock). + // m_runningInstances is safe to read under m_callbackLock (shared). + // The _EnsureNotLocked check is unnecessary here: _ConversionBegin removes + // a distribution from m_runningInstances before adding it to m_lockedDistributions, + // so a locked distribution will never be found in this lookup. + const auto instance = m_runningInstances.find(*Distro); + THROW_HR_IF(WSL_E_VM_MODE_INVALID_STATE, instance == m_runningInstances.end()); + + const auto distro = instance->second; const auto wsl2Distro = dynamic_cast(distro.get()); THROW_HR_IF(WSL_E_WSL2_NEEDED, !wsl2Distro); @@ -3904,7 +3936,12 @@ void LxssUserSessionImpl::_VmTerminate() m_telemetryThread.join(); } - m_utilityVm.reset(); + // Acquire exclusive callback lock to wait for any in-flight plugin callbacks + // (MountRootNamespaceFolder, CreateLinuxProcess) to complete before destroying the VM. + { + std::unique_lock callbackLock(m_callbackLock); + m_utilityVm.reset(); + } m_vmId.store(GUID_NULL); // Reset the user's token since its lifetime is tied to the VM. diff --git a/src/windows/service/exe/LxssUserSession.h b/src/windows/service/exe/LxssUserSession.h index 6e2d416873..c938f223fb 100644 --- a/src/windows/service/exe/LxssUserSession.h +++ b/src/windows/service/exe/LxssUserSession.h @@ -310,6 +310,10 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce7e") LxssUserSession /// class LxssUserSessionImpl { + // Plugin callbacks arrive on a different COM RPC thread and use m_callbackLock + // (shared) instead of m_instanceLock to access m_utilityVm and m_runningInstances. + friend class wsl::windows::service::PluginHostCallbackImpl; + public: LxssUserSessionImpl(_In_ PSID userSid, _In_ DWORD sessionId, _Inout_ wsl::windows::service::PluginManager& pluginManager); virtual ~LxssUserSessionImpl(); @@ -363,11 +367,6 @@ class LxssUserSessionImpl /// void ClearDiskStateInRegistry(_In_opt_ LPCWSTR Disk); - /// - /// Start a process in the root namespace or in a user distribution. - /// - HRESULT CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* socket); - /// /// Enumerates registered distributions, optionally including ones that are /// currently being registered, unregistered, or converted. @@ -443,8 +442,6 @@ class LxssUserSessionImpl HRESULT MoveDistribution(_In_ LPCGUID DistroGuid, _In_ LPCWSTR Location); - HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name); - /// /// Registers a distribution. /// @@ -533,6 +530,18 @@ class LxssUserSessionImpl static CreateLxProcessContext s_GetCreateProcessContext(_In_ const GUID& DistroGuid, _In_ bool SystemDistro); private: + /// + /// Plugin callback methods — called from PluginHostCallbackImpl on a COM RPC + /// thread during plugin notifications. These acquire m_callbackLock (shared) + /// instead of m_instanceLock, preventing _VmTerminate from destroying the VM + /// while a callback is in-flight. Access is restricted via friend declaration. + /// + _Requires_lock_not_held_(m_instanceLock) + HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name); + + _Requires_lock_not_held_(m_instanceLock) + HRESULT CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* socket); + /// /// Adds a distro to the list of converting distros. /// @@ -794,7 +803,9 @@ class LxssUserSessionImpl std::recursive_timed_mutex m_instanceLock; /// - /// Contains the currently running utility VM's. + /// Contains the currently running instances. + /// Reads guarded by m_instanceLock OR m_callbackLock (shared). + /// Mutations require BOTH m_instanceLock AND m_callbackLock (exclusive). /// _Guarded_by_(m_instanceLock) std::map, wsl::windows::common::helpers::GuidLess> m_runningInstances; @@ -811,9 +822,24 @@ class LxssUserSessionImpl /// /// The running utility vm for WSL2 distributions. - /// + /// Reads guarded by m_instanceLock OR m_callbackLock (shared). + /// Mutations require BOTH m_instanceLock AND m_callbackLock (exclusive). + /// _Guarded_by_(m_instanceLock) std::unique_ptr m_utilityVm; + /// + /// Reader-writer lock protecting m_utilityVm and m_runningInstances for + /// plugin callbacks. Callbacks take a shared (read) lock; _VmTerminate and + /// instance mutations take an exclusive (write) lock. + /// + /// Mutations of m_runningInstances/m_utilityVm require BOTH m_instanceLock + /// AND m_callbackLock (exclusive). Reads are safe under either lock alone. + /// + /// Lock ordering: m_instanceLock → m_callbackLock (never reverse). + /// Callbacks must NEVER acquire m_instanceLock (deadlock with notification thread). + /// + std::shared_mutex m_callbackLock; + std::atomic m_vmId{GUID_NULL}; /// diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp index 0d6b6ccbb7..0948ae443b 100644 --- a/src/windows/service/exe/PluginManager.cpp +++ b/src/windows/service/exe/PluginManager.cpp @@ -9,6 +9,8 @@ Module Name: Abstract: This file contains the PluginManager helper class implementation. + Plugins are loaded in isolated wslpluginhost.exe processes via COM, + so a crashing plugin cannot take down the WSL service. --*/ @@ -16,330 +18,242 @@ Module Name: #include "install.h" #include "PluginManager.h" #include "WslPluginApi.h" +#include "WslPluginHost.h" #include "LxssUserSessionFactory.h" #include "WSLCSessionManager.h" using wsl::windows::common::Context; using wsl::windows::common::ExecutionContext; +using wsl::windows::service::PluginHostCallbackImpl; using wsl::windows::service::PluginManager; +// Acquire an apartment-local IWslPluginHost proxy for `plugin` (named `host`). +// On a host-process crash, log it and `continue` the surrounding loop. On any +// other failure to acquire (which would indicate a fundamental COM problem, +// not a plugin-reported issue), log the HRESULT and `continue` so a single +// busted plugin does not break the iteration for the others. Use only inside +// the per-plugin loops in PluginManager hook methods. +#define ACQUIRE_PLUGIN_HOST_OR_CONTINUE(plugin, host, stage) \ + Microsoft::WRL::ComPtr host; \ + { \ + const HRESULT _acqHr = AcquireHostProxy((plugin), &(host)); \ + if (FAILED(_acqHr)) \ + { \ + if (IsHostCrash(_acqHr)) \ + { \ + LogPluginHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \ + } \ + else \ + { \ + LOG_HR_MSG(_acqHr, "Failed to acquire plugin host proxy for: '%ls'", (plugin).name.c_str()); \ + } \ + continue; \ + } \ + } + +// Same as ACQUIRE_PLUGIN_HOST_OR_CONTINUE, but for hook methods that surface a +// plugin error to abort the guarded operation (e.g. OnVmStarted). A host crash +// is still isolated (logged + `continue`), but any other acquisition failure is +// thrown so it is surfaced exactly like a plugin-reported error would be, rather +// than silently allowing the operation to proceed without consulting the plugin. +#define ACQUIRE_PLUGIN_HOST_OR_THROW(plugin, host, stage) \ + Microsoft::WRL::ComPtr host; \ + { \ + const HRESULT _acqHr = AcquireHostProxy((plugin), &(host)); \ + if (FAILED(_acqHr)) \ + { \ + if (IsHostCrash(_acqHr)) \ + { \ + LogPluginHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \ + continue; \ + } \ + THROW_HR_MSG(_acqHr, "Failed to acquire plugin host proxy for: '%ls'", (plugin).name.c_str()); \ + } \ + } + constexpr auto c_pluginPath = L"SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Lxss\\Plugins"; -constexpr WSLVersion Version = {wsl::shared::VersionMajor, wsl::shared::VersionMinor, wsl::shared::VersionRevision}; +// --- IWslPluginHostCallback implementation (service-side) --- +// These methods handle API calls from the plugin host process. -thread_local std::optional g_pluginErrorMessage; +// Returned to the plugin when a session cookie no longer resolves (the session +// was torn down). Deliberately not an RPC_* status: a plugin that propagates +// this from its hook must not be mistaken for a dead host by IsHostCrash. +constexpr HRESULT c_pluginSessionNotFound = HRESULT_FROM_WIN32(ERROR_NOT_FOUND); -extern "C" { -HRESULT MountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name) +STDMETHODIMP PluginHostCallbackImpl::MountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCWSTR LinuxPath, _In_ BOOL ReadOnly, _In_ LPCWSTR Name) try { - const auto session = FindSessionByCookie(Session); - RETURN_HR_IF(RPC_E_DISCONNECTED, !session); - - auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name); + RETURN_HR_IF(E_INVALIDARG, WindowsPath == nullptr || LinuxPath == nullptr || Name == nullptr); WSL_LOG( - "PluginMountFolderCall", + "PluginCallbackMountFolderBegin", TraceLoggingValue(WindowsPath, "WindowsPath"), - TraceLoggingValue(LinuxPath, "LinuxPath"), - TraceLoggingValue(ReadOnly, "ReadOnly"), - TraceLoggingValue(Name, "Name"), - TraceLoggingValue(result, "Result")); + TraceLoggingValue(SessionId, "SessionId")); + const auto session = FindSessionByCookie(SessionId); + RETURN_HR_IF(c_pluginSessionNotFound, !session); - return result; -} -CATCH_RETURN(); - -HRESULT ExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) -try -{ - - const auto session = FindSessionByCookie(Session); - RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name); - auto result = session->CreateLinuxProcess(nullptr, Path, Arguments, Socket); + WSL_LOG("PluginCallbackMountFolderEnd", TraceLoggingValue(WindowsPath, "WindowsPath"), TraceLoggingValue(result, "Result")); - WSL_LOG("PluginExecuteBinaryCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); return result; } CATCH_RETURN(); -HRESULT PluginError(LPCWSTR UserMessage) +STDMETHODIMP PluginHostCallbackImpl::ExecuteBinary( + _In_ DWORD SessionId, _In_ LPCSTR Path, _In_ DWORD ArgumentCount, _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _Out_ HANDLE* Socket) try { - const auto* context = ExecutionContext::Current(); - THROW_HR_IF(E_INVALIDARG, UserMessage == nullptr); - THROW_HR_IF_MSG( - E_ILLEGAL_METHOD_CALL, context == nullptr || WI_IsFlagClear(context->CurrentContext(), Context::Plugin), "Message: %ls", UserMessage); + RETURN_HR_IF(E_POINTER, Socket == nullptr); + *Socket = nullptr; + RETURN_HR_IF(E_INVALIDARG, Path == nullptr); + RETURN_HR_IF(E_INVALIDARG, ArgumentCount > 0 && Arguments == nullptr); - // Logs when a WSL plugin hits an error and what that error message is - WSL_LOG_TELEMETRY("PluginError", PDT_ProductAndServicePerformance, TraceLoggingValue(UserMessage, "Message")); - - THROW_HR_IF(E_ILLEGAL_STATE_CHANGE, g_pluginErrorMessage.has_value()); - - g_pluginErrorMessage.emplace(UserMessage); - - return S_OK; -} -CATCH_RETURN(); - -HRESULT ExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) -try -{ - THROW_HR_IF(E_INVALIDARG, Distro == nullptr); + WSL_LOG("PluginCallbackExecuteBinaryBegin", TraceLoggingValue(Path, "Path"), TraceLoggingValue(SessionId, "SessionId")); + const auto session = FindSessionByCookie(SessionId); + WSL_LOG( + "PluginCallbackExecuteBinaryFoundSession", + TraceLoggingValue(Path, "Path"), + TraceLoggingValue(session != nullptr, "Found")); + RETURN_HR_IF(c_pluginSessionNotFound, !session); + + // Build NULL-terminated argument array expected by CreateLinuxProcess. + std::vector args; + if (Arguments != nullptr) + { + args.assign(Arguments, Arguments + ArgumentCount); + } + args.push_back(nullptr); - const auto session = FindSessionByCookie(Session); - RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + WSL_LOG("PluginCallbackExecuteBinaryCallingCreateProcess", TraceLoggingValue(Path, "Path")); + wil::unique_socket sock; + auto result = session->CreateLinuxProcess(nullptr, Path, args.data(), &sock); - auto result = session->CreateLinuxProcess(Distro, Path, Arguments, Socket); + WSL_LOG("PluginCallbackExecuteBinaryEnd", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); - WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); + if (SUCCEEDED(result)) + { + // Return socket as HANDLE — COM's system_handle marshaling will + // duplicate it into the host process automatically. + *Socket = reinterpret_cast(sock.release()); + } return result; } CATCH_RETURN(); -} -namespace { - -// Opaque wrapper around IWSLCProcess, handed out as WSLCProcessHandle to plugins. -struct WslcProcessWrapper -{ - wil::com_ptr Process; -}; - -wil::com_ptr ResolveWslcSession(WSLCSessionId Session) -{ - auto* mgr = wsl::windows::service::wslc::WSLCSessionManagerImpl::Instance(); - THROW_HR_IF(RPC_E_DISCONNECTED, mgr == nullptr); - - return mgr->FindSession(static_cast(Session)); -} - -} // namespace - -extern "C" { - -HRESULT WSLCMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, LPCSTR Mountpoint, BOOL ReadOnly) +STDMETHODIMP PluginHostCallbackImpl::ExecuteBinaryInDistribution( + _In_ DWORD SessionId, + _In_ const GUID* DistributionId, + _In_ LPCSTR Path, + _In_ DWORD ArgumentCount, + _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, + _Out_ HANDLE* Socket) try { - // TODO: Once plugins are out of proc, add logic to validate that the mountpoint isn't in use by another plugin. - RETURN_HR_IF(E_POINTER, WindowsPath == nullptr || Mountpoint == nullptr); - - auto session = ResolveWslcSession(Session); - auto result = session->MountWindowsFolder(WindowsPath, Mountpoint, ReadOnly); - - WSL_LOG( - "WslcPluginMountFolderCall", - TraceLoggingValue(Session, "SessionId"), - TraceLoggingValue(WindowsPath, "WindowsPath"), - TraceLoggingValue(Mountpoint, "Mountpoint"), - TraceLoggingValue(ReadOnly, "ReadOnly"), - TraceLoggingValue(result, "Result")); + RETURN_HR_IF(E_POINTER, Socket == nullptr); + *Socket = nullptr; + RETURN_HR_IF(E_INVALIDARG, DistributionId == nullptr); + RETURN_HR_IF(E_INVALIDARG, Path == nullptr); + RETURN_HR_IF(E_INVALIDARG, ArgumentCount > 0 && Arguments == nullptr); - return result; -} -CATCH_RETURN(); + const auto session = FindSessionByCookie(SessionId); + RETURN_HR_IF(c_pluginSessionNotFound, !session); -HRESULT WSLCUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint) -try -{ - // TODO: Once plugins are out of proc, add logic to validate that the mountpoint is actually owned by the plugin. - RETURN_HR_IF(E_POINTER, Mountpoint == nullptr); + std::vector args; + if (Arguments != nullptr) + { + args.assign(Arguments, Arguments + ArgumentCount); + } + args.push_back(nullptr); - auto session = ResolveWslcSession(Session); + wil::unique_socket sock; + auto result = session->CreateLinuxProcess(DistributionId, Path, args.data(), &sock); - auto result = session->UnmountWindowsFolder(Mountpoint); + WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); - WSL_LOG( - "WslcPluginUnmountFolderCall", - TraceLoggingValue(Session, "SessionId"), - TraceLoggingValue(Mountpoint, "Mountpoint"), - TraceLoggingValue(result, "Result")); + if (SUCCEEDED(result)) + { + *Socket = reinterpret_cast(sock.release()); + } return result; } CATCH_RETURN(); -HRESULT WSLCCreateProcess(WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno) -try -{ - RETURN_HR_IF(E_POINTER, Executable == nullptr || Process == nullptr); +// --- PluginManager implementation --- - *Process = nullptr; - if (Errno != nullptr) +PluginManager::~PluginManager() +{ + // m_plugins should have been cleared by Shutdown() while COM was still + // initialized. Releasing GIT cookies and the GIT itself here (at global + // teardown, after CoUninitialize and after the proxy/stub DLL has been + // unloaded) crashes inside the marshaler. If anything is left here, leak + // it on purpose rather than dereferencing a torn-down proxy vtable. + if (!m_plugins.empty()) { - *Errno = 0; - } - - auto session = ResolveWslcSession(Session); - - // Count NULL-terminated arrays. - auto countArray = [](LPCSTR* arr) -> ULONG { - if (arr == nullptr) - { - return 0; - } - ULONG count = 0; - while (arr[count] != nullptr) + LOG_HR_MSG(E_UNEXPECTED, "PluginManager destroyed without Shutdown(); leaking %zu host registrations", m_plugins.size()); + for (auto& e : m_plugins) { - ++count; + // Drop the cookie without revoking — calling GIT after CoUninitialize crashes. + e.hostCookie = 0; + (void)e.callback.Detach(); } - return count; - }; - - WSLCProcessOptions options{}; - options.CommandLine.Values = Arguments; - options.CommandLine.Count = countArray(Arguments); - options.Environment.Values = Env; - options.Environment.Count = countArray(Env); - options.Flags = WSLCProcessFlagsStdin; - - wil::com_ptr process; - int errnoValue = 0; - auto result = session->CreateRootNamespaceProcess(Executable, &options, 0, 0, &process, &errnoValue); - - if (Errno != nullptr) - { - *Errno = errnoValue; + m_plugins.clear(); } - - if (FAILED(result)) + if (m_git) { - WSL_LOG( - "WslcPluginCreateProcessCall", - TraceLoggingValue(Session, "SessionId"), - TraceLoggingValue(Executable, "Executable"), - TraceLoggingValue(result, "Result"), - TraceLoggingValue(errnoValue, "Errno")); - return result; + // Same reasoning: leak the GIT reference rather than releasing it after teardown. + (void)m_git.Detach(); } - auto wrapper = std::make_unique(); - wrapper->Process = std::move(process); - *Process = wrapper.release(); - - WSL_LOG( - "WslcPluginCreateProcessCall", - TraceLoggingValue(Session, "SessionId"), - TraceLoggingValue(Executable, "Executable"), - TraceLoggingValue(*Process, "Process"), - TraceLoggingValue(S_OK, "Result")); - - return S_OK; -} -CATCH_RETURN(); - -HRESULT WSLCProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle) -try -{ - RETURN_HR_IF(E_POINTER, Process == nullptr || Handle == nullptr); - - *Handle = nullptr; - - auto* wrapper = static_cast(Process); - - WSLCFD wslcFd{}; - switch (Fd) + // WSLC session references are COM proxies too. Shutdown() should have + // drained them while COM was still initialized; if any survive, detach + // (leak) them rather than releasing a torn-down proxy. + if (!m_wslcSessionRefs.empty()) { - case WSLCProcessFdStdin: - wslcFd = WSLCFDStdin; - break; - case WSLCProcessFdStdout: - wslcFd = WSLCFDStdout; - break; - case WSLCProcessFdStderr: - wslcFd = WSLCFDStderr; - break; - default: - WSL_LOG( - "WslcPluginProcessGetFd", TraceLoggingValue(static_cast(Fd), "Fd"), TraceLoggingValue(E_INVALIDARG, "Result")); - return E_INVALIDARG; + LOG_HR_MSG(E_UNEXPECTED, "PluginManager destroyed with %zu WSLC session references still registered", m_wslcSessionRefs.size()); + for (auto& [id, ref] : m_wslcSessionRefs) + { + (void)ref.detach(); + } + m_wslcSessionRefs.clear(); } - - WSLCHandle handle{}; - auto result = wrapper->Process->GetStdHandle(wslcFd, &handle); - - WSL_LOG( - "WslcPluginProcessGetFd", - TraceLoggingValue(static_cast(Fd), "Fd"), - TraceLoggingValue(handle.Handle.Socket, "Handle"), - TraceLoggingValue(result, "Result")); - - RETURN_IF_FAILED(result); - WI_ASSERT(handle.Type == WSLCHandleTypeSocket); - - *Handle = handle.Handle.Socket; - return S_OK; -} -CATCH_RETURN(); - -HRESULT WSLCProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent) -try -{ - RETURN_HR_IF(E_POINTER, Process == nullptr || ExitEvent == nullptr); - - *ExitEvent = nullptr; - - auto* wrapper = static_cast(Process); - auto result = wrapper->Process->GetExitEvent(ExitEvent); - - WSL_LOG("WslcPluginProcessGetExitEvent", TraceLoggingValue(*ExitEvent, "ExitEvent"), TraceLoggingValue(result, "Result")); - - return result; + m_jobObject.reset(); } -CATCH_RETURN(); -HRESULT WSLCProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode) -try +void PluginManager::Shutdown() { - RETURN_HR_IF(E_POINTER, Process == nullptr || ExitCode == nullptr); - - *ExitCode = -1; - auto* wrapper = static_cast(Process); - - WSLCProcessState state{}; - auto result = wrapper->Process->GetState(&state, ExitCode); - - if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled) + // Must be called while COM is still initialized. Revoking each GIT cookie + // releases the underlying marshaled IWslPluginHost reference, which causes + // the wslpluginhost.exe processes to exit; the job object below makes that + // guaranteed even if a revoke fails. + if (m_git) { - result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE); + for (auto& e : m_plugins) + { + if (e.hostCookie != 0) + { + LOG_IF_FAILED(m_git->RevokeInterfaceFromGlobal(e.hostCookie)); + e.hostCookie = 0; + } + } + m_git.Reset(); } + m_plugins.clear(); - WSL_LOG( - "WslcPluginProcessGetExitCode", - TraceLoggingValue(*ExitCode, "ExitCode"), - TraceLoggingValue(static_cast(state), "State"), - TraceLoggingValue(result, "Result")); - - return result; -} -CATCH_RETURN(); - -void WSLCReleaseProcess(WSLCProcessHandle Process) -{ - if (Process != nullptr) + // Release any remaining WSLC session references while COM is still + // initialized. By this point all sessions should already have been torn + // down (and thus unregistered), but drain defensively so we never release + // these proxies from the destructor after CoUninitialize. { - WSL_LOG("WslcPluginReleaseProcess", TraceLoggingValue(Process, "Process")); - delete static_cast(Process); + std::lock_guard lock(m_wslcSessionRefLock); + m_wslcSessionRefs.clear(); } -} -} // extern "C" - -static constexpr WSLPluginAPIV1 ApiV1 = { - Version, - &MountFolder, - &ExecuteBinary, - &PluginError, - &ExecuteBinaryInDistribution, - &WSLCMountFolder, - &WSLCUnmountFolder, - &WSLCCreateProcess, - &WSLCProcessGetFd, - &WSLCProcessGetExitEvent, - &WSLCProcessGetExitCode, - &WSLCReleaseProcess}; + m_jobObject.reset(); +} void PluginManager::LoadPlugins() { @@ -365,190 +279,531 @@ void PluginManager::LoadPlugins() continue; } - auto loadResult = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { LoadPlugin(e.first.c_str(), path.c_str()); }); - - // Logs when a WSL plugin is loaded, used for evaluating plugin populations + // Record the plugin for deferred activation. The actual COM host process + // is created in EnsureInitialized(), which runs after the service's COM + // initialization is complete (CoInitializeSecurity must happen first). + OutOfProcPlugin plugin{}; + plugin.name = e.first; + plugin.path = path; + m_plugins.emplace_back(std::move(plugin)); + + // Discovery-only event. The plugin DLL is no longer loaded into + // wslservice.exe — actual load happens out-of-process via COM + // activation in EnsureInitialized(). See "PluginLoad" emitted from + // that path for the real load result. WSL_LOG_TELEMETRY( - "PluginLoad", + "PluginDiscovered", PDT_ProductAndServiceUsage, TraceLoggingValue(e.first.c_str(), "Name"), - TraceLoggingValue(path.c_str(), "Path"), - TraceLoggingValue(loadResult, "Result")); + TraceLoggingValue(path.c_str(), "Path")); + } +} + +PluginManager::ScopedComInit::ScopedComInit() : initHr(::CoInitializeEx(nullptr, COINIT_MULTITHREADED)) +{ +} + +PluginManager::ScopedComInit::~ScopedComInit() +{ + if (SUCCEEDED(initHr)) + { + ::CoUninitialize(); + } +} + +PluginManager::ScopedComInit::ScopedComInit(ScopedComInit&& other) noexcept : initHr(other.initHr) +{ + // Suppress uninit in moved-from instance. + other.initHr = RPC_E_CHANGED_MODE; +} + +HRESULT PluginManager::ScopedComInit::Result() const noexcept +{ + return (initHr == RPC_E_CHANGED_MODE) ? S_OK : initHr; +} - if (FAILED(loadResult)) +PluginManager::ScopedComInit PluginManager::EnsureInitialized() +{ + // Join the calling thread to the MTA for the duration of the dispatch. The + // returned guard must outlive any subsequent plugin host calls because + // those are cross-process COM calls that require an initialized apartment, + // and so does the GIT-based proxy acquisition below. + ScopedComInit coInit; + THROW_IF_FAILED(coInit.Result()); + + // Lazily create the process-wide Global Interface Table accessor. The GIT + // itself is a process-global COM-provided singleton; we just need an + // in-process accessor to register and look up cookies. + std::call_once(m_gitOnce, [this]() { + THROW_IF_FAILED(CoCreateInstance(CLSID_StdGlobalInterfaceTable, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&m_git))); + }); + + std::call_once(m_initOnce, [this]() { + for (auto& e : m_plugins) { - // If this plugin reported an error, record it to display it to the user - m_pluginError.emplace(PluginError{e.first, loadResult}); + auto loadResult = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { LoadPlugin(e); }); + + // Canonical "plugin was actually loaded" telemetry — matches the + // semantics of the pre-refactor PluginLoad event (emitted after + // the entry point ran). PluginHostActivation below is the more + // granular event covering the COM activation path specifically. + WSL_LOG_TELEMETRY( + "PluginLoad", + PDT_ProductAndServiceUsage, + TraceLoggingValue(e.name.c_str(), "Name"), + TraceLoggingValue(e.path.c_str(), "Path"), + TraceLoggingValue(loadResult, "Result")); + + WSL_LOG_TELEMETRY( + "PluginHostActivation", + PDT_ProductAndServiceUsage, + TraceLoggingValue(e.name.c_str(), "Name"), + TraceLoggingValue(e.path.c_str(), "Path"), + TraceLoggingValue(loadResult, "Result")); + + if (FAILED(loadResult)) + { + // Treat host-process crashes and benign COM activation races (server is + // shutting down or its exec failed) as non-fatal — the plugin is simply + // unavailable for this session. All other failures, including registration + // errors (REGDB_E_CLASSNOTREG), access denials, and plugin-reported errors + // from Initialize, are treated as fatal plugin load failures so the user + // gets a clear error rather than a silently-disabled plugin. + if (IsHostCrash(loadResult) || loadResult == CO_E_SERVER_EXEC_FAILURE || loadResult == CO_E_SERVER_STOPPING) + { + LOG_HR_MSG(loadResult, "Plugin host activation failed for: '%ls', skipping", e.name.c_str()); + } + else + { + m_pluginError.emplace(PluginError{e.name, loadResult}); + } + } } - } + }); + + return coInit; } -void PluginManager::LoadPlugin(LPCWSTR Name, LPCWSTR ModulePath) +void PluginManager::LoadPlugin(OutOfProcPlugin& plugin) { - // Validate the plugin signature before loading it. - // The handle to the module is kept open after validating the signature so the file can't be written to - // after the signature check. - wil::unique_hfile pluginHandle; - if constexpr (wsl::shared::OfficialBuild) + // One callback object per plugin host so the WSLC process-cookie map is + // isolated per plugin. When the plugin host process is released, the + // callback's refcount drops and any unreleased IWSLCProcess refs are freed. + plugin.callback = Microsoft::WRL::Make(*this); + THROW_IF_NULL_ALLOC(plugin.callback); + + // Activate the plugin host via COM. The LocalServer32 registration causes COM + // to spawn wslpluginhost.exe automatically. + Microsoft::WRL::ComPtr host; + HRESULT activationHr = CoCreateInstance(CLSID_WslPluginHost, nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&host)); + WSL_LOG( + "PluginHostActivation", + TraceLoggingValue(plugin.name.c_str(), "Plugin"), + TraceLoggingValue(plugin.path.c_str(), "Path"), + TraceLoggingValue(activationHr, "CoCreateInstanceResult")); + THROW_IF_FAILED_MSG(activationHr, "Failed to create plugin host for: '%ls'", plugin.path.c_str()); + + // Join the host to our job object before Initialize runs plugin code, so any + // child processes the plugin spawns inherit the job and are killed when the + // service exits. A failure here is non-fatal: the host is still reaped via + // CoReleaseServerProcess on clean shutdown, just not on a service crash. + EnsureJobObjectCreated(); + wil::unique_handle process; + const HRESULT getProcessHr = host->GetProcessHandle(&process); + LOG_IF_FAILED_MSG(getProcessHr, "Failed to get plugin host process handle for: '%ls'", plugin.path.c_str()); + if (SUCCEEDED(getProcessHr)) { - pluginHandle = wsl::windows::common::install::ValidateFileSignature(ModulePath); - WI_ASSERT(pluginHandle.is_valid()); + LOG_IF_WIN32_BOOL_FALSE_MSG( + AssignProcessToJobObject(m_jobObject.get(), process.get()), + "Failed to assign plugin host to job object for: '%ls'", + plugin.path.c_str()); } - LoadedPlugin plugin{}; - plugin.name = Name; + THROW_IF_FAILED_MSG( + host->Initialize(plugin.callback.Get(), plugin.path.c_str(), plugin.name.c_str()), + "Plugin host failed to initialize: '%ls'", + plugin.path.c_str()); + + // Stash the IWslPluginHost in the Global Interface Table. The proxy returned + // by CoCreateInstance is bound to the apartment of this thread (MTA, since + // EnsureInitialized() joined us to the MTA). Hook dispatch can arrive on + // threads in any apartment — in particular NTA-on-MTA, which is what wslservice's + // RPC dispatcher uses when the incoming call is via a WinRT-style interface + // (e.g. IWSLCSession). MTA-bound proxies are NOT callable from NTA, so storing + // the raw proxy here and reusing it cross-apartment fails with RPC_E_WRONG_THREAD. + // Storing in the GIT and re-unmarshaling per call yields an apartment-local + // proxy that works from any apartment. + DWORD cookie = 0; + THROW_IF_FAILED(m_git->RegisterInterfaceInGlobal(host.Get(), __uuidof(IWslPluginHost), &cookie)); + + auto revokeOnFailure = wil::scope_exit([&] { + if (cookie != 0) + { + LOG_IF_FAILED(m_git->RevokeInterfaceFromGlobal(cookie)); + } + }); + + // Drop the raw proxy — future access goes through the GIT to get an + // apartment-local proxy. The GIT keeps the underlying marshaled stream + // alive, so the host process stays running. + host.Reset(); - plugin.module.reset(LoadLibrary(ModulePath)); - THROW_LAST_ERROR_IF_NULL(plugin.module); + plugin.hostCookie = cookie; + cookie = 0; + revokeOnFailure.release(); +} - const WSLPluginAPI_EntryPointV1 entryPoint = - reinterpret_cast(GetProcAddress(plugin.module.get(), GSL_STRINGIFY(WSLPLUGINAPI_ENTRYPOINTV1))); +HRESULT PluginManager::AcquireHostProxy(const OutOfProcPlugin& plugin, _COM_Outptr_ IWslPluginHost** host) +{ + *host = nullptr; + if (plugin.hostCookie == 0 || !m_git) + { + return E_NOT_VALID_STATE; + } + return m_git->GetInterfaceFromGlobal(plugin.hostCookie, __uuidof(IWslPluginHost), reinterpret_cast(host)); +} - THROW_LAST_ERROR_IF_NULL(entryPoint); - THROW_IF_FAILED_MSG(entryPoint(&ApiV1, &plugin.hooks), "Error returned by plugin: '%ls'", ModulePath); +void PluginManager::EnsureJobObjectCreated() +{ + std::call_once(m_jobObjectOnce, [this]() { + m_jobObject.reset(CreateJobObjectW(nullptr, nullptr)); + THROW_LAST_ERROR_IF(!m_jobObject); + + JOBOBJECT_EXTENDED_LIMIT_INFORMATION jobInfo{}; + jobInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + THROW_IF_WIN32_BOOL_FALSE(SetInformationJobObject(m_jobObject.get(), JobObjectExtendedLimitInformation, &jobInfo, sizeof(jobInfo))); + }); +} - m_plugins.emplace_back(std::move(plugin)); +std::vector PluginManager::SerializeSid(PSID Sid) +{ + const DWORD sidLength = GetLengthSid(Sid); + std::vector buffer(sidLength); + CopySid(sidLength, buffer.data(), Sid); + return buffer; } void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); - for (const auto& e : m_plugins) + auto sidData = SerializeSid(Session->UserSid); + + for (auto& e : m_plugins) { - if (e.hooks.OnVMStarted != nullptr) + if (e.hostCookie == 0) { - WSL_LOG( - "PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); - - SlowOperationWatcher slowOperation{"PluginOnVmStarted"}; - ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), e.name.c_str()); + continue; + } + WSL_LOG("PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); + + ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnVmStarted"); + + wil::unique_cotaskmem_string errorMessage; + SlowOperationWatcher slowOperation{"PluginOnVmStarted"}; + WSL_LOG("PluginOnVmStartedBeginRpc", TraceLoggingValue(e.name.c_str(), "Plugin")); + HRESULT hr = host->OnVMStarted( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + static_cast(Settings->CustomConfigurationFlags), + &errorMessage); + WSL_LOG("PluginOnVmStartedEndRpc", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(hr, "Result")); + + if (IsHostCrash(hr)) + { + LogPluginHostCrash(e, hr, "OnVmStarted"); + continue; } + + ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); } } -void PluginManager::OnVmStopping(const WSLSessionInformation* Session) const +void PluginManager::OnVmStopping(const WSLSessionInformation* Session) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.OnVMStopping != nullptr) + if (e.hostCookie == 0) { - WSL_LOG( - "PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); + continue; + } + WSL_LOG("PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnVmStopping"); + + const auto result = host->OnVMStopping(Session->SessionId, Session->UserToken, static_cast(sidData.size()), sidData.data()); - const auto result = e.hooks.OnVMStopping(Session); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnVmStopping"); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.OnDistributionStarted != nullptr) + if (e.hostCookie == 0) { - WSL_LOG( - "PluginOnDistroStartedCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - SlowOperationWatcher slowOperation{"PluginOnDistributionStarted"}; - ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), e.name.c_str()); + continue; + } + WSL_LOG( + "PluginOnDistroStartedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnDistributionStarted"); + + wil::unique_cotaskmem_string errorMessage; + SlowOperationWatcher slowOperation{"PluginOnDistributionStarted"}; + HRESULT hr = host->OnDistributionStarted( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PidNamespace, + Distribution->PackageFamilyName, + Distribution->InitPid, + Distribution->Flavor, + Distribution->Version, + &errorMessage); + + if (IsHostCrash(hr)) + { + LogPluginHostCrash(e, hr, "OnDistributionStarted"); + continue; } + + ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); } } -void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) const +void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.OnDistributionStopping != nullptr) + if (e.hostCookie == 0) { - WSL_LOG( - "PluginOnDistroStoppingCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - const auto result = e.hooks.OnDistributionStopping(Session, Distribution); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; + } + WSL_LOG( + "PluginOnDistroStoppingCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnDistributionStopping"); + + const auto result = host->OnDistributionStopping( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PidNamespace, + Distribution->PackageFamilyName, + Distribution->InitPid, + Distribution->Flavor, + Distribution->Version); + + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnDistributionStopping"); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) const +void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.OnDistributionRegistered != nullptr) + if (e.hostCookie == 0) { - WSL_LOG( - "PluginOnDistributionRegisteredCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - const auto result = e.hooks.OnDistributionRegistered(Session, Distribution); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; + } + WSL_LOG( + "PluginOnDistributionRegisteredCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnDistributionRegistered"); + + const auto result = host->OnDistributionRegistered( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PackageFamilyName, + Distribution->Flavor, + Distribution->Version); + + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnDistributionRegistered"); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) const +void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.OnDistributionUnregistered != nullptr) + if (e.hostCookie == 0) { - WSL_LOG( - "PluginOnDistributionUnregisteredCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - const auto result = e.hooks.OnDistributionUnregistered(Session, Distribution); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; + } + WSL_LOG( + "PluginOnDistributionUnregisteredCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnDistributionUnregistered"); + + const auto result = host->OnDistributionUnregistered( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PackageFamilyName, + Distribution->Flavor, + Distribution->Version); + + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnDistributionUnregistered"); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin) +void PluginManager::ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId Session, LPCWSTR Plugin) { - const auto message = std::move(g_pluginErrorMessage); - g_pluginErrorMessage.reset(); // std::move() doesn't clear the previous std::optional - + // If the host process crashed, don't propagate as a fatal plugin error — + // log it and let the caller decide. The plugin is already dead. + if (IsHostCrash(Result)) + { + LOG_HR_MSG(Result, "Plugin host process crashed for plugin: '%ls'", Plugin); + return; + } + if (FAILED(Result)) { - if (message.has_value()) + if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0') { - THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginErrorWithMessage(Plugin, message->c_str())); + THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginErrorWithMessage(Plugin, ErrorMessage)); } else { THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginError(Plugin)); } } - else if (message.has_value()) + else if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0') { THROW_HR_MSG(E_ILLEGAL_STATE_CHANGE, "Plugin '%ls' emitted an error message but returned success", Plugin); } } -void PluginManager::ThrowIfFatalPluginError() const +bool PluginManager::IsHostCrash(HRESULT hr) +{ + // Each of these unambiguously indicates the COM server process has gone + // away. RPC_E_CALL_REJECTED is deliberately excluded: it means a busy + // server rejected the call, not that the server died — treating it as a + // crash would silently disable the plugin for the rest of the session. + switch (hr) + { + case RPC_E_DISCONNECTED: + case RPC_E_SERVER_DIED: + case RPC_E_SERVER_DIED_DNE: + case CO_E_OBJNOTCONNECTED: + case HRESULT_FROM_WIN32(RPC_S_SERVER_UNAVAILABLE): + case HRESULT_FROM_WIN32(RPC_S_CALL_FAILED): + case HRESULT_FROM_WIN32(RPC_S_CALL_FAILED_DNE): + return true; + default: + return false; + } +} + +void PluginManager::LogPluginHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage) +{ + LOG_HR_MSG(result, "Plugin host crashed at %hs for: '%ls'", stage, plugin.name.c_str()); + + // Fire telemetry only on first observation per plugin: a dead plugin will + // hit this path on every subsequent VM/distro lifecycle event, and we + // don't want to flood the telemetry channel with duplicates. + if (!plugin.crashTelemetryFired.exchange(true)) + { + // Release any WSLC processes the dead host created so they don't leak + // until service shutdown. + if (plugin.callback) + { + plugin.callback->DrainProcesses(); + } + + WSL_LOG_TELEMETRY( + "PluginHostCrash", + PDT_ProductAndServiceUsage, + TraceLoggingValue(plugin.name.c_str(), "Name"), + TraceLoggingValue(plugin.path.c_str(), "Path"), + TraceLoggingValue(result, "Result"), + TraceLoggingValue(stage, "Stage")); + } +} + +void PluginManager::ThrowIfFatalPluginError() { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); if (!m_pluginError.has_value()) { @@ -565,129 +820,547 @@ void PluginManager::ThrowIfFatalPluginError() const } } +void PluginManager::RegisterWslcSession(ULONG SessionId, IWSLCSessionReference* Reference) +{ + THROW_HR_IF(E_POINTER, Reference == nullptr); + + std::lock_guard lock(m_wslcSessionRefLock); + m_wslcSessionRefs[SessionId] = Reference; +} + +void PluginManager::UnregisterWslcSession(ULONG SessionId) +{ + std::lock_guard lock(m_wslcSessionRefLock); + m_wslcSessionRefs.erase(SessionId); +} + +wil::com_ptr PluginManager::ResolveWslcSession(ULONG SessionId) +{ + wil::com_ptr reference; + { + std::lock_guard lock(m_wslcSessionRefLock); + auto it = m_wslcSessionRefs.find(SessionId); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_wslcSessionRefs.end()); + reference = it->second; + } + + // OpenSession() is called OUTSIDE m_wslcSessionRefLock: the reference is a + // weak handle, so opening it can fail or block and must not hold the lock. + wil::com_ptr session; + const auto hr = reference->OpenSession(&session); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), FAILED(hr) || session == nullptr); + + return session; +} + void PluginManager::OnWslcSessionCreated(const WSLCSessionInformation* Session) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); - for (const auto& e : m_plugins) + auto sidData = SerializeSid(Session->UserSid); + + for (auto& e : m_plugins) { - if (e.hooks.OnSessionCreated != nullptr) + if (e.hostCookie == 0) { - auto result = e.hooks.OnSessionCreated(Session); - WSL_LOG( - "PluginOnWslcSessionCreatedCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->SessionId, "SessionId"), - TraceLoggingValue(Session->DisplayName, "DisplayName"), - TraceLoggingValue(result, "Result")); - - ThrowIfPluginError(result, e.name.c_str()); + continue; } + WSL_LOG( + "PluginOnWslcSessionCreatedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(Session->DisplayName, "DisplayName")); + + ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnWslcSessionCreated"); + + wil::unique_cotaskmem_string errorMessage; + HRESULT hr = host->OnWslcSessionCreated( + Session->SessionId, + Session->DisplayName, + Session->ApplicationPid, + Session->UserToken, + static_cast(sidData.size()), + sidData.empty() ? nullptr : sidData.data(), + &errorMessage); + + if (IsHostCrash(hr)) + { + LogPluginHostCrash(e, hr, "OnWslcSessionCreated"); + continue; + } + + ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); } } -void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session) const +void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); - for (const auto& e : m_plugins) + auto sidData = SerializeSid(Session->UserSid); + + for (auto& e : m_plugins) { - if (e.hooks.OnSessionStopping != nullptr) + if (e.hostCookie == 0) { - const auto result = e.hooks.OnSessionStopping(Session); - WSL_LOG( - "PluginOnWslcSessionStoppingCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->SessionId, "SessionId"), - TraceLoggingValue(result, "Result")); + continue; + } + WSL_LOG( + "PluginOnWslcSessionStoppingCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcSessionStopping"); + + const auto result = host->OnWslcSessionStopping( + Session->SessionId, + Session->DisplayName, + Session->ApplicationPid, + Session->UserToken, + static_cast(sidData.size()), + sidData.empty() ? nullptr : sidData.data()); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnWslcSessionStopping"); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -HRESULT PluginManager::OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const +HRESULT PluginManager::OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) try { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.ContainerStarted != nullptr) + if (e.hostCookie == 0) { - // Failure here aborts the container creation. Surface the first error. - const auto result = e.hooks.ContainerStarted(Session, InspectJson); - WSL_LOG( - "PluginOnWslcContainerStartedCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->SessionId, "SessionId"), - TraceLoggingValue(result, "Result")); - - ThrowIfPluginError(result, e.name.c_str()); + continue; + } + WSL_LOG( + "PluginOnWslcContainerStartedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId")); + + ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnWslcContainerStarted"); + + // Failure here aborts the container creation. Surface the first error. + wil::unique_cotaskmem_string errorMessage; + const HRESULT hr = host->OnWslcContainerStarted( + Session->SessionId, + Session->DisplayName, + Session->ApplicationPid, + Session->UserToken, + static_cast(sidData.size()), + sidData.empty() ? nullptr : sidData.data(), + InspectJson, + &errorMessage); + + if (IsHostCrash(hr)) + { + LogPluginHostCrash(e, hr, "OnWslcContainerStarted"); + continue; } + + ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); } return S_OK; } CATCH_RETURN() -void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const +void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.ContainerStopping != nullptr) + if (e.hostCookie == 0) { - - const auto result = e.hooks.ContainerStopping(Session, ContainerId); - WSL_LOG( - "PluginOnWslcContainerStoppingCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->SessionId, "SessionId"), - TraceLoggingValue(ContainerId, "ContainerId"), - TraceLoggingValue(result, "Result")); - - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; } + WSL_LOG( + "PluginOnWslcContainerStoppingCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(ContainerId, "ContainerId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcContainerStopping"); + + const auto result = host->OnWslcContainerStopping( + Session->SessionId, + Session->DisplayName, + Session->ApplicationPid, + Session->UserToken, + static_cast(sidData.size()), + sidData.empty() ? nullptr : sidData.data(), + ContainerId); + + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnWslcContainerStopping"); + continue; + } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const +void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.ImageCreated != nullptr) + if (e.hostCookie == 0) { - const auto result = e.hooks.ImageCreated(Session, InspectJson); - WSL_LOG( - "PluginOnWslcImageCreatedCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->SessionId, "SessionId"), - TraceLoggingValue(result, "Result")); - - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; } + WSL_LOG( + "PluginOnWslcImageCreatedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcImageCreated"); + + const auto result = host->OnWslcImageCreated( + Session->SessionId, + Session->DisplayName, + Session->ApplicationPid, + Session->UserToken, + static_cast(sidData.size()), + sidData.empty() ? nullptr : sidData.data(), + InspectJson); + + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnWslcImageCreated"); + continue; + } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const +void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) { ExecutionContext context(Context::Plugin); + auto coInit = EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); - for (const auto& e : m_plugins) + for (auto& e : m_plugins) { - if (e.hooks.ImageDeleted != nullptr) + if (e.hostCookie == 0) { - const auto result = e.hooks.ImageDeleted(Session, ImageId); - WSL_LOG( - "PluginOnWslcImageDeletedCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->SessionId, "SessionId"), - TraceLoggingValue(ImageId, "ImageId"), - TraceLoggingValue(result, "Result")); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; } + WSL_LOG( + "PluginOnWslcImageDeletedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->SessionId, "SessionId"), + TraceLoggingValue(ImageId, "ImageId")); + + ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnWslcImageDeleted"); + + const auto result = host->OnWslcImageDeleted( + Session->SessionId, + Session->DisplayName, + Session->ApplicationPid, + Session->UserToken, + static_cast(sidData.size()), + sidData.empty() ? nullptr : sidData.data(), + ImageId); + + if (IsHostCrash(result)) + { + LogPluginHostCrash(e, result, "OnWslcImageDeleted"); + continue; + } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + } +} + +// --- IWslPluginHostCallback WSLC implementations (service-side) --- + +DWORD PluginHostCallbackImpl::InsertProcessLocked(wil::com_ptr process) +{ + std::lock_guard lock(m_processLock); + + // Reserve a cookie that's neither 0 nor already in use. Wraparound is fine. + THROW_HR_IF(E_OUTOFMEMORY, m_processes.size() >= std::numeric_limits::max() - 1); + + DWORD cookie = m_nextCookie; + while (cookie == 0 || m_processes.find(cookie) != m_processes.end()) + { + ++cookie; + } + m_nextCookie = cookie + 1; + m_processes.emplace(cookie, std::move(process)); + return cookie; +} + +wil::com_ptr PluginHostCallbackImpl::FindProcess(DWORD cookie) const +{ + std::lock_guard lock(m_processLock); + + auto it = m_processes.find(cookie); + return (it == m_processes.end()) ? nullptr : it->second; +} + +wil::com_ptr PluginHostCallbackImpl::RemoveProcess(DWORD cookie) +{ + std::lock_guard lock(m_processLock); + + auto it = m_processes.find(cookie); + if (it == m_processes.end()) + { + return nullptr; + } + auto process = std::move(it->second); + m_processes.erase(it); + return process; +} + +void PluginHostCallbackImpl::DrainProcesses() noexcept +try +{ + std::unordered_map> processes; + { + std::lock_guard lock(m_processLock); + processes.swap(m_processes); + } + + // Release outside the lock: a process Release() may run teardown that + // re-enters this callback. +} +CATCH_LOG() + +STDMETHODIMP PluginHostCallbackImpl::WslcMountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCSTR Mountpoint, _In_ BOOL ReadOnly) +try +{ + // TODO: Once plugins are out of proc, add logic to validate that the mountpoint isn't in use by another plugin. + RETURN_HR_IF(E_POINTER, WindowsPath == nullptr || Mountpoint == nullptr); + + auto session = m_owner.ResolveWslcSession(SessionId); + + auto result = session->MountWindowsFolder(WindowsPath, Mountpoint, ReadOnly); + + WSL_LOG( + "WslcPluginMountFolderCall", + TraceLoggingValue(SessionId, "SessionId"), + TraceLoggingValue(WindowsPath, "WindowsPath"), + TraceLoggingValue(Mountpoint, "Mountpoint"), + TraceLoggingValue(ReadOnly, "ReadOnly"), + TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::WslcUnmountFolder(_In_ DWORD SessionId, _In_ LPCSTR Mountpoint) +try +{ + RETURN_HR_IF(E_POINTER, Mountpoint == nullptr); + + auto session = m_owner.ResolveWslcSession(SessionId); + auto result = session->UnmountWindowsFolder(Mountpoint); + + WSL_LOG( + "WslcPluginUnmountFolderCall", + TraceLoggingValue(SessionId, "SessionId"), + TraceLoggingValue(Mountpoint, "Mountpoint"), + TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::WslcCreateProcess( + _In_ DWORD SessionId, + _In_ LPCSTR Executable, + _In_ DWORD ArgumentCount, + _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, + _In_ DWORD EnvCount, + _In_reads_opt_(EnvCount) LPCSTR* Environment, + _Out_ DWORD* ProcessCookie, + _Out_ int* Errno) +try +{ + RETURN_HR_IF(E_POINTER, Executable == nullptr || ProcessCookie == nullptr || Errno == nullptr); + *ProcessCookie = 0; + *Errno = 0; + RETURN_HR_IF(E_INVALIDARG, (ArgumentCount > 0 && Arguments == nullptr) || (EnvCount > 0 && Environment == nullptr)); + + auto session = m_owner.ResolveWslcSession(SessionId); + + // Build NULL-terminated argument/env arrays expected by CreateRootNamespaceProcess. + std::vector argsTerminated; + if (Arguments != nullptr) + { + argsTerminated.assign(Arguments, Arguments + ArgumentCount); + } + argsTerminated.push_back(nullptr); + + std::vector envTerminated; + if (Environment != nullptr) + { + envTerminated.assign(Environment, Environment + EnvCount); + envTerminated.push_back(nullptr); + } + + WSLCProcessOptions options{}; + options.CommandLine.Values = argsTerminated.data(); + options.CommandLine.Count = ArgumentCount; + if (!envTerminated.empty()) + { + options.Environment.Values = envTerminated.data(); + options.Environment.Count = EnvCount; + } + options.Flags = WSLCProcessFlagsStdin; + + wil::com_ptr process; + int errnoValue = 0; + auto result = session->CreateRootNamespaceProcess(Executable, &options, 0, 0, &process, &errnoValue); + *Errno = errnoValue; + + if (FAILED(result)) + { + WSL_LOG( + "WslcPluginCreateProcessCall", + TraceLoggingValue(SessionId, "SessionId"), + TraceLoggingValue(Executable, "Executable"), + TraceLoggingValue(result, "Result"), + TraceLoggingValue(errnoValue, "Errno")); + return result; } + + *ProcessCookie = InsertProcessLocked(std::move(process)); + + WSL_LOG( + "WslcPluginCreateProcessCall", + TraceLoggingValue(SessionId, "SessionId"), + TraceLoggingValue(Executable, "Executable"), + TraceLoggingValue(*ProcessCookie, "ProcessCookie"), + TraceLoggingValue(S_OK, "Result")); + + return S_OK; } +CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::WslcProcessGetFd(_In_ DWORD ProcessCookie, _In_ DWORD Fd, _Out_ HANDLE* Handle) +try +{ + RETURN_HR_IF(E_POINTER, Handle == nullptr); + *Handle = nullptr; + + auto process = FindProcess(ProcessCookie); + RETURN_HR_IF(E_INVALIDARG, !process); + + WSLCFD wslcFd{}; + switch (static_cast(Fd)) + { + case WSLCProcessFdStdin: + wslcFd = WSLCFDStdin; + break; + case WSLCProcessFdStdout: + wslcFd = WSLCFDStdout; + break; + case WSLCProcessFdStderr: + wslcFd = WSLCFDStderr; + break; + default: + WSL_LOG( + "WslcPluginProcessGetFd", TraceLoggingValue(static_cast(Fd), "Fd"), TraceLoggingValue(E_INVALIDARG, "Result")); + return E_INVALIDARG; + } + + WSLCHandle handle{}; + auto result = process->GetStdHandle(wslcFd, &handle); + + WSL_LOG( + "WslcPluginProcessGetFd", + TraceLoggingValue(static_cast(Fd), "Fd"), + TraceLoggingValue(handle.Handle.Socket, "Handle"), + TraceLoggingValue(result, "Result")); + + RETURN_IF_FAILED(result); + WI_ASSERT(handle.Type == WSLCHandleTypeSocket); + + // Pass through as HANDLE; COM's system_handle(sh_socket) marshaling will duplicate + // it into the host process which then surfaces it to the plugin. + *Handle = handle.Handle.Socket; + return S_OK; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::WslcProcessGetExitEvent(_In_ DWORD ProcessCookie, _Out_ HANDLE* ExitEvent) +try +{ + RETURN_HR_IF(E_POINTER, ExitEvent == nullptr); + *ExitEvent = nullptr; + + auto process = FindProcess(ProcessCookie); + RETURN_HR_IF(E_INVALIDARG, !process); + + auto result = process->GetExitEvent(ExitEvent); + + WSL_LOG("WslcPluginProcessGetExitEvent", TraceLoggingValue(*ExitEvent, "ExitEvent"), TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::WslcProcessGetExitCode(_In_ DWORD ProcessCookie, _Out_ int* ExitCode) +try +{ + RETURN_HR_IF(E_POINTER, ExitCode == nullptr); + *ExitCode = -1; + + auto process = FindProcess(ProcessCookie); + RETURN_HR_IF(E_INVALIDARG, !process); + + WSLCProcessState state{}; + auto result = process->GetState(&state, ExitCode); + + if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled) + { + result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE); + } + + WSL_LOG( + "WslcPluginProcessGetExitCode", + TraceLoggingValue(*ExitCode, "ExitCode"), + TraceLoggingValue(static_cast(state), "State"), + TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::WslcReleaseProcess(_In_ DWORD ProcessCookie) +try +{ + auto process = RemoveProcess(ProcessCookie); + WSL_LOG( + "WslcPluginReleaseProcess", + TraceLoggingValue(ProcessCookie, "ProcessCookie"), + TraceLoggingValue(process != nullptr, "Found")); + return S_OK; +} +CATCH_RETURN(); diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h index a99a3e332d..beddd32f4e 100644 --- a/src/windows/service/exe/PluginManager.h +++ b/src/windows/service/exe/PluginManager.h @@ -9,17 +9,118 @@ Module Name: Abstract: This file contains the PluginManager class definition. + Plugins are loaded out-of-process in wslpluginhost.exe via COM + to isolate the service from plugin crashes. --*/ #pragma once #include +#include +#include +#include +#include #include +#include #include #include "WslPluginApi.h" +#include "WslPluginHost.h" +#include "wslc.h" namespace wsl::windows::service { + +class PluginManager; + +// +// IWslPluginHostCallback implementation — lives in the service process and +// handles API calls coming from the plugin host (MountFolder, ExecuteBinary, +// WSLC* APIs etc.). One instance is created per plugin host so that the +// per-plugin WSLC process map (cookie -> IWSLCProcess) is isolated: a plugin +// cannot guess another plugin's cookie, and the map drains automatically when +// the plugin host process goes away. +// +class PluginHostCallbackImpl + : public Microsoft::WRL::RuntimeClass, IWslPluginHostCallback> +{ +public: + explicit PluginHostCallbackImpl(PluginManager& owner) : m_owner(owner) + { + } + ~PluginHostCallbackImpl() override = default; + + PluginHostCallbackImpl(const PluginHostCallbackImpl&) = delete; + PluginHostCallbackImpl& operator=(const PluginHostCallbackImpl&) = delete; + + // WSL plugin API (host -> service) + STDMETHODIMP MountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCWSTR LinuxPath, _In_ BOOL ReadOnly, _In_ LPCWSTR Name) override; + + STDMETHODIMP ExecuteBinary( + _In_ DWORD SessionId, _In_ LPCSTR Path, _In_ DWORD ArgumentCount, _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _Out_ HANDLE* Socket) override; + + STDMETHODIMP ExecuteBinaryInDistribution( + _In_ DWORD SessionId, + _In_ const GUID* DistributionId, + _In_ LPCSTR Path, + _In_ DWORD ArgumentCount, + _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, + _Out_ HANDLE* Socket) override; + + // WSLC plugin API (host -> service) + STDMETHODIMP WslcMountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCSTR Mountpoint, _In_ BOOL ReadOnly) override; + + STDMETHODIMP WslcUnmountFolder(_In_ DWORD SessionId, _In_ LPCSTR Mountpoint) override; + + STDMETHODIMP WslcCreateProcess( + _In_ DWORD SessionId, + _In_ LPCSTR Executable, + _In_ DWORD ArgumentCount, + _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, + _In_ DWORD EnvCount, + _In_reads_opt_(EnvCount) LPCSTR* Environment, + _Out_ DWORD* ProcessCookie, + _Out_ int* Errno) override; + + STDMETHODIMP WslcProcessGetFd(_In_ DWORD ProcessCookie, _In_ DWORD Fd, _Out_ HANDLE* Handle) override; + + STDMETHODIMP WslcProcessGetExitEvent(_In_ DWORD ProcessCookie, _Out_ HANDLE* ExitEvent) override; + + STDMETHODIMP WslcProcessGetExitCode(_In_ DWORD ProcessCookie, _Out_ int* ExitCode) override; + + STDMETHODIMP WslcReleaseProcess(_In_ DWORD ProcessCookie) override; + + // Release all outstanding process mappings. Called when the plugin host + // crashes so the WSLC processes it created aren't stranded until shutdown. + void DrainProcesses() noexcept; + +private: + // Allocate a new cookie -> process mapping. Loops past 0 and past collisions. + // Throws on exhaustion. + DWORD InsertProcessLocked(wil::com_ptr process); + + // Resolve a cookie to its process under m_processLock; returns nullptr if unknown. + wil::com_ptr FindProcess(DWORD cookie) const; + + // Remove a cookie mapping; returns the removed process (may be null). + wil::com_ptr RemoveProcess(DWORD cookie); + + mutable std::mutex m_processLock; + std::unordered_map> m_processes; + DWORD m_nextCookie{1}; + + // The PluginManager that owns this callback. Used to resolve a WSLC + // SessionId to a live IWSLCSession via the manager's registered session + // reference map (see PluginManager::ResolveWslcSession). + PluginManager& m_owner; +}; + +/// +/// Manages out-of-process plugin hosts (wslpluginhost.exe) via COM activation. +/// Each plugin DLL is loaded in a separate process to isolate the service from +/// plugin crashes. Communication uses IWslPluginHost (service → host) for lifecycle +/// notifications and IWslPluginHostCallback (host → service) for plugin API calls. +/// A job object ensures all hosts are terminated if the service exits unexpectedly. +/// class PluginManager { public: @@ -30,6 +131,7 @@ class PluginManager }; PluginManager() = default; + ~PluginManager(); PluginManager(const PluginManager&) = delete; PluginManager& operator=(const PluginManager&) = delete; @@ -37,37 +139,152 @@ class PluginManager PluginManager& operator=(PluginManager&&) = delete; void LoadPlugins(); + + // Releases all out-of-process plugin host COM proxies and tears down the + // job object that contains the host processes. MUST be called from the + // service shutdown path WHILE COM IS STILL INITIALIZED on this thread + // (i.e. before CoUninitialize). Releasing IWslPluginHost proxies after + // COM has been uninitialized — for example, when this PluginManager is + // destroyed as a global at process exit — crashes inside the marshaler + // because the proxy/stub DLL has been unloaded. + void Shutdown(); + void OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings); - void OnVmStopping(const WSLSessionInformation* Session) const; + void OnVmStopping(const WSLSessionInformation* Session); void OnDistributionStarted(const WSLSessionInformation* Session, const WSLDistributionInformation* distro); - void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro) const; - void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const; - void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const; + void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro); + void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro); + void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro); - // WSLC notifications. Returning failure from OnSessionCreated/OnContainerStarted causes the - // corresponding operation to be aborted. Other notifications log errors and continue. + // WSLC notifications. Returning failure from OnWslcSessionCreated / OnWslcContainerStarted causes + // the corresponding operation to be aborted. Other notifications log errors and continue. void OnWslcSessionCreated(const WSLCSessionInformation* Session); - void OnWslcSessionStopping(const WSLCSessionInformation* Session) const; - HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson) const; - void OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId) const; - void OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson) const; - void OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId) const; + void OnWslcSessionStopping(const WSLCSessionInformation* Session); + HRESULT OnWslcContainerStarted(const WSLCSessionInformation* Session, LPCSTR InspectJson); + void OnWslcContainerStopping(const WSLCSessionInformation* Session, LPCSTR ContainerId); + void OnWslcImageCreated(const WSLCSessionInformation* Session, LPCSTR InspectJson); + void OnWslcImageDeleted(const WSLCSessionInformation* Session, LPCSTR ImageId); - void ThrowIfFatalPluginError() const; + void ThrowIfFatalPluginError(); -private: - void LoadPlugin(LPCWSTR Name, LPCWSTR Path); - static void ThrowIfPluginError(HRESULT Result, LPCWSTR Plugin); + // WSLC session reference registry. The WSLCSessionManager registers a weak + // IWSLCSessionReference for each live session here (keyed by SessionId) + // before notifying plugins of the session, and unregisters it when the + // session stops or creation is rolled back. Plugin API callbacks that take + // a SessionId (WslcMountFolder/WslcCreateProcess/etc.) resolve the live + // session through this registry instead of reaching back into the + // WSLCSessionManager — this breaks the lock-reentrancy cycle that would + // otherwise occur when an out-of-process plugin calls back on a different + // RPC thread while the manager holds its session lock. + void RegisterWslcSession(ULONG SessionId, IWSLCSessionReference* Reference); + void UnregisterWslcSession(ULONG SessionId); - struct LoadedPlugin + // Resolves a SessionId to a live IWSLCSession. The reference is copied out + // under m_wslcSessionRefLock and OpenSession() is then called OUTSIDE the + // lock (the reference is a weak handle, so OpenSession can fail/block). + // Throws HRESULT_FROM_WIN32(ERROR_NOT_FOUND) if the session is unknown or + // can no longer be opened. + wil::com_ptr ResolveWslcSession(ULONG SessionId); + +private: + struct OutOfProcPlugin { - wil::unique_hmodule module; + // GIT cookie for the IWslPluginHost proxy. Zero means "not loaded". + // Resolved per-call via AcquireHostProxy() because the raw proxy + // returned by CoCreateInstance is apartment-bound to the LoadPlugin + // thread, but hook dispatch can arrive on threads in any apartment + // (MTA, NTA-on-MTA via WinRT-style RPC dispatch, etc.). Storing in + // the Global Interface Table and re-unmarshaling per call yields an + // apartment-local proxy that COM can dispatch from any apartment. + DWORD hostCookie{0}; + Microsoft::WRL::ComPtr callback; std::wstring name; - WSLPluginHooksV1 hooks{}; + std::wstring path; + + // Set the first time a host crash is observed for this plugin during + // runtime (i.e. after successful load). Subsequent crash sites only + // log to ETL and skip the telemetry event, avoiding flood from a + // dead plugin that we'd notify on every distro/VM lifecycle event. + std::atomic crashTelemetryFired{false}; + + OutOfProcPlugin() = default; + OutOfProcPlugin(const OutOfProcPlugin&) = delete; + OutOfProcPlugin& operator=(const OutOfProcPlugin&) = delete; + OutOfProcPlugin(OutOfProcPlugin&& other) noexcept : + hostCookie(std::exchange(other.hostCookie, 0)), + callback(std::move(other.callback)), + name(std::move(other.name)), + path(std::move(other.path)), + crashTelemetryFired(other.crashTelemetryFired.load()) + { + } + OutOfProcPlugin& operator=(OutOfProcPlugin&&) = delete; + }; + + // RAII helper that joins the calling thread to the MTA for the duration of the + // scope. Plugin host activation and per-call dispatch through plugin host proxies + // are both cross-process COM calls that require the calling thread to be COM-init'd, + // and hook dispatch can run on threadpool threads (e.g. _VmIdleTerminate path) that + // haven't called CoInitializeEx. RPC_E_CHANGED_MODE means the thread is already + // initialized to a different apartment (typically STA), which is still fine: the + // existing apartment is preserved and COM dispatches proxy calls accordingly. + struct ScopedComInit + { + HRESULT initHr{RPC_E_CHANGED_MODE}; + + ScopedComInit(); + ~ScopedComInit(); + ScopedComInit(const ScopedComInit&) = delete; + ScopedComInit& operator=(const ScopedComInit&) = delete; + ScopedComInit(ScopedComInit&& other) noexcept; + ScopedComInit& operator=(ScopedComInit&&) = delete; + + HRESULT Result() const noexcept; }; - std::vector m_plugins; + void LoadPlugin(OutOfProcPlugin& plugin); + [[nodiscard]] ScopedComInit EnsureInitialized(); + void EnsureJobObjectCreated(); + + // Returns an apartment-local IWslPluginHost proxy for the given plugin by + // re-unmarshaling from the Global Interface Table. The returned proxy is + // only valid on the current thread / apartment and must not be cached. + // On failure (host process crashed, host released, etc.) returns the + // HRESULT from GIT; caller should treat host-crash HRESULTs via + // IsHostCrash() exactly as if the call itself had failed. + HRESULT AcquireHostProxy(const OutOfProcPlugin& plugin, _COM_Outptr_ IWslPluginHost** host); + + static void ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId session, LPCWSTR Plugin); + static std::vector SerializeSid(PSID Sid); + static bool IsHostCrash(HRESULT hr); + + // Logs a host crash to ETL and fires the PluginHostCrash telemetry event + // at most once per plugin per service lifetime, so a single bad plugin + // does not flood telemetry across every subsequent VM/distro lifecycle. + static void LogPluginHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage); + + std::once_flag m_initOnce; + std::vector m_plugins; std::optional m_pluginError; + + // Global Interface Table used to make IWslPluginHost proxies callable from + // any apartment. Created lazily inside EnsureInitialized() while COM is + // guaranteed to be initialized on the calling thread. + std::once_flag m_gitOnce; + Microsoft::WRL::ComPtr m_git; + + // Job object that automatically terminates all plugin host processes + // when wslservice exits or crashes (JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE). + std::once_flag m_jobObjectOnce; + wil::unique_handle m_jobObject; + + // Maps a WSLC SessionId to a weak reference to the live session. Populated + // by RegisterWslcSession / drained by UnregisterWslcSession and Shutdown. + // The references are COM proxies and MUST be released while COM is still + // initialized (see Shutdown); the destructor detaches (leaks) any survivors + // rather than releasing them after CoUninitialize. + std::mutex m_wslcSessionRefLock; + std::unordered_map> m_wslcSessionRefs; }; } // namespace wsl::windows::service diff --git a/src/windows/service/exe/ServiceMain.cpp b/src/windows/service/exe/ServiceMain.cpp index 40852b5511..f18c3321dc 100644 --- a/src/windows/service/exe/ServiceMain.cpp +++ b/src/windows/service/exe/ServiceMain.cpp @@ -258,6 +258,12 @@ void WslService::ServiceStopped() LxssClientUninitialize(); } + // Release plugin host COM proxies BEFORE CoUninitialize. The IWslPluginHost + // proxies must be torn down while COM is still initialized; otherwise the + // proxy/stub DLL is unloaded and releasing the proxies later (during global + // destruction) crashes inside the marshaler. + g_pluginManager.Shutdown(); + // There is a potential deadlock if CoUninitialize() is called before the LanguageChangeNotifyThread // isn't done initializing. Clearing the COM objects before calling CoUninitialize() works around the issue. winrt::clear_factory_cache(); diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp index 3966245df5..582a1dc030 100644 --- a/src/windows/service/exe/WSLCSessionManager.cpp +++ b/src/windows/service/exe/WSLCSessionManager.cpp @@ -134,26 +134,41 @@ WSLCSessionManagerImpl::~WSLCSessionManagerImpl() g_managerInstance.store(nullptr); // Terminate all sessions on shutdown. - // Call Terminate() directly rather than going through ForEachSession(), - // which would needlessly resolve weak references and call GetState(). - // Terminate() already handles the "session is gone" case gracefully. - std::lock_guard lock(m_wslcSessionsLock); - for (auto& entry : m_sessions) + // Plugin notifications are dispatched out-of-process to wslpluginhost.exe. + // Move the sessions out under the lock; notify and terminate after release + // so we never invoke a plugin callback while m_wslcSessionsLock is held. + std::vector snapshot; { - NotifySessionStoppingLockHeld(entry); + std::lock_guard lock(m_wslcSessionsLock); + snapshot = std::move(m_sessions); + m_sessions.clear(); + m_persistentSessions.clear(); + } + + for (auto& entry : snapshot) + { + DispatchSessionStopping(entry); LOG_IF_FAILED(entry.Ref->Terminate()); } } -void WSLCSessionManagerImpl::NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept +void WSLCSessionManagerImpl::DispatchSessionStopping(SessionEntry& entry) noexcept try { - if (entry.StoppingNotified) + // Fire OnWslcSessionStopping exactly once. std::exchange guarantees only + // the first caller dispatches even if the entry is observed from multiple + // teardown paths (ForEachSession cleanup, destructor, etc.). + if (std::exchange(entry.StoppingNotified, true)) { return; } - entry.StoppingNotified = true; + // Drop the session from the plugin session reference map once stopping + // completes, even if the notification throws. Unregistered AFTER the + // notification so a plugin's OnWslcSessionStopping handler can still + // resolve the session (e.g. to unmount folders). + auto unregister = wil::scope_exit([&] { g_pluginManager.UnregisterWslcSession(entry.SessionId); }); + WSLCSessionInformation info{}; info.SessionId = static_cast(entry.SessionId); info.DisplayName = entry.DisplayName.c_str(); @@ -168,10 +183,15 @@ void WSLCSessionManagerImpl::CreateSession( _In_ const WSLCSessionSettings* Settings, _In_ WSLCSessionFlags Flags, _In_opt_ IWarningCallback* WarningCallback, _Out_ IWSLCSession** WslcSession) { THROW_HR_IF_NULL(E_POINTER, WslcSession); + *WslcSession = nullptr; auto tokenInfo = GetCallingProcessTokenInfo(); const auto callerToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + // Snapshot before tokenInfo is moved into the SessionEntry below; read by + // the telemetry event at the end of this function. + const bool elevated = tokenInfo.Elevated; + // Resolve display name upfront (for both default and custom sessions). std::wstring resolvedDisplayName; if (Settings == nullptr) @@ -200,126 +220,244 @@ void WSLCSessionManagerImpl::CreateSession( resolvedDisplayName = Settings->DisplayName; } - std::lock_guard lock(m_wslcSessionsLock); + std::vector deadSessions; + std::optional existingResult; - // Check for an existing session first. - auto result = ForEachSession([&](auto& entry, const wil::com_ptr& session) noexcept -> std::optional { + wslutil::StopWatch stopWatch; + std::wstring callerFileName; + HRESULT creationResult = S_OK; + const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent); + + // State for a created-but-unconfirmed session, carried across the unlocked + // OnWslcSessionCreated notification. The session is intentionally NOT added + // to m_sessions until the notification succeeds, so another caller can never + // observe (or be handed) a session that a plugin might still veto. + wil::com_ptr createdRef; + wil::com_ptr createdSession; + Microsoft::WRL::ComPtr createdNotifier; + wil::shared_handle createdToken; + std::vector createdSid; + ULONG createdSessionId = 0; + DWORD createdPid = 0; + wil::unique_handle createdJob; + bool created = false; + + // Set once the created session's outcome (publish / veto / race-loss) has + // been handled. The guard only runs if an unexpected throw escapes before + // then, dropping the ref-map registration and terminating the orphan. + bool outcomeHandled = false; + auto createdCleanup = wil::scope_exit([&] { + if (created && !outcomeHandled) + { + g_pluginManager.UnregisterWslcSession(createdSessionId); + if (createdRef) + { + LOG_IF_FAILED(createdRef->Terminate()); + } + } + }); + + // Matches an existing session by name and, if the caller has access, copies + // it out. Used both for the initial existence check and, after the unlocked + // notification, to detect a concurrent create of the same name. + auto openExisting = [&](SessionEntry& entry, const wil::com_ptr& session) noexcept -> std::optional { if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), resolvedDisplayName.c_str())) { return {}; } RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), WI_IsFlagClear(Flags, WSLCSessionFlagsOpenExisting)); - RETURN_IF_FAILED(CheckTokenAccess(entry, tokenInfo)); - RETURN_IF_FAILED(wil::com_copy_to_nothrow(session, WslcSession)); - return S_OK; - }); + }; - if (result.has_value()) { - THROW_IF_FAILED(result.value()); - return; // Existing session was opened. - } - - wslutil::StopWatch stopWatch; + std::lock_guard lock(m_wslcSessionsLock); - // Initialize settings for the default session. - std::unique_ptr defaultSettings; - if (Settings == nullptr) - { - defaultSettings = SessionSettings::Default(callerToken.get(), resolvedDisplayName); - Settings = &defaultSettings->Settings; - } + existingResult = ForEachSessionLockHeld(openExisting, deadSessions); - std::wstring callerFileName; - - HRESULT creationResult = wil::ResultFromException([&]() { - // Get caller info. - const auto callerProcess = wslutil::OpenCallingProcess(PROCESS_QUERY_LIMITED_INFORMATION); - const ULONG sessionId = m_nextSessionId++; - const DWORD creatorPid = GetProcessId(callerProcess.get()); - - // Query the full image path of the calling process and extract just the file name. - std::wstring callerFilePath; - if (SUCCEEDED_LOG(wil::QueryFullProcessImageNameW(callerProcess.get(), 0, callerFilePath))) + if (!existingResult.has_value()) { - callerFileName = std::filesystem::path(callerFilePath).filename().wstring(); + // Initialize settings for the default session. + std::unique_ptr defaultSettings; + if (Settings == nullptr) + { + defaultSettings = SessionSettings::Default(callerToken.get(), resolvedDisplayName); + Settings = &defaultSettings->Settings; + } + + creationResult = wil::ResultFromException([&]() { + // Get caller info. + const auto callerProcess = wslutil::OpenCallingProcess(PROCESS_QUERY_LIMITED_INFORMATION); + const ULONG sessionId = m_nextSessionId++; + const DWORD creatorPid = GetProcessId(callerProcess.get()); + + // Query the full image path of the calling process and extract just the file name. + std::wstring callerFilePath; + if (SUCCEEDED_LOG(wil::QueryFullProcessImageNameW(callerProcess.get(), 0, callerFilePath))) + { + callerFileName = std::filesystem::path(callerFilePath).filename().wstring(); + } + + const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + + // Capture a duplicated user token + raw SID so PluginManager can build + // WSLCSessionInformation later (e.g. on shutdown) without re-impersonating. + // The token is shared between the SessionEntry and the WSLCPluginNotifier. + wil::unique_handle dupToken; + THROW_IF_WIN32_BOOL_FALSE(DuplicateTokenEx( + userToken.get(), TOKEN_QUERY | TOKEN_DUPLICATE, nullptr, SecurityImpersonation, TokenImpersonation, &dupToken)); + wil::shared_handle sharedToken{dupToken.release()}; + + const DWORD sidLen = GetLengthSid(tokenInfo.TokenInfo->User.Sid); + std::vector storedSid(sidLen); + THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLen, storedSid.data(), tokenInfo.TokenInfo->User.Sid)); + + // Build the plugin notifier service-side. Lifetime tracked via the SessionEntry. + Microsoft::WRL::ComPtr notifier; + notifier = wil::MakeOrThrow( + g_pluginManager, sessionId, creatorPid, std::wstring(resolvedDisplayName), wil::shared_handle(sharedToken), std::vector(storedSid)); + + // Create the VM in the SYSTEM service (privileged). + auto vm = Microsoft::WRL::Make(Settings); + + // Launch per-user COM server factory and add it to a fresh per-session job object for crash cleanup. + auto factory = wslutil::CreateComServerAsUser(__uuidof(WSLCSessionFactory), userToken.get()); + wil::unique_handle sessionJob = CreateSessionProcessJob(factory.get()); + + auto sessionSettings = CreateSessionSettings(sessionId, callerFileName.c_str(), Settings, resolvedDisplayName.c_str()); + wil::com_ptr session; + wil::com_ptr serviceRef; + THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), notifier.Get(), WarningCallback, &session, &serviceRef)); + + // Register the session reference so plugin API callbacks (WslcMountFolder etc.) + // can resolve it WITHOUT taking m_wslcSessionsLock during the OnWslcSessionCreated + // notification, even though the plugin may call back on a different RPC thread. The + // session is not yet in m_sessions, so it stays invisible to other callers until + // creation is confirmed below. + g_pluginManager.RegisterWslcSession(sessionId, serviceRef.get()); + + // Carry the created (but unconfirmed) session out to the notification and commit. + createdRef = serviceRef; + createdSession = session; + createdNotifier = notifier; + createdToken = sharedToken; + createdSid = std::move(storedSid); + createdSessionId = sessionId; + createdPid = creatorPid; + createdJob = std::move(sessionJob); + created = true; + }); } + } - const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); - - // Capture a duplicated user token + raw SID so PluginManager can build - // WSLCSessionInformation later (e.g. on shutdown) without re-impersonating. - // The token is shared between the SessionEntry and the WSLCPluginNotifier. - wil::unique_handle dupToken; - THROW_IF_WIN32_BOOL_FALSE(DuplicateTokenEx( - userToken.get(), TOKEN_QUERY | TOKEN_DUPLICATE, nullptr, SecurityImpersonation, TokenImpersonation, &dupToken)); - wil::shared_handle sharedToken{dupToken.release()}; - - const DWORD sidLen = GetLengthSid(tokenInfo.TokenInfo->User.Sid); - std::vector storedSid(sidLen); - THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLen, storedSid.data(), tokenInfo.TokenInfo->User.Sid)); - - // Build the plugin notifier service-side. Lifetime tracked via the SessionEntry. - Microsoft::WRL::ComPtr notifier; - notifier = wil::MakeOrThrow( - g_pluginManager, sessionId, creatorPid, std::wstring(resolvedDisplayName), wil::shared_handle(sharedToken), std::vector(storedSid)); - - // Create the VM in the SYSTEM service (privileged). - auto vm = Microsoft::WRL::Make(Settings); - - // Launch per-user COM server factory and add it to a fresh per-session job object for crash cleanup. - auto factory = wslutil::CreateComServerAsUser(__uuidof(WSLCSessionFactory), userToken.get()); - wil::unique_handle sessionJob = CreateSessionProcessJob(factory.get()); - - const auto sessionSettings = CreateSessionSettings(sessionId, callerFileName.c_str(), Settings, resolvedDisplayName.c_str()); - wil::com_ptr session; - wil::com_ptr serviceRef; - THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), notifier.Get(), WarningCallback, &session, &serviceRef)); - - // Track the session via its service ref, along with metadata and security info. - m_sessions.push_back(SessionEntry{ - std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo), notifier, false, sharedToken, std::move(storedSid), std::move(sessionJob)}); + // Dispatch OnWslcSessionStopping for any dead sessions found during the + // existence check, now that the lock is released. + for (auto& entry : deadSessions) + { + DispatchSessionStopping(entry); + } + deadSessions.clear(); - // For persistent sessions, also hold a strong reference to keep them alive. - const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent); - if (persistent) - { - m_persistentSessions.emplace_back(sessionId, session); - } + if (existingResult.has_value()) + { + THROW_IF_FAILED(existingResult.value()); + return; // Existing session was opened. + } - // Notify plugins that the session was created. A failure here aborts session creation. + if (created) + { + // Notify plugins with the lock RELEASED: a plugin runs arbitrary code and + // may call back into the service (including the externally-activatable + // session manager) on another thread, so holding m_wslcSessionsLock across + // the notification could deadlock. Plugin API callbacks resolve this + // session through the ref-map registered above instead. + WSLCSessionInformation info{}; + info.SessionId = static_cast(createdSessionId); + info.DisplayName = resolvedDisplayName.c_str(); + info.ApplicationPid = createdPid; + info.UserToken = createdToken.get(); + info.UserSid = createdSid.data(); + + HRESULT pluginHr = S_OK; try { - auto& entry = m_sessions.back(); - WSLCSessionInformation info{}; - info.SessionId = static_cast(entry.SessionId); - info.DisplayName = entry.DisplayName.c_str(); - info.ApplicationPid = entry.CreatorPid; - info.UserToken = entry.UserToken.get(); - info.UserSid = entry.UserSid.data(); g_pluginManager.OnWslcSessionCreated(&info); } catch (...) { - const auto error = wil::ResultFromCaughtException(); + pluginHr = wil::ResultFromCaughtException(); + } - // Plugin rejected the session: tear it down before propagating. - m_sessions.back().StoppingNotified = true; // Don't fire stopping for a session that never started successfully. - LOG_IF_FAILED(m_sessions.back().Ref->Terminate()); - m_sessions.pop_back(); + enum class Outcome + { + Commit, + Veto, + RaceLost + } outcome = Outcome::Commit; - auto remove = std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == sessionId; }); - m_persistentSessions.erase(remove.begin(), remove.end()); + { + std::lock_guard lock(m_wslcSessionsLock); + + if (FAILED(pluginHr)) + { + outcome = Outcome::Veto; + creationResult = pluginHr; + } + else if (auto raceWinner = ForEachSessionLockHeld(openExisting, deadSessions); raceWinner.has_value()) + { + outcome = Outcome::RaceLost; + creationResult = raceWinner.value(); + } + else + { + // Creation confirmed: publish the session and hand it to the + // caller. Reserve first so the inserts can't throw and leave a + // half-published, externally-visible session behind. + m_sessions.reserve(m_sessions.size() + 1); + if (persistent) + { + m_persistentSessions.reserve(m_persistentSessions.size() + 1); + } + + m_sessions.push_back(SessionEntry{ + std::move(createdRef), createdSessionId, createdPid, resolvedDisplayName, std::move(tokenInfo), createdNotifier, false, createdToken, std::move(createdSid), std::move(createdJob)}); + + if (persistent) + { + m_persistentSessions.emplace_back(createdSessionId, createdSession); + } + + *WslcSession = createdSession.detach(); + outcomeHandled = true; + } + } - THROW_HR(error); + // Tear down vetoed / race-lost sessions with the lock released: both + // Terminate() and the stopping notification are out-of-process calls + // that must not run under m_wslcSessionsLock. + if (outcome == Outcome::Veto || outcome == Outcome::RaceLost) + { + // A plugin vetoed the session in OnWslcSessionCreated, or a concurrent + // CreateSession won the name. Either way, fire OnWslcSessionStopping so + // every plugin that already observed OnWslcSessionCreated sees a matching + // stopping (the established veto convention, mirroring the VM path's + // OnVmStarted -> _VmTerminate -> OnVmStopping), then drop and terminate. + SessionEntry teardown{ + createdRef, createdSessionId, createdPid, resolvedDisplayName, std::move(tokenInfo), createdNotifier, false, createdToken, std::move(createdSid)}; + DispatchSessionStopping(teardown); + LOG_IF_FAILED(createdRef->Terminate()); + outcomeHandled = true; } + } - *WslcSession = session.detach(); - }); + // Dispatch OnWslcSessionStopping for any dead sessions found during the + // post-notification re-check, now that the lock is released. + for (auto& entry : deadSessions) + { + DispatchSessionStopping(entry); + } // This telemetry event is used to keep track of session creation performance (via CreationTimeMs) and failure reasons (via Result). WSL_LOG( @@ -330,7 +468,7 @@ void WSLCSessionManagerImpl::CreateSession( TraceLoggingValue(WSL_PACKAGE_VERSION, "wslVersion"), TraceLoggingValue(stopWatch.ElapsedMilliseconds(), "CreationTimeMs"), TraceLoggingValue(creationResult, "Result"), - TraceLoggingValue(tokenInfo.Elevated, "Elevated"), + TraceLoggingValue(elevated, "Elevated"), TraceLoggingValue(static_cast(Flags), "Flags"), TraceLoggingValue(callerFileName.c_str(), "CallerFileName"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); @@ -605,22 +743,4 @@ WSLCSessionManagerImpl* WSLCSessionManagerImpl::Instance() noexcept return g_managerInstance.load(); } -wil::com_ptr WSLCSessionManagerImpl::FindSession(ULONG Id) -{ - wil::com_ptr result; - - ForEachSession([&](SessionEntry& entry, const wil::com_ptr& session) noexcept -> std::optional { - if (entry.SessionId != Id) - { - return std::nullopt; - } - - result = session; - return S_OK; - }); - - THROW_HR_IF_MSG(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), !result, "WSLC session %lu not found", Id); - return result; -} - } // namespace wsl::windows::service::wslc diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h index d48f52e430..a09940bd23 100644 --- a/src/windows/service/exe/WSLCSessionManager.h +++ b/src/windows/service/exe/WSLCSessionManager.h @@ -91,9 +91,6 @@ class WSLCSessionManagerImpl void OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session); void OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session); - // Resolves a session by ID for plugin->API calls. Throws ERROR_NOT_FOUND if no session matches. - wil::com_ptr FindSession(ULONG Id); - static WSLCSessionManagerImpl* Instance() noexcept; private: @@ -104,13 +101,16 @@ class WSLCSessionManagerImpl // Returns true if the name matches a reserved default session prefix. static bool IsReservedSessionName(LPCWSTR Name); - // Iterates over all sessions, cleaning up released sessions. - // The routine receives a SessionEntry& and can return an optional to stop iteration. + // Iterates over all sessions, cleaning up released sessions. The CALLER + // MUST hold m_wslcSessionsLock. Sessions whose backing process is gone are + // moved out of tracking into `DeadSessions` so the caller can dispatch + // their OnWslcSessionStopping plugin notification (via DispatchSessionStopping) + // AFTER releasing the lock — the notification is an out-of-process call and + // must not run under the lock. The routine receives a SessionEntry& and can + // return an optional to stop iteration. template - inline auto ForEachSession(const auto& Routine) + inline auto ForEachSessionLockHeld(const auto& Routine, std::vector& DeadSessions) { - std::lock_guard lock(m_wslcSessionsLock); - // Enforce noexcept: remove_if leaves the container in an unspecified // (partially-moved) state if the predicate throws. Callers must handle // errors via return values, not exceptions. @@ -128,8 +128,24 @@ class WSLCSessionManagerImpl wil::com_ptr lockedSession; if (FAILED_LOG(entry.Ref->OpenSession(&lockedSession))) { - // Session is gone: notify plugins (if not already), then drop persistent reference if any. - NotifySessionStoppingLockHeld(entry); + // Session is gone: move it out for deferred OnWslcSessionStopping + // dispatch (must happen outside the lock) and drop any persistent + // reference. The StoppingNotified flag is intentionally left + // untouched here; DispatchSessionStopping flips it so the + // notification fires exactly once. + try + { + DeadSessions.push_back(std::move(entry)); + } + catch (...) + { + // Couldn't queue it for deferred stopping dispatch: keep it + // tracked (the move above is noexcept, so entry is intact) + // and reap it on a later pass rather than dropping it + // without unregistering. + LOG_CAUGHT_EXCEPTION(); + return false; + } auto remove = std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return e.first == entry.SessionId; }); @@ -166,15 +182,59 @@ class WSLCSessionManagerImpl } [[nodiscard]] wil::unique_handle CreateSessionProcessJob(_In_ IWSLCSessionFactory* Factory); + + // Convenience wrapper that takes m_wslcSessionsLock internally and dispatches + // OnWslcSessionStopping for any dead sessions after releasing the lock. Use + // this from call sites that do NOT already hold the lock. + template + inline auto ForEachSession(const auto& Routine) + { + std::vector deadSessions; + + using TResult = std::conditional_t, nullptr_t, std::optional>; + TResult result{}; + + { + std::lock_guard lock(m_wslcSessionsLock); + if constexpr (std::is_same_v) + { + ForEachSessionLockHeld(Routine, deadSessions); + } + else + { + result = ForEachSessionLockHeld(Routine, deadSessions); + } + } + + // Safe to invoke plugin callbacks now that the lock is released. + for (auto& entry : deadSessions) + { + DispatchSessionStopping(entry); + } + + if constexpr (std::is_same_v) + { + return; + } + else + { + return result; + } + } + WSLCSessionInitSettings CreateSessionSettings( _In_ ULONG SessionId, _In_ LPCWSTR CreatorProcessName, _In_ const WSLCSessionSettings* Settings, _In_ LPCWSTR ResolvedDisplayName); static CallingProcessTokenInfo GetCallingProcessTokenInfo(); static HRESULT CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo); - void NotifySessionStoppingLockHeld(SessionEntry& entry) noexcept; + // Fires OnWslcSessionStopping for `entry` exactly once (guarded by + // entry.StoppingNotified) and then unregisters the session from the plugin + // session reference map. MUST be called WITHOUT m_wslcSessionsLock held — + // the plugin notification is dispatched out-of-process to wslpluginhost.exe. + void DispatchSessionStopping(SessionEntry& entry) noexcept; std::atomic m_nextSessionId{1}; - std::recursive_mutex m_wslcSessionsLock; + std::mutex m_wslcSessionsLock; // All sessions tracked via SessionEntry (which holds weak refs and service-side security info). // Sessions are automatically cleaned up when the underlying session is released. diff --git a/src/windows/service/inc/CMakeLists.txt b/src/windows/service/inc/CMakeLists.txt index 59888c441f..e8cbc1dc24 100644 --- a/src/windows/service/inc/CMakeLists.txt +++ b/src/windows/service/inc/CMakeLists.txt @@ -1,3 +1,5 @@ add_idl(wslserviceidl "wslservice.idl;wslc.idl" "windowsdefs.idl") +add_idl(wslpluginhostidl "WslPluginHost.idl" "") set_target_properties(wslserviceidl PROPERTIES FOLDER windows) +set_target_properties(wslpluginhostidl PROPERTIES FOLDER windows) diff --git a/src/windows/service/inc/WslPluginHost.idl b/src/windows/service/inc/WslPluginHost.idl new file mode 100644 index 0000000000..73d51b101f --- /dev/null +++ b/src/windows/service/inc/WslPluginHost.idl @@ -0,0 +1,260 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WslPluginHost.idl + +Abstract: + + This file contains the COM interface definitions for out-of-process + plugin hosting. IWslPluginHost is implemented by the plugin host process + and called by the service. IWslPluginHostCallback is implemented by the + service and called by the plugin host when a plugin invokes API functions. + +--*/ + +import "unknwn.idl"; +import "wtypes.idl"; + +cpp_quote("const GUID CLSID_WslPluginHost = {0x7a1d2c3e, 0x4b5f, 0x6a7d, {0x8e, 0x9f, 0x0a, 0x1b, 0x2c, 0x3d, 0x4e, 0x5f}};") +cpp_quote("#ifdef __cplusplus") +cpp_quote("class DECLSPEC_UUID(\"7a1d2c3e-4b5f-6a7d-8e9f-0a1b2c3d4e5f\") WslPluginHost;") +cpp_quote("#endif") + +// +// IWslPluginHostCallback - implemented by the service, called by the plugin host +// when a plugin invokes WSLPluginAPIV1 functions (MountFolder, ExecuteBinary, etc.) +// + +[ + uuid(A2B3C4D5-E6F7-4890-AB12-CD34EF56A789), + pointer_default(unique), + object +] +interface IWslPluginHostCallback : IUnknown +{ + HRESULT MountFolder( + [in] DWORD SessionId, + [in, string] LPCWSTR WindowsPath, + [in, string] LPCWSTR LinuxPath, + [in] BOOL ReadOnly, + [in, string] LPCWSTR Name); + + HRESULT ExecuteBinary( + [in] DWORD SessionId, + [in, string] LPCSTR Path, + [in] DWORD ArgumentCount, + [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments, + [out, system_handle(sh_socket)] HANDLE* Socket); + + HRESULT ExecuteBinaryInDistribution( + [in] DWORD SessionId, + [in] const GUID* DistributionId, + [in, string] LPCSTR Path, + [in] DWORD ArgumentCount, + [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments, + [out, system_handle(sh_socket)] HANDLE* Socket); + + // + // WSLC plugin API. WSLCProcessHandle is mapped to a service-allocated + // DWORD cookie that the host wraps in an opaque void* for the plugin. + // + + HRESULT WslcMountFolder( + [in] DWORD SessionId, + [in, string] LPCWSTR WindowsPath, + [in, string] LPCSTR Mountpoint, + [in] BOOL ReadOnly); + + HRESULT WslcUnmountFolder( + [in] DWORD SessionId, + [in, string] LPCSTR Mountpoint); + + HRESULT WslcCreateProcess( + [in] DWORD SessionId, + [in, string] LPCSTR Executable, + [in] DWORD ArgumentCount, + [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments, + [in] DWORD EnvCount, + [in, unique, size_is(EnvCount), string] LPCSTR* Environment, + [out] DWORD* ProcessCookie, + [out] int* Errno); + + HRESULT WslcProcessGetFd( + [in] DWORD ProcessCookie, + [in] DWORD Fd, + [out, system_handle(sh_socket)] HANDLE* Handle); + + HRESULT WslcProcessGetExitEvent( + [in] DWORD ProcessCookie, + [out, system_handle(sh_event)] HANDLE* ExitEvent); + + HRESULT WslcProcessGetExitCode( + [in] DWORD ProcessCookie, + [out] int* ExitCode); + + HRESULT WslcReleaseProcess( + [in] DWORD ProcessCookie); +}; + +// +// IWslPluginHost - implemented by the plugin host process, called by the service +// to deliver lifecycle notifications to the plugin. +// + +[ + uuid(B3C4D5E6-F7A8-4901-BC23-DE45FA67B890), + pointer_default(unique), + object +] +interface IWslPluginHost : IUnknown +{ + // + // Initialize the plugin host: load the plugin DLL and call its entry point. + // The Callback interface is used by the plugin to call back into the service. + // + + HRESULT Initialize( + [in] IWslPluginHostCallback* Callback, + [in, string] LPCWSTR PluginPath, + [in, string] LPCWSTR PluginName); + + // + // Returns a handle to this COM server process. Used by the service to add + // the plugin host to a job object for automatic cleanup on service exit. + // + + HRESULT GetProcessHandle( + [out, system_handle(sh_process)] HANDLE* ProcessHandle); + + // + // Lifecycle hook dispatchers - mirror WSLPluginHooksV1. + // UserToken is duplicated into the host process by the service before calling. + // UserSid is serialized as a byte array. + // + + HRESULT OnVMStarted( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] DWORD CustomConfigurationFlags, + [out, string] LPWSTR* ErrorMessage); + + HRESULT OnVMStopping( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData); + + HRESULT OnDistributionStarted( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in] ULONGLONG PidNamespace, + [in, unique, string] LPCWSTR PackageFamilyName, + [in] DWORD InitPid, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version, + [out, string] LPWSTR* ErrorMessage); + + HRESULT OnDistributionStopping( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in] ULONGLONG PidNamespace, + [in, unique, string] LPCWSTR PackageFamilyName, + [in] DWORD InitPid, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version); + + HRESULT OnDistributionRegistered( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in, unique, string] LPCWSTR PackageFamilyName, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version); + + HRESULT OnDistributionUnregistered( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in, unique, string] LPCWSTR PackageFamilyName, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version); + + // + // WSLC plugin hooks. Mirror WSLPluginHooksV1 WSLC entries. + // UserToken is duplicated into the host by COM's system_handle marshaling. + // UserSid is serialized as a byte array (same convention as WSL hooks). + // + + HRESULT OnWslcSessionCreated( + [in] DWORD SessionId, + [in, string] LPCWSTR DisplayName, + [in] DWORD ApplicationPid, + [in, unique, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, unique, size_is(SidSize)] BYTE* SidData, + [out, string] LPWSTR* ErrorMessage); + + HRESULT OnWslcSessionStopping( + [in] DWORD SessionId, + [in, string] LPCWSTR DisplayName, + [in] DWORD ApplicationPid, + [in, unique, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, unique, size_is(SidSize)] BYTE* SidData); + + HRESULT OnWslcContainerStarted( + [in] DWORD SessionId, + [in, string] LPCWSTR DisplayName, + [in] DWORD ApplicationPid, + [in, unique, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, unique, size_is(SidSize)] BYTE* SidData, + [in, string] LPCSTR InspectJson, + [out, string] LPWSTR* ErrorMessage); + + HRESULT OnWslcContainerStopping( + [in] DWORD SessionId, + [in, string] LPCWSTR DisplayName, + [in] DWORD ApplicationPid, + [in, unique, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, unique, size_is(SidSize)] BYTE* SidData, + [in, string] LPCSTR ContainerId); + + HRESULT OnWslcImageCreated( + [in] DWORD SessionId, + [in, string] LPCWSTR DisplayName, + [in] DWORD ApplicationPid, + [in, unique, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, unique, size_is(SidSize)] BYTE* SidData, + [in, string] LPCSTR InspectJson); + + HRESULT OnWslcImageDeleted( + [in] DWORD SessionId, + [in, string] LPCWSTR DisplayName, + [in] DWORD ApplicationPid, + [in, unique, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, unique, size_is(SidSize)] BYTE* SidData, + [in, string] LPCSTR ImageId); +}; diff --git a/src/windows/service/stub/CMakeLists.txt b/src/windows/service/stub/CMakeLists.txt index c0308d6dcd..c8ef86f3f9 100644 --- a/src/windows/service/stub/CMakeLists.txt +++ b/src/windows/service/stub/CMakeLists.txt @@ -3,6 +3,8 @@ set(SOURCES ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslservice_p_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslc_i_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslc_p_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/WslPluginHost_i_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/WslPluginHost_p_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_LIST_DIR}/WslServiceProxyStub.def ${CMAKE_CURRENT_LIST_DIR}/WslServiceProxyStub.rc) @@ -10,6 +12,6 @@ set(SOURCES set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE) add_library(wslserviceproxystub SHARED ${SOURCES}) -add_dependencies(wslserviceproxystub wslserviceidl) +add_dependencies(wslserviceproxystub wslserviceidl wslpluginhostidl) target_link_libraries(wslserviceproxystub ${COMMON_LINK_LIBRARIES}) set_target_properties(wslserviceproxystub PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wslinstall/DllMain.cpp b/src/windows/wslinstall/DllMain.cpp index cc42baf018..01d39d4d39 100644 --- a/src/windows/wslinstall/DllMain.cpp +++ b/src/windows/wslinstall/DllMain.cpp @@ -827,7 +827,7 @@ void RegisterLspCategoriesImpl(DWORD flags) const auto installRoot = wsl::windows::common::wslutil::GetMsiPackagePath(); THROW_HR_IF(E_INVALIDARG, !installRoot.has_value()); - for (const auto& e : {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslg.exe", L"wslservice.exe"}) + for (const auto& e : {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslpluginhost.exe", L"wslg.exe", L"wslservice.exe"}) { auto executable = installRoot.value() + e; INT error{}; diff --git a/src/windows/wslpluginhost/CMakeLists.txt b/src/windows/wslpluginhost/CMakeLists.txt new file mode 100644 index 0000000000..ce8dc8f770 --- /dev/null +++ b/src/windows/wslpluginhost/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(exe) diff --git a/src/windows/wslpluginhost/exe/CMakeLists.txt b/src/windows/wslpluginhost/exe/CMakeLists.txt new file mode 100644 index 0000000000..fc16ec3bc2 --- /dev/null +++ b/src/windows/wslpluginhost/exe/CMakeLists.txt @@ -0,0 +1,25 @@ +set(SOURCES + main.cpp + main.rc + PluginHost.cpp) + +set(HEADERS + PluginHost.h + resource.h) + +add_executable(wslpluginhost WIN32 ${SOURCES} ${HEADERS}) +add_dependencies(wslpluginhost + wslpluginhostidl + common) + +target_include_directories(wslpluginhost PRIVATE + ${CMAKE_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) + +target_link_libraries(wslpluginhost + ${COMMON_LINK_LIBRARIES} + ${MSI_LINK_LIBRARIES} + common + ole32.lib) + +target_precompile_headers(wslpluginhost REUSE_FROM common) +set_target_properties(wslpluginhost PROPERTIES FOLDER windows) diff --git a/src/windows/wslpluginhost/exe/PluginHost.cpp b/src/windows/wslpluginhost/exe/PluginHost.cpp new file mode 100644 index 0000000000..19bcc62e18 --- /dev/null +++ b/src/windows/wslpluginhost/exe/PluginHost.cpp @@ -0,0 +1,947 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + PluginHost.cpp + +Abstract: + + This file contains the IWslPluginHost COM class implementation. + It loads a plugin DLL in this (host) process and forwards lifecycle + notifications from the service to the plugin, while routing plugin API + callbacks back to the service via IWslPluginHostCallback. + +--*/ + +#include "precomp.h" +#include "PluginHost.h" +#include "install.h" + +using namespace wsl::windows::pluginhost; + +// Defined in main.cpp — part of the COM local server lifecycle. +extern void AddComRef(); +extern void ReleaseComRef(); + +std::atomic wsl::windows::pluginhost::g_pluginHost{nullptr}; + +// Thread ID of the thread currently dispatching a plugin hook. +// Only that thread may call PluginError. Using thread ID instead of +// thread_local to avoid TLS initialization issues across DLL/EXE boundaries. +static std::atomic g_hookThreadId{0}; + +namespace { + +// Plugins may invoke API stubs from worker threads they create themselves, +// which often haven't called CoInitializeEx. Without COM initialization, the +// cross-process m_callback proxy can't be used and calls would fail with +// CO_E_NOTINITIALIZED. This RAII helper joins the calling thread to the MTA +// for the duration of the API call. RPC_E_CHANGED_MODE means the thread is +// already STA-initialized; we leave its apartment unchanged. The supported +// path is an uninitialized worker thread joining the MTA where m_callback was +// marshaled — calling the proxy from a caller-created STA is not supported. +struct ScopedComInitForCallback +{ + HRESULT initHr; + ScopedComInitForCallback() : initHr(::CoInitializeEx(nullptr, COINIT_MULTITHREADED)) + { + } + ~ScopedComInitForCallback() + { + if (SUCCEEDED(initHr)) + { + ::CoUninitialize(); + } + } + ScopedComInitForCallback(const ScopedComInitForCallback&) = delete; + ScopedComInitForCallback& operator=(const ScopedComInitForCallback&) = delete; + + HRESULT Result() const + { + return (initHr == RPC_E_CHANGED_MODE) ? S_OK : initHr; + } +}; + +} // namespace + +// Tracks how many PluginHost instances have ever been constructed in this +// process. Used by main() to distinguish "we were activated and the client +// disconnected normally" from "we sat here for the entire startup window +// without ever being activated by anyone" (orphaned host). +std::atomic wsl::windows::pluginhost::g_activationCount{0}; + +PluginHost::PluginHost() +{ + // Increment the COM server reference count so the process stays alive while + // this instance exists. Pairs with ReleaseComRef() in ~PluginHost(); tying + // both to the object's lifetime guarantees they always balance regardless of + // whether the factory's CopyTo() succeeds or fails. + AddComRef(); + g_activationCount.fetch_add(1, std::memory_order_release); +} + +PluginHost::~PluginHost() +{ + // Clear globally reachable state so late plugin API calls fail with + // E_UNEXPECTED instead of dereferencing freed memory. Release-store + // pairs with the acquire-loads in the Local* API stubs. + PluginHost* expected = this; + g_pluginHost.compare_exchange_strong(expected, nullptr, std::memory_order_acq_rel); + + // Module unloads automatically via wil::unique_hmodule destructor. + + // Decrement the COM server reference count. When it reaches zero, + // the process will exit. Matches AddComRef() in PluginHost::PluginHost(). + ReleaseComRef(); +} + +// --- IWslPluginHost implementation --- + +STDMETHODIMP PluginHost::Initialize(_In_ IWslPluginHostCallback* Callback, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) +try +{ + RETURN_HR_IF(E_INVALIDARG, Callback == nullptr || PluginPath == nullptr || PluginName == nullptr); + RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, m_module.is_valid()); // Already initialized + + m_callback = Callback; + + // Validate the plugin signature before loading it. + // Keep the file handle open to prevent TOCTOU (swap between validation and load). + wil::unique_hfile signatureHandle; + if constexpr (wsl::shared::OfficialBuild) + { + signatureHandle = wsl::windows::common::install::ValidateFileSignature(PluginPath); + } + + m_module.reset(LoadLibrary(PluginPath)); + THROW_LAST_ERROR_IF_NULL(m_module); + signatureHandle.reset(); // Safe to release after LoadLibrary has mapped the DLL + + const auto entryPoint = + reinterpret_cast(GetProcAddress(m_module.get(), GSL_STRINGIFY(WSLPLUGINAPI_ENTRYPOINTV1))); + THROW_LAST_ERROR_IF_NULL(entryPoint); + + // Build the API vtable that the plugin will use to call back into the service. + // The function pointers are static methods on this class that route through g_pluginHost. + static const WSLPluginAPIV1 api = { + {wsl::shared::VersionMajor, wsl::shared::VersionMinor, wsl::shared::VersionRevision}, + &LocalMountFolder, + &LocalExecuteBinary, + &LocalPluginError, + &LocalExecuteBinaryInDistribution, + &LocalWslcMountFolder, + &LocalWslcUnmountFolder, + &LocalWslcCreateProcess, + &LocalWslcProcessGetFd, + &LocalWslcProcessGetExitEvent, + &LocalWslcProcessGetExitCode, + &LocalWslcReleaseProcess}; + + // Publish g_pluginHost with release semantics so an acquire-load in a stub + // observes a fully-constructed m_callback / m_hooks. Only publish on success + // so a failed entry point never leaves a dangling pointer for late stub calls. + g_pluginHost.store(this, std::memory_order_release); + HRESULT hr = entryPoint(&api, &m_hooks); + + if (FAILED(hr)) + { + g_pluginHost.store(nullptr, std::memory_order_release); + RETURN_HR_MSG(hr, "Plugin entry point failed: '%ls'", PluginPath); + } + + return S_OK; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::GetProcessHandle(_Out_ HANDLE* ProcessHandle) +try +{ + RETURN_HR_IF(E_POINTER, ProcessHandle == nullptr); + *ProcessHandle = nullptr; + + wil::unique_handle process(OpenProcess(PROCESS_SET_QUOTA | PROCESS_TERMINATE, FALSE, GetCurrentProcessId())); + RETURN_LAST_ERROR_IF_NULL(process); + + // COM's system_handle(sh_process) marshaling will duplicate this into the caller's process. + *ProcessHandle = process.release(); + return S_OK; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnVMStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ DWORD CustomConfigurationFlags, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) +try +{ + RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr); + *ErrorMessage = nullptr; + + if (m_hooks.OnVMStarted == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + WSLVmCreationSettings settings{}; + settings.CustomConfigurationFlags = static_cast(CustomConfigurationFlags); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnVMStarted(&ctx.info, &settings); + + // If the plugin called PluginError during the hook, return the message. + if (m_pluginErrorMessage.has_value()) + { + *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release(); + m_pluginErrorMessage.reset(); + } + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnVMStopping(_In_ DWORD SessionId, _In_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_(SidSize) BYTE* SidData) +try +{ + if (m_hooks.OnVMStopping == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnVMStopping(&ctx.info); + + // Plugin must not call PluginError outside hooks that surface ErrorMessage; + // silently drop any value so it can't poison a later fatal hook. + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) +try +{ + RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr); + *ErrorMessage = nullptr; + + if (m_hooks.OnDistributionStarted == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + // Preserve nullability of PackageFamilyName/Flavor/Version per the public + // ABI documented in WslPluginApi.h. Coalescing to L"" would silently change + // behavior of plugins that check `if (Distribution->PackageFamilyName)`. + WSLDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PidNamespace = PidNamespace; + distro.PackageFamilyName = PackageFamilyName; + distro.InitPid = InitPid; + distro.Flavor = Flavor; + distro.Version = Version; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionStarted(&ctx.info, &distro); + + if (m_pluginErrorMessage.has_value()) + { + *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release(); + m_pluginErrorMessage.reset(); + } + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionStopping( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) +try +{ + if (m_hooks.OnDistributionStopping == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WSLDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PidNamespace = PidNamespace; + distro.PackageFamilyName = PackageFamilyName; + distro.InitPid = InitPid; + distro.Flavor = Flavor; + distro.Version = Version; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionStopping(&ctx.info, &distro); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionRegistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) +try +{ + if (m_hooks.OnDistributionRegistered == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WslOfflineDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PackageFamilyName = PackageFamilyName; + distro.Flavor = Flavor; + distro.Version = Version; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionRegistered(&ctx.info, &distro); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionUnregistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) +try +{ + if (m_hooks.OnDistributionUnregistered == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WslOfflineDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PackageFamilyName = PackageFamilyName; + distro.Flavor = Flavor; + distro.Version = Version; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionUnregistered(&ctx.info, &distro); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +// --- Helpers --- + +PluginHost::SessionContext PluginHost::BuildSessionContext(DWORD SessionId, HANDLE UserToken, DWORD SidSize, BYTE* SidData) +{ + SessionContext ctx{}; + ctx.info.SessionId = SessionId; + + // The marshaled UserToken is owned by the RPC stub, which closes it when the + // call returns. Borrow it for the duration of the hook; do not take ownership. + ctx.info.UserToken = UserToken; + + // Reconstruct the SID. Reject malformed inputs at the COM boundary — + // pointer arithmetic on a null pointer (even with size 0) is undefined + // behavior, and a malformed SID would be dereferenced by plugin code. + THROW_HR_IF(E_INVALIDARG, SidData == nullptr || SidSize == 0); + ctx.sidBuffer.assign(SidData, SidData + SidSize); + THROW_HR_IF(E_INVALIDARG, !::IsValidSid(reinterpret_cast(ctx.sidBuffer.data()))); + ctx.info.UserSid = reinterpret_cast(ctx.sidBuffer.data()); + + return ctx; +} + +// --- Static API stubs --- + +HRESULT CALLBACK PluginHost::LocalMountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + auto hr = host->m_callback->MountFolder(Session, WindowsPath, LinuxPath, ReadOnly, Name); + return hr; +} + +HRESULT CALLBACK PluginHost::LocalExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_POINTER, Socket == nullptr); + // Initialize the out-param so callers don't observe an uninitialized + // socket value on any error return below. + *Socket = INVALID_SOCKET; + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + // Count arguments (NULL-terminated array) + DWORD count = 0; + if (Arguments != nullptr) + { + for (const LPCSTR* p = Arguments; *p != nullptr; ++p) + { + ++count; + } + } + + HANDLE socketResult = nullptr; + HRESULT hr = host->m_callback->ExecuteBinary(Session, Path, count, Arguments, &socketResult); + + if (SUCCEEDED(hr)) + { + // COM's system_handle marshaling duplicated the socket into our process. + *Socket = reinterpret_cast(socketResult); + } + else if (socketResult != nullptr) + { + if (closesocket(reinterpret_cast(socketResult)) == SOCKET_ERROR) + { + LOG_WIN32(WSAGetLastError()); + } + } + + return hr; +} + +HRESULT CALLBACK PluginHost::LocalPluginError(LPCWSTR UserMessage) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr) + { + // Not on a hook thread — PluginError must only be called + // synchronously from within OnVMStarted/OnDistributionStarted. + return E_ILLEGAL_METHOD_CALL; + } + + RETURN_HR_IF(E_INVALIDARG, UserMessage == nullptr); + RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, GetCurrentThreadId() != g_hookThreadId.load()); + RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, host->m_pluginErrorMessage.has_value()); + + // Store locally — returned to service alongside the hook HRESULT. + host->m_pluginErrorMessage.emplace(UserMessage); + return S_OK; +} + +HRESULT CALLBACK PluginHost::LocalExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_INVALIDARG, Distro == nullptr); + RETURN_HR_IF(E_POINTER, Socket == nullptr); + // Initialize the out-param so callers don't observe an uninitialized + // socket value on any error return below. + *Socket = INVALID_SOCKET; + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + DWORD count = 0; + if (Arguments != nullptr) + { + for (const LPCSTR* p = Arguments; *p != nullptr; ++p) + { + ++count; + } + } + + HANDLE socketResult = nullptr; + HRESULT hr = host->m_callback->ExecuteBinaryInDistribution(Session, Distro, Path, count, Arguments, &socketResult); + + if (SUCCEEDED(hr)) + { + *Socket = reinterpret_cast(socketResult); + } + else if (socketResult != nullptr) + { + if (closesocket(reinterpret_cast(socketResult)) == SOCKET_ERROR) + { + LOG_WIN32(WSAGetLastError()); + } + } + + return hr; +} + +// --- WSLC hook implementations --- + +PluginHost::WslcSessionContext PluginHost::BuildWslcSessionContext( + DWORD SessionId, LPCWSTR DisplayName, DWORD ApplicationPid, HANDLE UserToken, DWORD SidSize, BYTE* SidData) +{ + WslcSessionContext ctx{}; + ctx.info.SessionId = SessionId; + ctx.info.DisplayName = DisplayName; + ctx.info.ApplicationPid = ApplicationPid; + + // UserToken / SidData are unique in the IDL — both may be null in stopping + // paths invoked during session teardown. The marshaled token is owned by the + // RPC stub (closed when the call returns), so borrow it without taking ownership. + ctx.info.UserToken = UserToken; + + if (SidData != nullptr && SidSize > 0) + { + ctx.sidBuffer.assign(SidData, SidData + SidSize); + THROW_HR_IF(E_INVALIDARG, !::IsValidSid(reinterpret_cast(ctx.sidBuffer.data()))); + ctx.info.UserSid = reinterpret_cast(ctx.sidBuffer.data()); + } + + return ctx; +} + +STDMETHODIMP PluginHost::OnWslcSessionCreated( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) +try +{ + RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr); + *ErrorMessage = nullptr; + + if (m_hooks.OnSessionCreated == nullptr) + { + return S_OK; + } + + auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnSessionCreated(&ctx.info); + + if (m_pluginErrorMessage.has_value()) + { + *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release(); + m_pluginErrorMessage.reset(); + } + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnWslcSessionStopping( + _In_ DWORD SessionId, _In_ LPCWSTR DisplayName, _In_ DWORD ApplicationPid, _In_opt_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_opt_(SidSize) BYTE* SidData) +try +{ + if (m_hooks.OnSessionStopping == nullptr) + { + return S_OK; + } + + auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnSessionStopping(&ctx.info); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnWslcContainerStarted( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR InspectJson, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) +try +{ + RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr); + *ErrorMessage = nullptr; + + if (m_hooks.ContainerStarted == nullptr) + { + return S_OK; + } + + auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.ContainerStarted(&ctx.info, InspectJson); + + if (m_pluginErrorMessage.has_value()) + { + *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release(); + m_pluginErrorMessage.reset(); + } + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnWslcContainerStopping( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR ContainerId) +try +{ + if (m_hooks.ContainerStopping == nullptr) + { + return S_OK; + } + + auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.ContainerStopping(&ctx.info, ContainerId); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnWslcImageCreated( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR InspectJson) +try +{ + if (m_hooks.ImageCreated == nullptr) + { + return S_OK; + } + + auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.ImageCreated(&ctx.info, InspectJson); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnWslcImageDeleted( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR ImageId) +try +{ + if (m_hooks.ImageDeleted == nullptr) + { + return S_OK; + } + + auto ctx = BuildWslcSessionContext(SessionId, DisplayName, ApplicationPid, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.ImageDeleted(&ctx.info, ImageId); + + m_pluginErrorMessage.reset(); + return hr; +} +CATCH_RETURN(); + +// --- WSLC API stubs --- + +namespace { + +// Opaque wrapper handed to the plugin as WSLCProcessHandle. The DWORD cookie +// identifies the IWSLCProcess held by the service-side PluginHostCallbackImpl. +struct WslcProcessWrapper +{ + DWORD cookie; +}; + +} // namespace + +HRESULT CALLBACK PluginHost::LocalWslcMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, LPCSTR Mountpoint, BOOL ReadOnly) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_POINTER, Mountpoint == nullptr); + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + return host->m_callback->WslcMountFolder(Session, WindowsPath, Mountpoint, ReadOnly); +} + +HRESULT CALLBACK PluginHost::LocalWslcUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + return host->m_callback->WslcUnmountFolder(Session, Mountpoint); +} + +HRESULT CALLBACK PluginHost::LocalWslcCreateProcess( + WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_POINTER, Process == nullptr); + *Process = nullptr; + + int localErrno = 0; + if (Errno != nullptr) + { + *Errno = 0; + } + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + DWORD argCount = 0; + if (Arguments != nullptr) + { + for (const LPCSTR* p = Arguments; *p != nullptr; ++p) + { + ++argCount; + } + } + + DWORD envCount = 0; + if (Env != nullptr) + { + for (const LPCSTR* p = Env; *p != nullptr; ++p) + { + ++envCount; + } + } + + // Allocate the wrapper before creating the remote process so a throwing + // allocation can't strand a service-side cookie that only WslcReleaseProcess + // frees. Nothing between the remote create and release() below can throw. + auto wrapper = std::make_unique(); + + DWORD cookie = 0; + HRESULT hr = host->m_callback->WslcCreateProcess(Session, Executable, argCount, Arguments, envCount, Env, &cookie, &localErrno); + if (Errno != nullptr) + { + *Errno = localErrno; + } + + if (FAILED(hr)) + { + return hr; + } + + wrapper->cookie = cookie; + *Process = wrapper.release(); + return S_OK; +} + +HRESULT CALLBACK PluginHost::LocalWslcProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_INVALIDARG, Process == nullptr); + RETURN_HR_IF(E_POINTER, Handle == nullptr); + *Handle = nullptr; + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + auto* wrapper = static_cast(Process); + return host->m_callback->WslcProcessGetFd(wrapper->cookie, static_cast(Fd), Handle); +} + +HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_INVALIDARG, Process == nullptr); + RETURN_HR_IF(E_POINTER, ExitEvent == nullptr); + *ExitEvent = nullptr; + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + auto* wrapper = static_cast(Process); + return host->m_callback->WslcProcessGetExitEvent(wrapper->cookie, ExitEvent); +} + +HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode) +{ + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_INVALIDARG, Process == nullptr); + RETURN_HR_IF(E_POINTER, ExitCode == nullptr); + + ScopedComInitForCallback coInit; + RETURN_IF_FAILED(coInit.Result()); + + auto* wrapper = static_cast(Process); + return host->m_callback->WslcProcessGetExitCode(wrapper->cookie, ExitCode); +} + +void CALLBACK PluginHost::LocalWslcReleaseProcess(WSLCProcessHandle Process) +{ + if (Process == nullptr) + { + return; + } + + std::unique_ptr wrapper{static_cast(Process)}; + + auto* host = g_pluginHost.load(std::memory_order_acquire); + if (host == nullptr || host->m_callback == nullptr) + { + return; + } + + ScopedComInitForCallback coInit; + if (FAILED(coInit.Result())) + { + return; + } + + LOG_IF_FAILED(host->m_callback->WslcReleaseProcess(wrapper->cookie)); +} \ No newline at end of file diff --git a/src/windows/wslpluginhost/exe/PluginHost.h b/src/windows/wslpluginhost/exe/PluginHost.h new file mode 100644 index 0000000000..b7381c9d4a --- /dev/null +++ b/src/windows/wslpluginhost/exe/PluginHost.h @@ -0,0 +1,222 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + PluginHost.h + +Abstract: + + This file contains the COM class that implements IWslPluginHost. + It loads a plugin DLL and dispatches lifecycle notifications to it, + forwarding plugin API callbacks to the service via IWslPluginHostCallback. + +--*/ + +#pragma once + +#include "WslPluginApi.h" +#include "WslPluginHost.h" + +namespace wsl::windows::pluginhost { + +class PluginHost : public Microsoft::WRL::RuntimeClass, IWslPluginHost> +{ +public: + PluginHost(); + ~PluginHost(); + + PluginHost(const PluginHost&) = delete; + PluginHost& operator=(const PluginHost&) = delete; + + // IWslPluginHost + STDMETHODIMP Initialize(_In_ IWslPluginHostCallback* Callback, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) override; + STDMETHODIMP GetProcessHandle(_Out_ HANDLE* ProcessHandle) override; + + STDMETHODIMP OnVMStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ DWORD CustomConfigurationFlags, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override; + + STDMETHODIMP OnVMStopping(_In_ DWORD SessionId, _In_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_(SidSize) BYTE* SidData) override; + + STDMETHODIMP OnDistributionStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override; + + STDMETHODIMP OnDistributionStopping( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) override; + + STDMETHODIMP OnDistributionRegistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) override; + + STDMETHODIMP OnDistributionUnregistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) override; + + // + // WSLC plugin hooks. UserToken and SidData are marshaled the same way as + // the existing WSL hooks. The string DisplayName plus ApplicationPid make + // up the rest of the WSLCSessionInformation struct. + // + + STDMETHODIMP OnWslcSessionCreated( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override; + + STDMETHODIMP OnWslcSessionStopping( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData) override; + + STDMETHODIMP OnWslcContainerStarted( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR InspectJson, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override; + + STDMETHODIMP OnWslcContainerStopping( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR ContainerId) override; + + STDMETHODIMP OnWslcImageCreated( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR InspectJson) override; + + STDMETHODIMP OnWslcImageDeleted( + _In_ DWORD SessionId, + _In_ LPCWSTR DisplayName, + _In_ DWORD ApplicationPid, + _In_opt_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_opt_(SidSize) BYTE* SidData, + _In_ LPCSTR ImageId) override; + +private: + // Build a WSLSessionInformation struct from the marshaled parameters. + // The returned struct and its SID allocation are valid for the lifetime of the wil::unique_handle. + struct SessionContext + { + WSLSessionInformation info{}; + std::vector sidBuffer; + }; + + SessionContext BuildSessionContext(DWORD SessionId, HANDLE UserToken, DWORD SidSize, BYTE* SidData); + + // WSLC counterpart — builds a WSLCSessionInformation. DisplayName is held + // by the caller (the IDL marshaled buffer lives for the call), so the info + // struct's LPCWSTR DisplayName points at it directly. UserToken/SidData are + // optional in the WSLC hooks (NotifySessionStopping is invoked without an + // application token in dtor paths). + struct WslcSessionContext + { + WSLCSessionInformation info{}; + std::vector sidBuffer; + }; + + WslcSessionContext BuildWslcSessionContext(DWORD SessionId, LPCWSTR DisplayName, DWORD ApplicationPid, HANDLE UserToken, DWORD SidSize, BYTE* SidData); + + // Local stubs for the WSLPluginAPIV1 function pointers. + // These forward calls to the service via m_callback. + static HRESULT CALLBACK LocalMountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name); + static HRESULT CALLBACK LocalExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket); + static HRESULT CALLBACK LocalPluginError(LPCWSTR UserMessage); + static HRESULT CALLBACK LocalExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket); + + // WSLC API stubs. WSLCProcessHandle is a heap-allocated wrapper around the + // DWORD process cookie owned by the service-side PluginHostCallbackImpl. + static HRESULT CALLBACK LocalWslcMountFolder(WSLCSessionId Session, LPCWSTR WindowsPath, LPCSTR Mountpoint, BOOL ReadOnly); + static HRESULT CALLBACK LocalWslcUnmountFolder(WSLCSessionId Session, LPCSTR Mountpoint); + static HRESULT CALLBACK LocalWslcCreateProcess( + WSLCSessionId Session, LPCSTR Executable, LPCSTR* Arguments, LPCSTR* Env, WSLCProcessHandle* Process, int* Errno); + static HRESULT CALLBACK LocalWslcProcessGetFd(WSLCProcessHandle Process, WSLCProcessFd Fd, HANDLE* Handle); + static HRESULT CALLBACK LocalWslcProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent); + static HRESULT CALLBACK LocalWslcProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode); + static void CALLBACK LocalWslcReleaseProcess(WSLCProcessHandle Process); + + wil::unique_hmodule m_module; + WSLPluginHooksV1 m_hooks{}; + Microsoft::WRL::ComPtr m_callback; + + // Serializes hook dispatch so m_pluginErrorMessage and g_hookThreadId + // are not raced when multiple sessions call hooks concurrently (MTA). + std::mutex m_hookLock; + + // Error message captured by LocalPluginError during hook execution + std::optional m_pluginErrorMessage; +}; + +// Process-wide pointer to the single PluginHost instance. Safe because +// REGCLS_SINGLEUSE guarantees one PluginHost per wslpluginhost.exe process. +// This allows plugin DLLs to call API functions from any thread, not just +// the thread dispatching the current hook. Atomic so concurrent stub calls +// from plugin worker threads observe a coherent value during ctor/dtor. +extern std::atomic g_pluginHost; + +// Number of PluginHost instances ever constructed in this process. Used by +// main() to detect orphan hosts that COM-activated wslpluginhost.exe but +// never followed through with a successful CreateInstance. +extern std::atomic g_activationCount; + +} // namespace wsl::windows::pluginhost diff --git a/src/windows/wslpluginhost/exe/main.cpp b/src/windows/wslpluginhost/exe/main.cpp new file mode 100644 index 0000000000..bb2a708452 --- /dev/null +++ b/src/windows/wslpluginhost/exe/main.cpp @@ -0,0 +1,139 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + main.cpp + +Abstract: + + This file contains the entry point for wslpluginhost.exe. + This process acts as a COM local server that loads a single WSL plugin DLL + in an isolated process, preventing a buggy or malicious plugin from crashing + the main WSL service. + + The host is activated through COM local-server activation. It registers its + COM class factory, serves activation requests, and remains alive until all + COM server-process references are released, at which point it exits. + +--*/ + +#include "precomp.h" +#include "PluginHost.h" +#include "WslPluginHost.h" + +using namespace Microsoft::WRL; + +static wil::unique_event g_exitEvent(wil::EventOptions::ManualReset); + +void AddComRef() +{ + CoAddRefServerProcess(); +} + +void ReleaseComRef() +{ + if (CoReleaseServerProcess() == 0) + { + g_exitEvent.SetEvent(); + } +} + +class PluginHostFactory : public RuntimeClass, IClassFactory> +{ +public: + STDMETHODIMP CreateInstance(_In_opt_ IUnknown* pUnkOuter, _In_ REFIID riid, _Outptr_ void** ppCreated) override + try + { + RETURN_HR_IF_NULL(E_POINTER, ppCreated); + *ppCreated = nullptr; + RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr); + + auto host = Make(); + RETURN_IF_NULL_ALLOC(host); + + // The PluginHost ctor/dtor pair manages the process keep-alive ref; + // no manual AddComRef/ReleaseComRef needed here. If CopyTo fails, the + // local ComPtr destructor releases the only reference, which destroys + // the PluginHost and decrements the keep-alive count. + RETURN_IF_FAILED(host.CopyTo(riid, ppCreated)); + return S_OK; + } + CATCH_RETURN(); + + STDMETHODIMP LockServer(BOOL lock) noexcept override + { + if (lock) + { + AddComRef(); + } + else + { + ReleaseComRef(); + } + return S_OK; + } +}; + +int WINAPI wWinMain(_In_ HINSTANCE, _In_opt_ HINSTANCE, _In_ LPWSTR, _In_ int) +try +{ + wsl::windows::common::wslutil::ConfigureCrt(); + wsl::windows::common::wslutil::InitializeWil(); + + // Initialize logging. + WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild); + auto cleanupTracing = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [] { WslTraceLoggingUninitialize(); }); + + // Harden the process before loading any third-party plugin code. Match the + // mitigation set applied by the other WSL COM server processes + // (wslservice.exe / wslcsession.exe). + wsl::windows::common::security::ApplyProcessMitigationPolicies(); + + auto coInit = wil::CoInitializeEx(COINIT_MULTITHREADED); + wsl::windows::common::wslutil::CoInitializeSecurity(); + + // Initialize Winsock — plugins receive sockets from ExecuteBinary and need + // Winsock to be initialized for recv/send/closesocket to work. + WSADATA wsaData{}; + THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &wsaData)); + auto cleanupWinsock = wil::scope_exit([] { WSACleanup(); }); + + // Register the class factory so the service can CoCreateInstance on us. + DWORD cookie = 0; + auto factory = Make(); + THROW_IF_NULL_ALLOC(factory); + + THROW_IF_FAILED(::CoRegisterClassObject(CLSID_WslPluginHost, factory.Get(), CLSCTX_LOCAL_SERVER, REGCLS_SINGLEUSE, &cookie)); + + auto revokeOnExit = wil::scope_exit([&]() { ::CoRevokeClassObject(cookie); }); + + // Bounded shutdown for orphaned hosts: if COM activates wslpluginhost.exe + // but no client ever successfully creates an instance (e.g., the service + // crashes between launch and CreateInstance, or activation is abandoned + // by COM), exit instead of blocking on g_exitEvent forever. Once at least + // one PluginHost has been constructed, AddComRef/ReleaseComRef govern + // shutdown and we wait indefinitely for that ref count to drop to 0. + constexpr DWORD c_startupTimeoutMs = 60'000; + const DWORD waitResult = ::WaitForSingleObject(g_exitEvent.get(), c_startupTimeoutMs); + if (waitResult == WAIT_TIMEOUT && wsl::windows::pluginhost::g_activationCount.load(std::memory_order_acquire) == 0) + { + LOG_HR_MSG( + HRESULT_FROM_WIN32(ERROR_TIMEOUT), + "wslpluginhost.exe startup timeout: no client activated the host within %u ms; exiting", + c_startupTimeoutMs); + return 1; + } + + // Either the event was signaled, or a host was activated. Wait for the + // exit signal (which fires when CoReleaseServerProcess returns 0). + g_exitEvent.wait(); + + return 0; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION(); + return 1; +} diff --git a/src/windows/wslpluginhost/exe/main.rc b/src/windows/wslpluginhost/exe/main.rc new file mode 100644 index 0000000000..84d6b8b8c6 --- /dev/null +++ b/src/windows/wslpluginhost/exe/main.rc @@ -0,0 +1,25 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + main.rc + +Abstract: + + This file contains resources for wslpluginhost. + +--*/ + +#include +#include "resource.h" +#include "wslversioninfo.h" + +#define VER_INTERNALNAME_STR "wslpluginhost.exe" +#define VER_ORIGINALFILENAME_STR "wslpluginhost.exe" + +#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux" +ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\..\Images\wsl.ico" + +#include diff --git a/src/windows/wslpluginhost/exe/resource.h b/src/windows/wslpluginhost/exe/resource.h new file mode 100644 index 0000000000..355437a443 --- /dev/null +++ b/src/windows/wslpluginhost/exe/resource.h @@ -0,0 +1,15 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + resource.h + +Abstract: + + This file contains resource declarations for wslpluginhost.exe + +--*/ + +#define ID_ICON 1 diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp index f5aa3b1a0e..ebc941782e 100644 --- a/test/windows/Common.cpp +++ b/test/windows/Common.cpp @@ -831,6 +831,7 @@ void CreateWerReports() L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", + L"wslpluginhost.exe", L"wslservice.exe", L"wslg.exe", L"vmcompute.exe", @@ -1329,6 +1330,48 @@ void StopWslService() StopService(service.get()); } +DWORD GetWslServicePid() +{ + const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)}; + VERIFY_IS_NOT_NULL(manager); + + const wil::unique_schandle service{OpenService(manager.get(), L"wslservice", SERVICE_QUERY_STATUS | SERVICE_START)}; + VERIFY_IS_NOT_NULL(service); + + auto [state, pid] = GetServiceState(service.get()); + if (state == SERVICE_STOPPED) + { + // The service is on-demand; start it so we can capture a stable PID. + if (!StartService(service.get(), 0, nullptr)) + { + const auto error = GetLastError(); + VERIFY_IS_TRUE(error == ERROR_SERVICE_ALREADY_RUNNING); + } + } + + if (state != SERVICE_RUNNING) + { + WaitForServiceState(service.get(), SERVICE_RUNNING, 0); + std::tie(state, pid) = GetServiceState(service.get()); + } + + VERIFY_ARE_EQUAL(static_cast(SERVICE_RUNNING), state); + return pid; +} + +DWORD GetWslServiceRunningPid() +{ + const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)}; + VERIFY_IS_NOT_NULL(manager); + + const wil::unique_schandle service{OpenService(manager.get(), L"wslservice", SERVICE_QUERY_STATUS)}; + VERIFY_IS_NOT_NULL(service); + + auto [state, pid] = GetServiceState(service.get()); + VERIFY_ARE_EQUAL(static_cast(SERVICE_RUNNING), state); + return pid; +} + wil::unique_handle GetNonElevatedToken(TOKEN_TYPE Type) { auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS); diff --git a/test/windows/Common.h b/test/windows/Common.h index ad2706ebf2..611900d95c 100644 --- a/test/windows/Common.h +++ b/test/windows/Common.h @@ -120,6 +120,7 @@ using namespace std::chrono_literals; TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"WslServiceProxyStub.dll") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslhost.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslrelay.exe") \ + TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslpluginhost.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslconfig.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wsl.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslg.exe") \ @@ -501,6 +502,13 @@ std::vector LxssSplitString(_In_ const std::wstring& string, _In_ void RestartWslService(); +DWORD GetWslServicePid(); + +// Returns the PID of wslservice, asserting it is already RUNNING. Unlike +// GetWslServicePid it never starts the service, so a crashed/stopped service +// fails at capture instead of being silently restarted with a new PID. +DWORD GetWslServiceRunningPid(); + wil::unique_handle GetNonElevatedToken(TOKEN_TYPE Type = TokenPrimary); std::wstring LxssWriteWslConfig(const std::wstring& Content); diff --git a/test/windows/InstallerTests.cpp b/test/windows/InstallerTests.cpp index 91d9c699c1..13f77be5c4 100644 --- a/test/windows/InstallerTests.cpp +++ b/test/windows/InstallerTests.cpp @@ -836,7 +836,7 @@ class InstallerTests return flags; }; - const std::vector executables = {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslg.exe"}; + const std::vector executables = {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslpluginhost.exe", L"wslg.exe"}; for (const auto& e : executables) { auto fullPath = installPath.value() + e; diff --git a/test/windows/PluginTests.cpp b/test/windows/PluginTests.cpp index 61b28a74e2..2cd5f03850 100644 --- a/test/windows/PluginTests.cpp +++ b/test/windows/PluginTests.cpp @@ -34,6 +34,17 @@ class PluginTests std::wstring pluginDll; std::optional config; + // Returns true if the file does not exist, cannot be stat'd, or is zero + // bytes. Uses the non-throwing std::filesystem overloads so racy file + // deletion / access denials between the existence check and size query + // can't surface as exceptions out of test code. + static bool LogFileAbsentOrEmpty(const std::filesystem::path& path) + { + std::error_code ec; + const auto size = std::filesystem::file_size(path, ec); + return static_cast(ec) || size == 0; + } + WSL_TEST_CLASS(PluginTests) TEST_CLASS_SETUP(TestClassSetup) @@ -340,11 +351,13 @@ class PluginTests WSL1_TEST_METHOD(SuccessWSL1) { - constexpr auto ExpectedOutput = LR"(Plugin loaded. TestMode=1)"; - + // Plugins are not loaded for WSL1-only sessions (no VM, no plugin hooks). + // Verify the plugin log file is absent/empty to assert no plugin code ran. ConfigurePlugin(PluginTestType::Success); StartWsl(0); - ValidateLogFile(ExpectedOutput); + + VERIFY_IS_TRUE( + LogFileAbsentOrEmpty(logFile), std::format(L"Expected plugin log file '{}' to be absent or empty for WSL1", logFile).c_str()); } WSL2_TEST_METHOD(LoadFailureFatalWSL2) @@ -363,13 +376,14 @@ class PluginTests WSL1_TEST_METHOD(LoadFailureNonFatalWSL1) { - constexpr auto ExpectedOutput = - LR"(Plugin loaded. TestMode=2 - OnLoad: E_UNEXPECTED)"; - + // Plugins are not loaded for WSL1-only sessions, so a plugin that + // would fail to load on WSL2 has no effect on WSL1. Assert the plugin + // log file is absent/empty to confirm no plugin code ran. ConfigurePlugin(PluginTestType::FailToLoad); StartWsl(0); - ValidateLogFile(ExpectedOutput); + + VERIFY_IS_TRUE( + LogFileAbsentOrEmpty(logFile), std::format(L"Expected plugin log file '{}' to be absent or empty for WSL1", logFile).c_str()); } WSL2_TEST_METHOD(VmStartFailure) @@ -598,6 +612,7 @@ class PluginTests StartWsl(0); ValidateLogFile(ExpectedOutput); } + static wil::com_ptr OpenWslcSessionManager() { wil::com_ptr sessionManager; @@ -744,7 +759,8 @@ class PluginTests constexpr auto ExpectedOutput = LR"(Plugin loaded. TestMode=19 WSLC Session created, name=plugin-wslc-rejected, id=*, pid=*, token=set, sid=set - OnWslcSessionCreated: ERROR_ACCESS_DENIED)"; + OnWslcSessionCreated: ERROR_ACCESS_DENIED + WSLC Session stopping, name=plugin-wslc-rejected, id=*)"; ValidateLogFile(ExpectedOutput); } @@ -776,6 +792,150 @@ class PluginTests ValidateLogFile(ExpectedOutput); } + // --- PR #40120 (out-of-process plugin host) coverage --- + // + // These tests validate the new isolation and locking behavior: + // * HostCrashIsolation — host process crash is non-fatal (IsHostCrash). + // * ConcurrentCallbacks — concurrent shared_lock readers on m_callbackLock. + // * AsyncApiCallFromWorker — cross-apartment plugin API call from a non-hook thread. + // * CallbacksDuringTerminationDoNotCrash — exclusive m_callbackLock drains in-flight + // callbacks before m_utilityVm.reset(). + + WSL2_TEST_METHOD(HostCrashIsolation) + { + ConfigurePlugin(PluginTestType::HostCrash); + + // wslservice is on-demand; the first StartWsl below starts it. Capture + // its PID afterward, run more commands, and confirm the PID hasn't + // changed (i.e. the plugin host crash did not take the service down). + StartWsl(0); + const DWORD pidBefore = GetWslServiceRunningPid(); + VERIFY_IS_TRUE(pidBefore != 0); + + // Shut down the VM. _VmTerminate will call OnVmStopping against the + // dead host; IsHostCrash treats RPC_E_DISCONNECTED-style errors as + // non-fatal (logged + skipped) so this must not fail. + WslShutdown(); + + // Service must accept new work after the host crash. + StartWsl(0); + + const DWORD pidAfter = GetWslServiceRunningPid(); + VERIFY_ARE_EQUAL(pidBefore, pidAfter); + + // Plugin host is loaded once via std::call_once and is NOT re-activated + // after a crash. Verify by counting "Plugin loaded" occurrences: should + // be exactly one across both wsl invocations. + StopWslService(); + + std::wifstream file(logFile); + const auto fileContent = std::wstring{std::istreambuf_iterator(file), {}}; + LogInfo("Logfile: %ls", fileContent.c_str()); + + auto countOccurrences = [&](const std::wstring& needle) { + size_t count = 0; + size_t pos = 0; + while ((pos = fileContent.find(needle, pos)) != std::wstring::npos) + { + ++count; + pos += needle.size(); + } + return count; + }; + + VERIFY_ARE_EQUAL(static_cast(1), countOccurrences(L"Plugin loaded. TestMode=")); + VERIFY_ARE_EQUAL(static_cast(1), countOccurrences(L"Crashing host")); + } + + WSL2_TEST_METHOD(ConcurrentCallbacks) + { + // The hook spawns 4 threads behind two gates: the first ensures all + // workers are spawned, the second rendezvouses them at the callback + // boundary so maxConcurrent deterministically reaches 4 (proving the + // plugin issues 4 callbacks concurrently). Per-thread logging is + // intentionally suppressed; only the deterministic summary line is + // asserted. Lifecycle pre/post lines validate the hook itself ran + // to completion. + constexpr auto ExpectedOutput = + LR"(Plugin loaded. TestMode=23 + VM created (settings->CustomConfigurationFlags=0) + Concurrent callbacks complete: success=4 failures=0 maxConcurrent=4 + Distribution started, name=test_distro, package=, PidNs=*, InitPid=*, Flavor=debian, Version=13 + Distribution Stopping, name=test_distro, package=, PidNs=*, Flavor=debian, Version=13 + VM Stopping)"; + + ConfigurePlugin(PluginTestType::ConcurrentApiCalls); + StartWsl(0); + ValidateLogFile(ExpectedOutput); + } + + WSL2_TEST_METHOD(AsyncApiCallFromWorker) + { + // The worker thread is created in OnDistroStarted, sleeps briefly to + // ensure it runs after the hook has returned, then calls + // ExecuteBinaryInDistribution. It's joined unconditionally in + // OnDistroStopping (which defers its own "Distribution Stopping" log + // until after the join), so the worker-output line is guaranteed to + // precede "Distribution Stopping" in the log. + // + // The wsl command sleeps for 1s so the distro is alive long enough + // for the post-hook worker call to land before shutdown. + constexpr auto ExpectedOutput = + LR"(Plugin loaded. TestMode=24 + VM created (settings->CustomConfigurationFlags=0) + Distribution started, name=test_distro, package=, PidNs=*, InitPid=*, Flavor=debian, Version=13 + Async worker output: hello-from-worker + Distribution Stopping, name=test_distro, package=, PidNs=*, Flavor=debian, Version=13 + VM Stopping)"; + + ConfigurePlugin(PluginTestType::AsyncApiCall); + + auto [output, error] = LxsstuLaunchWslAndCaptureOutput(L"sh -c \"sleep 1; echo -n OK\"", 0); + VERIFY_ARE_EQUAL(output, L"OK"); + + ValidateLogFile(ExpectedOutput); + } + + WSL2_TEST_METHOD(CallbacksDuringTerminationDoNotCrash) + { + // Drain test: 4 workers loop ExecuteBinaryInDistribution (with /bin/true, + // sub-ms callback) while the distro is alive. They keep calling across + // OnDistroStopping and _VmTerminate; the exclusive m_callbackLock acquire + // in _VmTerminate must drain in-flight callbacks before resetting + // m_utilityVm. After OnVmStopping signals wind-down, workers run a bounded + // number of further iterations and exit (see Plugin.cpp), so termination + // is deterministic and no worker can revive against a later VM. + // + // The post-shutdown StartWsl below triggers a second OnDistroStarted that + // joins the finished workers, so this test needs no fixed sleep. + // + // Scope: + // - Validates: dual-lock invariant under racing callbacks; drain works + // when callbacks complete in sub-ms; service survives the race. + // - Does NOT validate: drain semantics when a callback is genuinely + // stuck (e.g. service-side CreateLinuxProcess waiting on a hung + // Linux init). That requires cancellation plumbing through + // WslCoreInstance::CreateLinuxProcess and is tracked separately. + ConfigurePlugin(PluginTestType::CallbackDuringTermination); + + // Use a 1s sleep so workers ramp up while the distro is alive. + auto [output, error] = LxsstuLaunchWslAndCaptureOutput(L"sh -c \"sleep 1; echo -n OK\"", 0); + VERIFY_ARE_EQUAL(output, L"OK"); + + const DWORD pidBefore = GetWslServiceRunningPid(); + VERIFY_IS_TRUE(pidBefore != 0); + + // Trigger VM termination with workers still running. + WslShutdown(); + + // Subsequent WSL command must still succeed — service survived. Its + // OnDistroStarted also joins the wound-down workers. + StartWsl(0); + + const DWORD pidAfter = GetWslServiceRunningPid(); + VERIFY_ARE_EQUAL(pidBefore, pidAfter); + } + // This test must run last so it doesn't break test cases that depends on plugin signature. WSL2_TEST_METHOD(InvalidPluginSignature) { diff --git a/test/windows/PluginTests.h b/test/windows/PluginTests.h index c9a779031c..9fb68a00ab 100644 --- a/test/windows/PluginTests.h +++ b/test/windows/PluginTests.h @@ -41,7 +41,11 @@ enum class PluginTestType WslcSuccess, WslcSessionRejected, WslcContainerRejected, - WslcImagePull + WslcImagePull, + HostCrash, + ConcurrentApiCalls, + AsyncApiCall, + CallbackDuringTermination }; constexpr auto c_testType = L"TestType"; diff --git a/test/windows/testplugin/Plugin.cpp b/test/windows/testplugin/Plugin.cpp index 17087e6edc..1e4b77df68 100644 --- a/test/windows/testplugin/Plugin.cpp +++ b/test/windows/testplugin/Plugin.cpp @@ -16,6 +16,9 @@ Module Name: #include "WslPluginApi.h" #include "wslc_schema.h" +#include +#include + #include "PluginTests.h" using namespace wsl::windows::common::registry; @@ -31,6 +34,47 @@ PluginTestType g_testType = PluginTestType::Invalid; std::optional g_previousInitPid; +// Serializes writes to g_logfile from multiple threads in modes that spawn +// worker threads (ConcurrentApiCalls, AsyncApiCall, CallbackDuringTermination). +// Hook-thread writes that don't overlap with worker writes don't need to take +// this — but it's harmless to do so. +std::mutex g_logMutex; + +void LogLine(const std::string& line) +{ + std::lock_guard guard{g_logMutex}; + g_logfile << line << std::endl; +} + +// State for AsyncApiCall: worker thread launched in OnDistroStarted, joined +// in OnDistroStopping. The promise carries the result so the hook can log it +// after the join. The future is retrieved exactly once (in OnDistroStarted) +// and consumed in OnDistroStopping — std::promise::get_future() can only be +// called once per promise instance. +std::optional g_asyncWorker; +std::optional> g_asyncWorkerResult; +std::future g_asyncWorkerFuture; +std::string g_asyncWorkerOutput; + +// State for CallbackDuringTermination. Workers loop ExecuteBinaryInDistribution +// while the distro is alive. OnVmStopping (which fires before _VmTerminate's +// exclusive-lock drain) sets g_drainWindDown; workers then keep racing the drain +// for a bounded number of iterations and exit. Exiting on a fixed count (rather +// than on a post-reset failure) keeps shutdown deterministic and prevents a +// worker from "reviving" against a VM that a later StartWsl creates with the +// same session/distro IDs. +// +// g_drainWorkersStarted prevents the post-shutdown StartWsl (which the test uses +// to verify the service survived) from spawning a fresh batch; that second +// OnDistroStarted instead joins the finished workers. Threads are kept joinable +// in g_drainWorkers so the test needs no fixed sleep. +constexpr int c_drainWindDownIterations = 100; +std::atomic g_drainSuccess{0}; +std::atomic g_drainFailures{0}; +std::atomic g_drainWindDown{false}; +std::atomic g_drainWorkersStarted{false}; +std::vector g_drainWorkers; + std::vector ReadFromSocket(SOCKET socket) { // Simplified error handling for the sake of the demo. @@ -123,7 +167,7 @@ HRESULT OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSet } result = g_api->ExecuteBinary(0xcafe, arguments[0], arguments.data(), &socket); - if (result != RPC_E_DISCONNECTED) + if (result != HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) { g_logfile << "Unexpected error for ExecuteBinary(): " << result << std::endl; return E_ABORT; @@ -182,12 +226,133 @@ HRESULT OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSet return S_OK; } + else if (g_testType == PluginTestType::HostCrash) + { + // Validate plugin host crash isolation. Forcefully exit the host + // process so the COM RPC returns one of the HRESULTs in IsHostCrash + // (RPC_E_DISCONNECTED / RPC_E_SERVER_DIED / ...). The service should + // treat this as non-fatal and continue. + LogLine("Crashing host"); + g_logfile.flush(); + TerminateProcess(GetCurrentProcess(), 1); + // Unreachable. + return E_UNEXPECTED; + } + else if (g_testType == PluginTestType::ConcurrentApiCalls) + { + // Validate concurrent service-side callbacks under the new + // m_callbackLock (shared_mutex). N threads call MountFolder + + // ExecuteBinary in parallel via a start-gate so the shared_lock has + // multiple readers in flight at once. + // + // maxConcurrent records how many workers are simultaneously at the + // callback boundary (via a second rendezvous). Reaching N proves the + // plugin issues N callbacks concurrently. It does NOT prove the service + // executes them in parallel — a black-box plugin can't observe whether + // m_callbackLock is shared or exclusive, since either way the RPCs + // simply appear in flight. This is the strongest honest assertion here. + constexpr int N = 4; + + std::filesystem::path modulePath = wil::GetModuleFileNameW(wil::GetModuleInstanceHandle()).get(); + const auto mountSource = modulePath.parent_path().wstring(); + + std::mutex gateMutex; + std::condition_variable gateCv; + int arrived = 0; + bool released = false; + int inFlight = 0; + int maxConcurrent = 0; + + std::atomic successes{0}; + std::atomic failures{0}; + + const auto worker = [&](int index) { + // Gate 1: wait for all workers to be spawned so they overlap. + { + std::unique_lock lock{gateMutex}; + ++arrived; + if (arrived == N) + { + released = true; + gateCv.notify_all(); + } + else + { + gateCv.wait(lock, [&]() { return released; }); + } + } + + // Gate 2: rendezvous right before the callbacks so all N are at the + // boundary at once, making maxConcurrent == N deterministic. + { + std::unique_lock lock{gateMutex}; + ++inFlight; + maxConcurrent = std::max(maxConcurrent, inFlight); + if (inFlight == N) + { + gateCv.notify_all(); + } + else + { + gateCv.wait(lock, [&]() { return inFlight == N; }); + } + } + + const auto linuxPath = L"/test-plugin/concurrent-" + std::to_wstring(index); + const auto mountName = L"test-plugin-concurrent-" + std::to_wstring(index); + HRESULT hr = g_api->MountFolder(Session->SessionId, mountSource.c_str(), linuxPath.c_str(), true, mountName.c_str()); + if (FAILED(hr)) + { + ++failures; + return; + } + + wil::unique_socket socket; + std::vector args = {"/bin/true", nullptr}; + hr = g_api->ExecuteBinary(Session->SessionId, args[0], args.data(), &socket); + if (FAILED(hr)) + { + ++failures; + return; + } + + ++successes; + }; + + std::vector threads; + threads.reserve(N); + for (int i = 0; i < N; ++i) + { + threads.emplace_back(worker, i); + } + for (auto& t : threads) + { + t.join(); + } + + LogLine( + "Concurrent callbacks complete: success=" + std::to_string(successes.load()) + + " failures=" + std::to_string(failures.load()) + " maxConcurrent=" + std::to_string(maxConcurrent)); + + if (failures.load() != 0) + { + return E_FAIL; + } + } return S_OK; } HRESULT OnVmStopping(const WSLSessionInformation* Session) { + if (g_testType == PluginTestType::CallbackDuringTermination) + { + // Signal drain workers to begin a bounded wind-down. Fires before + // _VmTerminate's exclusive m_callbackLock acquire, so workers keep + // racing the drain for a fixed number of iterations before exiting. + g_drainWindDown = true; + } + g_logfile << "VM Stopping" << std::endl; if (g_testType == PluginTestType::FailToStopVm) @@ -283,16 +448,168 @@ HRESULT OnDistroStarted(const WSLSessionInformation* Session, const WSLDistribut g_logfile << "Invalid distro launch returned: " << g_api->ExecuteBinaryInDistribution(Session->SessionId, &guid, arguments[0], arguments.data(), &socket) << std::endl; } + else if (g_testType == PluginTestType::AsyncApiCall) + { + // Validate plugin API calls from a worker thread that outlives the + // hook. The worker thread is joined in OnDistroStopping — joining is + // unconditional (no timeout) because letting the worker outlive + // g_pluginHost (cleared in ~PluginHost) would dereference freed memory. + g_asyncWorkerOutput.clear(); + g_asyncWorkerResult.emplace(); + g_asyncWorkerFuture = g_asyncWorkerResult->get_future(); + + const DWORD sessionId = Session->SessionId; + const GUID distroId = Distribution->Id; + + g_asyncWorker.emplace([sessionId, distroId]() { + // Sleep briefly so the call is guaranteed to happen after the + // hook has returned — exercises the cross-apartment callback + // path from a non-hook thread that hasn't called CoInitializeEx. + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + wil::unique_socket socket; + std::vector args = {"/bin/echo", "hello-from-worker", nullptr}; + const HRESULT hr = g_api->ExecuteBinaryInDistribution(sessionId, &distroId, args[0], args.data(), &socket); + + if (SUCCEEDED(hr)) + { + const auto output = ReadFromSocket(socket.get()); + std::string captured(output.begin(), output.end()); + // Strip trailing newline added by /bin/echo so the log line + // doesn't get split when ValidateLogFile splits on '\n'. + while (!captured.empty() && (captured.back() == '\n' || captured.back() == '\r')) + { + captured.pop_back(); + } + std::lock_guard guard{g_logMutex}; + g_asyncWorkerOutput = std::move(captured); + } + + g_asyncWorkerResult->set_value(hr); + }); + } + else if (g_testType == PluginTestType::CallbackDuringTermination) + { + // Validate that the new exclusive m_callbackLock acquire in + // _VmTerminate drains in-flight callbacks before m_utilityVm.reset(). + // Workers keep calling into the service across OnDistroStopping / + // _VmTerminate, then wind down deterministically (see globals above). + // + // Scope: this test exercises only the *happy-path* drain — the + // callback (/bin/true) returns in sub-millisecond, so workers are + // almost always between iterations when the exclusive lock is + // acquired. It is *not* a regression test for the hung-callback + // case, where a service-side callback is stuck inside CreateLinuxProcess + // waiting on a non-responsive Linux init; that scenario requires + // termination-event plumbing through WslCoreInstance::CreateLinuxProcess + // and is tracked separately. + constexpr int N = 4; + + // Spawn at most once. The post-shutdown StartWsl in + // CallbacksDuringTerminationDoNotCrash triggers another + // OnDistroStarted; join the (already wound-down) workers there instead + // of starting a fresh batch. + if (g_drainWorkersStarted.exchange(true)) + { + for (auto& t : g_drainWorkers) + { + if (t.joinable()) + { + t.join(); + } + } + g_drainWorkers.clear(); + return S_OK; + } + + g_drainSuccess = 0; + g_drainFailures = 0; + g_drainWindDown = false; + + const DWORD sessionId = Session->SessionId; + const GUID distroId = Distribution->Id; + + for (int i = 0; i < N; ++i) + { + g_drainWorkers.emplace_back([sessionId, distroId]() { + int windDownRemaining = -1; + while (true) + { + wil::unique_socket socket; + std::vector args = {"/bin/true", nullptr}; + const HRESULT hr = g_api->ExecuteBinaryInDistribution(sessionId, &distroId, args[0], args.data(), &socket); + if (SUCCEEDED(hr)) + { + ++g_drainSuccess; + } + else + { + ++g_drainFailures; + } + + if (g_drainWindDown.load()) + { + if (windDownRemaining < 0) + { + windDownRemaining = c_drainWindDownIterations; + } + if (windDownRemaining-- == 0) + { + return; + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + } + } return S_OK; } HRESULT OnDistroStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) { - g_logfile << "Distribution Stopping, name=" << wsl::shared::string::WideToMultiByte(Distribution->Name) - << ", package=" << wsl::shared::string::WideToMultiByte(Distribution->PackageFamilyName) - << ", PidNs=" << Distribution->PidNamespace << ", Flavor=" << wsl::shared::string::WideToMultiByte(Distribution->Flavor) - << ", Version=" << wsl::shared::string::WideToMultiByte(Distribution->Version) << std::endl; + // For AsyncApiCall we defer the "Distribution Stopping" line until after + // the worker thread has been joined, so the worker's "Async worker output" + // line is guaranteed to appear before it in the log. + auto logDistroStopping = [&]() { + g_logfile << "Distribution Stopping, name=" << wsl::shared::string::WideToMultiByte(Distribution->Name) + << ", package=" << wsl::shared::string::WideToMultiByte(Distribution->PackageFamilyName) + << ", PidNs=" << Distribution->PidNamespace << ", Flavor=" << wsl::shared::string::WideToMultiByte(Distribution->Flavor) + << ", Version=" << wsl::shared::string::WideToMultiByte(Distribution->Version) << std::endl; + }; + + if (g_testType == PluginTestType::AsyncApiCall) + { + if (g_asyncWorker.has_value()) + { + // Unconditional join — letting the worker outlive g_pluginHost + // (cleared in ~PluginHost) would dereference freed memory. + g_asyncWorker->join(); + g_asyncWorker.reset(); + + HRESULT workerHr = S_OK; + if (g_asyncWorkerFuture.valid()) + { + workerHr = g_asyncWorkerFuture.get(); + g_asyncWorkerResult.reset(); + } + + if (SUCCEEDED(workerHr)) + { + LogLine("Async worker output: " + g_asyncWorkerOutput); + } + else + { + LogLine("Async worker failed: " + std::to_string(workerHr)); + } + } + + logDistroStopping(); + return S_OK; + } + + logDistroStopping(); if (g_testType == PluginTestType::FailToStopDistro) { @@ -550,7 +867,7 @@ EXTERN_C __declspec(dllexport) HRESULT WSLPLUGINAPI_ENTRYPOINTV1(const WSLPlugin THROW_HR_IF(E_UNEXPECTED, !g_logfile); g_testType = static_cast(ReadDword(key.get(), nullptr, c_testType, static_cast(PluginTestType::Invalid))); - THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::WslcImagePull)); + THROW_HR_IF(E_INVALIDARG, static_cast(g_testType) <= 0 || static_cast(g_testType) > static_cast(PluginTestType::CallbackDuringTermination)); g_logfile << "Plugin loaded. TestMode=" << static_cast(g_testType) << std::endl; g_api = Api; From 13b4be865e3c5a8c84067054c77d96a685c69572 Mon Sep 17 00:00:00 2001 From: Ben Hillis Date: Mon, 8 Jun 2026 13:50:21 -0700 Subject: [PATCH 2/3] Address PR feedback: out-of-process plugin host - Pass the kill-on-close job object into IWslPluginHost::Initialize and drop GetProcessHandle so the host assigns itself to the job before running any plugin code. - Return an IWSLCProcess reference from WslcCreateProcess instead of opaque process cookies; remove the cookie bookkeeping methods. - Make plugin host crashes fatal: a crash during a start/veto hook (or at load) blocks the WSL operation with a recorded, plugin-named error, matching pre-refactor behavior. Latch the first crash so subsequent operations fail fast with one consistent error instead of repeatedly driving a dead host (m_pluginError guarded by m_pluginErrorLock). Teardown hooks latch but stay non-fatal. Re-activation/non-fatal resilience is a follow-up. - Drop the sidData empty-guards so WSL and WSLC hooks pass the serialized SID consistently. - Rework the HostCrash test to assert the new fatal behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/windows/service/exe/PluginManager.cpp | 338 +++++------------- src/windows/service/exe/PluginManager.h | 57 ++- .../service/exe/WSLCSessionManager.cpp | 11 +- src/windows/service/exe/WSLCSessionManager.h | 16 +- src/windows/service/inc/WslPluginHost.idl | 39 +- src/windows/wslpluginhost/exe/PluginHost.cpp | 115 +++--- src/windows/wslpluginhost/exe/PluginHost.h | 7 +- test/windows/PluginTests.cpp | 60 ++-- test/windows/testplugin/Plugin.cpp | 9 +- 9 files changed, 254 insertions(+), 398 deletions(-) diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp index 0948ae443b..82a1aa8a2f 100644 --- a/src/windows/service/exe/PluginManager.cpp +++ b/src/windows/service/exe/PluginManager.cpp @@ -28,11 +28,13 @@ using wsl::windows::service::PluginHostCallbackImpl; using wsl::windows::service::PluginManager; // Acquire an apartment-local IWslPluginHost proxy for `plugin` (named `host`). -// On a host-process crash, log it and `continue` the surrounding loop. On any -// other failure to acquire (which would indicate a fundamental COM problem, -// not a plugin-reported issue), log the HRESULT and `continue` so a single -// busted plugin does not break the iteration for the others. Use only inside -// the per-plugin loops in PluginManager hook methods. +// On a host-process crash, latch a fatal plugin error and `continue` the +// surrounding loop: teardown/notification hooks must not block WSL, but the +// latch makes the next start operation fail fast instead of repeatedly driving +// a dead host. On any other failure to acquire (which would indicate a +// fundamental COM problem, not a plugin-reported issue), log the HRESULT and +// `continue` so a single busted plugin does not break the iteration for the +// others. Use only inside the per-plugin loops in PluginManager hook methods. #define ACQUIRE_PLUGIN_HOST_OR_CONTINUE(plugin, host, stage) \ Microsoft::WRL::ComPtr host; \ { \ @@ -41,7 +43,7 @@ using wsl::windows::service::PluginManager; { \ if (IsHostCrash(_acqHr)) \ { \ - LogPluginHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \ + LatchHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \ } \ else \ { \ @@ -53,9 +55,11 @@ using wsl::windows::service::PluginManager; // Same as ACQUIRE_PLUGIN_HOST_OR_CONTINUE, but for hook methods that surface a // plugin error to abort the guarded operation (e.g. OnVmStarted). A host crash -// is still isolated (logged + `continue`), but any other acquisition failure is -// thrown so it is surfaced exactly like a plugin-reported error would be, rather -// than silently allowing the operation to proceed without consulting the plugin. +// is fatal: it is latched and thrown as a fatal plugin error, matching the +// pre-refactor behavior where an in-process plugin crash brought down WSL. Any +// other acquisition failure is likewise thrown so it is surfaced exactly like a +// plugin-reported error would be, rather than silently allowing the operation to +// proceed without consulting the plugin. #define ACQUIRE_PLUGIN_HOST_OR_THROW(plugin, host, stage) \ Microsoft::WRL::ComPtr host; \ { \ @@ -64,8 +68,7 @@ using wsl::windows::service::PluginManager; { \ if (IsHostCrash(_acqHr)) \ { \ - LogPluginHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \ - continue; \ + ThrowHostCrash((plugin), _acqHr, stage "/AcquireHostProxy"); \ } \ THROW_HR_MSG(_acqHr, "Failed to acquire plugin host proxy for: '%ls'", (plugin).name.c_str()); \ } \ @@ -363,20 +366,13 @@ PluginManager::ScopedComInit PluginManager::EnsureInitialized() if (FAILED(loadResult)) { - // Treat host-process crashes and benign COM activation races (server is - // shutting down or its exec failed) as non-fatal — the plugin is simply - // unavailable for this session. All other failures, including registration - // errors (REGDB_E_CLASSNOTREG), access denials, and plugin-reported errors - // from Initialize, are treated as fatal plugin load failures so the user - // gets a clear error rather than a silently-disabled plugin. - if (IsHostCrash(loadResult) || loadResult == CO_E_SERVER_EXEC_FAILURE || loadResult == CO_E_SERVER_STOPPING) - { - LOG_HR_MSG(loadResult, "Plugin host activation failed for: '%ls', skipping", e.name.c_str()); - } - else - { - m_pluginError.emplace(PluginError{e.name, loadResult}); - } + // Any load failure is fatal: the plugin is recorded so that + // subsequent operations block with a clear error rather than a + // silently-disabled plugin. This includes host-process crashes + // and benign-looking COM activation failures (the server is + // shutting down or its exec failed) — matching the pre-refactor + // behavior where a plugin that failed to load blocked WSL. + m_pluginError.emplace(PluginError{e.name, loadResult}); } } }); @@ -403,24 +399,19 @@ void PluginManager::LoadPlugin(OutOfProcPlugin& plugin) TraceLoggingValue(activationHr, "CoCreateInstanceResult")); THROW_IF_FAILED_MSG(activationHr, "Failed to create plugin host for: '%ls'", plugin.path.c_str()); - // Join the host to our job object before Initialize runs plugin code, so any - // child processes the plugin spawns inherit the job and are killed when the - // service exits. A failure here is non-fatal: the host is still reaped via - // CoReleaseServerProcess on clean shutdown, just not on a service crash. + // Create the job object before initializing the host so we can hand it to + // Initialize. The host assigns itself to the job before running any plugin + // code, so any child processes the plugin spawns inherit the job and are + // killed when the service exits. Job assignment is fatal in the host: a host + // that isn't in the job would escape the kill-on-close guarantee, so a + // failure surfaces here (alongside activation and plugin entry-point errors) + // and the host process exits when its proxy is released on unwind. + // system_handle(sh_job) marshaling duplicates the job handle into the host + // for the duration of the Initialize call. EnsureJobObjectCreated(); - wil::unique_handle process; - const HRESULT getProcessHr = host->GetProcessHandle(&process); - LOG_IF_FAILED_MSG(getProcessHr, "Failed to get plugin host process handle for: '%ls'", plugin.path.c_str()); - if (SUCCEEDED(getProcessHr)) - { - LOG_IF_WIN32_BOOL_FALSE_MSG( - AssignProcessToJobObject(m_jobObject.get(), process.get()), - "Failed to assign plugin host to job object for: '%ls'", - plugin.path.c_str()); - } THROW_IF_FAILED_MSG( - host->Initialize(plugin.callback.Get(), plugin.path.c_str(), plugin.name.c_str()), + host->Initialize(plugin.callback.Get(), m_jobObject.get(), plugin.path.c_str(), plugin.name.c_str()), "Plugin host failed to initialize: '%ls'", plugin.path.c_str()); @@ -514,8 +505,7 @@ void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLV if (IsHostCrash(hr)) { - LogPluginHostCrash(e, hr, "OnVmStarted"); - continue; + ThrowHostCrash(e, hr, "OnVmStarted"); } ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); @@ -543,7 +533,7 @@ void PluginManager::OnVmStopping(const WSLSessionInformation* Session) if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnVmStopping"); + LatchHostCrash(e, result, "OnVmStopping"); continue; } @@ -590,8 +580,7 @@ void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session, if (IsHostCrash(hr)) { - LogPluginHostCrash(e, hr, "OnDistributionStarted"); - continue; + ThrowHostCrash(e, hr, "OnDistributionStarted"); } ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); @@ -634,7 +623,7 @@ void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnDistributionStopping"); + LatchHostCrash(e, result, "OnDistributionStopping"); continue; } @@ -676,7 +665,7 @@ void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Sessio if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnDistributionRegistered"); + LatchHostCrash(e, result, "OnDistributionRegistered"); continue; } @@ -718,7 +707,7 @@ void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Sess if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnDistributionUnregistered"); + LatchHostCrash(e, result, "OnDistributionUnregistered"); continue; } @@ -728,14 +717,6 @@ void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Sess void PluginManager::ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId Session, LPCWSTR Plugin) { - // If the host process crashed, don't propagate as a fatal plugin error — - // log it and let the caller decide. The plugin is already dead. - if (IsHostCrash(Result)) - { - LOG_HR_MSG(Result, "Plugin host process crashed for plugin: '%ls'", Plugin); - return; - } - if (FAILED(Result)) { if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0') @@ -774,22 +755,17 @@ bool PluginManager::IsHostCrash(HRESULT hr) } } -void PluginManager::LogPluginHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage) +void PluginManager::LatchHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage) { LOG_HR_MSG(result, "Plugin host crashed at %hs for: '%ls'", stage, plugin.name.c_str()); // Fire telemetry only on first observation per plugin: a dead plugin will // hit this path on every subsequent VM/distro lifecycle event, and we - // don't want to flood the telemetry channel with duplicates. + // don't want to flood the telemetry channel with duplicates. Any WSLC + // processes the dead host created are released automatically by COM when + // the host process exits, so there is nothing to drain here. if (!plugin.crashTelemetryFired.exchange(true)) { - // Release any WSLC processes the dead host created so they don't leak - // until service shutdown. - if (plugin.callback) - { - plugin.callback->DrainProcesses(); - } - WSL_LOG_TELEMETRY( "PluginHostCrash", PDT_ProductAndServiceUsage, @@ -798,6 +774,27 @@ void PluginManager::LogPluginHostCrash(OutOfProcPlugin& plugin, HRESULT result, TraceLoggingValue(result, "Result"), TraceLoggingValue(stage, "Stage")); } + + // Latch a fatal plugin error so subsequent operations fail fast with a + // single consistent message instead of repeatedly driving a dead host that + // is never re-activated for this service lifetime. The first crash wins; a + // load-time failure already recorded in m_pluginError is left untouched. + auto lock = m_pluginErrorLock.lock_exclusive(); + if (!m_pluginError.has_value()) + { + m_pluginError.emplace(PluginError{plugin.name, result}); + } +} + +void PluginManager::ThrowHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage) +{ + // Record the crash (telemetry + fatal latch), then throw a fatal plugin + // error so the guarded start/veto operation (VM/distro/session/container + // creation) is aborted. The HRESULT is whichever RPC/CO_E_* code COM + // surfaced for the dead host; it is reported the same way a plugin-returned + // fatal error would be. + LatchHostCrash(plugin, result, stage); + THROW_HR_WITH_USER_ERROR(result, wsl::shared::Localization::MessageFatalPluginError(plugin.name.c_str())); } void PluginManager::ThrowIfFatalPluginError() @@ -805,18 +802,26 @@ void PluginManager::ThrowIfFatalPluginError() ExecutionContext context(Context::Plugin); auto coInit = EnsureInitialized(); - if (!m_pluginError.has_value()) + // m_pluginError can be set at load time (single-threaded, under m_initOnce) + // or latched at runtime from a hook on any RPC thread when a host crash is + // observed, so read it under the lock and act on a local copy. + std::optional error; + { + auto lock = m_pluginErrorLock.lock_shared(); + error = m_pluginError; + } + + if (!error.has_value()) { return; } - else if (m_pluginError->error == WSL_E_PLUGIN_REQUIRES_UPDATE) + else if (error->error == WSL_E_PLUGIN_REQUIRES_UPDATE) { - THROW_HR_WITH_USER_ERROR( - WSL_E_PLUGIN_REQUIRES_UPDATE, wsl::shared::Localization::MessagePluginRequiresUpdate(m_pluginError->plugin)); + THROW_HR_WITH_USER_ERROR(WSL_E_PLUGIN_REQUIRES_UPDATE, wsl::shared::Localization::MessagePluginRequiresUpdate(error->plugin)); } else { - THROW_HR_WITH_USER_ERROR(m_pluginError->error, wsl::shared::Localization::MessageFatalPluginError(m_pluginError->plugin)); + THROW_HR_WITH_USER_ERROR(error->error, wsl::shared::Localization::MessageFatalPluginError(error->plugin)); } } @@ -881,13 +886,12 @@ void PluginManager::OnWslcSessionCreated(const WSLCSessionInformation* Session) Session->ApplicationPid, Session->UserToken, static_cast(sidData.size()), - sidData.empty() ? nullptr : sidData.data(), + sidData.data(), &errorMessage); if (IsHostCrash(hr)) { - LogPluginHostCrash(e, hr, "OnWslcSessionCreated"); - continue; + ThrowHostCrash(e, hr, "OnWslcSessionCreated"); } ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); @@ -920,11 +924,11 @@ void PluginManager::OnWslcSessionStopping(const WSLCSessionInformation* Session) Session->ApplicationPid, Session->UserToken, static_cast(sidData.size()), - sidData.empty() ? nullptr : sidData.data()); + sidData.data()); if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnWslcSessionStopping"); + LatchHostCrash(e, result, "OnWslcSessionStopping"); continue; } @@ -961,14 +965,13 @@ try Session->ApplicationPid, Session->UserToken, static_cast(sidData.size()), - sidData.empty() ? nullptr : sidData.data(), + sidData.data(), InspectJson, &errorMessage); if (IsHostCrash(hr)) { - LogPluginHostCrash(e, hr, "OnWslcContainerStarted"); - continue; + ThrowHostCrash(e, hr, "OnWslcContainerStarted"); } ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); @@ -1004,12 +1007,12 @@ void PluginManager::OnWslcContainerStopping(const WSLCSessionInformation* Sessio Session->ApplicationPid, Session->UserToken, static_cast(sidData.size()), - sidData.empty() ? nullptr : sidData.data(), + sidData.data(), ContainerId); if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnWslcContainerStopping"); + LatchHostCrash(e, result, "OnWslcContainerStopping"); continue; } @@ -1043,12 +1046,12 @@ void PluginManager::OnWslcImageCreated(const WSLCSessionInformation* Session, LP Session->ApplicationPid, Session->UserToken, static_cast(sidData.size()), - sidData.empty() ? nullptr : sidData.data(), + sidData.data(), InspectJson); if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnWslcImageCreated"); + LatchHostCrash(e, result, "OnWslcImageCreated"); continue; } @@ -1083,12 +1086,12 @@ void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LP Session->ApplicationPid, Session->UserToken, static_cast(sidData.size()), - sidData.empty() ? nullptr : sidData.data(), + sidData.data(), ImageId); if (IsHostCrash(result)) { - LogPluginHostCrash(e, result, "OnWslcImageDeleted"); + LatchHostCrash(e, result, "OnWslcImageDeleted"); continue; } @@ -1098,59 +1101,6 @@ void PluginManager::OnWslcImageDeleted(const WSLCSessionInformation* Session, LP // --- IWslPluginHostCallback WSLC implementations (service-side) --- -DWORD PluginHostCallbackImpl::InsertProcessLocked(wil::com_ptr process) -{ - std::lock_guard lock(m_processLock); - - // Reserve a cookie that's neither 0 nor already in use. Wraparound is fine. - THROW_HR_IF(E_OUTOFMEMORY, m_processes.size() >= std::numeric_limits::max() - 1); - - DWORD cookie = m_nextCookie; - while (cookie == 0 || m_processes.find(cookie) != m_processes.end()) - { - ++cookie; - } - m_nextCookie = cookie + 1; - m_processes.emplace(cookie, std::move(process)); - return cookie; -} - -wil::com_ptr PluginHostCallbackImpl::FindProcess(DWORD cookie) const -{ - std::lock_guard lock(m_processLock); - - auto it = m_processes.find(cookie); - return (it == m_processes.end()) ? nullptr : it->second; -} - -wil::com_ptr PluginHostCallbackImpl::RemoveProcess(DWORD cookie) -{ - std::lock_guard lock(m_processLock); - - auto it = m_processes.find(cookie); - if (it == m_processes.end()) - { - return nullptr; - } - auto process = std::move(it->second); - m_processes.erase(it); - return process; -} - -void PluginHostCallbackImpl::DrainProcesses() noexcept -try -{ - std::unordered_map> processes; - { - std::lock_guard lock(m_processLock); - processes.swap(m_processes); - } - - // Release outside the lock: a process Release() may run teardown that - // re-enters this callback. -} -CATCH_LOG() - STDMETHODIMP PluginHostCallbackImpl::WslcMountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCSTR Mountpoint, _In_ BOOL ReadOnly) try { @@ -1198,12 +1148,12 @@ STDMETHODIMP PluginHostCallbackImpl::WslcCreateProcess( _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _In_ DWORD EnvCount, _In_reads_opt_(EnvCount) LPCSTR* Environment, - _Out_ DWORD* ProcessCookie, + _COM_Outptr_ IWSLCProcess** Process, _Out_ int* Errno) try { - RETURN_HR_IF(E_POINTER, Executable == nullptr || ProcessCookie == nullptr || Errno == nullptr); - *ProcessCookie = 0; + RETURN_HR_IF(E_POINTER, Executable == nullptr || Process == nullptr || Errno == nullptr); + *Process = nullptr; *Errno = 0; RETURN_HR_IF(E_INVALIDARG, (ArgumentCount > 0 && Arguments == nullptr) || (EnvCount > 0 && Environment == nullptr)); @@ -1250,117 +1200,17 @@ try return result; } - *ProcessCookie = InsertProcessLocked(std::move(process)); + // Hand the IWSLCProcess back to the host. COM marshals it across the host + // boundary; the host then calls it directly (GetStdHandle/GetExitEvent/ + // GetState) and owns its lifetime, so the service keeps no per-process state. + *Process = process.detach(); WSL_LOG( "WslcPluginCreateProcessCall", TraceLoggingValue(SessionId, "SessionId"), TraceLoggingValue(Executable, "Executable"), - TraceLoggingValue(*ProcessCookie, "ProcessCookie"), TraceLoggingValue(S_OK, "Result")); return S_OK; } CATCH_RETURN(); - -STDMETHODIMP PluginHostCallbackImpl::WslcProcessGetFd(_In_ DWORD ProcessCookie, _In_ DWORD Fd, _Out_ HANDLE* Handle) -try -{ - RETURN_HR_IF(E_POINTER, Handle == nullptr); - *Handle = nullptr; - - auto process = FindProcess(ProcessCookie); - RETURN_HR_IF(E_INVALIDARG, !process); - - WSLCFD wslcFd{}; - switch (static_cast(Fd)) - { - case WSLCProcessFdStdin: - wslcFd = WSLCFDStdin; - break; - case WSLCProcessFdStdout: - wslcFd = WSLCFDStdout; - break; - case WSLCProcessFdStderr: - wslcFd = WSLCFDStderr; - break; - default: - WSL_LOG( - "WslcPluginProcessGetFd", TraceLoggingValue(static_cast(Fd), "Fd"), TraceLoggingValue(E_INVALIDARG, "Result")); - return E_INVALIDARG; - } - - WSLCHandle handle{}; - auto result = process->GetStdHandle(wslcFd, &handle); - - WSL_LOG( - "WslcPluginProcessGetFd", - TraceLoggingValue(static_cast(Fd), "Fd"), - TraceLoggingValue(handle.Handle.Socket, "Handle"), - TraceLoggingValue(result, "Result")); - - RETURN_IF_FAILED(result); - WI_ASSERT(handle.Type == WSLCHandleTypeSocket); - - // Pass through as HANDLE; COM's system_handle(sh_socket) marshaling will duplicate - // it into the host process which then surfaces it to the plugin. - *Handle = handle.Handle.Socket; - return S_OK; -} -CATCH_RETURN(); - -STDMETHODIMP PluginHostCallbackImpl::WslcProcessGetExitEvent(_In_ DWORD ProcessCookie, _Out_ HANDLE* ExitEvent) -try -{ - RETURN_HR_IF(E_POINTER, ExitEvent == nullptr); - *ExitEvent = nullptr; - - auto process = FindProcess(ProcessCookie); - RETURN_HR_IF(E_INVALIDARG, !process); - - auto result = process->GetExitEvent(ExitEvent); - - WSL_LOG("WslcPluginProcessGetExitEvent", TraceLoggingValue(*ExitEvent, "ExitEvent"), TraceLoggingValue(result, "Result")); - - return result; -} -CATCH_RETURN(); - -STDMETHODIMP PluginHostCallbackImpl::WslcProcessGetExitCode(_In_ DWORD ProcessCookie, _Out_ int* ExitCode) -try -{ - RETURN_HR_IF(E_POINTER, ExitCode == nullptr); - *ExitCode = -1; - - auto process = FindProcess(ProcessCookie); - RETURN_HR_IF(E_INVALIDARG, !process); - - WSLCProcessState state{}; - auto result = process->GetState(&state, ExitCode); - - if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled) - { - result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE); - } - - WSL_LOG( - "WslcPluginProcessGetExitCode", - TraceLoggingValue(*ExitCode, "ExitCode"), - TraceLoggingValue(static_cast(state), "State"), - TraceLoggingValue(result, "Result")); - - return result; -} -CATCH_RETURN(); - -STDMETHODIMP PluginHostCallbackImpl::WslcReleaseProcess(_In_ DWORD ProcessCookie) -try -{ - auto process = RemoveProcess(ProcessCookie); - WSL_LOG( - "WslcPluginReleaseProcess", - TraceLoggingValue(ProcessCookie, "ProcessCookie"), - TraceLoggingValue(process != nullptr, "Found")); - return S_OK; -} -CATCH_RETURN(); diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h index beddd32f4e..17c1b23ba5 100644 --- a/src/windows/service/exe/PluginManager.h +++ b/src/windows/service/exe/PluginManager.h @@ -35,10 +35,10 @@ class PluginManager; // // IWslPluginHostCallback implementation — lives in the service process and // handles API calls coming from the plugin host (MountFolder, ExecuteBinary, -// WSLC* APIs etc.). One instance is created per plugin host so that the -// per-plugin WSLC process map (cookie -> IWSLCProcess) is isolated: a plugin -// cannot guess another plugin's cookie, and the map drains automatically when -// the plugin host process goes away. +// WSLC* APIs etc.). WslcCreateProcess returns a marshaled IWSLCProcess directly +// to the host, which calls it (GetStdHandle/GetExitEvent/GetState) without any +// service-side process bookkeeping; the remote process is released by COM when +// the host releases it or the host process exits. // class PluginHostCallbackImpl : public Microsoft::WRL::RuntimeClass, IWslPluginHostCallback> @@ -78,36 +78,10 @@ class PluginHostCallbackImpl _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _In_ DWORD EnvCount, _In_reads_opt_(EnvCount) LPCSTR* Environment, - _Out_ DWORD* ProcessCookie, + _COM_Outptr_ IWSLCProcess** Process, _Out_ int* Errno) override; - STDMETHODIMP WslcProcessGetFd(_In_ DWORD ProcessCookie, _In_ DWORD Fd, _Out_ HANDLE* Handle) override; - - STDMETHODIMP WslcProcessGetExitEvent(_In_ DWORD ProcessCookie, _Out_ HANDLE* ExitEvent) override; - - STDMETHODIMP WslcProcessGetExitCode(_In_ DWORD ProcessCookie, _Out_ int* ExitCode) override; - - STDMETHODIMP WslcReleaseProcess(_In_ DWORD ProcessCookie) override; - - // Release all outstanding process mappings. Called when the plugin host - // crashes so the WSLC processes it created aren't stranded until shutdown. - void DrainProcesses() noexcept; - private: - // Allocate a new cookie -> process mapping. Loops past 0 and past collisions. - // Throws on exhaustion. - DWORD InsertProcessLocked(wil::com_ptr process); - - // Resolve a cookie to its process under m_processLock; returns nullptr if unknown. - wil::com_ptr FindProcess(DWORD cookie) const; - - // Remove a cookie mapping; returns the removed process (may be null). - wil::com_ptr RemoveProcess(DWORD cookie); - - mutable std::mutex m_processLock; - std::unordered_map> m_processes; - DWORD m_nextCookie{1}; - // The PluginManager that owns this callback. Used to resolve a WSLC // SessionId to a live IWSLCSession via the manager's registered session // reference map (see PluginManager::ResolveWslcSession). @@ -258,13 +232,26 @@ class PluginManager static std::vector SerializeSid(PSID Sid); static bool IsHostCrash(HRESULT hr); - // Logs a host crash to ETL and fires the PluginHostCrash telemetry event - // at most once per plugin per service lifetime, so a single bad plugin - // does not flood telemetry across every subsequent VM/distro lifecycle. - static void LogPluginHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage); + // Records an observed host-process crash: logs it to ETL, fires the + // PluginHostCrash telemetry event at most once per plugin per service + // lifetime (so a single bad plugin does not flood telemetry across every + // subsequent VM/distro lifecycle), and latches a fatal plugin error so + // later operations fail fast. The host is not re-activated for the rest of + // the service lifetime. + void LatchHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage); + + // LatchHostCrash, then throws the latched error as a fatal plugin error to + // abort a start/veto operation. Use from start hooks (OnVmStarted, etc.); + // teardown hooks latch but cannot block, so they call LatchHostCrash + skip. + [[noreturn]] void ThrowHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage); std::once_flag m_initOnce; std::vector m_plugins; + + // Guards m_pluginError. The error is written once at load time under + // m_initOnce and may later be latched from a hook on any RPC thread when a + // host crash is observed, so all accesses outside the initial load take it. + wil::srwlock m_pluginErrorLock; std::optional m_pluginError; // Global Interface Table used to make IWslPluginHost proxies callable from diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp index 582a1dc030..e1a8531036 100644 --- a/src/windows/service/exe/WSLCSessionManager.cpp +++ b/src/windows/service/exe/WSLCSessionManager.cpp @@ -422,7 +422,16 @@ void WSLCSessionManagerImpl::CreateSession( } m_sessions.push_back(SessionEntry{ - std::move(createdRef), createdSessionId, createdPid, resolvedDisplayName, std::move(tokenInfo), createdNotifier, false, createdToken, std::move(createdSid), std::move(createdJob)}); + std::move(createdRef), + createdSessionId, + createdPid, + resolvedDisplayName, + std::move(tokenInfo), + createdNotifier, + false, + createdToken, + std::move(createdSid), + std::move(createdJob)}); if (persistent) { diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h index a09940bd23..ab8ae387cf 100644 --- a/src/windows/service/exe/WSLCSessionManager.h +++ b/src/windows/service/exe/WSLCSessionManager.h @@ -118,7 +118,7 @@ class WSLCSessionManagerImpl std::is_nothrow_invocable_v&>, "ForEachSession routine must be noexcept to preserve container invariants during remove_if"); - using TResult = std::conditional_t, nullptr_t, std::optional>; + using TResult = std::conditional_t, std::nullptr_t, std::optional>; TResult result{}; auto each = [&](SessionEntry& entry) { @@ -133,16 +133,20 @@ class WSLCSessionManagerImpl // reference. The StoppingNotified flag is intentionally left // untouched here; DispatchSessionStopping flips it so the // notification fires exactly once. + // Reserve the slot up front so push_back cannot reallocate (and + // therefore cannot throw) after `entry` is moved-from. Combined + // with SessionEntry's noexcept move, this guarantees `entry` is + // only consumed once the destination slot is known to exist. + DeadSessions.reserve(DeadSessions.size() + 1); try { DeadSessions.push_back(std::move(entry)); } catch (...) { - // Couldn't queue it for deferred stopping dispatch: keep it - // tracked (the move above is noexcept, so entry is intact) - // and reap it on a later pass rather than dropping it - // without unregistering. + // Defensive: if queuing for the deferred stopping dispatch + // still fails, keep the session tracked and reap it on a + // later pass rather than dropping it without unregistering. LOG_CAUGHT_EXCEPTION(); return false; } @@ -191,7 +195,7 @@ class WSLCSessionManagerImpl { std::vector deadSessions; - using TResult = std::conditional_t, nullptr_t, std::optional>; + using TResult = std::conditional_t, std::nullptr_t, std::optional>; TResult result{}; { diff --git a/src/windows/service/inc/WslPluginHost.idl b/src/windows/service/inc/WslPluginHost.idl index 73d51b101f..81abaf5df1 100644 --- a/src/windows/service/inc/WslPluginHost.idl +++ b/src/windows/service/inc/WslPluginHost.idl @@ -17,6 +17,7 @@ Abstract: import "unknwn.idl"; import "wtypes.idl"; +import "wslc.idl"; cpp_quote("const GUID CLSID_WslPluginHost = {0x7a1d2c3e, 0x4b5f, 0x6a7d, {0x8e, 0x9f, 0x0a, 0x1b, 0x2c, 0x3d, 0x4e, 0x5f}};") cpp_quote("#ifdef __cplusplus") @@ -58,8 +59,12 @@ interface IWslPluginHostCallback : IUnknown [out, system_handle(sh_socket)] HANDLE* Socket); // - // WSLC plugin API. WSLCProcessHandle is mapped to a service-allocated - // DWORD cookie that the host wraps in an opaque void* for the plugin. + // WSLC plugin API. WslcCreateProcess returns a marshaled IWSLCProcess that + // the host wraps in an opaque WSLCProcessHandle for the plugin and calls + // directly (GetStdHandle/GetExitEvent/GetState), so there is no service-side + // process bookkeeping: lifetime is owned by the host via COM reference + // counting and the remote process is released when the host releases it (or + // when the host process exits). // HRESULT WslcMountFolder( @@ -79,24 +84,8 @@ interface IWslPluginHostCallback : IUnknown [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments, [in] DWORD EnvCount, [in, unique, size_is(EnvCount), string] LPCSTR* Environment, - [out] DWORD* ProcessCookie, + [out] IWSLCProcess** Process, [out] int* Errno); - - HRESULT WslcProcessGetFd( - [in] DWORD ProcessCookie, - [in] DWORD Fd, - [out, system_handle(sh_socket)] HANDLE* Handle); - - HRESULT WslcProcessGetExitEvent( - [in] DWORD ProcessCookie, - [out, system_handle(sh_event)] HANDLE* ExitEvent); - - HRESULT WslcProcessGetExitCode( - [in] DWORD ProcessCookie, - [out] int* ExitCode); - - HRESULT WslcReleaseProcess( - [in] DWORD ProcessCookie); }; // @@ -114,21 +103,17 @@ interface IWslPluginHost : IUnknown // // Initialize the plugin host: load the plugin DLL and call its entry point. // The Callback interface is used by the plugin to call back into the service. + // JobObject is the service's job object; the host assigns itself to it before + // running any plugin code so that processes spawned by the plugin inherit the + // job and are terminated if the service exits unexpectedly. // HRESULT Initialize( [in] IWslPluginHostCallback* Callback, + [in, system_handle(sh_job)] HANDLE JobObject, [in, string] LPCWSTR PluginPath, [in, string] LPCWSTR PluginName); - // - // Returns a handle to this COM server process. Used by the service to add - // the plugin host to a job object for automatic cleanup on service exit. - // - - HRESULT GetProcessHandle( - [out, system_handle(sh_process)] HANDLE* ProcessHandle); - // // Lifecycle hook dispatchers - mirror WSLPluginHooksV1. // UserToken is duplicated into the host process by the service before calling. diff --git a/src/windows/wslpluginhost/exe/PluginHost.cpp b/src/windows/wslpluginhost/exe/PluginHost.cpp index 19bcc62e18..b451ea090c 100644 --- a/src/windows/wslpluginhost/exe/PluginHost.cpp +++ b/src/windows/wslpluginhost/exe/PluginHost.cpp @@ -39,9 +39,10 @@ namespace { // cross-process m_callback proxy can't be used and calls would fail with // CO_E_NOTINITIALIZED. This RAII helper joins the calling thread to the MTA // for the duration of the API call. RPC_E_CHANGED_MODE means the thread is -// already STA-initialized; we leave its apartment unchanged. The supported -// path is an uninitialized worker thread joining the MTA where m_callback was -// marshaled — calling the proxy from a caller-created STA is not supported. +// already STA-initialized; we leave its apartment unchanged and still proceed, +// relying on COM cross-apartment marshaling to dispatch the call to the MTA +// where m_callback was marshaled. The primary (and tested) path is an +// uninitialized worker thread joining the MTA directly. struct ScopedComInitForCallback { HRESULT initHr; @@ -99,12 +100,24 @@ PluginHost::~PluginHost() // --- IWslPluginHost implementation --- -STDMETHODIMP PluginHost::Initialize(_In_ IWslPluginHostCallback* Callback, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) +STDMETHODIMP PluginHost::Initialize(_In_ IWslPluginHostCallback* Callback, _In_ HANDLE JobObject, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) try { - RETURN_HR_IF(E_INVALIDARG, Callback == nullptr || PluginPath == nullptr || PluginName == nullptr); + RETURN_HR_IF(E_INVALIDARG, Callback == nullptr || JobObject == nullptr || PluginPath == nullptr || PluginName == nullptr); RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, m_module.is_valid()); // Already initialized + // Join the service's job object before loading or running any plugin code, so that + // any child processes the plugin spawns inherit the job and are terminated when the + // service exits (the service sets JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE). The host runs + // as SYSTEM, so it can assign itself. JobObject is a duplicate owned by the marshaler + // for the duration of this call and is freed on return; the assignment persists. + // A failure here is fatal: a host that isn't in the job escapes the kill-on-close + // guarantee and would be orphaned (along with any children it spawns) when the + // service closes the job. Failing Initialize makes the service release this host's + // proxy, which exits the host process before any plugin code runs. + RETURN_IF_WIN32_BOOL_FALSE_MSG( + AssignProcessToJobObject(JobObject, GetCurrentProcess()), "Failed to assign plugin host to job object: '%ls'", PluginName); + m_callback = Callback; // Validate the plugin signature before loading it. @@ -155,21 +168,6 @@ try } CATCH_RETURN(); -STDMETHODIMP PluginHost::GetProcessHandle(_Out_ HANDLE* ProcessHandle) -try -{ - RETURN_HR_IF(E_POINTER, ProcessHandle == nullptr); - *ProcessHandle = nullptr; - - wil::unique_handle process(OpenProcess(PROCESS_SET_QUOTA | PROCESS_TERMINATE, FALSE, GetCurrentProcessId())); - RETURN_LAST_ERROR_IF_NULL(process); - - // COM's system_handle(sh_process) marshaling will duplicate this into the caller's process. - *ProcessHandle = process.release(); - return S_OK; -} -CATCH_RETURN(); - STDMETHODIMP PluginHost::OnVMStarted( _In_ DWORD SessionId, _In_ HANDLE UserToken, @@ -766,11 +764,13 @@ CATCH_RETURN(); namespace { -// Opaque wrapper handed to the plugin as WSLCProcessHandle. The DWORD cookie -// identifies the IWSLCProcess held by the service-side PluginHostCallbackImpl. +// Opaque wrapper handed to the plugin as WSLCProcessHandle. It owns the +// IWSLCProcess COM proxy marshaled from wslcsession (via the service), so the +// host calls the process interface directly and the remote process is released +// when the plugin releases this wrapper (or when the host process exits). struct WslcProcessWrapper { - DWORD cookie; + wil::com_ptr process; }; } // namespace @@ -845,12 +845,12 @@ HRESULT CALLBACK PluginHost::LocalWslcCreateProcess( } // Allocate the wrapper before creating the remote process so a throwing - // allocation can't strand a service-side cookie that only WslcReleaseProcess + // allocation can't strand a remote process that only WslcReleaseProcess // frees. Nothing between the remote create and release() below can throw. auto wrapper = std::make_unique(); - DWORD cookie = 0; - HRESULT hr = host->m_callback->WslcCreateProcess(Session, Executable, argCount, Arguments, envCount, Env, &cookie, &localErrno); + HRESULT hr = + host->m_callback->WslcCreateProcess(Session, Executable, argCount, Arguments, envCount, Env, wrapper->process.put(), &localErrno); if (Errno != nullptr) { *Errno = localErrno; @@ -861,7 +861,6 @@ HRESULT CALLBACK PluginHost::LocalWslcCreateProcess( return hr; } - wrapper->cookie = cookie; *Process = wrapper.release(); return S_OK; } @@ -882,7 +881,33 @@ HRESULT CALLBACK PluginHost::LocalWslcProcessGetFd(WSLCProcessHandle Process, WS RETURN_IF_FAILED(coInit.Result()); auto* wrapper = static_cast(Process); - return host->m_callback->WslcProcessGetFd(wrapper->cookie, static_cast(Fd), Handle); + RETURN_HR_IF(E_INVALIDARG, wrapper->process == nullptr); + + WSLCFD wslcFd{}; + switch (Fd) + { + case WSLCProcessFdStdin: + wslcFd = WSLCFDStdin; + break; + case WSLCProcessFdStdout: + wslcFd = WSLCFDStdout; + break; + case WSLCProcessFdStderr: + wslcFd = WSLCFDStderr; + break; + default: + return E_INVALIDARG; + } + + WSLCHandle handle{}; + RETURN_IF_FAILED(wrapper->process->GetStdHandle(wslcFd, &handle)); + + WI_ASSERT(handle.Type == WSLCHandleTypeSocket); + + // Pass through as HANDLE; COM's system_handle(sh_socket) marshaling already + // duplicated it into this process. + *Handle = handle.Handle.Socket; + return S_OK; } HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitEvent(WSLCProcessHandle Process, HANDLE* ExitEvent) @@ -901,7 +926,8 @@ HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitEvent(WSLCProcessHandle Proc RETURN_IF_FAILED(coInit.Result()); auto* wrapper = static_cast(Process); - return host->m_callback->WslcProcessGetExitEvent(wrapper->cookie, ExitEvent); + RETURN_HR_IF(E_INVALIDARG, wrapper->process == nullptr); + return wrapper->process->GetExitEvent(ExitEvent); } HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitCode(WSLCProcessHandle Process, int* ExitCode) @@ -914,34 +940,37 @@ HRESULT CALLBACK PluginHost::LocalWslcProcessGetExitCode(WSLCProcessHandle Proce RETURN_HR_IF(E_INVALIDARG, Process == nullptr); RETURN_HR_IF(E_POINTER, ExitCode == nullptr); + *ExitCode = -1; ScopedComInitForCallback coInit; RETURN_IF_FAILED(coInit.Result()); auto* wrapper = static_cast(Process); - return host->m_callback->WslcProcessGetExitCode(wrapper->cookie, ExitCode); -} + RETURN_HR_IF(E_INVALIDARG, wrapper->process == nullptr); -void CALLBACK PluginHost::LocalWslcReleaseProcess(WSLCProcessHandle Process) -{ - if (Process == nullptr) + WSLCProcessState state{}; + auto result = wrapper->process->GetState(&state, ExitCode); + if (SUCCEEDED(result) && state != WslcProcessStateExited && state != WslcProcessStateSignalled) { - return; + result = HRESULT_FROM_WIN32(ERROR_INVALID_STATE); } - std::unique_ptr wrapper{static_cast(Process)}; + return result; +} - auto* host = g_pluginHost.load(std::memory_order_acquire); - if (host == nullptr || host->m_callback == nullptr) +void CALLBACK PluginHost::LocalWslcReleaseProcess(WSLCProcessHandle Process) +{ + if (Process == nullptr) { return; } + // Initialize COM before taking ownership: destroying the wrapper releases + // the IWSLCProcess proxy, which marshals a Release back to wslcsession and + // needs COM initialized on this thread. coInit is declared first so it + // outlives the wrapper (reverse destruction order). ScopedComInitForCallback coInit; - if (FAILED(coInit.Result())) - { - return; - } + LOG_IF_FAILED(coInit.Result()); - LOG_IF_FAILED(host->m_callback->WslcReleaseProcess(wrapper->cookie)); + std::unique_ptr wrapper{static_cast(Process)}; } \ No newline at end of file diff --git a/src/windows/wslpluginhost/exe/PluginHost.h b/src/windows/wslpluginhost/exe/PluginHost.h index b7381c9d4a..d092917680 100644 --- a/src/windows/wslpluginhost/exe/PluginHost.h +++ b/src/windows/wslpluginhost/exe/PluginHost.h @@ -31,8 +31,7 @@ class PluginHost : public Microsoft::WRL::RuntimeClass(file), {}}; LogInfo("Logfile: %ls", fileContent.c_str()); - auto countOccurrences = [&](const std::wstring& needle) { - size_t count = 0; - size_t pos = 0; - while ((pos = fileContent.find(needle, pos)) != std::wstring::npos) - { - ++count; - pos += needle.size(); - } - return count; - }; - - VERIFY_ARE_EQUAL(static_cast(1), countOccurrences(L"Plugin loaded. TestMode=")); - VERIFY_ARE_EQUAL(static_cast(1), countOccurrences(L"Crashing host")); + VERIFY_IS_TRUE( + fileContent.find(L"Crashing host") != std::wstring::npos, + std::format(L"Expected the plugin to reach the crash point, log: '{}'", fileContent).c_str()); } WSL2_TEST_METHOD(ConcurrentCallbacks) diff --git a/test/windows/testplugin/Plugin.cpp b/test/windows/testplugin/Plugin.cpp index 1e4b77df68..94c20d35fc 100644 --- a/test/windows/testplugin/Plugin.cpp +++ b/test/windows/testplugin/Plugin.cpp @@ -228,10 +228,11 @@ HRESULT OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSet } else if (g_testType == PluginTestType::HostCrash) { - // Validate plugin host crash isolation. Forcefully exit the host - // process so the COM RPC returns one of the HRESULTs in IsHostCrash - // (RPC_E_DISCONNECTED / RPC_E_SERVER_DIED / ...). The service should - // treat this as non-fatal and continue. + // Validate plugin host crash handling. Forcefully exit the host process + // so the COM RPC returns one of the HRESULTs in IsHostCrash + // (RPC_E_DISCONNECTED / RPC_E_SERVER_DIED / ...). The service treats a + // crash during a veto hook as a fatal plugin error and aborts the + // operation. LogLine("Crashing host"); g_logfile.flush(); TerminateProcess(GetCurrentProcess(), 1); From cdb7b1897881bf2f37d5cdfa8420f7050d7c6ed3 Mon Sep 17 00:00:00 2001 From: Ben Hillis Date: Wed, 10 Jun 2026 13:59:46 -0700 Subject: [PATCH 3/3] Implement ThreadedPluginAPI (PluginCallPump) alternative for out-of-process plugin host Alternative to the m_callbackLock design in PR #40120, per @OneBlue's review. Instead of guarding plugin callbacks with a new shared_mutex (m_callbackLock) and a session-ref registry, this runs each outbound notification (host->On...) on a worker thread and pumps the plugin's service-side API calls back onto the original notifying thread, which already holds the session's recursive m_instanceLock. This reproduces in-process re-entrancy and lets us delete the synchronization the PR added. - New PluginCallPump primitive (Run/Invoke): the worker makes the COM notification call; the notifying thread pumps queued callback work and runs it under m_instanceLock. - PluginManager: WSL notifications (OnVm*/OnDistribution*) route via RunHostNotification; WSL callbacks (MountFolder/ExecuteBinary[InDistribution]) route via a hybrid InvokeOnWslPump - pump when a hook is in flight, direct (own m_instanceLock) when not, preserving async out-of-hook callbacks. - Revert m_callbackLock in LxssUserSessionImpl; MountRootNamespaceFolder/CreateLinuxProcess go back to the recursive m_instanceLock. Add LxssUserSessionImpl::TryInvokeUnderInstanceLock (timed try_lock_for on m_instanceLock) used by the direct path. - Deadlock fix: a callback racing the pre-notification window (lock held, pump not yet registered) must not block on m_instanceLock - InvokeOnWslPump uses a timed acquire and re-checks for a pump, so the racing callback is routed onto the pump instead of dead-blocking. - Hardening (from multi-model review): PluginCallPump::Invoke reports ran-vs-stopped out-of-band (no RPC_E_DISCONNECTED sentinel overload, avoids double-execution); the pump loop polls workerDone each iteration so callback traffic cannot starve the stop path; Run is exception-safe so std::thread creation failure never escapes into teardown hooks. - WSLC session-ref registry left as-is (documented follow-up): its lock is non-recursive, notifications already fire outside it, and it is entangled with the create-veto protocol. - Update test/plugin comments to describe the pump model (assertions unchanged and still pass). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/windows/service/exe/CMakeLists.txt | 2 + src/windows/service/exe/LxssUserSession.cpp | 85 ++--- src/windows/service/exe/LxssUserSession.h | 56 ++- src/windows/service/exe/PluginCallPump.cpp | 147 ++++++++ src/windows/service/exe/PluginCallPump.h | 85 +++++ src/windows/service/exe/PluginManager.cpp | 361 +++++++++++++------- src/windows/service/exe/PluginManager.h | 42 +++ test/windows/PluginTests.cpp | 32 +- test/windows/testplugin/Plugin.cpp | 43 +-- 9 files changed, 613 insertions(+), 240 deletions(-) create mode 100644 src/windows/service/exe/PluginCallPump.cpp create mode 100644 src/windows/service/exe/PluginCallPump.h diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt index 841883579e..d5cdb06cb9 100644 --- a/src/windows/service/exe/CMakeLists.txt +++ b/src/windows/service/exe/CMakeLists.txt @@ -7,6 +7,7 @@ set(SOURCES LxssHttpProxy.cpp LxssInstance.cpp PluginManager.cpp + PluginCallPump.cpp ServiceMain.cpp BridgedNetworking.cpp GnsRpcServer.cpp @@ -37,6 +38,7 @@ set(HEADERS LxssIptables.h LxssHttpProxy.h PluginManager.h + PluginCallPump.h LxssInstance.h BridgedNetworking.h GnsRpcServer.h diff --git a/src/windows/service/exe/LxssUserSession.cpp b/src/windows/service/exe/LxssUserSession.cpp index ac6f5ed3d3..578ebd1ac7 100644 --- a/src/windows/service/exe/LxssUserSession.cpp +++ b/src/windows/service/exe/LxssUserSession.cpp @@ -2635,18 +2635,13 @@ std::shared_ptr LxssUserSessionImpl::_CreateInstance(_In_op registration.Write(Property::OsVersion, distributionInfo->Version); } - // This needs to be done before plugins are notified because they might try to run a command inside the distribution. - { - std::unique_lock callbackLock(m_callbackLock); - m_runningInstances[registration.Id()] = instance; - } + // This needs to be done before plugins are notifed because they might try to run a command inside the distribution. + m_runningInstances[registration.Id()] = instance; if (version == LXSS_WSL_VERSION_2) { - auto cleanupOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { - std::unique_lock callbackLock(m_callbackLock); - m_runningInstances.erase(registration.Id()); - }); + auto cleanupOnFailure = + wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { m_runningInstances.erase(registration.Id()); }); m_pluginManager.OnDistributionStarted(&m_session, instance->DistributionInformation()); cleanupOnFailure.release(); } @@ -2882,13 +2877,7 @@ void LxssUserSessionImpl::_CreateVm() m_vmId.store(vmId); // Create the utility VM and register for callbacks. - // Publish m_utilityVm under m_callbackLock exclusive to honor the dual-lock - // invariant for mutations of m_utilityVm; this is uncontended here because - // no plugin callbacks can race against initial creation. - { - std::unique_lock callbackLock(m_callbackLock); - m_utilityVm = WslCoreVm::Create(m_userToken, std::move(config), vmId); - } + m_utilityVm = WslCoreVm::Create(m_userToken, std::move(config), vmId); if (m_httpProxyStateTracker) { @@ -3619,26 +3608,17 @@ bool LxssUserSessionImpl::_TerminateInstanceInternal(_In_ LPCGUID DistroGuid, _I m_pluginManager.OnDistributionStopping(&m_session, wslcoreInstance->DistributionInformation()); } - m_lifetimeManager.RemoveCallback(clientKey); + instance->second->Stop(); - // Stop the instance and remove it from m_runningInstances atomically - // under m_callbackLock. This prevents plugin callbacks (which hold - // m_callbackLock shared) from finding a stopped-but-still-listed - // instance between Stop() and erase. - ULONG clientId; + const auto clientId = instance->second->GetClientId(); { - std::unique_lock callbackLock(m_callbackLock); - - instance->second->Stop(); - clientId = instance->second->GetClientId(); + auto lock = m_terminatedInstanceLock.lock_exclusive(); + m_terminatedInstances.push_back(std::move(instance->second)); + } - { - auto lock = m_terminatedInstanceLock.lock_exclusive(); - m_terminatedInstances.push_back(std::move(instance->second)); - } + m_lifetimeManager.RemoveCallback(clientKey); - m_runningInstances.erase(instance); - } + m_runningInstances.erase(instance); // If the instance that was terminated was a WSL2 instance, // check if the VM is now idle. @@ -3666,21 +3646,28 @@ void LxssUserSessionImpl::_UpdateInit(_In_ const LXSS_DISTRO_CONFIGURATION& Conf HRESULT LxssUserSessionImpl::MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name) { - // Shared lock prevents _VmTerminate from destroying the VM while we use it. - // Do NOT acquire m_instanceLock — callbacks arrive on a different COM thread - // from the notification thread that holds m_instanceLock. - std::shared_lock lock(m_callbackLock); + std::lock_guard lock(m_instanceLock); RETURN_HR_IF(E_NOT_VALID_STATE, !m_utilityVm); m_utilityVm->MountRootNamespaceFolder(HostPath, GuestPath, ReadOnly, Name); return S_OK; } +bool LxssUserSessionImpl::TryInvokeUnderInstanceLock(std::chrono::milliseconds Timeout, const std::function& Work, _Out_ HRESULT& Result) +{ + std::unique_lock lock(m_instanceLock, std::defer_lock); + if (!lock.try_lock_for(Timeout)) + { + return false; + } + + Result = Work(); + return true; +} + HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* Socket) { - // Shared lock prevents _VmTerminate from destroying the VM or instances - // while we use them. See MountRootNamespaceFolder for rationale. - std::shared_lock lock(m_callbackLock); + std::lock_guard lock(m_instanceLock); RETURN_HR_IF(E_NOT_VALID_STATE, !m_utilityVm); if (Distro == nullptr) @@ -3689,16 +3676,9 @@ HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In } else { - // Look up the running instance directly instead of calling _RunningInstance, - // which accesses m_lockedDistributions (guarded only by m_instanceLock). - // m_runningInstances is safe to read under m_callbackLock (shared). - // The _EnsureNotLocked check is unnecessary here: _ConversionBegin removes - // a distribution from m_runningInstances before adding it to m_lockedDistributions, - // so a locked distribution will never be found in this lookup. - const auto instance = m_runningInstances.find(*Distro); - THROW_HR_IF(WSL_E_VM_MODE_INVALID_STATE, instance == m_runningInstances.end()); - - const auto distro = instance->second; + const auto distro = _RunningInstance(Distro); + THROW_HR_IF(WSL_E_VM_MODE_INVALID_STATE, !distro); + const auto wsl2Distro = dynamic_cast(distro.get()); THROW_HR_IF(WSL_E_WSL2_NEEDED, !wsl2Distro); @@ -3936,12 +3916,7 @@ void LxssUserSessionImpl::_VmTerminate() m_telemetryThread.join(); } - // Acquire exclusive callback lock to wait for any in-flight plugin callbacks - // (MountRootNamespaceFolder, CreateLinuxProcess) to complete before destroying the VM. - { - std::unique_lock callbackLock(m_callbackLock); - m_utilityVm.reset(); - } + m_utilityVm.reset(); m_vmId.store(GUID_NULL); // Reset the user's token since its lifetime is tied to the VM. diff --git a/src/windows/service/exe/LxssUserSession.h b/src/windows/service/exe/LxssUserSession.h index c938f223fb..fbeac06923 100644 --- a/src/windows/service/exe/LxssUserSession.h +++ b/src/windows/service/exe/LxssUserSession.h @@ -310,10 +310,6 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce7e") LxssUserSession /// class LxssUserSessionImpl { - // Plugin callbacks arrive on a different COM RPC thread and use m_callbackLock - // (shared) instead of m_instanceLock to access m_utilityVm and m_runningInstances. - friend class wsl::windows::service::PluginHostCallbackImpl; - public: LxssUserSessionImpl(_In_ PSID userSid, _In_ DWORD sessionId, _Inout_ wsl::windows::service::PluginManager& pluginManager); virtual ~LxssUserSessionImpl(); @@ -367,6 +363,11 @@ class LxssUserSessionImpl /// void ClearDiskStateInRegistry(_In_opt_ LPCWSTR Disk); + /// + /// Start a process in the root namespace or in a user distribution. + /// + HRESULT CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* socket); + /// /// Enumerates registered distributions, optionally including ones that are /// currently being registered, unregistered, or converted. @@ -442,6 +443,20 @@ class LxssUserSessionImpl HRESULT MoveDistribution(_In_ LPCGUID DistroGuid, _In_ LPCWSTR Location); + HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name); + + /// + /// Attempts to run Work while holding m_instanceLock, acquired with the + /// given timeout. Returns true and sets Result to Work()'s HRESULT if the + /// lock was acquired and Work ran; returns false (without running Work) if + /// the lock could not be acquired within Timeout. Used by the plugin + /// callback pump to run an out-of-hook callback directly without blocking + /// indefinitely on the instance lock — a notification thread may hold it in + /// its pre-notification phase and later wait for this very callback, which + /// would deadlock if we blocked on the lock unconditionally. + /// + bool TryInvokeUnderInstanceLock(std::chrono::milliseconds Timeout, const std::function& Work, _Out_ HRESULT& Result); + /// /// Registers a distribution. /// @@ -530,18 +545,6 @@ class LxssUserSessionImpl static CreateLxProcessContext s_GetCreateProcessContext(_In_ const GUID& DistroGuid, _In_ bool SystemDistro); private: - /// - /// Plugin callback methods — called from PluginHostCallbackImpl on a COM RPC - /// thread during plugin notifications. These acquire m_callbackLock (shared) - /// instead of m_instanceLock, preventing _VmTerminate from destroying the VM - /// while a callback is in-flight. Access is restricted via friend declaration. - /// - _Requires_lock_not_held_(m_instanceLock) - HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name); - - _Requires_lock_not_held_(m_instanceLock) - HRESULT CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* socket); - /// /// Adds a distro to the list of converting distros. /// @@ -803,9 +806,7 @@ class LxssUserSessionImpl std::recursive_timed_mutex m_instanceLock; /// - /// Contains the currently running instances. - /// Reads guarded by m_instanceLock OR m_callbackLock (shared). - /// Mutations require BOTH m_instanceLock AND m_callbackLock (exclusive). + /// Contains the currently running utility VM's. /// _Guarded_by_(m_instanceLock) std::map, wsl::windows::common::helpers::GuidLess> m_runningInstances; @@ -822,23 +823,8 @@ class LxssUserSessionImpl /// /// The running utility vm for WSL2 distributions. - /// Reads guarded by m_instanceLock OR m_callbackLock (shared). - /// Mutations require BOTH m_instanceLock AND m_callbackLock (exclusive). - /// - _Guarded_by_(m_instanceLock) std::unique_ptr m_utilityVm; - - /// - /// Reader-writer lock protecting m_utilityVm and m_runningInstances for - /// plugin callbacks. Callbacks take a shared (read) lock; _VmTerminate and - /// instance mutations take an exclusive (write) lock. /// - /// Mutations of m_runningInstances/m_utilityVm require BOTH m_instanceLock - /// AND m_callbackLock (exclusive). Reads are safe under either lock alone. - /// - /// Lock ordering: m_instanceLock → m_callbackLock (never reverse). - /// Callbacks must NEVER acquire m_instanceLock (deadlock with notification thread). - /// - std::shared_mutex m_callbackLock; + _Guarded_by_(m_instanceLock) std::unique_ptr m_utilityVm; std::atomic m_vmId{GUID_NULL}; diff --git a/src/windows/service/exe/PluginCallPump.cpp b/src/windows/service/exe/PluginCallPump.cpp new file mode 100644 index 0000000000..2bac10e9e9 --- /dev/null +++ b/src/windows/service/exe/PluginCallPump.cpp @@ -0,0 +1,147 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. + +#include "precomp.h" +#include "PluginCallPump.h" +#include + +using wsl::windows::service::PluginCallPump; + +PluginCallPump::PluginCallPump() = default; + +void PluginCallPump::DrainQueue() +{ + for (;;) + { + Call* call = nullptr; + { + auto lock = m_lock.lock_exclusive(); + if (m_queue.empty()) + { + break; + } + call = m_queue.front(); + m_queue.pop_front(); + } + + // Run the plugin's service-side work on the pump (notification) thread, + // which still holds the notifying lock — recursive locks re-enter here. + // The closures already translate exceptions to HRESULTs via CATCH_RETURN, + // but guard defensively so a throw can never skip done.SetEvent() and + // strand the waiting RPC thread. + try + { + call->result = call->work ? call->work() : E_UNEXPECTED; + } + catch (...) + { + call->result = wil::ResultFromCaughtException(); + } + + call->done.SetEvent(); + } +} + +HRESULT PluginCallPump::Run(const std::function& Notification) +try +{ + HRESULT notificationResult = E_FAIL; + + // The worker makes the outbound cross-process notification call. It runs on + // its own thread so THIS thread is free to pump the plugin's callbacks while + // the notification is in flight. (Thread creation can throw std::system_error + // under resource exhaustion; the surrounding try/CATCH_RETURN converts that to + // an HRESULT so it never escapes into a non-throwing teardown notification.) + wil::unique_event workerDone(wil::EventOptions::ManualReset); + std::thread worker([&]() { + // Guard defensively: a throw escaping the thread's top-level function + // would call std::terminate() and crash the service, and would skip + // workerDone.SetEvent() — stranding the pumping thread. Translate to an + // HRESULT and always signal completion (mirrors DrainQueue). + try + { + notificationResult = Notification(); + } + catch (...) + { + notificationResult = wil::ResultFromCaughtException(); + } + + workerDone.SetEvent(); + }); + + auto join = wil::scope_exit([&]() { + if (worker.joinable()) + { + worker.join(); + } + }); + + const HANDLE waits[] = {m_callAvailable.get(), workerDone.get()}; + for (;;) + { + const DWORD wait = ::WaitForMultipleObjects(ARRAYSIZE(waits), waits, FALSE, INFINITE); + + // A kernel wait failure must never spin this thread. Stop the pump and + // fail any queued/future calls so their RPC threads aren't stranded. + FAIL_FAST_LAST_ERROR_IF(wait == WAIT_FAILED); + + // Drain regardless of which handle woke us: the auto-reset event + // coalesces multiple enqueues into a single signal, and a final call may + // race in just as the worker completes. + DrainQueue(); + + // Check worker completion independently of the wait result. + // WaitForMultipleObjects reports the LOWEST signaled index, so a steady + // stream of callbacks on m_callAvailable (index 0) would otherwise starve + // the workerDone (index 1) branch and keep this thread pumping (and the + // notifying lock held) forever. workerDone is manual-reset, so this is a + // cheap non-consuming poll. + if (::WaitForSingleObject(workerDone.get(), 0) == WAIT_OBJECT_0) + { + // Worker (notification) finished. Stop accepting further work and + // fail any call that raced in after this point so its RPC thread is + // never stranded, then drain whatever was already queued. + { + auto lock = m_lock.lock_exclusive(); + m_stopped = true; + } + DrainQueue(); + break; + } + } + + return notificationResult; +} +CATCH_RETURN() + +bool PluginCallPump::Invoke(std::function Work, _Out_ HRESULT& Result) +{ + Call call; + call.work = std::move(Work); + + const HRESULT createHr = call.done.create(); + if (FAILED(createHr)) + { + // Could not create the completion event (resource exhaustion). Surface + // the failure as the executed result rather than reporting "not run": + // retrying on the direct path would not help and risks double-execution. + Result = createHr; + return true; + } + + { + auto lock = m_lock.lock_exclusive(); + if (m_stopped) + { + // The notification already returned; there is no pump thread left to + // run this. Report "not run" so the caller executes it directly. + return false; + } + m_queue.push_back(&call); + } + + m_callAvailable.SetEvent(); + call.done.wait(); + Result = call.result; + return true; +} diff --git a/src/windows/service/exe/PluginCallPump.h b/src/windows/service/exe/PluginCallPump.h new file mode 100644 index 0000000000..2128fdf80b --- /dev/null +++ b/src/windows/service/exe/PluginCallPump.h @@ -0,0 +1,85 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace wsl::windows::service { + +// +// PluginCallPump implements the threaded-callback model for out-of-process +// plugin notifications. +// +// Problem: a plugin lifecycle notification (OnVMStarted, OnDistributionStarted, +// ...) is an outbound cross-process COM call. While the service thread is +// blocked inside that call, the plugin may call back into the service +// (MountFolder, ExecuteBinary, ...). That callback arrives on a *different* COM +// RPC thread, so it cannot re-enter the locks held by the notifying thread +// (m_instanceLock etc.) without a second, parallel locking scheme. +// +// Solution (matches the in-process model): make the outbound notification on a +// worker thread and "pump" the plugin's service-side API calls back onto the +// ORIGINAL notifying thread. Because the work then runs on the lock-holding +// thread, recursive locks (std::recursive_timed_mutex m_instanceLock) re-enter +// exactly as they did when plugins were loaded in-process — no second lock +// (m_callbackLock) and no out-of-band session registry are needed. +// +// Notifying thread: pump.Run([&]{ return host->OnVMStarted(...); }); +// RPC callback: return pump.Invoke([&]{ return session->Mount...(); }); +// +// A single pump instance services one outbound notification at a time. The +// pump thread is the single consumer; any number of RPC threads may Invoke(). +// +class PluginCallPump +{ +public: + PluginCallPump(); + ~PluginCallPump() = default; + + PluginCallPump(const PluginCallPump&) = delete; + PluginCallPump& operator=(const PluginCallPump&) = delete; + PluginCallPump(PluginCallPump&&) = delete; + PluginCallPump& operator=(PluginCallPump&&) = delete; + + // Runs `Notification` (the outbound host->On... COM call) on a dedicated + // worker thread and pumps queued Invoke() calls on the CALLING thread until + // the worker completes. The worker performs its own COM initialization, so + // `Notification` should acquire the apartment-local host proxy itself. + // + // Returns the HRESULT returned by `Notification`. Any service-side work + // requested by the plugin runs on the calling (lock-holding) thread, so + // recursive locks behave exactly as in the in-process model. + HRESULT Run(const std::function& Notification); + + // Called from a COM RPC callback thread. Marshals `Work` to the pump thread, + // blocks until it has executed there, and reports its HRESULT via `Result`. + // Returns true if `Work` was executed (`Result` is set); returns false if the + // pump is no longer running (the notification already returned), in which case + // `Work` was NOT run and the caller must run it itself. Reporting "not run" + // out-of-band (rather than via a sentinel HRESULT) lets `Work` legitimately + // return any HRESULT, including RPC_E_DISCONNECTED, without ambiguity. + bool Invoke(std::function Work, _Out_ HRESULT& Result); + +private: + struct Call + { + std::function work; + HRESULT result{E_FAIL}; + wil::unique_event_nothrow done; + }; + + // Runs every currently-queued call on the calling (pump) thread. + void DrainQueue(); + + wil::srwlock m_lock; + _Guarded_by_(m_lock) std::deque m_queue; + _Guarded_by_(m_lock) bool m_stopped { false }; + + // Auto-reset event signaled whenever a call is enqueued. + wil::unique_event m_callAvailable{wil::EventOptions::None}; +}; + +} // namespace wsl::windows::service diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp index 82a1aa8a2f..1b31d19555 100644 --- a/src/windows/service/exe/PluginManager.cpp +++ b/src/windows/service/exe/PluginManager.cpp @@ -89,18 +89,20 @@ try { RETURN_HR_IF(E_INVALIDARG, WindowsPath == nullptr || LinuxPath == nullptr || Name == nullptr); - WSL_LOG( - "PluginCallbackMountFolderBegin", - TraceLoggingValue(WindowsPath, "WindowsPath"), - TraceLoggingValue(SessionId, "SessionId")); - const auto session = FindSessionByCookie(SessionId); - RETURN_HR_IF(c_pluginSessionNotFound, !session); + // Marshal the work onto the notifying thread so MountRootNamespaceFolder runs + // under the session's recursive m_instanceLock (see InvokeOnWslPump). The + // captured pointers reference the caller's stack frame, which stays alive + // because InvokeOnWslPump blocks until the work has run. + return m_owner.InvokeOnWslPump(SessionId, [=]() -> HRESULT { + const auto session = FindSessionByCookie(SessionId); + RETURN_HR_IF(c_pluginSessionNotFound, !session); - auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name); + auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name); - WSL_LOG("PluginCallbackMountFolderEnd", TraceLoggingValue(WindowsPath, "WindowsPath"), TraceLoggingValue(result, "Result")); - - return result; + WSL_LOG( + "PluginCallbackMountFolderEnd", TraceLoggingValue(WindowsPath, "WindowsPath"), TraceLoggingValue(result, "Result")); + return result; + }); } CATCH_RETURN(); @@ -113,36 +115,34 @@ try RETURN_HR_IF(E_INVALIDARG, Path == nullptr); RETURN_HR_IF(E_INVALIDARG, ArgumentCount > 0 && Arguments == nullptr); - WSL_LOG("PluginCallbackExecuteBinaryBegin", TraceLoggingValue(Path, "Path"), TraceLoggingValue(SessionId, "SessionId")); - const auto session = FindSessionByCookie(SessionId); - WSL_LOG( - "PluginCallbackExecuteBinaryFoundSession", - TraceLoggingValue(Path, "Path"), - TraceLoggingValue(session != nullptr, "Found")); - RETURN_HR_IF(c_pluginSessionNotFound, !session); + return m_owner.InvokeOnWslPump(SessionId, [=]() -> HRESULT { + WSL_LOG("PluginCallbackExecuteBinaryBegin", TraceLoggingValue(Path, "Path"), TraceLoggingValue(SessionId, "SessionId")); + const auto session = FindSessionByCookie(SessionId); + RETURN_HR_IF(c_pluginSessionNotFound, !session); - // Build NULL-terminated argument array expected by CreateLinuxProcess. - std::vector args; - if (Arguments != nullptr) - { - args.assign(Arguments, Arguments + ArgumentCount); - } - args.push_back(nullptr); + // Build NULL-terminated argument array expected by CreateLinuxProcess. + std::vector args; + if (Arguments != nullptr) + { + args.assign(Arguments, Arguments + ArgumentCount); + } + args.push_back(nullptr); - WSL_LOG("PluginCallbackExecuteBinaryCallingCreateProcess", TraceLoggingValue(Path, "Path")); - wil::unique_socket sock; - auto result = session->CreateLinuxProcess(nullptr, Path, args.data(), &sock); + WSL_LOG("PluginCallbackExecuteBinaryCallingCreateProcess", TraceLoggingValue(Path, "Path")); + wil::unique_socket sock; + auto result = session->CreateLinuxProcess(nullptr, Path, args.data(), &sock); - WSL_LOG("PluginCallbackExecuteBinaryEnd", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); + WSL_LOG("PluginCallbackExecuteBinaryEnd", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); - if (SUCCEEDED(result)) - { - // Return socket as HANDLE — COM's system_handle marshaling will - // duplicate it into the host process automatically. - *Socket = reinterpret_cast(sock.release()); - } + if (SUCCEEDED(result)) + { + // Return socket as HANDLE — COM's system_handle marshaling will + // duplicate it into the host process automatically. + *Socket = reinterpret_cast(sock.release()); + } - return result; + return result; + }); } CATCH_RETURN(); @@ -161,27 +161,29 @@ try RETURN_HR_IF(E_INVALIDARG, Path == nullptr); RETURN_HR_IF(E_INVALIDARG, ArgumentCount > 0 && Arguments == nullptr); - const auto session = FindSessionByCookie(SessionId); - RETURN_HR_IF(c_pluginSessionNotFound, !session); + return m_owner.InvokeOnWslPump(SessionId, [=]() -> HRESULT { + const auto session = FindSessionByCookie(SessionId); + RETURN_HR_IF(c_pluginSessionNotFound, !session); - std::vector args; - if (Arguments != nullptr) - { - args.assign(Arguments, Arguments + ArgumentCount); - } - args.push_back(nullptr); + std::vector args; + if (Arguments != nullptr) + { + args.assign(Arguments, Arguments + ArgumentCount); + } + args.push_back(nullptr); - wil::unique_socket sock; - auto result = session->CreateLinuxProcess(DistributionId, Path, args.data(), &sock); + wil::unique_socket sock; + auto result = session->CreateLinuxProcess(DistributionId, Path, args.data(), &sock); - WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); + WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); - if (SUCCEEDED(result)) - { - *Socket = reinterpret_cast(sock.release()); - } + if (SUCCEEDED(result)) + { + *Socket = reinterpret_cast(sock.release()); + } - return result; + return result; + }); } CATCH_RETURN(); @@ -470,10 +472,135 @@ std::vector PluginManager::SerializeSid(PSID Sid) { const DWORD sidLength = GetLengthSid(Sid); std::vector buffer(sidLength); - CopySid(sidLength, buffer.data(), Sid); + THROW_IF_WIN32_BOOL_FALSE(CopySid(sidLength, buffer.data(), Sid)); return buffer; } +// Registers the pump servicing the in-flight WSL notification for a session, so +// that plugin callbacks arriving on RPC threads can be marshaled onto the +// notifying thread (see InvokeOnWslPump). Invariant: a session has at most one +// WSL notification in flight at a time — notifications for a session are issued +// serially by that session, and the callbacks marshaled by the pump +// (Mount/CreateLinuxProcess) never trigger further session notifications — so a +// single SessionId-keyed entry never needs to nest. Register/Unregister are +// bracketed by a scope_exit in RunHostNotification. +void PluginManager::RegisterWslPump(ULONG SessionId, const std::shared_ptr& Pump) +{ + auto lock = m_wslPumpLock.lock_exclusive(); + m_wslPumps[SessionId] = Pump; +} + +void PluginManager::UnregisterWslPump(ULONG SessionId) +{ + auto lock = m_wslPumpLock.lock_exclusive(); + m_wslPumps.erase(SessionId); +} + +HRESULT PluginManager::InvokeOnWslPump(ULONG SessionId, std::function Work) +{ + RETURN_HR_IF(E_UNEXPECTED, !Work); + + // Granularity for the direct path's timed lock acquisition (see below). + constexpr auto c_directLockPollInterval = std::chrono::milliseconds(25); + + for (;;) + { + // 1. If a notification hook for this session is in flight, marshal the + // work onto the notifying thread (which holds the session's recursive + // m_instanceLock) via that hook's pump. Copy the shared_ptr out under + // the lock and release it before the (blocking) Invoke, so the + // registry lock is never held across a callback. + std::shared_ptr pump; + { + auto lock = m_wslPumpLock.lock_shared(); + const auto it = m_wslPumps.find(SessionId); + if (it != m_wslPumps.end()) + { + pump = it->second; + } + } + + if (pump) + { + // Pass Work by copy (not moved): if the pump has already stopped we + // fall through to the direct path below and still need it. Invoke + // reports out-of-band whether it actually ran the work, so a real + // HRESULT from Work (even RPC_E_DISCONNECTED) is never mistaken for + // "pump stopped" — which would otherwise double-run the work. + HRESULT hr = E_FAIL; + if (pump->Invoke(Work, hr)) + { + return hr; + } + + // The hook returned (the pump stopped) between our lookup and the + // Invoke, but the entry is not yet unregistered. The notifying thread + // is no longer waiting on this callback, so fall through and run the + // work directly. The timed lock acquisition below backs off until that + // thread releases m_instanceLock, so a live call is never spuriously + // failed. + } + + // 2. No hook is in flight for this session (or its pump just stopped). + // Run the work directly on this RPC thread, like an in-process plugin + // calling the API from one of its own worker threads. We must NOT just + // block acquiring m_instanceLock: a notification thread may already + // hold it (in its pre-notification phase, before registering its pump) + // and then wait for THIS callback - e.g. a hook that joins the worker + // thread that issued it. Blocking would deadlock (the holder waits for + // us; we wait for the holder). Instead acquire with a short timeout; on + // failure, loop back and re-check for a pump, which the notifying + // thread registers before it can wait on us, so step 1 then routes the + // work onto it. + const auto session = FindSessionByCookie(SessionId); + if (!session) + { + // Unknown session and no hook in flight: run Work() so it surfaces + // the appropriate "session not found" failure. + return Work(); + } + + HRESULT result = E_FAIL; + if (session->TryInvokeUnderInstanceLock(c_directLockPollInterval, Work, result)) + { + return result; + } + + // The instance lock is held by another thread that has not registered a + // pump (most likely a session operation in its pre-notification phase). + // Loop: re-check for a pump or retry the direct acquisition. + } +} + +HRESULT PluginManager::RunHostNotification(OutOfProcPlugin& Plugin, ULONG SessionId, const std::function& Notify) +{ + auto pump = std::make_shared(); + RegisterWslPump(SessionId, pump); + auto unregister = wil::scope_exit([&]() { UnregisterWslPump(SessionId); }); + + return pump->Run([&]() -> HRESULT { + // Runs on the pump's worker thread. The host proxy is apartment-local, + // so acquire it here (and COM-init this thread) rather than on the + // notifying thread. + ScopedComInit init; + RETURN_IF_FAILED(init.Result()); + Microsoft::WRL::ComPtr host; + const HRESULT acquire = AcquireHostProxy(Plugin, &host); + if (FAILED(acquire)) + { + // Surface the acquire failure as the notification result; the caller + // routes it through IsHostCrash exactly like a call-time failure. + if (!IsHostCrash(acquire)) + { + LOG_HR_MSG(acquire, "Failed to acquire plugin host proxy for: '%ls'", Plugin.name.c_str()); + } + return acquire; + } + + return Notify(host.Get()); + }); +} + void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings) { ExecutionContext context(Context::Plugin); @@ -489,18 +616,18 @@ void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLV } WSL_LOG("PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); - ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnVmStarted"); - wil::unique_cotaskmem_string errorMessage; SlowOperationWatcher slowOperation{"PluginOnVmStarted"}; WSL_LOG("PluginOnVmStartedBeginRpc", TraceLoggingValue(e.name.c_str(), "Plugin")); - HRESULT hr = host->OnVMStarted( - Session->SessionId, - Session->UserToken, - static_cast(sidData.size()), - sidData.data(), - static_cast(Settings->CustomConfigurationFlags), - &errorMessage); + const HRESULT hr = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) { + return host->OnVMStarted( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + static_cast(Settings->CustomConfigurationFlags), + &errorMessage); + }); WSL_LOG("PluginOnVmStartedEndRpc", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(hr, "Result")); if (IsHostCrash(hr)) @@ -527,9 +654,9 @@ void PluginManager::OnVmStopping(const WSLSessionInformation* Session) } WSL_LOG("PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); - ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnVmStopping"); - - const auto result = host->OnVMStopping(Session->SessionId, Session->UserToken, static_cast(sidData.size()), sidData.data()); + const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) { + return host->OnVMStopping(Session->SessionId, Session->UserToken, static_cast(sidData.size()), sidData.data()); + }); if (IsHostCrash(result)) { @@ -560,23 +687,23 @@ void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session, TraceLoggingValue(Session->UserSid, "Sid"), TraceLoggingValue(Distribution->Id, "DistributionId")); - ACQUIRE_PLUGIN_HOST_OR_THROW(e, host, "OnDistributionStarted"); - wil::unique_cotaskmem_string errorMessage; SlowOperationWatcher slowOperation{"PluginOnDistributionStarted"}; - HRESULT hr = host->OnDistributionStarted( - Session->SessionId, - Session->UserToken, - static_cast(sidData.size()), - sidData.data(), - &Distribution->Id, - Distribution->Name, - Distribution->PidNamespace, - Distribution->PackageFamilyName, - Distribution->InitPid, - Distribution->Flavor, - Distribution->Version, - &errorMessage); + const HRESULT hr = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) { + return host->OnDistributionStarted( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PidNamespace, + Distribution->PackageFamilyName, + Distribution->InitPid, + Distribution->Flavor, + Distribution->Version, + &errorMessage); + }); if (IsHostCrash(hr)) { @@ -606,20 +733,20 @@ void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, TraceLoggingValue(Session->UserSid, "Sid"), TraceLoggingValue(Distribution->Id, "DistributionId")); - ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnDistributionStopping"); - - const auto result = host->OnDistributionStopping( - Session->SessionId, - Session->UserToken, - static_cast(sidData.size()), - sidData.data(), - &Distribution->Id, - Distribution->Name, - Distribution->PidNamespace, - Distribution->PackageFamilyName, - Distribution->InitPid, - Distribution->Flavor, - Distribution->Version); + const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) { + return host->OnDistributionStopping( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PidNamespace, + Distribution->PackageFamilyName, + Distribution->InitPid, + Distribution->Flavor, + Distribution->Version); + }); if (IsHostCrash(result)) { @@ -650,18 +777,18 @@ void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Sessio TraceLoggingValue(Session->UserSid, "Sid"), TraceLoggingValue(Distribution->Id, "DistributionId")); - ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnDistributionRegistered"); - - const auto result = host->OnDistributionRegistered( - Session->SessionId, - Session->UserToken, - static_cast(sidData.size()), - sidData.data(), - &Distribution->Id, - Distribution->Name, - Distribution->PackageFamilyName, - Distribution->Flavor, - Distribution->Version); + const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) { + return host->OnDistributionRegistered( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PackageFamilyName, + Distribution->Flavor, + Distribution->Version); + }); if (IsHostCrash(result)) { @@ -692,18 +819,18 @@ void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Sess TraceLoggingValue(Session->UserSid, "Sid"), TraceLoggingValue(Distribution->Id, "DistributionId")); - ACQUIRE_PLUGIN_HOST_OR_CONTINUE(e, host, "OnDistributionUnregistered"); - - const auto result = host->OnDistributionUnregistered( - Session->SessionId, - Session->UserToken, - static_cast(sidData.size()), - sidData.data(), - &Distribution->Id, - Distribution->Name, - Distribution->PackageFamilyName, - Distribution->Flavor, - Distribution->Version); + const HRESULT result = RunHostNotification(e, Session->SessionId, [&](IWslPluginHost* host) { + return host->OnDistributionUnregistered( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PackageFamilyName, + Distribution->Flavor, + Distribution->Version); + }); if (IsHostCrash(result)) { diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h index 17c1b23ba5..ecf1744a08 100644 --- a/src/windows/service/exe/PluginManager.h +++ b/src/windows/service/exe/PluginManager.h @@ -27,6 +27,7 @@ Module Name: #include "WslPluginApi.h" #include "WslPluginHost.h" #include "wslc.h" +#include "PluginCallPump.h" namespace wsl::windows::service { @@ -160,6 +161,20 @@ class PluginManager // can no longer be opened. wil::com_ptr ResolveWslcSession(ULONG SessionId); + // Routes a WSL-session plugin API callback (MountFolder / ExecuteBinary / + // ExecuteBinaryInDistribution) so it runs with in-process semantics. If a + // notification hook for SessionId is in flight, the work is marshaled onto + // the notifying thread (which holds the session's recursive m_instanceLock) + // via that hook's PluginCallPump — so out-of-process callbacks re-enter the + // lock exactly as in-process plugins did, and no second (m_callbackLock) + // lock is needed. Otherwise the work runs directly on the calling RPC thread + // (acquiring m_instanceLock itself via a timed try-acquire that re-checks for + // a pump on contention, so it can never deadlock against a notification + // thread that holds m_instanceLock and later waits for this callback). This + // supports plugin API calls made from a plugin's own worker threads outside + // any hook. + HRESULT InvokeOnWslPump(ULONG SessionId, std::function Work); + private: struct OutOfProcPlugin { @@ -245,6 +260,23 @@ class PluginManager // teardown hooks latch but cannot block, so they call LatchHostCrash + skip. [[noreturn]] void ThrowHostCrash(OutOfProcPlugin& plugin, HRESULT result, const char* stage); + // Registers/unregisters the active PluginCallPump for a WSL session while a + // notification (OnVmStarted, etc.) is in flight. Keyed by the session cookie + // that is handed to the plugin host and echoed back on callbacks. Plugin + // notifications for a given session are serialized by m_instanceLock, so at + // most one pump is registered per SessionId at a time. + void RegisterWslPump(ULONG SessionId, const std::shared_ptr& Pump); + void UnregisterWslPump(ULONG SessionId); + + // Drives one outbound plugin-host notification (OnVMStarted, etc.) through a + // PluginCallPump registered under SessionId. The pump runs `Notify` on a + // worker thread (which acquires the apartment-local host proxy and performs + // its own COM init) while THIS thread pumps the plugin's API callbacks — so + // they execute back here, under the session's recursive m_instanceLock. + // Returns the HRESULT from `Notify`, or the proxy-acquire failure (the + // caller routes it through IsHostCrash exactly as before). + HRESULT RunHostNotification(OutOfProcPlugin& Plugin, ULONG SessionId, const std::function& Notify); + std::once_flag m_initOnce; std::vector m_plugins; @@ -272,6 +304,16 @@ class PluginManager // rather than releasing them after CoUninitialize. std::mutex m_wslcSessionRefLock; std::unordered_map> m_wslcSessionRefs; + + // Active WSL-session notification pumps, keyed by session cookie. Populated + // for the duration of an out-of-process notification call so that plugin API + // callbacks (which arrive on a different RPC thread) can be marshaled back + // onto the notifying thread. Held by shared_ptr so InvokeOnWslPump can copy + // out a stable reference under the lock and then release the lock before the + // (blocking) Invoke — the lock is never held across a callback. See + // InvokeOnWslPump. + wil::srwlock m_wslPumpLock; + _Guarded_by_(m_wslPumpLock) std::unordered_map> m_wslPumps; }; } // namespace wsl::windows::service diff --git a/test/windows/PluginTests.cpp b/test/windows/PluginTests.cpp index 3726eceaf5..3fb16b9166 100644 --- a/test/windows/PluginTests.cpp +++ b/test/windows/PluginTests.cpp @@ -794,12 +794,18 @@ class PluginTests // --- PR #40120 (out-of-process plugin host) coverage --- // - // These tests validate the new isolation and locking behavior: + // These tests validate the isolation and callback model: // * HostCrashIsFatal — host process crash aborts the guarded operation (fatal). - // * ConcurrentCallbacks — concurrent shared_lock readers on m_callbackLock. - // * AsyncApiCallFromWorker — cross-apartment plugin API call from a non-hook thread. - // * CallbacksDuringTerminationDoNotCrash — exclusive m_callbackLock drains in-flight - // callbacks before m_utilityVm.reset(). + // * ConcurrentCallbacks — many plugin threads issue API callbacks during a hook. + // * AsyncApiCallFromWorker — plugin API call from a worker thread OUTSIDE any hook. + // * CallbacksDuringTerminationDoNotCrash — callbacks racing VM teardown fail gracefully, never crash. + // + // Callback model (PluginCallPump): a plugin API callback that arrives WHILE a + // notification hook is in flight is marshaled back onto the notifying thread + // (which holds the session's recursive m_instanceLock), reproducing in-process + // re-entrancy; a callback that arrives with no hook in flight runs directly on + // the RPC thread (taking m_instanceLock itself). There is no separate callback + // lock — m_instanceLock alone serializes callbacks against VM teardown. WSL2_TEST_METHOD(HostCrashIsFatal) { @@ -892,18 +898,20 @@ class PluginTests { // Drain test: 4 workers loop ExecuteBinaryInDistribution (with /bin/true, // sub-ms callback) while the distro is alive. They keep calling across - // OnDistroStopping and _VmTerminate; the exclusive m_callbackLock acquire - // in _VmTerminate must drain in-flight callbacks before resetting - // m_utilityVm. After OnVmStopping signals wind-down, workers run a bounded - // number of further iterations and exit (see Plugin.cpp), so termination - // is deterministic and no worker can revive against a later VM. + // OnDistroStopping and _VmTerminate. Because callbacks run under (or block + // on) the session's recursive m_instanceLock, _VmTerminate's m_utilityVm + // reset is naturally serialized against them: a racing callback either + // runs before the reset (valid VM) or after it (returns E_NOT_VALID_STATE). + // Either way the service must not crash. After OnVmStopping signals + // wind-down, workers run a bounded number of further iterations and exit + // (see Plugin.cpp), so termination is deterministic and no worker can + // revive against a later VM. // // The post-shutdown StartWsl below triggers a second OnDistroStarted that // joins the finished workers, so this test needs no fixed sleep. // // Scope: - // - Validates: dual-lock invariant under racing callbacks; drain works - // when callbacks complete in sub-ms; service survives the race. + // - Validates: callbacks racing teardown never crash; service survives. // - Does NOT validate: drain semantics when a callback is genuinely // stuck (e.g. service-side CreateLinuxProcess waiting on a hung // Linux init). That requires cancellation plumbing through diff --git a/test/windows/testplugin/Plugin.cpp b/test/windows/testplugin/Plugin.cpp index 94c20d35fc..30ea8e88a0 100644 --- a/test/windows/testplugin/Plugin.cpp +++ b/test/windows/testplugin/Plugin.cpp @@ -241,17 +241,17 @@ HRESULT OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSet } else if (g_testType == PluginTestType::ConcurrentApiCalls) { - // Validate concurrent service-side callbacks under the new - // m_callbackLock (shared_mutex). N threads call MountFolder + - // ExecuteBinary in parallel via a start-gate so the shared_lock has - // multiple readers in flight at once. + // Validate service-side callbacks issued by multiple plugin threads + // during a hook. N threads call MountFolder + ExecuteBinary via a + // start-gate so the RPCs are all in flight at once. // // maxConcurrent records how many workers are simultaneously at the - // callback boundary (via a second rendezvous). Reaching N proves the - // plugin issues N callbacks concurrently. It does NOT prove the service - // executes them in parallel — a black-box plugin can't observe whether - // m_callbackLock is shared or exclusive, since either way the RPCs - // simply appear in flight. This is the strongest honest assertion here. + // plugin-side callback boundary (via a second rendezvous). Reaching N + // proves the plugin issues N callbacks concurrently. It does NOT prove + // the service executes them in parallel: with the PluginCallPump the + // service marshals these onto the single notifying thread and runs them + // serially. The test asserts only that all N succeed, which is the + // strongest honest black-box assertion. constexpr int N = 4; std::filesystem::path modulePath = wil::GetModuleFileNameW(wil::GetModuleInstanceHandle()).get(); @@ -349,8 +349,8 @@ HRESULT OnVmStopping(const WSLSessionInformation* Session) if (g_testType == PluginTestType::CallbackDuringTermination) { // Signal drain workers to begin a bounded wind-down. Fires before - // _VmTerminate's exclusive m_callbackLock acquire, so workers keep - // racing the drain for a fixed number of iterations before exiting. + // _VmTerminate resets m_utilityVm, so workers keep racing teardown for + // a fixed number of iterations before exiting. g_drainWindDown = true; } @@ -491,19 +491,20 @@ HRESULT OnDistroStarted(const WSLSessionInformation* Session, const WSLDistribut } else if (g_testType == PluginTestType::CallbackDuringTermination) { - // Validate that the new exclusive m_callbackLock acquire in - // _VmTerminate drains in-flight callbacks before m_utilityVm.reset(). + // Validate that callbacks racing VM teardown never crash the service. // Workers keep calling into the service across OnDistroStopping / - // _VmTerminate, then wind down deterministically (see globals above). + // _VmTerminate; each callback runs under (or blocks on) the session's + // recursive m_instanceLock, so it is naturally serialized against + // m_utilityVm.reset() and fails gracefully if it lands after teardown. + // Workers then wind down deterministically (see globals above). // - // Scope: this test exercises only the *happy-path* drain — the + // Scope: this test exercises only the *happy-path* race — the // callback (/bin/true) returns in sub-millisecond, so workers are - // almost always between iterations when the exclusive lock is - // acquired. It is *not* a regression test for the hung-callback - // case, where a service-side callback is stuck inside CreateLinuxProcess - // waiting on a non-responsive Linux init; that scenario requires - // termination-event plumbing through WslCoreInstance::CreateLinuxProcess - // and is tracked separately. + // almost always between iterations when teardown runs. It is *not* a + // regression test for the hung-callback case, where a service-side + // callback is stuck inside CreateLinuxProcess waiting on a non-responsive + // Linux init; that scenario requires termination-event plumbing through + // WslCoreInstance::CreateLinuxProcess and is tracked separately. constexpr int N = 4; // Spawn at most once. The post-shutdown StartWsl in