Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6354.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Session type validation is now properly enforced when creating sessions within scaling groups
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
MountNameValidationRule,
ScalingGroupAccessRule,
ServicePortRule,
SessionTypeRule,
SessionValidator,
)

Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(self, args: SchedulingControllerArgs) -> None:
validator_rules = [
ContainerLimitRule(),
ScalingGroupAccessRule(),
SessionTypeRule(),
ServicePortRule(),
ClusterValidationRule(),
MountNameValidationRule(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ResourceLimitRule,
ScalingGroupAccessRule,
ServicePortRule,
SessionTypeRule,
)
from .validator import SessionValidator

Expand All @@ -16,6 +17,7 @@
"SessionValidatorRule",
"ContainerLimitRule",
"ScalingGroupAccessRule",
"SessionTypeRule",
"ServicePortRule",
"ResourceLimitRule",
"ClusterValidationRule",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Validator rules for session creation."""

from typing import Mapping
from typing import Mapping, override

from ai.backend.common.exception import BackendAIError
from ai.backend.common.service_ports import parse_service_ports
Expand All @@ -20,9 +20,11 @@
class ContainerLimitRule(SessionValidatorRule):
"""Validates cluster size against resource policy limits."""

@override
def name(self) -> str:
return "container_limit"

@override
def validate(
self,
spec: SessionCreationSpec,
Expand All @@ -39,9 +41,11 @@ def validate(
class ScalingGroupAccessRule(SessionValidatorRule):
"""Validates that the scaling group is accessible."""

@override
def name(self) -> str:
return "scaling_group_access"

@override
def validate(
self,
spec: SessionCreationSpec,
Expand All @@ -66,12 +70,44 @@ def validate(
raise InvalidAPIParameters(f"Scaling group {spec.scaling_group} is not accessible")


class SessionTypeRule(SessionValidatorRule):
"""Validates session type compatibility with scaling group."""

@override
def name(self) -> str:
return "session_type"

@override
def validate(
self,
spec: SessionCreationSpec,
context: SessionCreationContext,
allowed_groups: list[AllowedScalingGroup],
) -> None:
if spec.scaling_group is None:
# Should have been resolved already
return

for sg in allowed_groups:
if sg.name == spec.scaling_group:
allowed_session_types = sg.scheduler_opts.allowed_session_types
if spec.session_type not in allowed_session_types:
raise InvalidAPIParameters(
f"Session type {spec.session_type} is not allowed in scaling group {sg.name}"
)
return

raise InvalidAPIParameters(f"Scaling group {spec.scaling_group} is not accessible")

Comment on lines +91 to +101
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the spec's scaling group does not exist in allow list, an error should be raised.


Comment on lines +101 to +102
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation doesn't return after finding a matching scaling group. If the session type is allowed, the loop continues unnecessarily and the method doesn't explicitly return. Add an explicit return after line 93, or add a break after line 93 and a return after the loop completes successfully.

Suggested change
return

Copilot uses AI. Check for mistakes.
class ServicePortRule(SessionValidatorRule):
"""Validates preopen ports against service ports."""

@override
def name(self) -> str:
return "service_port"

@override
def validate(
self,
spec: SessionCreationSpec,
Expand Down Expand Up @@ -138,12 +174,14 @@ def validate(
class ResourceLimitRule(SessionValidatorRule):
"""Validates requested resources against image limits."""

@override
def name(self) -> str:
return "resource_limit"

def __init__(self, known_slot_types: Mapping[SlotName, SlotTypes] | None = None):
self._known_slot_types = known_slot_types

@override
def validate(
self,
spec: SessionCreationSpec,
Expand Down
152 changes: 150 additions & 2 deletions tests/manager/sokovan/scheduling_controller/validators/test_rules.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
"""Tests for validation rules."""

import uuid
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Optional
from unittest.mock import MagicMock

import pytest

from ai.backend.common.types import SessionTypes
import yarl

from ai.backend.common.types import (
AccessKey,
ClusterMode,
KernelEnqueueingConfig,
SessionId,
SessionTypes,
)
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.errors.kernel import QuotaExceeded
from ai.backend.manager.models import NetworkRow
from ai.backend.manager.models.scaling_group import ScalingGroupOpts
from ai.backend.manager.repositories.scheduler.types.session_creation import (
AllowedScalingGroup,
Expand All @@ -22,7 +34,9 @@
ContainerLimitRule,
ScalingGroupAccessRule,
ServicePortRule,
SessionTypeRule,
)
from ai.backend.manager.types import UserScope


@pytest.fixture
Expand Down Expand Up @@ -55,6 +69,67 @@ def basic_context():
)


@pytest.fixture
def session_spec_factory() -> Callable[..., SessionCreationSpec]:
def create_spec(
session_creation_id: str = "test-001",
session_name: str = "test-session",
access_key: AccessKey = AccessKey("test-key"),
user_scope: UserScope = UserScope(
domain_name="default",
group_id=uuid.uuid4(),
user_uuid=uuid.uuid4(),
user_role="user",
),
session_type: SessionTypes = SessionTypes.INTERACTIVE,
cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE,
cluster_size: int = 1,
priority: int = 10,
resource_policy: dict | None = None,
kernel_specs: list[KernelEnqueueingConfig] | None = None,
creation_spec: dict | None = None,
scaling_group: Optional[str] = None,
session_tag: Optional[str] = None,
starts_at: Optional[datetime] = None,
batch_timeout: Optional[timedelta] = None,
dependency_sessions: Optional[list[SessionId]] = None,
callback_url: Optional[yarl.URL] = None,
route_id: Optional[uuid.UUID] = None,
sudo_session_enabled: bool = False,
network: Optional[NetworkRow] = None,
designated_agent_list: Optional[list[str]] = None,
internal_data: Optional[dict] = None,
public_sgroup_only: bool = True,
) -> SessionCreationSpec:
return SessionCreationSpec(
session_creation_id=session_creation_id,
session_name=session_name,
access_key=access_key,
user_scope=user_scope,
session_type=session_type,
cluster_mode=cluster_mode,
cluster_size=cluster_size,
priority=priority,
resource_policy=resource_policy or {},
kernel_specs=kernel_specs or [],
creation_spec=creation_spec or {},
scaling_group=scaling_group,
session_tag=session_tag,
starts_at=starts_at,
batch_timeout=batch_timeout,
dependency_sessions=dependency_sessions,
callback_url=callback_url,
route_id=route_id,
sudo_session_enabled=sudo_session_enabled,
network=network,
designated_agent_list=designated_agent_list,
internal_data=internal_data,
public_sgroup_only=public_sgroup_only,
)

return create_spec


class TestContainerLimitRule:
"""Test cases for ContainerLimitRule."""

Expand Down Expand Up @@ -211,6 +286,79 @@ def test_inaccessible_sgroup(self, basic_context):
assert "not accessible" in str(exc_info.value)


class TestSessionTypeRule:
"""Test cases for SessionTypeRule."""

def test_allowed_session_type(
self,
basic_context: SessionCreationContext,
session_spec_factory: Callable[..., SessionCreationSpec],
) -> None:
"""Test session type that is allowed in scaling group."""
rule = SessionTypeRule()

allowed_groups = [
AllowedScalingGroup(
name="test-sg",
is_private=False,
scheduler_opts=ScalingGroupOpts(
allowed_session_types=[SessionTypes.INTERACTIVE, SessionTypes.BATCH]
),
)
]

spec = session_spec_factory(
session_type=SessionTypes.INTERACTIVE,
scaling_group="test-sg",
)

# Should not raise
rule.validate(spec, basic_context, allowed_groups)

def test_disallowed_session_type(
self,
basic_context: SessionCreationContext,
session_spec_factory: Callable[..., SessionCreationSpec],
) -> None:
"""Test session type that is not allowed in scaling group."""
rule = SessionTypeRule()

allowed_groups = [
AllowedScalingGroup(
name="batch-only-sg",
is_private=False,
scheduler_opts=ScalingGroupOpts(allowed_session_types=[SessionTypes.BATCH]),
)
]

spec = session_spec_factory(
session_type=SessionTypes.INTERACTIVE,
scaling_group="batch-only-sg",
)

with pytest.raises(InvalidAPIParameters) as exc_info:
rule.validate(spec, basic_context, allowed_groups)
assert "not allowed in scaling group" in str(exc_info.value)

def test_empty_allowed_groups(
self,
basic_context: SessionCreationContext,
session_spec_factory: Callable[..., SessionCreationSpec],
) -> None:
"""Test with empty allowed groups list."""
rule = SessionTypeRule()

spec = session_spec_factory(
session_type=SessionTypes.INTERACTIVE,
scaling_group="any-sg",
)

# Should raise - no allowed groups available
with pytest.raises(InvalidAPIParameters) as exc_info:
rule.validate(spec, basic_context, [])
assert "not accessible" in str(exc_info.value)


class TestServicePortRule:
"""Test cases for ServicePortRule."""

Expand Down
Loading