diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index 03decb3e1..320193744 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -368,14 +368,45 @@ std::wstring wsl::windows::common::wslutil::DownloadFileImpl( Filename = Url.substr(lastSlash + 1); } - const auto downloadFolder = - winrt::Windows::Storage::StorageFolder::GetFolderFromPathAsync(std::filesystem::temp_directory_path().wstring()).get(); - - const auto file = - downloadFolder.CreateFileAsync(Filename, winrt::Windows::Storage::CreationCollisionOption::GenerateUniqueName).get(); - auto deleteFileOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { file.DeleteAsync().get(); }); + // GetFolderFromPathAsync won't work if the folder is hidden or system. + auto downloadFolderPath = std::filesystem::temp_directory_path(); + auto filenameStem = std::filesystem::path(Filename).stem().wstring(); + auto filenameExtension = std::filesystem::path(Filename).extension().wstring(); + std::wstring filePath{}; + winrt::Windows::Storage::Streams::IRandomAccessStream outputStream{}; + for (int suffix = 1; outputStream == nullptr; suffix++) + { + if (suffix == 1) + { + filePath = (downloadFolderPath / Filename).wstring(); + } + else + { + filePath = (downloadFolderPath / std::format(L"{} ({}){}", filenameStem, suffix, filenameExtension)).wstring(); + } + try + { + outputStream = winrt::Windows::Storage::Streams::FileRandomAccessStream::OpenAsync( + filePath, + winrt::Windows::Storage::FileAccessMode::ReadWrite, + winrt::Windows::Storage::StorageOpenOptions::None, + winrt::Windows::Storage::Streams::FileOpenDisposition::CreateNew) + .get(); + } + catch (...) + { + if (wil::ResultFromCaughtException() != HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)) + { + throw; + } + } + } - const auto outputStream = file.OpenAsync(winrt::Windows::Storage::FileAccessMode::ReadWrite).get().GetOutputStreamAt(0); + auto deleteFileOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + outputStream.Close(); + std::error_code ec; + std::filesystem::remove(filePath, ec); + }); // By default downloaded files are cached in %appdata%/local/packages/{package-family}/AC/InetCache . // Disable caching since there's no reason to keep local copies of .msixbundle files. @@ -409,7 +440,7 @@ std::wstring wsl::windows::common::wslutil::DownloadFileImpl( download.get(); deleteFileOnFailure.release(); - return file.Path().c_str(); + return filePath; } [[nodiscard]] HANDLE wsl::windows::common::wslutil::DuplicateHandle(_In_ HANDLE Handle, _In_ std::optional DesiredAccess, _In_ BOOL InheritHandle) diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp index f5aa3b1a0..69a40cc35 100644 --- a/test/windows/Common.cpp +++ b/test/windows/Common.cpp @@ -2476,12 +2476,27 @@ void Trim(std::wstring& string) ScopedEnvVariable::ScopedEnvVariable(const std::wstring& Name, const std::wstring& Value) : m_name(Name) { + std::wstring value; + const auto result = wil::GetEnvironmentVariableW(Name.c_str(), value); + if (result != HRESULT_FROM_WIN32(ERROR_ENVVAR_NOT_FOUND)) + { + VERIFY_SUCCEEDED(result); + m_originalValue = std::move(value); + } + VERIFY_IS_TRUE(SetEnvironmentVariable(Name.c_str(), Value.c_str())); } ScopedEnvVariable::~ScopedEnvVariable() { - VERIFY_IS_TRUE(SetEnvironmentVariable(m_name.c_str(), nullptr)); + if (m_originalValue.has_value()) + { + VERIFY_IS_TRUE(SetEnvironmentVariable(m_name.c_str(), m_originalValue->c_str())); + } + else + { + VERIFY_IS_TRUE(SetEnvironmentVariable(m_name.c_str(), nullptr)); + } } UniqueWebServer::UniqueWebServer(LPCWSTR Endpoint, LPCWSTR Content) diff --git a/test/windows/Common.h b/test/windows/Common.h index ad2706ebf..f018ccf77 100644 --- a/test/windows/Common.h +++ b/test/windows/Common.h @@ -312,13 +312,12 @@ class ScopedEnvVariable ScopedEnvVariable(const std::wstring& Name, const std::wstring& Value); ~ScopedEnvVariable(); - ScopedEnvVariable(const WslConfigChange&) = delete; - ScopedEnvVariable(WslConfigChange&&) = delete; - const ScopedEnvVariable& operator=(ScopedEnvVariable&&) = delete; - const ScopedEnvVariable& operator=(ScopedEnvVariable&) = delete; + NON_COPYABLE(ScopedEnvVariable); + NON_MOVABLE(ScopedEnvVariable); private: std::wstring m_name; + std::optional m_originalValue{std::nullopt}; }; class UniqueWebServer diff --git a/test/windows/UnitTests.cpp b/test/windows/UnitTests.cpp index c8aa0791d..08fbf51cc 100644 --- a/test/windows/UnitTests.cpp +++ b/test/windows/UnitTests.cpp @@ -7387,5 +7387,70 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", } } + TEST_METHOD(DownloadToHiddenSystemTempFolder) + { + // Avoid contaminating the real temp folder. + const auto testTempFolder = std::filesystem::temp_directory_path() / L"wsl-download-test"; + std::filesystem::create_directories(testTempFolder); + auto cleanupTempFolder = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + std::error_code error; + std::filesystem::remove_all(testTempFolder, error); + }); + + const auto originalAttributes = GetFileAttributesW(testTempFolder.c_str()); + VERIFY_IS_TRUE(originalAttributes != INVALID_FILE_ATTRIBUTES); + VERIFY_IS_TRUE(SetFileAttributesW(testTempFolder.c_str(), originalAttributes | FILE_ATTRIBUTE_HIDDEN | FILE_ATTRIBUTE_SYSTEM)); + + ScopedEnvVariable temp(L"TEMP", testTempFolder.wstring()); + ScopedEnvVariable tmp(L"TMP", testTempFolder.wstring()); + + VERIFY_IS_TRUE(std::filesystem::equivalent(std::filesystem::temp_directory_path(), testTempFolder)); + + constexpr USHORT port = 6666; + const auto endpoint = std::format(L"http://127.0.0.1:{}/", port); + constexpr auto fileName = L"downloaded-file.bin"; + constexpr auto fileContent = L"wsl download test content"; + UniqueWebServer server(endpoint.c_str(), fileContent); + + const auto url = endpoint + fileName; + const auto noProgress = [](uint64_t, uint64_t) {}; + + wsl::shared::retry::RetryWithTimeout( + [&]() { + wil::unique_socket probe{socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)}; + THROW_LAST_ERROR_IF(!probe); + + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_port = htons(port); + address.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + THROW_LAST_ERROR_IF(connect(probe.get(), reinterpret_cast(&address), sizeof(address)) == SOCKET_ERROR); + }, + std::chrono::milliseconds(500), + std::chrono::seconds(5)); + + const auto firstPath = wsl::windows::common::wslutil::DownloadFileImpl(url, L"", noProgress); + + auto readFile = [](const std::filesystem::path& Path) { + std::ifstream file(Path, std::ios::binary); + VERIFY_IS_TRUE(file.good()); + return std::string{std::istreambuf_iterator(file), {}}; + }; + + VERIFY_ARE_EQUAL(std::filesystem::path(firstPath).parent_path(), testTempFolder); + VERIFY_ARE_EQUAL(std::filesystem::path(firstPath).filename().wstring(), std::wstring(fileName)); + VERIFY_IS_TRUE(std::filesystem::exists(firstPath)); + VERIFY_ARE_EQUAL(readFile(firstPath), wsl::shared::string::WideToMultiByte(fileContent)); + + const auto secondPath = wsl::windows::common::wslutil::DownloadFileImpl(url, L"", noProgress); + + VERIFY_ARE_EQUAL(std::filesystem::path(secondPath).parent_path(), testTempFolder); + VERIFY_ARE_EQUAL(std::filesystem::path(secondPath).filename().wstring(), std::wstring(L"downloaded-file (2).bin")); + VERIFY_IS_TRUE(std::filesystem::exists(firstPath)); + VERIFY_IS_TRUE(std::filesystem::exists(secondPath)); + VERIFY_ARE_EQUAL(readFile(secondPath), wsl::shared::string::WideToMultiByte(fileContent)); + } + }; // namespace UnitTests } // namespace UnitTests