Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/9891.test.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add component tests for session lifecycle management
103 changes: 100 additions & 3 deletions tests/component/session/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ class SessionSeedData:
kernel_id: uuid.UUID
access_key: str
domain_name: str
user_uuid: uuid.UUID | None = None


@pytest.fixture()
def scheduling_controller_mock() -> AsyncMock:
"""Mock SchedulingController exposed as a fixture for test-level configuration."""
return AsyncMock()


@pytest.fixture()
Expand All @@ -68,6 +75,7 @@ async def session_processors(
background_task_manager: BackgroundTaskManager,
error_monitor: ErrorPluginContext,
appproxy_client_pool: AsyncMock,
scheduling_controller_mock: AsyncMock,
) -> SessionProcessors:
"""Real SessionProcessors with real SessionService and SessionRepository."""
session_repo = SessionRepository(database_engine)
Expand All @@ -79,7 +87,7 @@ async def session_processors(
error_monitor=error_monitor,
idle_checker_host=AsyncMock(),
session_repository=session_repo,
scheduling_controller=AsyncMock(),
scheduling_controller=scheduling_controller_mock,
appproxy_client_pool=appproxy_client_pool,
)
service = SessionService(args)
Expand Down Expand Up @@ -223,10 +231,11 @@ async def session_seed(
kernel_id=kernel_id,
access_key=admin_user_fixture.keypair.access_key,
domain_name=domain_fixture,
user_uuid=admin_user_fixture.user_uuid,
)

async with db_engine.begin() as conn:
await conn.execute(kernels.delete().where(kernels.c.id == kernel_id))
await conn.execute(kernels.delete().where(kernels.c.session_id == session_id))
await conn.execute(
SessionRow.__table__.delete().where(SessionRow.__table__.c.id == session_id)
)
Expand Down Expand Up @@ -317,10 +326,98 @@ async def terminated_session_seed(
kernel_id=kernel_id,
access_key=admin_user_fixture.keypair.access_key,
domain_name=domain_fixture,
user_uuid=admin_user_fixture.user_uuid,
)

async with db_engine.begin() as conn:
await conn.execute(kernels.delete().where(kernels.c.session_id == session_id))
await conn.execute(
SessionRow.__table__.delete().where(SessionRow.__table__.c.id == session_id)
)


@pytest.fixture()
async def user_session_seed(
db_engine: SAEngine,
domain_fixture: str,
group_fixture: uuid.UUID,
regular_user_fixture: UserFixtureData,
scaling_group_fixture: str,
) -> AsyncIterator[SessionSeedData]:
"""Seed a RUNNING session owned by the regular user."""
unique = secrets.token_hex(4)
session_id = SessionId(uuid.uuid4())
session_name = f"test-user-session-{unique}"
kernel_id = uuid.uuid4()
now = datetime.now(tzutc())

status_history: dict[str, Any] = {
SessionStatus.PENDING.name: now.isoformat(),
SessionStatus.RUNNING.name: now.isoformat(),
}

async with db_engine.begin() as conn:
await conn.execute(
sa.insert(SessionRow.__table__).values(
id=session_id,
creation_id=f"cid-{unique}",
name=session_name,
session_type=SessionTypes.INTERACTIVE,
cluster_size=1,
cluster_mode="single-node",
domain_name=domain_fixture,
group_id=group_fixture,
user_uuid=regular_user_fixture.user_uuid,
access_key=regular_user_fixture.keypair.access_key,
scaling_group_name=scaling_group_fixture,
status=SessionStatus.RUNNING,
status_info="",
status_history=status_history,
occupying_slots=ResourceSlot(),
requested_slots=ResourceSlot(),
created_at=now,
)
)
await conn.execute(
sa.insert(kernels).values(
id=kernel_id,
session_id=session_id,
session_creation_id=f"cid-{unique}",
session_name=session_name,
session_type=SessionTypes.INTERACTIVE,
cluster_role="main",
cluster_idx=0,
cluster_hostname="main0",
cluster_mode="single-node",
cluster_size=1,
domain_name=domain_fixture,
group_id=group_fixture,
user_uuid=regular_user_fixture.user_uuid,
access_key=regular_user_fixture.keypair.access_key,
scaling_group=scaling_group_fixture,
status=KernelStatus.RUNNING,
status_info="",
occupied_slots=ResourceSlot(),
requested_slots=ResourceSlot(),
repl_in_port=0,
repl_out_port=0,
stdin_port=0,
stdout_port=0,
created_at=now,
)
)

yield SessionSeedData(
session_id=session_id,
session_name=session_name,
kernel_id=kernel_id,
access_key=regular_user_fixture.keypair.access_key,
domain_name=domain_fixture,
user_uuid=regular_user_fixture.user_uuid,
)

async with db_engine.begin() as conn:
await conn.execute(kernels.delete().where(kernels.c.id == kernel_id))
await conn.execute(kernels.delete().where(kernels.c.session_id == session_id))
await conn.execute(
SessionRow.__table__.delete().where(SessionRow.__table__.c.id == session_id)
)
185 changes: 183 additions & 2 deletions tests/component/session/test_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from unittest.mock import AsyncMock

import pytest

from ai.backend.client.exceptions import BackendAPIError
from ai.backend.client.v2.exceptions import NotFoundError
from ai.backend.client.v2.registry import BackendAIClientRegistry
from ai.backend.common.dto.manager.session.request import (
Expand All @@ -16,6 +19,8 @@
GetStatusHistoryResponse,
MatchSessionsResponse,
)
from ai.backend.manager.data.session.types import SessionStatus
from ai.backend.manager.repositories.scheduler.types.session import MarkTerminatingResult

from .conftest import SessionSeedData

Expand All @@ -31,7 +36,7 @@ async def test_admin_gets_session_info(
# LegacySessionInfo.asdict() uses camelCase keys and does not include
# session id/name — verify that the returned dict contains expected
# fields from the seeded row instead.
assert result.root["status"] == "RUNNING"
assert result.root["status"] == SessionStatus.RUNNING.name
assert result.root["domainName"] == session_seed.domain_name

async def test_get_nonexistent_session_returns_not_found(
Expand All @@ -56,7 +61,7 @@ async def test_admin_renames_session(
# Verify the rename took effect by fetching with the new name
result = await admin_registry.session.get_info(new_name)
assert isinstance(result, GetSessionInfoResponse)
assert result.root["status"] == "RUNNING"
assert result.root["status"] == SessionStatus.RUNNING.name

async def test_rename_nonexistent_session_returns_not_found(
self,
Expand All @@ -68,6 +73,68 @@ async def test_rename_nonexistent_session_returns_not_found(
RenameSessionRequest(session_name="new-name"),
)

async def test_rename_back_and_forth(
self,
admin_registry: BackendAIClientRegistry,
session_seed: SessionSeedData,
) -> None:
"""Verifies that rename is fully reversible — after each rename the new name
resolves successfully while the old name returns NotFoundError.
"""
original_name = session_seed.session_name
new_name = f"{original_name}-lifecycle-test"

await admin_registry.session.rename(
original_name,
RenameSessionRequest(session_name=new_name),
)
result = await admin_registry.session.get_info(new_name)
assert result.root["status"] == SessionStatus.RUNNING.name
with pytest.raises(NotFoundError):
await admin_registry.session.get_info(original_name)

await admin_registry.session.rename(
new_name,
RenameSessionRequest(session_name=original_name),
)
result = await admin_registry.session.get_info(original_name)
assert result.root["status"] == SessionStatus.RUNNING.name
with pytest.raises(NotFoundError):
await admin_registry.session.get_info(new_name)

async def test_rename_to_same_name_is_rejected(
self,
admin_registry: BackendAIClientRegistry,
session_seed: SessionSeedData,
) -> None:
"""The server rejects a no-op rename as an invalid request
rather than silently succeeding.
"""
with pytest.raises(BackendAPIError):
await admin_registry.session.rename(
session_seed.session_name,
RenameSessionRequest(session_name=session_seed.session_name),
)

async def test_user_renames_own_session(
self,
user_registry: BackendAIClientRegistry,
user_session_seed: SessionSeedData,
) -> None:
"""Non-admin users have rename permission on sessions they own,
and the old name is no longer resolvable after rename.
"""
original_name = user_session_seed.session_name
new_name = f"{original_name}-renamed"
await user_registry.session.rename(
original_name,
RenameSessionRequest(session_name=new_name),
)
result = await user_registry.session.get_info(new_name)
assert result.root["status"] == SessionStatus.RUNNING.name
with pytest.raises(NotFoundError):
await user_registry.session.get_info(original_name)


class TestSessionMatchSessions:
async def test_admin_matches_sessions_by_name(
Expand Down Expand Up @@ -139,6 +206,52 @@ async def test_destroy_nonexistent_session_returns_not_found(
DestroySessionRequest(forced=True),
)

async def test_destroy_already_terminated_session_succeeds(
self,
admin_registry: BackendAIClientRegistry,
terminated_session_seed: SessionSeedData,
scheduling_controller_mock: AsyncMock,
) -> None:
"""Forced destroy is idempotent — calling destroy on an already-terminated
session does not raise an error and returns status "terminated".
"""
scheduling_controller_mock.mark_sessions_for_termination.return_value = (
MarkTerminatingResult(
cancelled_sessions=[],
terminating_sessions=[],
force_terminated_sessions=[terminated_session_seed.session_id],
skipped_sessions=[],
)
)
result = await admin_registry.session.destroy(
terminated_session_seed.session_name,
DestroySessionRequest(forced=True),
)
assert isinstance(result, DestroySessionResponse)
assert result.root["stats"]["status"] == "terminated"

async def test_user_destroys_own_session(
self,
user_registry: BackendAIClientRegistry,
user_session_seed: SessionSeedData,
scheduling_controller_mock: AsyncMock,
) -> None:
"""Non-admin users have destroy permission on sessions they own."""
scheduling_controller_mock.mark_sessions_for_termination.return_value = (
MarkTerminatingResult(
cancelled_sessions=[],
terminating_sessions=[],
force_terminated_sessions=[user_session_seed.session_id],
skipped_sessions=[],
)
)
result = await user_registry.session.destroy(
user_session_seed.session_name,
DestroySessionRequest(forced=True),
)
assert isinstance(result, DestroySessionResponse)
assert result.root["stats"]["status"] == "terminated"


class TestSessionGetContainerLogs:
async def test_admin_gets_container_logs_for_terminated_session(
Expand All @@ -160,3 +273,71 @@ async def test_get_container_logs_nonexistent_session(
await admin_registry.session.get_container_logs(
"nonexistent-session-xyz-99999",
)


class TestSessionPermissions:
"""Test role-based access control for session operations."""

async def test_user_gets_own_session_info(
self,
user_registry: BackendAIClientRegistry,
user_session_seed: SessionSeedData,
) -> None:
"""Users can read their own session metadata including
status and domain name through the get_info endpoint.
"""
result = await user_registry.session.get_info(user_session_seed.session_name)
assert result.root["status"] == SessionStatus.RUNNING.name
assert result.root["domainName"] == user_session_seed.domain_name

async def test_user_cannot_access_admin_session(
self,
user_registry: BackendAIClientRegistry,
session_seed: SessionSeedData,
) -> None:
"""Session visibility is scoped by access key — a user keypair
cannot resolve sessions belonging to a different keypair.
"""
with pytest.raises((NotFoundError, BackendAPIError)):
await user_registry.session.get_info(session_seed.session_name)

async def test_admin_cannot_access_user_session_without_ownership(
self,
admin_registry: BackendAIClientRegistry,
user_session_seed: SessionSeedData,
) -> None:
"""The get_info handler resolves scope using the requester's own access
key, so sessions owned by other access keys are not found. This is a
known limitation of the current implementation that may change in
future refactoring.
"""
with pytest.raises(NotFoundError):
await admin_registry.session.get_info(user_session_seed.session_name)

async def test_user_cannot_destroy_admin_session(
self,
user_registry: BackendAIClientRegistry,
session_seed: SessionSeedData,
) -> None:
"""The destroy endpoint enforces ownership — the session is not
resolvable under the user's access key scope.
"""
with pytest.raises((NotFoundError, BackendAPIError)):
await user_registry.session.destroy(
session_seed.session_name,
DestroySessionRequest(forced=True),
)

async def test_user_cannot_rename_admin_session(
self,
user_registry: BackendAIClientRegistry,
session_seed: SessionSeedData,
) -> None:
"""The rename endpoint enforces ownership — the session is not
resolvable under the user's access key scope.
"""
with pytest.raises((NotFoundError, BackendAPIError)):
await user_registry.session.rename(
session_seed.session_name,
RenameSessionRequest(session_name="hacked-name"),
)
Loading
Loading