Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/windows/WslcSDK/wslcsdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Module Name:
#include "ProgressCallback.h"
#include "TerminationCallback.h"
#include "wslutil.h"
#include "WSLCSessionDefaults.h"

using namespace std::string_view_literals;
using namespace wsl::windows::common::wslutil;
Expand Down Expand Up @@ -442,6 +443,27 @@ try
}
CATCH_RETURN();

STDAPI WslcGetCliSession(_Out_ WslcSession* session, _Outptr_opt_result_z_ PWSTR* errorMessage)
try
{
RETURN_HR_IF_NULL(E_POINTER, session);
*session = nullptr;
ErrorInfoWrapper errorInfoWrapper{errorMessage};

wil::com_ptr<IWSLCSessionManager> sessionManager = CreateSessionManager();

auto result = std::make_unique<WslcSessionImpl>();
auto defaultSessionName = wsl::windows::common::WSLCSessionDefaults::GetDefaultSessionName();
if (SUCCEEDED(errorInfoWrapper.CaptureResult(sessionManager->OpenSessionByName(defaultSessionName, &result->session))))
{
wsl::windows::common::security::ConfigureForCOMImpersonation(result->session.get());
*session = reinterpret_cast<WslcSession>(result.release());
}

return errorInfoWrapper;
}
CATCH_RETURN();

STDAPI WslcTerminateSession(_In_ WslcSession session)
try
{
Expand Down
1 change: 1 addition & 0 deletions src/windows/WslcSDK/wslcsdk.def
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ WslcInitContainerSettings
WslcInitProcessSettings

WslcCreateSession
WslcGetCliSession
WslcCreateContainer

WslcReleaseSession
Expand Down
2 changes: 2 additions & 0 deletions src/windows/WslcSDK/wslcsdk.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ STDAPI WslcInitSessionSettings(_In_ PCWSTR name, _In_ PCWSTR storagePath, _Out_

STDAPI WslcCreateSession(_In_ WslcSessionSettings* sessionSettings, _Out_ WslcSession* session, _Outptr_opt_result_z_ PWSTR* errorMessage);

STDAPI WslcGetCliSession(_Out_ WslcSession* session, _Outptr_opt_result_z_ PWSTR* errorMessage);

// OPTIONAL SESSION SETTINGS
STDAPI WslcSetSessionSettingsCpuCount(_In_ WslcSessionSettings* sessionSettings, _In_ uint32_t cpuCount);
STDAPI WslcSetSessionSettingsMemory(_In_ WslcSessionSettings* sessionSettings, _In_ uint32_t memoryMb);
Expand Down
1 change: 1 addition & 0 deletions src/windows/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ set(HEADERS
WslCoreMessageQueue.h
WslCoreNetworkEndpointSettings.h
WslCoreNetworkingSupport.h
WSLCSessionDefaults.h
WslInstall.h
WslSecurity.h
WslTelemetry.h
Expand Down
44 changes: 44 additions & 0 deletions src/windows/common/WSLCSessionDefaults.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

/*++

Copyright (c) Microsoft. All rights reserved.

Module Name:

WSLCSessionDefaults.h

Abstract:

This file contains default WSLC session name helpers.

--*/

#pragma once

#include <string>
#include "WslSecurity.h"
#include "stringshared.h"

namespace wsl::windows::common {

class WSLCSessionDefaults
{
public:
// These are elevation-aware static methods that will return the correct
// session name or validate against the correct session name based on the
// elevation of the process.
static const wchar_t* GetDefaultSessionName()
{
return wsl::windows::common::security::IsElevatedOrAbove() ? defaultAdminSessionName : defaultSessionName;
}

static bool IsDefaultSessionName(const std::wstring& sessionName)
{
return wsl::shared::string::IsEqual(sessionName, GetDefaultSessionName());
}

static constexpr const wchar_t defaultSessionName[] = L"wslc-cli";
static constexpr const wchar_t defaultAdminSessionName[] = L"wslc-cli-admin";
};

} // namespace wsl::windows::common
6 changes: 6 additions & 0 deletions src/windows/common/WslSecurity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ bool wsl::windows::common::security::IsTokenLocalSystem(_In_opt_ HANDLE token)
return member ? true : false;
}

bool wsl::windows::common::security::IsElevatedOrAbove()
{
auto token = wil::open_current_access_token(TOKEN_QUERY);
return GetUserBasicIntegrityLevel(token.get()) >= SECURITY_MANDATORY_HIGH_RID;
}

wsl::windows::common::security::unique_revert_to_self wsl::windows::common::security::RpcImpersonateCaller(_In_ RPC_BINDING_HANDLE handle)
{
THROW_IF_WIN32_ERROR(static_cast<DWORD>(RpcImpersonateClient(handle)));
Expand Down
5 changes: 5 additions & 0 deletions src/windows/common/WslSecurity.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ bool IsTokenElevated(_In_ HANDLE token);
/// </summary>
bool IsTokenLocalSystem(_In_opt_ HANDLE token);

/// <summary>
/// Returns true if the current context is elevated or above (e.g. local system).
/// </summary>
bool IsElevatedOrAbove();

/// <summary>
/// Impersonates the RPC caller
/// </summary>
Expand Down
33 changes: 7 additions & 26 deletions src/windows/wslc/services/SessionModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,13 @@ Module Name:
#include <precomp.h>
#include "SessionModel.h"
#include "UserSettings.h"
#include "WSLCSessionDefaults.h"

namespace wsl::windows::wslc::models {

const wchar_t* SessionOptions::GetDefaultSessionName()
{
return IsElevated() ? s_defaultAdminSessionName : s_defaultSessionName;
}

bool SessionOptions::IsDefaultSessionName(const std::wstring& sessionName)
{
// Only returns true for the default session name that matches current elevation.
return wsl::shared::string::IsEqual(sessionName, GetDefaultSessionName());
}

SessionOptions::SessionOptions()
{
m_sessionSettings.DisplayName = GetDefaultSessionName();
m_sessionSettings.DisplayName = wsl::windows::common::WSLCSessionDefaults::GetDefaultSessionName();
m_sessionSettings.StoragePath = GetStoragePath().c_str();
m_sessionSettings.CpuCount = settings::User().Get<settings::Setting::SessionCpuCount>();
m_sessionSettings.MemoryMb = settings::User().Get<settings::Setting::SessionMemoryMb>();
Expand All @@ -49,17 +39,6 @@ SessionOptions::SessionOptions()
}
}

bool SessionOptions::IsElevated()
{
auto token = wil::open_current_access_token(TOKEN_QUERY);

// IsTokenElevated checks if the integrity level is exactly HIGH.
// We must also check for local system because it is above HIGH.
// However, IsTokenLocalSystem() does not work correctly and fails.
// TODO: Add proper handling for system user callers.
return wsl::windows::common::security::IsTokenElevated(token.get());
}

const std::filesystem::path& SessionOptions::GetStoragePath()
{
static const std::filesystem::path basePath = []() {
Expand All @@ -68,10 +47,12 @@ const std::filesystem::path& SessionOptions::GetStoragePath()
: settings::User().Get<settings::Setting::SessionStoragePath>().c_str();
}();

static const std::filesystem::path storagePathNonAdmin = basePath / std::wstring{s_defaultSessionName};
static const std::filesystem::path storagePathAdmin = basePath / std::wstring{s_defaultAdminSessionName};
static const std::filesystem::path storagePathNonAdmin =
basePath / std::wstring{wsl::windows::common::WSLCSessionDefaults::defaultSessionName};
static const std::filesystem::path storagePathAdmin =
basePath / std::wstring{wsl::windows::common::WSLCSessionDefaults::defaultAdminSessionName};

return IsElevated() ? storagePathAdmin : storagePathNonAdmin;
return wsl::windows::common::security::IsElevatedOrAbove ? storagePathAdmin : storagePathNonAdmin;
}

} // namespace wsl::windows::wslc::models
10 changes: 0 additions & 10 deletions src/windows/wslc/services/SessionModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ struct Session
class SessionOptions
{
public:
// These are elevation-aware static methods that will return the correct
// session name or validate against the correct session name based on the
// elevation of the process.
static const wchar_t* GetDefaultSessionName();
static bool IsDefaultSessionName(const std::wstring& sessionName);

SessionOptions();

static const std::filesystem::path& GetStoragePath();
Expand All @@ -55,13 +49,9 @@ class SessionOptions
}

private:
static constexpr const wchar_t s_defaultSessionName[] = L"wslc-cli";
static constexpr const wchar_t s_defaultAdminSessionName[] = L"wslc-cli-admin";
static constexpr const wchar_t s_defaultStorageSubPath[] = L"wslc\\sessions";
static constexpr uint32_t s_defaultBootTimeoutMs = 30 * 1000;

static bool IsElevated();

WSLCSessionSettings m_sessionSettings{};
};

Expand Down
7 changes: 4 additions & 3 deletions src/windows/wslc/tasks/SessionTasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Module Name:
#include "SessionTasks.h"
#include "TableOutput.h"
#include "Task.h"
#include "WSLCSessionDefaults.h"

using namespace wsl::shared;
using namespace wsl::shared::string;
Expand All @@ -38,7 +39,7 @@ void AttachToSession(CLIExecutionContext& context)
}
else
{
sessionId = SessionOptions::GetDefaultSessionName();
sessionId = wsl::windows::common::WSLCSessionDefaults::GetDefaultSessionName();
}

context.ExitCode = SessionService::Attach(sessionId);
Expand All @@ -54,7 +55,7 @@ void CreateSession(CLIExecutionContext& context)
// a non-admin session will fail to create but succeed to open, preventing
// accidental creation of a non-admin session with admin permissions.
const auto& sessionName = context.Args.Get<ArgType::Session>();
if (!SessionOptions::IsDefaultSessionName(sessionName))
if (!wsl::windows::common::WSLCSessionDefaults::IsDefaultSessionName(sessionName))
{
context.Data.Add<Data::Session>(SessionService::OpenSession(sessionName));
return;
Expand Down Expand Up @@ -100,7 +101,7 @@ void TerminateSession(CLIExecutionContext& context)
}
else
{
sessionId = SessionOptions::GetDefaultSessionName();
sessionId = wsl::windows::common::WSLCSessionDefaults::GetDefaultSessionName();
}

context.ExitCode = SessionService::TerminateSession(sessionId);
Expand Down
34 changes: 34 additions & 0 deletions test/windows/WslcSdkTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Module Name:
#include "Common.h"
#include "wslcsdk.h"
#include "wslc_schema.h"
#include "e2e/WSLCExecutor.h"
#include <optional>

extern std::wstring g_testDataPath;
Expand All @@ -40,6 +41,16 @@ void CloseSession(WslcSession session)

using UniqueSession = wil::unique_any<WslcSession, decltype(CloseSession), CloseSession>;

void ReleaseSession(WslcSession session)
{
if (session)
{
WslcReleaseSession(session);
}
}

using UniqueSessionRef = wil::unique_any<WslcSession, decltype(ReleaseSession), ReleaseSession>;

void CloseContainer(WslcContainer container)
{
if (container)
Expand Down Expand Up @@ -263,6 +274,29 @@ class WslcSdkTests
VERIFY_ARE_EQUAL(WslcCreateSession(nullptr, &session2, nullptr), E_POINTER);
}

WSLC_TEST_METHOD(GetCliSession)
{
// Null output pointer must fail.
VERIFY_ARE_EQUAL(WslcGetCliSession(nullptr, nullptr), E_POINTER);

// Ensure no CLI session is running.
WSLCE2ETests::RunWslc(L"session terminate");

// WslcGetCliSession must return ERROR_NOT_FOUND when no CLI session exists.
UniqueSessionRef notFoundSession;
VERIFY_ARE_EQUAL(WslcGetCliSession(&notFoundSession, nullptr), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
VERIFY_IS_NULL(notFoundSession.get());

// Start the CLI session by running a wslc command.
auto result = WSLCE2ETests::RunWslc(L"container list");
VERIFY_ARE_EQUAL(result.ExitCode.value(), (DWORD)0);

// Now WslcGetCliSession should find the running CLI session.
UniqueSessionRef foundSession;
VERIFY_SUCCEEDED(WslcGetCliSession(&foundSession, nullptr));
VERIFY_IS_NOT_NULL(foundSession.get());
}

WSLC_TEST_METHOD(TerminationCallbackViaTerminate)
{
std::promise<WslcSessionTerminationReason> promise;
Expand Down
20 changes: 5 additions & 15 deletions test/windows/wslc/e2e/WSLCE2EContainerCreateTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,10 +1028,8 @@ class WSLCE2EContainerCreateTests
result.Verify({.Stdout = L"unable to find group badgid: no matching entries in group file\r\n", .ExitCode = 126});
}

TEST_METHOD(WSLCE2E_Container_Create_Tmpfs)
WSLC_TEST_METHOD(WSLCE2E_Container_Create_Tmpfs)
{
WSL2_TEST_ONLY();

auto result = RunWslc(std::format(
L"container create --name {} --tmpfs /wslc-tmpfs {} sh -c \"echo -n 'tmpfs_test' > /wslc-tmpfs/data && cat "
L"/wslc-tmpfs/data\"",
Expand All @@ -1043,10 +1041,8 @@ class WSLCE2EContainerCreateTests
result.Verify({.Stdout = L"tmpfs_test", .Stderr = L"", .ExitCode = 0});
}

TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_With_Options)
WSLC_TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_With_Options)
{
WSL2_TEST_ONLY();

auto result = RunWslc(std::format(
L"container create --name {} --tmpfs /wslc-tmpfs:size=64k {} sh -c \"mount | grep -q ' on /wslc-tmpfs type tmpfs ' "
L"&& echo mounted\"",
Expand All @@ -1058,10 +1054,8 @@ class WSLCE2EContainerCreateTests
result.Verify({.Stdout = L"mounted\n", .Stderr = L"", .ExitCode = 0});
}

TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_Multiple_With_Options)
WSLC_TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_Multiple_With_Options)
{
WSL2_TEST_ONLY();

auto result = RunWslc(std::format(
L"container create --name {} --tmpfs /wslc-tmpfs1:size=64k --tmpfs /wslc-tmpfs2:size=128k {} sh -c \"mount | grep -q "
L"' on /wslc-tmpfs1 type tmpfs ' && mount | grep -q ' on /wslc-tmpfs2 type tmpfs ' && echo mounted\"",
Expand All @@ -1073,19 +1067,15 @@ class WSLCE2EContainerCreateTests
result.Verify({.Stdout = L"mounted\n", .Stderr = L"", .ExitCode = 0});
}

TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_RelativePath_Fails)
WSLC_TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_RelativePath_Fails)
{
WSL2_TEST_ONLY();

auto result =
RunWslc(std::format(L"container create --name {} --tmpfs wslc-tmpfs {}", WslcContainerName, DebianImage.NameAndTag()));
result.Verify({.Stderr = L"invalid mount path: 'wslc-tmpfs' mount path must be absolute\r\nError code: E_FAIL\r\n", .ExitCode = 1});
}

TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_EmptyDestination_Fails)
WSLC_TEST_METHOD(WSLCE2E_Container_Create_Tmpfs_EmptyDestination_Fails)
{
WSL2_TEST_ONLY();

auto result =
RunWslc(std::format(L"container create --name {} --tmpfs :size=64k {}", WslcContainerName, DebianImage.NameAndTag()));
result.Verify({.Stderr = L"invalid mount path: '' mount path must be absolute\r\nError code: E_FAIL\r\n", .ExitCode = 1});
Expand Down
Loading
Loading