diff --git a/localization/strings/en-US/Resources.resw b/localization/strings/en-US/Resources.resw index 5a8097efca..e99d2e59f8 100644 --- a/localization/strings/en-US/Resources.resw +++ b/localization/strings/en-US/Resources.resw @@ -1054,6 +1054,16 @@ Falling back to NAT networking. Update + + WSL Update + {Locked="WSL"}Product names should not be translated + + + An update for WSL is ready to install, but it is currently in use. + +Select Yes to shutdown WSL and install now. Select No to install on the next reboot. + {Locked="WSL"}Product names should not be translated + See Docs diff --git a/src/windows/common/CMakeLists.txt b/src/windows/common/CMakeLists.txt index fafd201352..4c341c6031 100644 --- a/src/windows/common/CMakeLists.txt +++ b/src/windows/common/CMakeLists.txt @@ -53,7 +53,8 @@ set(SOURCES wslutil.cpp install.cpp WSLCUserSettings.cpp - notifications.cpp) + notifications.cpp + WslActivityMarker.cpp) set(HEADERS ../../../generated/Localization.h @@ -140,6 +141,7 @@ set(HEADERS EnumVariantMap.h WSLCUserSettings.h WSLCSessionDefaults.h + WslActivityMarker.h ) add_library(common STATIC ${SOURCES} ${HEADERS}) diff --git a/src/windows/common/WslActivityMarker.cpp b/src/windows/common/WslActivityMarker.cpp new file mode 100644 index 0000000000..14edfb59ff --- /dev/null +++ b/src/windows/common/WslActivityMarker.cpp @@ -0,0 +1,60 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WslActivityMarker.cpp + +Abstract: + + This file contains the implementation for tracking whether WSL is in use. + +--*/ + +#include "precomp.h" +#include "WslActivityMarker.h" + +namespace { + +constexpr auto c_activityObjectName = L"Global\\WslActive"; + +wil::srwlock g_activityLock; +_Guarded_by_(g_activityLock) size_t g_activityCount = 0; +_Guarded_by_(g_activityLock) wil::unique_handle g_activityEvent; + +} // namespace + +namespace wsl::windows::common { + +WslActivityMarker::WslActivityMarker() noexcept +{ + auto lock = g_activityLock.lock_exclusive(); + + g_activityCount++; + + if (!g_activityEvent) + { + g_activityEvent.reset(CreateEventW(nullptr, TRUE, FALSE, c_activityObjectName)); + LOG_LAST_ERROR_IF_MSG(!g_activityEvent, "Failed to create WSL activity object"); + } +} + +WslActivityMarker::~WslActivityMarker() noexcept +{ + auto lock = g_activityLock.lock_exclusive(); + + g_activityCount--; + if (g_activityCount == 0) + { + g_activityEvent.reset(); + } +} + +bool WslActivityMarker::IsWslActive() noexcept +{ + wil::unique_handle event{OpenEventW(SYNCHRONIZE, FALSE, c_activityObjectName)}; + return event != nullptr; +} + +} // namespace wsl::windows::common diff --git a/src/windows/common/WslActivityMarker.h b/src/windows/common/WslActivityMarker.h new file mode 100644 index 0000000000..a6adad3043 --- /dev/null +++ b/src/windows/common/WslActivityMarker.h @@ -0,0 +1,31 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WslActivityMarker.h + +Abstract: + + This file contains declarations for tracking whether WSL is in use. + +--*/ + +#pragma once + +namespace wsl::windows::common { + +class WslActivityMarker +{ +public: + WslActivityMarker() noexcept; + ~WslActivityMarker() noexcept; + + NON_COPYABLE(WslActivityMarker); + NON_MOVABLE(WslActivityMarker); + + static bool IsWslActive() noexcept; +}; + +} // namespace wsl::windows::common diff --git a/src/windows/service/exe/LxssCreateProcess.h b/src/windows/service/exe/LxssCreateProcess.h index 8471e7b408..fb2d3844ef 100644 --- a/src/windows/service/exe/LxssCreateProcess.h +++ b/src/windows/service/exe/LxssCreateProcess.h @@ -16,6 +16,7 @@ Module Name: #include "SocketChannel.h" #include "WslPluginApi.h" +#include "WslActivityMarker.h" // Macro to test if Windows interop is enabled. #define LXSS_INTEROP_FLAGS (LXSS_DISTRO_FLAGS_ENABLE_DRIVE_MOUNTING | LXSS_DISTRO_FLAGS_ENABLE_INTEROP) @@ -163,4 +164,6 @@ class LxssRunningInstance private: int m_idleTimeout; + + wsl::windows::common::WslActivityMarker m_activityMarker; }; diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp index 3966245df5..8b1f357065 100644 --- a/src/windows/service/exe/WSLCSessionManager.cpp +++ b/src/windows/service/exe/WSLCSessionManager.cpp @@ -282,7 +282,17 @@ void WSLCSessionManagerImpl::CreateSession( // Track the session via its service ref, along with metadata and security info. m_sessions.push_back(SessionEntry{ - std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo), notifier, false, sharedToken, std::move(storedSid), std::move(sessionJob)}); + std::move(serviceRef), + sessionId, + creatorPid, + resolvedDisplayName, + std::move(tokenInfo), + notifier, + false, + sharedToken, + std::move(storedSid), + std::move(sessionJob), + std::make_unique()}); // For persistent sessions, also hold a strong reference to keep them alive. const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent); diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h index d48f52e430..9ea10def66 100644 --- a/src/windows/service/exe/WSLCSessionManager.h +++ b/src/windows/service/exe/WSLCSessionManager.h @@ -34,6 +34,7 @@ Module Name: #include "wslc.h" #include "COMImplClass.h" #include "wslutil.h" +#include "WslActivityMarker.h" #include #include #include @@ -70,6 +71,8 @@ struct SessionEntry std::vector UserSid; wil::unique_handle JobObject; + + std::unique_ptr ActivityMarker; }; class WSLCSessionManagerImpl diff --git a/src/windows/wslinstaller/exe/CMakeLists.txt b/src/windows/wslinstaller/exe/CMakeLists.txt index 18c8f6074f..f7e386b3c1 100644 --- a/src/windows/wslinstaller/exe/CMakeLists.txt +++ b/src/windows/wslinstaller/exe/CMakeLists.txt @@ -17,6 +17,7 @@ set_target_properties(wslinstaller PROPERTIES LINK_FLAGS "/merge:minATL=.rdata / target_link_libraries(wslinstaller ${COMMON_LINK_LIBRARIES} ${MSI_LINK_LIBRARIES} + Wtsapi32.lib common legacy_stdio_definitions) diff --git a/src/windows/wslinstaller/exe/WslInstaller.cpp b/src/windows/wslinstaller/exe/WslInstaller.cpp index 7ac796d6c9..dcf6b4701e 100644 --- a/src/windows/wslinstaller/exe/WslInstaller.cpp +++ b/src/windows/wslinstaller/exe/WslInstaller.cpp @@ -15,6 +15,8 @@ Module Name: #include "precomp.h" #include "install.h" #include "WslInstaller.h" +#include "WslActivityMarker.h" +#include extern wil::unique_event g_stopEvent; @@ -164,6 +166,74 @@ std::pair IsUpdateNeeded() } } +static bool DeferUpdatePromptEnabled() +try +{ + const auto key = wsl::windows::common::registry::OpenLxssMachineKey(KEY_READ); + + auto value = wsl::windows::common::registry::ReadDword(key.get(), L"MSI", L"EnableMsixDeferUpdatePrompt", 1); + + return value == 1; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION(); + return true; +} + +static bool PromptUserToUpgradeWhileActive() +try +{ + // For test automation purpose + if (!DeferUpdatePromptEnabled()) + { + wsl::windows::common::install::WriteInstallLog("DeferUpdatePrompt is disabled"); + return false; + } + + constexpr DWORD c_upgradePromptTimeoutSeconds = 60; + + const DWORD sessionId = WTSGetActiveConsoleSessionId(); + if (sessionId == 0xFFFFFFFF) + { + wsl::windows::common::install::WriteInstallLog("No active console session"); + return false; + } + + auto title = wsl::shared::Localization::MessageUpgradeWhileActiveTitle(); + auto message = wsl::shared::Localization::MessageUpgradeWhileActivePrompt(); + + DWORD response = 0; + if (!WTSSendMessageW( + WTS_CURRENT_SERVER_HANDLE, + sessionId, + title.data(), + static_cast(title.size() * sizeof(wchar_t)), + message.data(), + static_cast(message.size() * sizeof(wchar_t)), + MB_YESNO | MB_ICONQUESTION | MB_TOPMOST, + c_upgradePromptTimeoutSeconds, + &response, + TRUE)) + { + LOG_LAST_ERROR_MSG("WTSSendMessageW failed"); + return false; + } + + if (response != IDYES) + { + wsl::windows::common::install::WriteInstallLog(std::format("User declined upgrade prompt (response {})", response)); + return false; + } + + return true; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION(); + return false; +} + std::shared_ptr LaunchInstall() { static wil::srwlock mutex; @@ -177,6 +247,12 @@ std::shared_ptr LaunchInstall() return {}; } + if (wsl::windows::common::WslActivityMarker::IsWslActive() && !PromptUserToUpgradeWhileActive()) + { + wsl::windows::common::install::WriteInstallLog("WSL is active; deferring MSI upgrade until WSL is idle."); + return {}; + } + wsl::windows::common::install::WriteInstallLog(std::format("Starting upgrade via WslInstaller. Previous version: {}", existingVersion)); // Return an existing install if any diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp index 65caf40ddf..e802ec2645 100644 --- a/test/windows/Common.cpp +++ b/test/windows/Common.cpp @@ -2476,6 +2476,20 @@ std::optional GetDistributionId(LPCWSTR Name) return {}; } +LxssDistributionState GetDistributionState(LPCWSTR Name) +{ + wsl::windows::common::SvcComm service; + for (const auto& e : service.EnumerateDistributions()) + { + if (wsl::shared::string::IsEqual(e.DistroName, Name)) + { + return e.State; + } + } + + return LxssDistributionStateInvalid; +} + wil::unique_hkey OpenDistributionKey(LPCWSTR Name) { const auto id = GetDistributionId(Name); diff --git a/test/windows/Common.h b/test/windows/Common.h index d758567ca7..43eaea5e03 100644 --- a/test/windows/Common.h +++ b/test/windows/Common.h @@ -617,6 +617,7 @@ void StopWslService(); std::optional GetDistributionId(LPCWSTR Name); wil::unique_hkey OpenDistributionKey(LPCWSTR Name); +LxssDistributionState GetDistributionState(LPCWSTR Name = LXSS_DISTRO_NAME_TEST_L); void ValidateOutput(LPCWSTR CommandLine, const std::wstring& ExpectedOutput, const std::wstring& ExpectedWarnings = L"", int ExitCode = -1); diff --git a/test/windows/InstallerTests.cpp b/test/windows/InstallerTests.cpp index 91d9c699c1..6e2d09b0a8 100644 --- a/test/windows/InstallerTests.cpp +++ b/test/windows/InstallerTests.cpp @@ -119,7 +119,7 @@ class InstallerTests try { - wsl::shared::retry::RetryWithTimeout(pred, std::chrono::hours(3), std::chrono::minutes(2)); + wsl::shared::retry::RetryWithTimeout(pred, std::chrono::seconds(3), std::chrono::minutes(2)); } catch (...) { @@ -694,6 +694,43 @@ class InstallerTests output); } + TEST_METHOD(MsixUpgradeDefer) + { + InstallMsi(); + VERIFY_IS_TRUE(IsMsiPackageInstalled()); + + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { InstallMsi(); }); + + RegistryKeyChange changeVersion( + HKEY_LOCAL_MACHINE, L"Software\\Microsoft\\Windows\\CurrentVersion\\Lxss\\MSI", L"Version", L"1.0.0"); + + RegistryKeyChange disablePrompt( + HKEY_LOCAL_MACHINE, L"Software\\Microsoft\\Windows\\CurrentVersion\\Lxss\\MSI", L"EnableMsixDeferUpdatePrompt", 0); + + const auto commandLine = LxssGenerateWslCommandLine(L"sleep infinity"); + wsl::windows::common::SubProcess process(nullptr, commandLine.c_str()); + auto processHandle = process.Start(); + + wsl::shared::retry::RetryWithTimeout( + []() { THROW_HR_IF(E_ABORT, GetDistributionState() != LxssDistributionStateRunning); }, + std::chrono::seconds(1), + std::chrono::seconds(30)); + + // Cannot redeploy directly even with ForceUpdateFromAnyVersion set. + UninstallMsix(); + VERIFY_IS_FALSE(IsMsixInstalled()); + + InstallMsix(); + VERIFY_IS_TRUE(IsMsixInstalled()); + + WaitForInstallerServiceStop(); + + const auto key = wsl::windows::common::registry::OpenLxssMachineKey(); + VERIFY_ARE_EQUAL(wsl::windows::common::registry::ReadString(key.get(), L"MSI", L"Version"), L"1.0.0"); + + VERIFY_ARE_EQUAL(WaitForSingleObject(processHandle.get(), 0), static_cast(WAIT_TIMEOUT)); + } + TEST_METHOD(WslUpdateNoNewVersion) { constexpr auto endpoint = L"http://127.0.0.1:12345/"; diff --git a/test/windows/UnitTests.cpp b/test/windows/UnitTests.cpp index c8aa0791d0..d5f3516410 100644 --- a/test/windows/UnitTests.cpp +++ b/test/windows/UnitTests.cpp @@ -442,7 +442,9 @@ class UnitTests // Wait for the distro to exit. VERIFY_NO_THROW(wsl::shared::retry::RetryWithTimeout( - [&]() { THROW_HR_IF(E_ABORT, GetDistroState() == LxssDistributionStateRunning); }, std::chrono::seconds(1), std::chrono::seconds(30))); + [&]() { THROW_HR_IF(E_ABORT, GetDistributionState() == LxssDistributionStateRunning); }, + std::chrono::seconds(1), + std::chrono::seconds(30))); // Verify that a new WSL command succeeds (the distro restarts cleanly). auto [out, err] = LxsstuLaunchWslAndCaptureOutput(L"echo hello"); @@ -6363,21 +6365,6 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", } } - static LxssDistributionState GetDistroState() - { - wsl::windows::common::SvcComm service; - - for (const auto& e : service.EnumerateDistributions()) - { - if (wsl::shared::string::IsEqual(e.DistroName, LXSS_DISTRO_NAME_TEST_L)) - { - return e.State; - } - } - - return LxssDistributionStateInvalid; - } - TEST_METHOD(DistroTimeout) { WslConfigChange config(LxssGenerateTestConfig() + L"[general]\ninstanceIdleTimeout=-1"); @@ -6388,7 +6375,7 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", VERIFY_ARE_EQUAL(LxsstuLaunchWsl(L"echo OK"), 0L); std::this_thread::sleep_for(std::chrono::seconds(20)); - VERIFY_ARE_EQUAL(GetDistroState(), LxssDistributionStateRunning); + VERIFY_ARE_EQUAL(GetDistributionState(), LxssDistributionStateRunning); } // Validate that distributions time out when timeout value is > 0 @@ -6402,7 +6389,7 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", unsigned long iterations = 0; while (std::chrono::steady_clock::now() < deadline) { - if (GetDistroState() == LxssDistributionStateInstalled) + if (GetDistributionState() == LxssDistributionStateInstalled) { LogInfo("Distribution stopped after %lu iterations", iterations); return; @@ -6412,7 +6399,7 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", iterations++; } - LogError("Distribution failed to time out after %lu iterations. State: %i", iterations, GetDistroState()); + LogError("Distribution failed to time out after %lu iterations. State: %i", iterations, GetDistributionState()); VERIFY_FAIL(); } }