Skip to content
47 changes: 39 additions & 8 deletions src/windows/common/wslutil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,45 @@ std::wstring wsl::windows::common::wslutil::DownloadFileImpl(
Filename = Url.substr(lastSlash + 1);
}

const auto downloadFolder =
winrt::Windows::Storage::StorageFolder::GetFolderFromPathAsync(std::filesystem::temp_directory_path().wstring()).get();

const auto file =
downloadFolder.CreateFileAsync(Filename, winrt::Windows::Storage::CreationCollisionOption::GenerateUniqueName).get();
auto deleteFileOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { file.DeleteAsync().get(); });
// GetFolderFromPathAsync won't work if the folder is hidden or system.
auto downloadFolderPath = std::filesystem::temp_directory_path();
auto filenameStem = std::filesystem::path(Filename).stem().wstring();
auto filenameExtension = std::filesystem::path(Filename).extension().wstring();
Comment thread
chemwolf6922 marked this conversation as resolved.
std::wstring filePath{};
winrt::Windows::Storage::Streams::IRandomAccessStream outputStream{};
for (int suffix = 1; outputStream == nullptr; suffix++)
{
if (suffix == 1)
{
filePath = (downloadFolderPath / Filename).wstring();
}
else
{
filePath = (downloadFolderPath / std::format(L"{} ({}){}", filenameStem, suffix, filenameExtension)).wstring();
}
try
{
outputStream = winrt::Windows::Storage::Streams::FileRandomAccessStream::OpenAsync(
filePath,
winrt::Windows::Storage::FileAccessMode::ReadWrite,
winrt::Windows::Storage::StorageOpenOptions::None,
winrt::Windows::Storage::Streams::FileOpenDisposition::CreateNew)
.get();
}
catch (...)
{
if (wil::ResultFromCaughtException() != HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS))
{
throw;
}
}
}

const auto outputStream = file.OpenAsync(winrt::Windows::Storage::FileAccessMode::ReadWrite).get().GetOutputStreamAt(0);
auto deleteFileOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] {
outputStream.Close();
std::error_code ec;
std::filesystem::remove(filePath, ec);
});

// By default downloaded files are cached in %appdata%/local/packages/{package-family}/AC/InetCache .
// Disable caching since there's no reason to keep local copies of .msixbundle files.
Expand Down Expand Up @@ -409,7 +440,7 @@ std::wstring wsl::windows::common::wslutil::DownloadFileImpl(
download.get();
deleteFileOnFailure.release();

return file.Path().c_str();
return filePath;
}

[[nodiscard]] HANDLE wsl::windows::common::wslutil::DuplicateHandle(_In_ HANDLE Handle, _In_ std::optional<DWORD> DesiredAccess, _In_ BOOL InheritHandle)
Expand Down
17 changes: 16 additions & 1 deletion test/windows/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2476,12 +2476,27 @@ void Trim(std::wstring& string)

ScopedEnvVariable::ScopedEnvVariable(const std::wstring& Name, const std::wstring& Value) : m_name(Name)
{
std::wstring value;
const auto result = wil::GetEnvironmentVariableW(Name.c_str(), value);
if (result != HRESULT_FROM_WIN32(ERROR_ENVVAR_NOT_FOUND))
{
VERIFY_SUCCEEDED(result);
m_originalValue = std::move(value);
}

VERIFY_IS_TRUE(SetEnvironmentVariable(Name.c_str(), Value.c_str()));
}

ScopedEnvVariable::~ScopedEnvVariable()
{
VERIFY_IS_TRUE(SetEnvironmentVariable(m_name.c_str(), nullptr));
if (m_originalValue.has_value())
{
VERIFY_IS_TRUE(SetEnvironmentVariable(m_name.c_str(), m_originalValue->c_str()));
}
else
{
VERIFY_IS_TRUE(SetEnvironmentVariable(m_name.c_str(), nullptr));
}
}

UniqueWebServer::UniqueWebServer(LPCWSTR Endpoint, LPCWSTR Content)
Expand Down
7 changes: 3 additions & 4 deletions test/windows/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,12 @@ class ScopedEnvVariable
ScopedEnvVariable(const std::wstring& Name, const std::wstring& Value);
~ScopedEnvVariable();

ScopedEnvVariable(const WslConfigChange&) = delete;
ScopedEnvVariable(WslConfigChange&&) = delete;
const ScopedEnvVariable& operator=(ScopedEnvVariable&&) = delete;
const ScopedEnvVariable& operator=(ScopedEnvVariable&) = delete;
NON_COPYABLE(ScopedEnvVariable);
NON_MOVABLE(ScopedEnvVariable);

private:
std::wstring m_name;
std::optional<std::wstring> m_originalValue{std::nullopt};
};

class UniqueWebServer
Expand Down
65 changes: 65 additions & 0 deletions test/windows/UnitTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7387,5 +7387,70 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
}
}

TEST_METHOD(DownloadToHiddenSystemTempFolder)
{
// Avoid contaminating the real temp folder.
const auto testTempFolder = std::filesystem::temp_directory_path() / L"wsl-download-test";
std::filesystem::create_directories(testTempFolder);
auto cleanupTempFolder = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] {
std::error_code error;
std::filesystem::remove_all(testTempFolder, error);
});

const auto originalAttributes = GetFileAttributesW(testTempFolder.c_str());
VERIFY_IS_TRUE(originalAttributes != INVALID_FILE_ATTRIBUTES);
VERIFY_IS_TRUE(SetFileAttributesW(testTempFolder.c_str(), originalAttributes | FILE_ATTRIBUTE_HIDDEN | FILE_ATTRIBUTE_SYSTEM));

ScopedEnvVariable temp(L"TEMP", testTempFolder.wstring());
ScopedEnvVariable tmp(L"TMP", testTempFolder.wstring());

VERIFY_IS_TRUE(std::filesystem::equivalent(std::filesystem::temp_directory_path(), testTempFolder));

constexpr USHORT port = 6666;
const auto endpoint = std::format(L"http://127.0.0.1:{}/", port);
constexpr auto fileName = L"downloaded-file.bin";
constexpr auto fileContent = L"wsl download test content";
UniqueWebServer server(endpoint.c_str(), fileContent);

const auto url = endpoint + fileName;
const auto noProgress = [](uint64_t, uint64_t) {};

wsl::shared::retry::RetryWithTimeout<void>(
[&]() {
wil::unique_socket probe{socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)};
THROW_LAST_ERROR_IF(!probe);

sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_port = htons(port);
address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

THROW_LAST_ERROR_IF(connect(probe.get(), reinterpret_cast<const sockaddr*>(&address), sizeof(address)) == SOCKET_ERROR);
},
std::chrono::milliseconds(500),
std::chrono::seconds(5));

const auto firstPath = wsl::windows::common::wslutil::DownloadFileImpl(url, L"", noProgress);

auto readFile = [](const std::filesystem::path& Path) {
std::ifstream file(Path, std::ios::binary);
VERIFY_IS_TRUE(file.good());
return std::string{std::istreambuf_iterator<char>(file), {}};
};

VERIFY_ARE_EQUAL(std::filesystem::path(firstPath).parent_path(), testTempFolder);
VERIFY_ARE_EQUAL(std::filesystem::path(firstPath).filename().wstring(), std::wstring(fileName));
VERIFY_IS_TRUE(std::filesystem::exists(firstPath));
VERIFY_ARE_EQUAL(readFile(firstPath), wsl::shared::string::WideToMultiByte(fileContent));

const auto secondPath = wsl::windows::common::wslutil::DownloadFileImpl(url, L"", noProgress);

VERIFY_ARE_EQUAL(std::filesystem::path(secondPath).parent_path(), testTempFolder);
VERIFY_ARE_EQUAL(std::filesystem::path(secondPath).filename().wstring(), std::wstring(L"downloaded-file (2).bin"));
VERIFY_IS_TRUE(std::filesystem::exists(firstPath));
VERIFY_IS_TRUE(std::filesystem::exists(secondPath));
VERIFY_ARE_EQUAL(readFile(secondPath), wsl::shared::string::WideToMultiByte(fileContent));
}

}; // namespace UnitTests
} // namespace UnitTests