diff --git a/changes/10916.breaking.md b/changes/10916.breaking.md new file mode 100644 index 00000000000..15aa2541d1e --- /dev/null +++ b/changes/10916.breaking.md @@ -0,0 +1 @@ +Change the REST v1 session API's delegation mechanism from the `owner_access_key` query parameter to an `owner_id` (user UUID) field. The `owner_access_key` parameter is removed from all session endpoints (`GET /session/{name}`, `DELETE /session/{name}`, `POST /session/_/create-from-template`, `POST /session/_/create`, `POST /session/_/create-cluster`, `GET /session/{name}/logs`, `GET /session/{name}/status-history`, etc.). Clients that previously passed `owner_access_key=` to act on behalf of another user must now pass `owner_id=` (and only on the session-creation endpoints; for read/control endpoints the caller always acts as themselves). Session and kernel `access_key` semantics also change: the value is no longer tied to the keypair used to create the session but resolved at read time from the owner's `main_access_key`. diff --git a/changes/11040.breaking.md b/changes/11040.breaking.md new file mode 100644 index 00000000000..970089854d3 --- /dev/null +++ b/changes/11040.breaking.md @@ -0,0 +1 @@ +Drop the `access_key` column from `sessions` and `kernels` tables; the owner's keypair is now resolved from `users.main_access_key` at read time, with `user_uuid` remaining as the canonical owner reference. diff --git a/changes/11051.enhance.md b/changes/11051.enhance.md new file mode 100644 index 00000000000..47e2c0e2812 --- /dev/null +++ b/changes/11051.enhance.md @@ -0,0 +1 @@ +Test updates and remaining ORM/repository touch-ups for the BA-5650 stack: scheduler db_source, keypair and endpoint row cleanup, gql_legacy endpoint/routing, session lifecycle/service tests, sokovan scheduler tests, and dependency-injection tests. diff --git a/changes/BA-5650-E.misc.md b/changes/BA-5650-E.misc.md deleted file mode 100644 index 63b2ee24773..00000000000 --- a/changes/BA-5650-E.misc.md +++ /dev/null @@ -1 +0,0 @@ -Collapse scheduler / predicates / scheduler options signatures to take `owner_id: UUID`. Rename `access_key` field to `main_access_key` on `ScheduledSessionData` / `TerminatingSessionData` / `SweptSessionInfo` / scheduler types. diff --git a/changes/BA-5650-H.breaking.md b/changes/BA-5650-H.breaking.md new file mode 100644 index 00000000000..7386d44c391 --- /dev/null +++ b/changes/BA-5650-H.breaking.md @@ -0,0 +1 @@ +**Breaking**: Remove `owner_access_key` query parameter from REST v1 session endpoints. Delegation is now performed via `owner_id` (user UUID) and only on the session creation endpoints (`/session/_/create-from-template`, `/session/_/create`, `/session/_/create-cluster`). Read/control endpoints always act as the authenticated caller. Clients that previously passed `owner_access_key=` must migrate to `owner_id=`. diff --git a/changes/BA-5650-I.misc.md b/changes/BA-5650-I.misc.md new file mode 100644 index 00000000000..d9effd4a79b --- /dev/null +++ b/changes/BA-5650-I.misc.md @@ -0,0 +1 @@ +Test updates and remaining ORM/repository touch-ups for the BA-5650 stack: scheduler db_source, keypair and endpoint row cleanup, gql_legacy endpoint/routing, session lifecycle/service tests, sokovan scheduler tests, and dependency-injection tests. No external behavior change beyond what earlier slices documented. diff --git a/dev b/dev index 36f2a7e4427..4e41201e45b 100755 --- a/dev +++ b/dev @@ -149,32 +149,14 @@ cmd_restart() { cmd_log() { local svc=$1 - local follow=${2:-} local winname winname=$(_tmux_window_name "$svc") local win win=$(tmux list-windows -t "$TMUX_SESSION" -F "#{window_name}" 2>/dev/null | grep "^${winname}$" | head -1) || true - if [ -z "$win" ]; then - echo "$(_color red "No tmux window found for $svc")" - return 1 - fi - if [ "$follow" = "-f" ]; then - local last_hash="" - trap 'exit 0' INT - while true; do - local output - output=$(tmux capture-pane -t "$TMUX_SESSION:$win" -p -S -50 2>/dev/null) - local cur_hash - cur_hash=$(echo "$output" | md5sum | cut -d' ' -f1) - if [ "$cur_hash" != "$last_hash" ]; then - clear - echo "$output" - last_hash=$cur_hash - fi - sleep 1 - done - else + if [ -n "$win" ]; then tmux capture-pane -t "$TMUX_SESSION:$win" -p -S -50 + else + echo "$(_color red "No tmux window found for $svc")" fi } @@ -188,7 +170,7 @@ Commands: start Start a service stop Stop a service restart Restart a service - log [-f] Show recent log output (-f to follow) + log Show recent log output Services: mgr, agent, storage, web, proxy-coordinator, proxy-worker, all @@ -225,9 +207,9 @@ case "$1" in "cmd_$1" "$2" ;; log) - [ $# -lt 2 ] && { echo "$(_color red "Usage: ./dev log [-f]")"; exit 1; } + [ $# -lt 2 ] && { echo "$(_color red "Usage: ./dev log ")"; exit 1; } _validate_service "$2" - cmd_log "$2" "${3:-}" + cmd_log "$2" ;; *) echo "$(_color red "Unknown command: $1")"; usage; exit 1 ;; esac diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index a9850630c02..a589c2fe35c 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -1836,9 +1836,6 @@ type Routing implements Item { endpoint: String session: UUID status: String - - """Added in 26.4.1.""" - health_status: String traffic_ratio: Float created_at: DateTime error_data: JSONString diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index ba8493a3277..d329fc62708 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -166,16 +166,6 @@ input AddRevisionInput extraMounts: [ExtraVFolderMountInput!] = null } -"""Added in 26.4.1. Options for the add_model_revision mutation.""" -input AddRevisionOptions - @join__type(graph: STRAWBERRY) -{ - """ - When true, automatically activate the newly added revision immediately after creation. - """ - autoActivate: Boolean! = false -} - """Added in 25.19.0. Payload for adding a revision.""" type AddRevisionPayload @join__type(graph: STRAWBERRY) @@ -4683,9 +4673,9 @@ input DeploymentOrderBy enum DeploymentOrderField @join__type(graph: STRAWBERRY) { - NAME @join__enumValue(graph: STRAWBERRY) CREATED_AT @join__enumValue(graph: STRAWBERRY) UPDATED_AT @join__enumValue(graph: STRAWBERRY) + NAME @join__enumValue(graph: STRAWBERRY) } """Added in 25.19.0. Deployment policy configuration.""" @@ -10019,7 +10009,7 @@ type Mutation syncReplicas(input: SyncReplicaInput!): SyncReplicaPayload! @join__field(graph: STRAWBERRY) """Added in 25.16.0. Add model revision.""" - addModelRevision(input: AddRevisionInput!, options: AddRevisionOptions = null): AddRevisionPayload! @join__field(graph: STRAWBERRY) + addModelRevision(input: AddRevisionInput!): AddRevisionPayload! @join__field(graph: STRAWBERRY) """ Added in 26.4.1. Create or update the deployment policy for a given deployment (upsert semantics). If the deployment already has a policy, it is replaced entirely with the new configuration @@ -14933,9 +14923,6 @@ type Routing implements Item endpoint: String session: UUID status: String - - """Added in 26.4.1.""" - health_status: String traffic_ratio: Float created_at: DateTime error_data: JSONString diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 7cbf4b51e3d..ac53b7df518 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -126,14 +126,6 @@ input AddRevisionInput { extraMounts: [ExtraVFolderMountInput!] = null } -"""Added in 26.4.1. Options for the add_model_revision mutation.""" -input AddRevisionOptions { - """ - When true, automatically activate the newly added revision immediately after creation. - """ - autoActivate: Boolean! = false -} - """Added in 25.19.0. Payload for adding a revision.""" type AddRevisionPayload { """Added revision""" @@ -3012,9 +3004,9 @@ input DeploymentOrderBy { } enum DeploymentOrderField { - NAME CREATED_AT UPDATED_AT + NAME } """Added in 25.19.0. Deployment policy configuration.""" @@ -6021,7 +6013,7 @@ type Mutation { syncReplicas(input: SyncReplicaInput!): SyncReplicaPayload! """Added in 25.16.0. Add model revision.""" - addModelRevision(input: AddRevisionInput!, options: AddRevisionOptions = null): AddRevisionPayload! + addModelRevision(input: AddRevisionInput!): AddRevisionPayload! """ Added in 26.4.1. Create or update the deployment policy for a given deployment (upsert semantics). If the deployment already has a policy, it is replaced entirely with the new configuration diff --git a/src/ai/backend/common/clients/prometheus/__init__.py b/src/ai/backend/common/clients/prometheus/__init__.py index b4b0e60f382..7a5d683c578 100644 --- a/src/ai/backend/common/clients/prometheus/__init__.py +++ b/src/ai/backend/common/clients/prometheus/__init__.py @@ -1,11 +1,9 @@ from .client import PrometheusClient -from .preset import LabelMatcher, LabelOperator, MetricPreset +from .preset import MetricPreset from .querier import ContainerMetricQuerier, MetricQuerier from .types import ValueType __all__ = [ - "LabelMatcher", - "LabelOperator", "PrometheusClient", "MetricPreset", "MetricQuerier", diff --git a/src/ai/backend/common/clients/prometheus/preset.py b/src/ai/backend/common/clients/prometheus/preset.py index 298d52c4303..247aa5b6ed1 100644 --- a/src/ai/backend/common/clients/prometheus/preset.py +++ b/src/ai/backend/common/clients/prometheus/preset.py @@ -1,30 +1,5 @@ from collections.abc import Mapping, Set from dataclasses import dataclass, field -from enum import StrEnum -from typing import Self - - -class LabelOperator(StrEnum): - EQUAL = "=" - NOT_EQUAL = "!=" - REGEX = "=~" - NOT_REGEX = "!~" - - -@dataclass(frozen=True) -class LabelMatcher: - """PromQL label matcher with an explicit operator.""" - - value: str - operator: LabelOperator = LabelOperator.EQUAL - - @classmethod - def exact(cls, value: str) -> Self: - return cls(value=value, operator=LabelOperator.EQUAL) - - @classmethod - def regex(cls, value: str) -> Self: - return cls(value=value, operator=LabelOperator.REGEX) def _escape_label_value(value: str) -> str: @@ -40,7 +15,7 @@ class MetricPreset: template: str # Query labels (injected into {labels} placeholder) - labels: Mapping[str, LabelMatcher] = field(default_factory=dict) + labels: Mapping[str, str] = field(default_factory=dict) # Group by labels (injected into {group_by} placeholder) group_by: Set[str] = field(default_factory=frozenset) @@ -50,10 +25,7 @@ class MetricPreset: def render(self) -> str: """Render the PromQL query with all values injected.""" - label_str = ",".join( - f'{key}{value.operator}"{_escape_label_value(value.value)}"' - for key, value in self.labels.items() - ) + label_str = ",".join(f'{k}="{_escape_label_value(v)}"' for k, v in self.labels.items()) return self.template.format( labels=label_str, window=self.window, diff --git a/src/ai/backend/common/clients/prometheus/querier.py b/src/ai/backend/common/clients/prometheus/querier.py index 5b973418e59..f28d6876d71 100644 --- a/src/ai/backend/common/clients/prometheus/querier.py +++ b/src/ai/backend/common/clients/prometheus/querier.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from uuid import UUID -from ai.backend.common.clients.prometheus.preset import LabelMatcher from ai.backend.common.clients.prometheus.types import ValueType @@ -15,7 +14,7 @@ class MetricQuerier(ABC): """ @abstractmethod - def labels(self) -> Mapping[str, LabelMatcher]: + def labels(self) -> Mapping[str, str]: """Return the labels to be used in the Prometheus query.""" ... @@ -35,22 +34,22 @@ class ContainerMetricQuerier(MetricQuerier): user_id: UUID | None = None project_id: UUID | None = None - def labels(self) -> Mapping[str, LabelMatcher]: + def labels(self) -> Mapping[str, str]: """Return the labels for the container metric query.""" - result: dict[str, LabelMatcher] = { - "container_metric_name": LabelMatcher.exact(self.metric_name), - "value_type": LabelMatcher.exact(self.value_type), + result: dict[str, str] = { + "container_metric_name": self.metric_name, + "value_type": self.value_type, } if self.kernel_id is not None: - result["kernel_id"] = LabelMatcher.exact(str(self.kernel_id)) + result["kernel_id"] = str(self.kernel_id) if self.session_id is not None: - result["session_id"] = LabelMatcher.exact(str(self.session_id)) + result["session_id"] = str(self.session_id) if self.agent_id is not None: - result["agent_id"] = LabelMatcher.exact(self.agent_id) + result["agent_id"] = self.agent_id if self.user_id is not None: - result["user_id"] = LabelMatcher.exact(str(self.user_id)) + result["user_id"] = str(self.user_id) if self.project_id is not None: - result["project_id"] = LabelMatcher.exact(str(self.project_id)) + result["project_id"] = str(self.project_id) return result def group_by_labels(self) -> frozenset[str]: diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py index 4d8ae23d5a1..76a6dc9594d 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py @@ -95,14 +95,11 @@ class RouteHealthRecord: route_id: str created_at: int # Unix timestamp when route was created - initial_delay_until: int # Unix timestamp = running_at + initial_delay + initial_delay_until: int # Unix timestamp = created_at + initial_delay health_path: str # extracted from model_definition inference_port: int # extracted from kernel replica_host: str # extracted from kernel - # Timestamp when route entered RUNNING state (set by coordinator) - running_at: int | None = None - # Agent check results agent_healthy: bool = False agent_last_check: int = 0 # Unix timestamp @@ -131,7 +128,7 @@ def is_stale(self, current_time: int, staleness_sec: int = MAX_HEALTH_STALENESS_ def to_valkey_hash(self) -> Mapping[str, str]: """Serialize to Valkey hash fields.""" - data: dict[str, str] = { + return { "route_id": self.route_id, "created_at": str(self.created_at), "initial_delay_until": str(self.initial_delay_until), @@ -143,9 +140,6 @@ def to_valkey_hash(self) -> Mapping[str, str]: "manager_healthy": "1" if self.manager_healthy else "0", "manager_last_check": str(self.manager_last_check), } - if self.running_at is not None: - data["running_at"] = str(self.running_at) - return data @classmethod def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthRecord: @@ -157,7 +151,6 @@ def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthRecord: health_path=data["health_path"], inference_port=int(data["inference_port"]), replica_host=data["replica_host"], - running_at=int(raw) if (raw := data.get("running_at")) and raw != "0" else None, agent_healthy=data.get("agent_healthy", "0") == "1", agent_last_check=int(data.get("agent_last_check", "0")), manager_healthy=data.get("manager_healthy", "0") == "1", @@ -675,52 +668,6 @@ async def update_route_liveness(self, route_id: str, liveness: bool) -> None: async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) - @valkey_schedule_resilience.apply() - async def mark_route_running_at(self, route_id: str) -> None: - """ - Record the RUNNING transition timestamp for a route. - Called when a route transitions to RUNNING status. - Uses Redis time for consistency with health check comparisons. - - :param route_id: The route ID that entered RUNNING state - """ - key = self._get_route_health_key(route_id) - current_time = str(await self._get_redis_time()) - async with self._client.client() as conn: - await conn.hset(key, {"running_at": current_time}) - await conn.expire(key, ROUTE_HEALTH_TTL_SEC) - - @valkey_schedule_resilience.apply() - async def get_route_running_at_batch(self, route_ids: Sequence[str]) -> dict[str, int | None]: - """ - Batch read running_at field from route health hashes. - Works even on partial hashes (before full RouteHealthRecord is initialized). - - :param route_ids: Route IDs to look up - :return: Mapping of route_id to running_at timestamp (None if not set) - """ - if not route_ids: - return {} - - batch = Batch(is_atomic=False) - for route_id in route_ids: - key = self._get_route_health_key(route_id) - batch.hget(key, "running_at") - - async with self._client.client() as conn: - results = await conn.exec(batch, raise_on_error=False) - if results is None: - return dict.fromkeys(route_ids) - - running_at_map: dict[str, int | None] = {} - for i, route_id in enumerate(route_ids): - raw = results[i] if len(results) > i else None - if raw and raw != b"0": - running_at_map[route_id] = int(raw) - else: - running_at_map[route_id] = None - return running_at_map - @valkey_schedule_resilience.apply() async def refresh_route_health_ttl(self, route_id: str) -> None: """ @@ -881,8 +828,6 @@ async def get_route_health_record(self, route_id: str) -> RouteHealthRecord | No return None data = {k.decode(): v.decode() for k, v in result.items()} - if "route_id" not in data: - return None return RouteHealthRecord.from_valkey_hash(data) @valkey_schedule_resilience.apply() @@ -921,10 +866,6 @@ async def get_route_health_records_batch( continue data = {k.decode(): v.decode() for k, v in raw.items()} - if "route_id" not in data: - # Partial hash (e.g., only running_at set by mark_route_running_at) - records[route_id] = None - continue records[route_id] = RouteHealthRecord.from_valkey_hash(data) return records diff --git a/src/ai/backend/common/dto/manager/auth/request.py b/src/ai/backend/common/dto/manager/auth/request.py index a562bfdbf9d..42fa3b1e851 100644 --- a/src/ai/backend/common/dto/manager/auth/request.py +++ b/src/ai/backend/common/dto/manager/auth/request.py @@ -51,12 +51,10 @@ class AuthorizeRequest(BaseRequestModel): default=None, description="One-time password for TOTP-based two-factor authentication", ) - client_type_id: UUID | None = Field( - default=None, + client_type_id: UUID = Field( description=( "Login client type UUID (must reference an existing login_client_types row). " - "Concurrent session limits are enforced per client type. " - "When omitted, session tracking is not scoped by client type." + "Concurrent session limits are enforced per client type." ), ) force: bool = Field( diff --git a/src/ai/backend/common/dto/manager/deployment/__init__.py b/src/ai/backend/common/dto/manager/deployment/__init__.py index ee7d38bb6c0..b38698b1dd1 100644 --- a/src/ai/backend/common/dto/manager/deployment/__init__.py +++ b/src/ai/backend/common/dto/manager/deployment/__init__.py @@ -5,7 +5,6 @@ from __future__ import annotations from .request import ( - AddRevisionOptions, AddRevisionRequest, BlueGreenConfigInput, ClusterConfigInput, @@ -113,7 +112,6 @@ # Request DTOs - Create/Update requests "CreateDeploymentRequest", "UpsertDeploymentPolicyRequest", - "AddRevisionOptions", "AddRevisionRequest", "UpdateDeploymentRequest", "UpdateRouteTrafficStatusRequest", diff --git a/src/ai/backend/common/dto/manager/deployment/request.py b/src/ai/backend/common/dto/manager/deployment/request.py index b6500bce652..51f6aea1445 100644 --- a/src/ai/backend/common/dto/manager/deployment/request.py +++ b/src/ai/backend/common/dto/manager/deployment/request.py @@ -334,20 +334,7 @@ class DeploymentPolicyPathParam(BaseRequestModel): ) -class AddRevisionOptions(BaseRequestModel): - """Options for the add revision operation.""" - - auto_activate: bool = Field( - default=False, - description="When true, automatically activate the newly added revision immediately after creation.", - ) - - class AddRevisionRequest(BaseRequestModel): """Request to add a new revision to an existing deployment.""" revision: RevisionInput = Field(description="Revision configuration") - options: AddRevisionOptions = Field( - default_factory=AddRevisionOptions, - description="Additional options for the add revision operation.", - ) diff --git a/src/ai/backend/common/dto/manager/v2/deployment/request.py b/src/ai/backend/common/dto/manager/v2/deployment/request.py index 0d50ff8d837..0367cd344d4 100644 --- a/src/ai/backend/common/dto/manager/v2/deployment/request.py +++ b/src/ai/backend/common/dto/manager/v2/deployment/request.py @@ -201,17 +201,8 @@ class CreateRevisionInputDTO(BaseRequestModel): ) -class AddRevisionOptions(BaseRequestModel): - """Options for the add revision operation.""" - - auto_activate: bool = Field( - default=False, - description="When true, automatically activate the newly added revision immediately after creation.", - ) - - class AddRevisionGQLInputDTO(BaseRequestModel): - """Input for adding a revision. Used by both GQL and REST v2 APIs.""" + """Input for adding a revision via GQL (flat structure matching GQL AddRevisionInput).""" name: str | None = Field(default=None, description="Revision name") revision_preset_id: UUID | None = Field( @@ -231,9 +222,9 @@ class AddRevisionGQLInputDTO(BaseRequestModel): extra_mounts: list[ExtraVFolderMountInput] | None = Field( default=None, description="Additional vfolder mounts" ) - options: AddRevisionOptions | None = Field( - default=None, - description="Additional options for the add revision operation.", + auto_activate: bool = Field( + default=False, + description="If true, automatically activate this revision after creation.", ) diff --git a/src/ai/backend/manager/api/adapters/deployment.py b/src/ai/backend/manager/api/adapters/deployment.py index 33ba248efaa..efa9ffdeb2e 100644 --- a/src/ai/backend/manager/api/adapters/deployment.py +++ b/src/ai/backend/manager/api/adapters/deployment.py @@ -27,7 +27,6 @@ from ai.backend.common.dto.manager.v2.deployment.request import ( ActivateRevisionInput, AddRevisionGQLInputDTO, - AddRevisionOptions, AdminSearchDeploymentsInput, AdminSearchRevisionsInput, BulkDeleteAccessTokensInput, @@ -991,7 +990,6 @@ async def upsert_policy( async def add_revision( self, input: AddRevisionGQLInputDTO, - options: AddRevisionOptions, ) -> AddRevisionPayload: """Add a new model revision to a deployment.""" mounts_creator = VFolderMountsCreator( @@ -1029,20 +1027,11 @@ async def add_revision( ), model_definition=input.model_definition, revision_preset_id=input.revision_preset_id, + auto_activate=input.auto_activate, ) action_result = await self._processors.deployment.add_model_revision.wait_for_complete( - AddModelRevisionAction( - model_deployment_id=input.deployment_id, - adder=adder, - ) + AddModelRevisionAction(model_deployment_id=input.deployment_id, adder=adder) ) - if options.auto_activate: - await self._processors.deployment.activate_revision.wait_for_complete( - ActivateRevisionAction( - deployment_id=input.deployment_id, - revision_id=action_result.revision.id, - ) - ) return AddRevisionPayload(revision=self._revision_data_to_dto(action_result.revision)) async def get_revision(self, revision_id: UUID) -> RevisionNode: diff --git a/src/ai/backend/manager/api/adapters/model_card.py b/src/ai/backend/manager/api/adapters/model_card.py index c3caa4898a0..a795cae5e2d 100644 --- a/src/ai/backend/manager/api/adapters/model_card.py +++ b/src/ai/backend/manager/api/adapters/model_card.py @@ -468,6 +468,7 @@ async def deploy( execution=ExecutionSpec(runtime_variant=RuntimeVariant("custom")), model_definition=None, revision_preset_id=input.revision_preset_id, + auto_activate=True, ), policy=policy, ) diff --git a/src/ai/backend/manager/api/adapters/vfolder.py b/src/ai/backend/manager/api/adapters/vfolder.py index 1d57ab629d3..e1640cc11a7 100644 --- a/src/ai/backend/manager/api/adapters/vfolder.py +++ b/src/ai/backend/manager/api/adapters/vfolder.py @@ -477,6 +477,7 @@ async def deploy( execution=ExecutionSpec(runtime_variant=RuntimeVariant("custom")), model_definition=None, revision_preset_id=input.revision_preset_id, + auto_activate=True, ), policy=policy, ) diff --git a/src/ai/backend/manager/api/gql/deployment/resolver/revision.py b/src/ai/backend/manager/api/gql/deployment/resolver/revision.py index 8a1d9a55b1e..5be9d49cced 100644 --- a/src/ai/backend/manager/api/gql/deployment/resolver/revision.py +++ b/src/ai/backend/manager/api/gql/deployment/resolver/revision.py @@ -9,12 +9,7 @@ from strawberry.relay import PageInfo from strawberry.scalars import JSON -from ai.backend.common.dto.manager.v2.deployment.request import ( - AddRevisionOptions as AdapterAddRevisionOptions, -) -from ai.backend.common.dto.manager.v2.deployment.request import ( - AdminSearchRevisionsInput, -) +from ai.backend.common.dto.manager.v2.deployment.request import AdminSearchRevisionsInput from ai.backend.manager.api.gql.base import encode_cursor, resolve_global_id from ai.backend.manager.api.gql.decorators import ( BackendAIGQLMeta, @@ -27,7 +22,6 @@ ActivateRevisionInputGQL, ActivateRevisionPayloadGQL, AddRevisionInput, - AddRevisionOptionsGQL, AddRevisionPayload, ModelRevision, ModelRevisionConnection, @@ -149,15 +143,10 @@ async def inference_runtime_configs(info: Info[StrawberryGQLContext]) -> JSON: @gql_mutation(BackendAIGQLMeta(added_version="25.16.0", description="Add model revision.")) # type: ignore[misc] async def add_model_revision( - input: AddRevisionInput, - info: Info[StrawberryGQLContext], - options: AddRevisionOptionsGQL | None = None, + input: AddRevisionInput, info: Info[StrawberryGQLContext] ) -> AddRevisionPayload: """Add a model revision to a deployment.""" - payload = await info.context.adapters.deployment.add_revision( - input.to_pydantic(), - options=options.to_pydantic() if options else AdapterAddRevisionOptions(), - ) + payload = await info.context.adapters.deployment.add_revision(input.to_pydantic()) return AddRevisionPayload(revision=ModelRevision.from_pydantic(payload.revision)) diff --git a/src/ai/backend/manager/api/gql/deployment/types/access_token.py b/src/ai/backend/manager/api/gql/deployment/types/access_token.py index deae4f15a6e..25834c95e9a 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/access_token.py +++ b/src/ai/backend/manager/api/gql/deployment/types/access_token.py @@ -31,9 +31,6 @@ from ai.backend.common.dto.manager.v2.deployment.response import ( DeleteAccessTokenPayload as DeleteAccessTokenPayloadDTO, ) -from ai.backend.common.dto.manager.v2.deployment.types import ( - AccessTokenOrderField, -) from ai.backend.manager.api.gql.base import DateTimeFilter, OrderDirection, StringFilter from ai.backend.manager.api.gql.decorators import ( BackendAIGQLMeta, @@ -46,6 +43,9 @@ ) from ai.backend.manager.api.gql.pydantic_compat import PydanticNodeMixin from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.types import ( + AccessTokenOrderField, +) @gql_pydantic_input( diff --git a/src/ai/backend/manager/api/gql/deployment/types/auto_scaling.py b/src/ai/backend/manager/api/gql/deployment/types/auto_scaling.py index 02d62d5af0b..454b85ef6cb 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/auto_scaling.py +++ b/src/ai/backend/manager/api/gql/deployment/types/auto_scaling.py @@ -39,9 +39,6 @@ from ai.backend.common.dto.manager.v2.deployment.response import ( UpdateAutoScalingRulePayload as UpdateAutoScalingRulePayloadDTO, ) -from ai.backend.common.dto.manager.v2.deployment.types import ( - AutoScalingRuleOrderField, -) from ai.backend.common.meta import NEXT_RELEASE_VERSION from ai.backend.manager.api.gql.base import DateTimeFilter, OrderDirection from ai.backend.manager.api.gql.decorators import ( @@ -57,6 +54,9 @@ ) from ai.backend.manager.api.gql.pydantic_compat import PydanticNodeMixin from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.types import ( + AutoScalingRuleOrderField, +) @gql_enum( diff --git a/src/ai/backend/manager/api/gql/deployment/types/deployment.py b/src/ai/backend/manager/api/gql/deployment/types/deployment.py index 61cecea2b54..c664d117469 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/deployment.py +++ b/src/ai/backend/manager/api/gql/deployment/types/deployment.py @@ -71,7 +71,6 @@ from ai.backend.common.dto.manager.v2.deployment.types import ( DeploymentMetadataInfoDTO, DeploymentNetworkAccessInfoDTO, - DeploymentOrderField, DeploymentStrategyInfoDTO, ReplicaStateInfo, ) @@ -141,6 +140,7 @@ from ai.backend.manager.data.deployment.types import ( AccessTokenSearchScope, AutoScalingRuleSearchScope, + DeploymentOrderField, ReplicaSearchScope, RevisionSearchScope, ) diff --git a/src/ai/backend/manager/api/gql/deployment/types/replica.py b/src/ai/backend/manager/api/gql/deployment/types/replica.py index a74ece063f7..a64c7bc2733 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/replica.py +++ b/src/ai/backend/manager/api/gql/deployment/types/replica.py @@ -31,9 +31,6 @@ from ai.backend.common.dto.manager.v2.deployment.response import ( ReplicaStatusChangedPayload as ReplicaStatusChangedPayloadDTO, ) -from ai.backend.common.dto.manager.v2.deployment.types import ( - ReplicaOrderField, -) from ai.backend.manager.api.gql.base import ( OrderDirection, to_global_id, @@ -53,6 +50,7 @@ from ai.backend.manager.api.gql.types import StrawberryGQLContext from ai.backend.manager.api.gql_legacy.session import ComputeSessionNode from ai.backend.manager.data.deployment.types import ( + ReplicaOrderField, RouteStatus, RouteTrafficStatus, ) diff --git a/src/ai/backend/manager/api/gql/deployment/types/revision.py b/src/ai/backend/manager/api/gql/deployment/types/revision.py index 95d4b377578..2e0082c2eeb 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/revision.py +++ b/src/ai/backend/manager/api/gql/deployment/types/revision.py @@ -38,9 +38,6 @@ AddRevisionGQLInputDTO, CreateRevisionInputDTO, ) -from ai.backend.common.dto.manager.v2.deployment.request import ( - AddRevisionOptions as AddRevisionOptionsDTO, -) from ai.backend.common.dto.manager.v2.deployment.request import ( ClusterConfigInput as ClusterConfigInputDTO, ) @@ -844,20 +841,6 @@ class CreateRevisionInput(PydanticInputMixin[CreateRevisionInputDTO]): ) -@gql_pydantic_input( - BackendAIGQLMeta( - description="Options for the add_model_revision mutation.", - added_version=NEXT_RELEASE_VERSION, - ), - name="AddRevisionOptions", -) -class AddRevisionOptionsGQL(PydanticInputMixin[AddRevisionOptionsDTO]): - auto_activate: bool = gql_field( - default=False, - description="When true, automatically activate the newly added revision immediately after creation.", - ) - - @gql_pydantic_input( BackendAIGQLMeta(description="", added_version="25.19.0"), ) diff --git a/src/ai/backend/manager/api/gql_legacy/endpoint.py b/src/ai/backend/manager/api/gql_legacy/endpoint.py index 6502953dc0c..dabee691705 100644 --- a/src/ai/backend/manager/api/gql_legacy/endpoint.py +++ b/src/ai/backend/manager/api/gql_legacy/endpoint.py @@ -32,7 +32,7 @@ RuleId, RuntimeVariant, ) -from ai.backend.manager.data.deployment.types import RouteHealthStatus, RouteStatus +from ai.backend.manager.data.deployment.types import RouteStatus from ai.backend.manager.data.model_serving.creator import EndpointAutoScalingRuleCreator from ai.backend.manager.data.model_serving.modifier import ( ExtraMount, @@ -872,15 +872,17 @@ async def resolve_status(self, info: graphene.ResolveInfo) -> str: case _: if not self.routings: return EndpointStatus.DEGRADED - active_status_names = {s.name for s in RouteStatus.active_route_statuses()} - active_routings = [r for r in self.routings if r.status in active_status_names] - if not active_routings: + active_route_status_names = {s.name for s in RouteStatus.active_route_statuses()} + active_routings = [ + r for r in self.routings if r.status in active_route_status_names + ] + healthy_count = sum( + 1 for r in active_routings if r.status == RouteStatus.RUNNING.name + ) + if healthy_count == 0: return EndpointStatus.UNHEALTHY - health_statuses = {r.health_status for r in active_routings} - if health_statuses == {RouteHealthStatus.HEALTHY.name}: + if healthy_count == len(active_routings): return EndpointStatus.HEALTHY - if health_statuses == {RouteHealthStatus.UNHEALTHY.name}: - return EndpointStatus.UNHEALTHY return EndpointStatus.DEGRADED async def resolve_model_vfolder(self, info: graphene.ResolveInfo) -> VirtualFolderNode: diff --git a/src/ai/backend/manager/api/gql_legacy/routing.py b/src/ai/backend/manager/api/gql_legacy/routing.py index a254110096e..8102f29382a 100644 --- a/src/ai/backend/manager/api/gql_legacy/routing.py +++ b/src/ai/backend/manager/api/gql_legacy/routing.py @@ -9,7 +9,6 @@ from graphene.types.datetime import DateTime as GQLDateTime from sqlalchemy.exc import NoResultFound -from ai.backend.common.meta import NEXT_RELEASE_VERSION from ai.backend.manager.data.deployment.types import RouteStatus from ai.backend.manager.errors.service import RoutingNotFound from ai.backend.manager.models.routing import RoutingRow @@ -32,7 +31,6 @@ class Meta: endpoint = graphene.String() session = graphene.UUID() status = graphene.String() - health_status = graphene.String(description=f"Added in {NEXT_RELEASE_VERSION}.") traffic_ratio = graphene.Float() created_at = GQLDateTime() error = InferenceSessionError() @@ -51,7 +49,6 @@ def from_dto(cls, dto: Any) -> Self | None: endpoint=dto.endpoint, session=dto.session, status=dto.status.name, - health_status=dto.health_status.name, traffic_ratio=dto.traffic_ratio, created_at=dto.created_at, error_data=dto.error_data, @@ -69,7 +66,6 @@ async def from_row( endpoint=(endpoint or row.endpoint_row).url, session=row.session, status=row.status.name, - health_status=row.health_status.name, traffic_ratio=row.traffic_ratio, created_at=row.created_at, error_data=row.error_data, diff --git a/src/ai/backend/manager/api/rest/deployment/handler.py b/src/ai/backend/manager/api/rest/deployment/handler.py index 75e31220ad1..b0969b5265d 100644 --- a/src/ai/backend/manager/api/rest/deployment/handler.py +++ b/src/ai/backend/manager/api/rest/deployment/handler.py @@ -255,13 +255,6 @@ async def add_revision( adder=revision_creator, ) ) - if body.parsed.options.auto_activate: - await self._deployment.activate_revision.wait_for_complete( - ActivateRevisionAction( - deployment_id=path.parsed.deployment_id, - revision_id=action_result.revision.id, - ) - ) # Build response resp = AddRevisionResponse( diff --git a/src/ai/backend/manager/api/rest/session/handler.py b/src/ai/backend/manager/api/rest/session/handler.py index 0a6ac7dab65..9bd8bac1f09 100644 --- a/src/ai/backend/manager/api/rest/session/handler.py +++ b/src/ai/backend/manager/api/rest/session/handler.py @@ -275,7 +275,7 @@ async def create_from_template( owner_id = params.owner_id if params.owner_id is not None else request["user"]["uuid"] log.info( - "GET_OR_CREATE (ak:{0}/{1}, img:{2}, s:{3})", + "GET_OR_CREATE (ak:{0}, img:{1}, s:{2})", requester_access_key, owner_id if owner_id != request["user"]["uuid"] else "*", params.image, @@ -380,7 +380,7 @@ async def create_from_params( requester_access_key = AccessKey(request["keypair"]["access_key"]) owner_id = params.owner_id if params.owner_id is not None else request["user"]["uuid"] log.info( - "GET_OR_CREATE (ak:{0}/{1}, img:{2}, s:{3})", + "GET_OR_CREATE (ak:{0}, img:{1}, s:{2})", requester_access_key, owner_id if owner_id != request["user"]["uuid"] else "*", params.image, @@ -443,7 +443,7 @@ async def create_cluster( requester_access_key = AccessKey(request["keypair"]["access_key"]) owner_id = params.owner_id if params.owner_id is not None else request["user"]["uuid"] log.info( - "CREAT_CLUSTER (ak:{0}/{1}, s:{2})", + "CREAT_CLUSTER (ak:{0}, s:{1})", requester_access_key, owner_id if owner_id != request["user"]["uuid"] else "*", params.session_name, diff --git a/src/ai/backend/manager/api/rest/v2/deployment/handler.py b/src/ai/backend/manager/api/rest/v2/deployment/handler.py index cf055e8ef3c..e59e7b41958 100644 --- a/src/ai/backend/manager/api/rest/v2/deployment/handler.py +++ b/src/ai/backend/manager/api/rest/v2/deployment/handler.py @@ -18,7 +18,6 @@ from ai.backend.common.dto.manager.v2.deployment.request import ( ActivateRevisionInput, AddRevisionGQLInputDTO, - AddRevisionOptions, AdminSearchDeploymentsInput, AdminSearchRevisionsInput, BulkDeleteAccessTokensInput, @@ -156,9 +155,7 @@ async def add_revision( body: BodyParam[AddRevisionGQLInputDTO], ) -> APIResponse: """Add a new model revision to a deployment.""" - result = await self._adapter.add_revision( - body.parsed, body.parsed.options or AddRevisionOptions() - ) + result = await self._adapter.add_revision(body.parsed) return APIResponse.build(status_code=HTTPStatus.CREATED, response_model=result) async def get_revision( diff --git a/src/ai/backend/manager/data/deployment/creator.py b/src/ai/backend/manager/data/deployment/creator.py index 8cc03e77938..7b459207144 100644 --- a/src/ai/backend/manager/data/deployment/creator.py +++ b/src/ai/backend/manager/data/deployment/creator.py @@ -44,6 +44,7 @@ class ModelRevisionCreator: model_definition: ModelDefinition | None revision_preset_id: UUID | None = None preset_values: list[PresetValueData] = field(default_factory=list) + auto_activate: bool = False @dataclass diff --git a/src/ai/backend/manager/models/alembic/versions/8c1d2e3f4a5b_drop_session_kernel_access_key.py b/src/ai/backend/manager/models/alembic/versions/8c1d2e3f4a5b_drop_session_kernel_access_key.py new file mode 100644 index 00000000000..141b096ad45 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/8c1d2e3f4a5b_drop_session_kernel_access_key.py @@ -0,0 +1,57 @@ +"""drop sessions/kernels access_key columns + +Part of BA-5653. The ``access_key`` column is removed from the +``sessions`` and ``kernels`` tables. Downstream code now resolves the +owner's ``main_access_key`` from the ``users`` table when needed +(keypair-scoped concurrency tracking, resource policy lookups, agent +RPC payloads). The ``user_uuid`` column is kept on both tables as the +canonical owner reference; only the redundant ``access_key`` snapshot +is dropped. + +Revision ID: 8c1d2e3f4a5b +Revises: 2a531e0c528e +Create Date: 2026-04-14 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8c1d2e3f4a5b" +down_revision = "2a531e0c528e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # The (access_key, sess_id) partial unique index on ``kernels`` references the + # column being dropped — remove it first. + op.drop_index("ix_kernels_unique_sess_token", table_name="kernels") + op.drop_column("kernels", "access_key") + op.drop_column("sessions", "access_key") + + +def downgrade() -> None: + """Recreate the ``access_key`` columns as nullable. + + NOTE: This downgrade is intentionally lossy. Previous values cannot + be restored because the upgrade does not preserve them. Callers that + depended on the old column must resolve ``main_access_key`` via the + ``users`` table instead. + """ + op.add_column( + "sessions", + sa.Column("access_key", sa.String(length=20), nullable=True), + ) + op.add_column( + "kernels", + sa.Column("access_key", sa.String(length=20), nullable=True), + ) + op.create_index( + op.f("ix_kernels_unique_sess_token"), + "kernels", + ["access_key", "sess_id"], + unique=True, + postgresql_where=sa.text("kernels.status != 'TERMINATED' and kernels.role = 'master'"), + ) diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index c1a2a9e8ea5..69372496992 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -1215,7 +1215,7 @@ async def get_pending_timeout_sessions_by_ids( SweptSessionInfo( session_id=row.id, creation_id=row.creation_id, - main_access_key=row.access_key, + main_access_key=row.main_access_key, ) ) @@ -2896,7 +2896,8 @@ async def _get_sessions_by_statuses( KernelRow.status, KernelRow.status_changed, ) - ) + ), + selectinload(SessionRow.user).options(load_only(UserRow.main_access_key)), ) ) result = await db_sess.execute(stmt) @@ -2921,12 +2922,11 @@ async def _get_sessions_by_statuses( ) kernels_data.append(kernel_data) + owner_main_ak = session.user.main_access_key if session.user else None scheduled_session = ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - main_access_key=AccessKey(session.access_key) - if session.access_key - else AccessKey(""), + main_access_key=AccessKey(owner_main_ak) if owner_main_ak else AccessKey(""), reason="triggered-by-scheduler", ) scheduled_sessions.append(scheduled_session) @@ -2953,10 +2953,11 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes KernelRow.architecture, ) ), + selectinload(SessionRow.user).options(load_only(UserRow.main_access_key)), load_only( SessionRow.id, SessionRow.creation_id, - SessionRow.access_key, + SessionRow.user_uuid, SessionRow.session_type, SessionRow.name, ), @@ -2967,13 +2968,12 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes scheduled_sessions: list[ScheduledSessionData] = [] for session in sessions: + owner_main_ak = session.user.main_access_key if session.user else None scheduled_sessions.append( ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - main_access_key=AccessKey(session.access_key) - if session.access_key - else AccessKey(""), + main_access_key=AccessKey(owner_main_ak) if owner_main_ak else AccessKey(""), reason="triggered-by-scheduler", ) ) @@ -4625,7 +4625,7 @@ async def search_sessions_with_kernels_and_user( sessions_for_start: list[SessionDataForStart] = [] for session_id in session_ids: session_info = session_info_map[session_id] - user_info = user_map.get(session_info["user_uuid"]) + user_info = user_map.get(session_info["owner_id"]) if not user_info: log.warning(f"User info not found for session {session_id}") continue diff --git a/src/ai/backend/manager/services/auth/actions/authorize.py b/src/ai/backend/manager/services/auth/actions/authorize.py index 409d997392b..c201a8bd647 100644 --- a/src/ai/backend/manager/services/auth/actions/authorize.py +++ b/src/ai/backend/manager/services/auth/actions/authorize.py @@ -20,7 +20,7 @@ class AuthorizeAction(AuthAction): password: str stoken: str | None otp: str | None - client_type_id: UUID | None + client_type_id: UUID force: bool = False @override diff --git a/src/ai/backend/manager/services/auth/service.py b/src/ai/backend/manager/services/auth/service.py index db2299ff200..98c86a1de67 100644 --- a/src/ai/backend/manager/services/auth/service.py +++ b/src/ai/backend/manager/services/auth/service.py @@ -210,7 +210,7 @@ async def _verify_user( self, action: AuthorizeAction, auth_config: AuthConfig, - login_client_type_id: uuid.UUID | None, + login_client_type_id: uuid.UUID, ) -> tuple[RowMapping, list[ActiveSessionInfo]]: """Step 1: Verify user identity via hook or password.""" params = action.hook_params @@ -329,7 +329,7 @@ async def _create_login_session( keypair_row: Any, live_sessions: list[ActiveSessionInfo], auth_config: AuthConfig, - login_client_type_id: uuid.UUID | None, + login_client_type_id: uuid.UUID, ) -> AuthorizeActionResult: """Step 3: Create login session (DB + Valkey), force-invalidate old sessions if needed. diff --git a/src/ai/backend/manager/services/deployment/service.py b/src/ai/backend/manager/services/deployment/service.py index 1ef4fd86506..6cc01efb896 100644 --- a/src/ai/backend/manager/services/deployment/service.py +++ b/src/ai/backend/manager/services/deployment/service.py @@ -493,16 +493,11 @@ async def create_deployment( # Create initial revision if provided, via the same path as add_model_revision # to ensure preset/merge/resolve logic is applied consistently. if revision is not None: - add_result = await self.add_model_revision( + initial_revision_creator = dataclasses.replace(revision, auto_activate=True) + await self.add_model_revision( AddModelRevisionAction( model_deployment_id=deployment_info.id, - adder=revision, - ) - ) - await self.activate_revision( - ActivateRevisionAction( - deployment_id=deployment_info.id, - revision_id=add_result.revision.id, + adder=initial_revision_creator, ) ) @@ -991,6 +986,15 @@ async def add_model_revision( creator, deployment_id ) + # Auto-activate revision if requested + if action.adder.auto_activate: + await self.activate_revision( + ActivateRevisionAction( + deployment_id=deployment_id, + revision_id=revision_data.id, + ) + ) + return AddModelRevisionActionResult(revision=revision_data) async def get_revision_by_id( diff --git a/src/ai/backend/manager/services/model_serving/services/model_serving.py b/src/ai/backend/manager/services/model_serving/services/model_serving.py index 0eaa9d658fc..a6e8179a72d 100644 --- a/src/ai/backend/manager/services/model_serving/services/model_serving.py +++ b/src/ai/backend/manager/services/model_serving/services/model_serving.py @@ -896,10 +896,8 @@ async def modify_endpoint(self, action: ModifyEndpointAction) -> ModifyEndpointA self._storage_manager, ) spec = cast(EndpointUpdaterSpec, action.updater.spec) - if spec.replica_count_modified() or spec.has_revision_changes(): - # Trigger CHECK_REPLICA for both cases: - # - replica count change: reconcile running session count - # - revision change: rotate sessions to use the newly activated revision + if spec.replica_count_modified(): + # Notify appproxy to update routing info await self._deployment_controller.mark_lifecycle_needed( DeploymentLifecycleType.CHECK_REPLICA, ) diff --git a/src/ai/backend/manager/services/prometheus_query_preset/service.py b/src/ai/backend/manager/services/prometheus_query_preset/service.py index 085a0cf535a..c45e2191eb2 100644 --- a/src/ai/backend/manager/services/prometheus_query_preset/service.py +++ b/src/ai/backend/manager/services/prometheus_query_preset/service.py @@ -1,7 +1,7 @@ import logging from ai.backend.common.clients.prometheus.client import PrometheusClient -from ai.backend.common.clients.prometheus.preset import LabelMatcher, MetricPreset +from ai.backend.common.clients.prometheus.preset import MetricPreset from ai.backend.common.dto.clients.prometheus.response import PrometheusResponse from ai.backend.common.exception import PrometheusQueryPresetInvalidLabel from ai.backend.logging.utils import BraceStyleAdapter @@ -90,12 +90,6 @@ def _validate_labels( f"Allowed: {sorted(preset_data.group_labels)}" ) - def _build_filter_label_matchers( - self, - filter_labels: dict[str, str], - ) -> dict[str, LabelMatcher]: - return {key: LabelMatcher.exact(value) for key, value in filter_labels.items()} - async def execute_preset(self, action: ExecutePresetAction) -> ExecutePresetActionResult: preset_data = await self._repository.get_by_id(action.preset_id) self._validate_labels(action.options, preset_data) @@ -104,7 +98,7 @@ async def execute_preset(self, action: ExecutePresetAction) -> ExecutePresetActi metric_preset = MetricPreset( template=preset_data.query_template, - labels=self._build_filter_label_matchers(action.options.filter_labels), + labels=action.options.filter_labels, group_by=set(action.options.group_labels), window=time_window, ) diff --git a/src/ai/backend/manager/services/user/service.py b/src/ai/backend/manager/services/user/service.py index 1b618f52978..b24a0fa4f0f 100644 --- a/src/ai/backend/manager/services/user/service.py +++ b/src/ai/backend/manager/services/user/service.py @@ -231,10 +231,18 @@ async def purge_user(self, action: PurgeUserAction) -> PurgeUserActionResult: # Handle endpoint ownership delegation if action.delegate_endpoint_ownership.optional_value(): + target_main_ak = await self._user_repository.get_main_access_key_by_id( + action.user_info_ctx.uuid + ) + if target_main_ak is None: + raise UserPurgeFailure( + f"Cannot delegate endpoint ownership: target user " + f"{action.user_info_ctx.uuid} has no main_access_key" + ) await self._user_repository.delegate_endpoint_ownership( user_uuid=user_uuid, target_user_uuid=action.user_info_ctx.uuid, - target_main_access_key=action.user_info_ctx.main_access_key, + target_main_access_key=AccessKey(target_main_ak), ) await self._user_repository.delete_endpoints( user_uuid=user_uuid, @@ -304,10 +312,18 @@ async def _purge_single_user( # Handle endpoint ownership delegation if action.delegate_endpoint_ownership.optional_value(): + target_main_ak = await self._user_repository.get_main_access_key_by_id( + user_info_ctx.uuid + ) + if target_main_ak is None: + raise UserPurgeFailure( + f"Cannot delegate endpoint ownership: target user " + f"{user_info_ctx.uuid} has no main_access_key" + ) await self._user_repository.delegate_endpoint_ownership( user_uuid=user_uuid, target_user_uuid=user_info_ctx.uuid, - target_main_access_key=user_info_ctx.main_access_key, + target_main_access_key=AccessKey(target_main_ak), ) await self._user_repository.delete_endpoints( user_uuid=user_uuid, diff --git a/src/ai/backend/web/server.py b/src/ai/backend/web/server.py index 5f527d9de3b..b4159d9b2ac 100644 --- a/src/ai/backend/web/server.py +++ b/src/ai/backend/web/server.py @@ -755,7 +755,6 @@ async def redis_ctx( redis_storage = RedisStorage( valkey_session_client, max_age=config.session.max_age, - secure=config.service.ssl_enabled, ) setup_session(app, redis_storage) try: diff --git a/tests/unit/common/clients/prometheus/test_client.py b/tests/unit/common/clients/prometheus/test_client.py index a3e95ab6ec0..8724e618693 100644 --- a/tests/unit/common/clients/prometheus/test_client.py +++ b/tests/unit/common/clients/prometheus/test_client.py @@ -5,7 +5,6 @@ import pytest from ai.backend.common.clients.prometheus import ( - LabelMatcher, MetricPreset, PrometheusClient, ) @@ -67,10 +66,7 @@ class TestQueryRange: def sample_preset(self) -> MetricPreset: return MetricPreset( template="sum(my_metric{{{labels}}}) by ({group_by})", - labels={ - "container_metric_name": LabelMatcher.exact("mem"), - "value_type": LabelMatcher.exact("current"), - }, + labels={"container_metric_name": "mem", "value_type": "current"}, group_by=frozenset({"value_type"}), window="5m", ) @@ -179,10 +175,7 @@ class TestQueryInstant: def sample_preset(self) -> MetricPreset: return MetricPreset( template="sum(my_metric{{{labels}}}) by ({group_by})", - labels={ - "container_metric_name": LabelMatcher.exact("mem"), - "value_type": LabelMatcher.exact("current"), - }, + labels={"container_metric_name": "mem", "value_type": "current"}, group_by=frozenset({"value_type"}), window="5m", ) diff --git a/tests/unit/common/clients/prometheus/test_preset.py b/tests/unit/common/clients/prometheus/test_preset.py index 0c4404d8337..738ecc2ffb5 100644 --- a/tests/unit/common/clients/prometheus/test_preset.py +++ b/tests/unit/common/clients/prometheus/test_preset.py @@ -2,14 +2,14 @@ import pytest -from ai.backend.common.clients.prometheus import LabelMatcher, MetricPreset +from ai.backend.common.clients.prometheus import MetricPreset @dataclass class RenderTestCase: id: str template: str - labels: dict[str, LabelMatcher] + labels: dict[str, str] group_by: frozenset[str] window: str expected: str @@ -32,7 +32,7 @@ class TestMetricPresetRender: RenderTestCase( id="multiple_group_by_sorted", template="sum(my_metric{{{labels}}}) by ({group_by})", - labels={"job": LabelMatcher.exact("test")}, + labels={"job": "test"}, group_by=frozenset({"value_type", "kernel_id", "session_id"}), window="", expected='sum(my_metric{job="test"}) by (kernel_id,session_id,value_type)', @@ -52,7 +52,7 @@ class TestMetricPresetRender: RenderTestCase( id="with_window", template="sum(rate(my_metric{{{labels}}}[{window}])) by ({group_by})", - labels={"job": LabelMatcher.exact("test")}, + labels={"job": "test"}, group_by=frozenset({"instance"}), window="5m", expected='sum(rate(my_metric{job="test"}[5m])) by (instance)', @@ -60,7 +60,7 @@ class TestMetricPresetRender: RenderTestCase( id="escapes_double_quotes_in_label_value", template="my_metric{{{labels}}}", - labels={"key": LabelMatcher.exact('value with "quotes"')}, + labels={"key": 'value with "quotes"'}, group_by=frozenset(), window="", expected='my_metric{key="value with \\"quotes\\""}', @@ -68,7 +68,7 @@ class TestMetricPresetRender: RenderTestCase( id="escapes_backslash_in_label_value", template="my_metric{{{labels}}}", - labels={"path": LabelMatcher.exact("C:\\Users\\test")}, + labels={"path": "C:\\Users\\test"}, group_by=frozenset(), window="", expected='my_metric{path="C:\\\\Users\\\\test"}', @@ -76,7 +76,7 @@ class TestMetricPresetRender: RenderTestCase( id="escapes_newline_in_label_value", template="my_metric{{{labels}}}", - labels={"msg": LabelMatcher.exact("line1\nline2")}, + labels={"msg": "line1\nline2"}, group_by=frozenset(), window="", expected='my_metric{msg="line1\\nline2"}', @@ -84,19 +84,11 @@ class TestMetricPresetRender: RenderTestCase( id="escapes_mixed_special_chars", template="my_metric{{{labels}}}", - labels={"data": LabelMatcher.exact('path\\to\\"file"\nend')}, + labels={"data": 'path\\to\\"file"\nend'}, group_by=frozenset(), window="", expected='my_metric{data="path\\\\to\\\\\\"file\\"\\nend"}', ), - RenderTestCase( - id="regex_matcher", - template="my_metric{{{labels}}}", - labels={"kernel_id": LabelMatcher.regex("kernel-1|kernel-2")}, - group_by=frozenset(), - window="", - expected='my_metric{kernel_id=~"kernel-1|kernel-2"}', - ), ], ids=lambda c: c.id, ) diff --git a/tests/unit/common/clients/prometheus/test_querier.py b/tests/unit/common/clients/prometheus/test_querier.py index 1c0a98251e0..da2765ef45e 100644 --- a/tests/unit/common/clients/prometheus/test_querier.py +++ b/tests/unit/common/clients/prometheus/test_querier.py @@ -1,6 +1,6 @@ from uuid import UUID -from ai.backend.common.clients.prometheus import ContainerMetricQuerier, LabelMatcher, ValueType +from ai.backend.common.clients.prometheus import ContainerMetricQuerier, ValueType class TestContainerMetricQuerier: @@ -15,8 +15,8 @@ async def test_labels_required_only(self) -> None: result = querier.labels() assert result == { - "container_metric_name": LabelMatcher.exact("cpu_util"), - "value_type": LabelMatcher.exact("current"), + "container_metric_name": "cpu_util", + "value_type": "current", } async def test_labels_all_fields(self) -> None: @@ -38,13 +38,13 @@ async def test_labels_all_fields(self) -> None: result = querier.labels() assert result == { - "container_metric_name": LabelMatcher.exact("net_rx"), - "value_type": LabelMatcher.exact("current"), - "kernel_id": LabelMatcher.exact(str(kernel_id)), - "session_id": LabelMatcher.exact(str(session_id)), - "agent_id": LabelMatcher.exact("agent-001"), - "user_id": LabelMatcher.exact(str(user_id)), - "project_id": LabelMatcher.exact(str(project_id)), + "container_metric_name": "net_rx", + "value_type": "current", + "kernel_id": str(kernel_id), + "session_id": str(session_id), + "agent_id": "agent-001", + "user_id": str(user_id), + "project_id": str(project_id), } async def test_group_by_required_only(self) -> None: diff --git a/tests/unit/manager/api/endpoint/test_types.py b/tests/unit/manager/api/endpoint/test_types.py index a64bcb048ef..0fd9df30daf 100644 --- a/tests/unit/manager/api/endpoint/test_types.py +++ b/tests/unit/manager/api/endpoint/test_types.py @@ -4,7 +4,7 @@ from ai.backend.common.data.endpoint.types import EndpointLifecycle, EndpointStatus from ai.backend.manager.api.gql_legacy.endpoint import Endpoint -from ai.backend.manager.data.deployment.types import RouteHealthStatus, RouteStatus +from ai.backend.manager.data.deployment.types import RouteStatus class TestEndpointType: @@ -21,7 +21,6 @@ async def test_status_unhealthy_when_no_healthy_routes(self) -> None: unhealthy_route = Mock() unhealthy_route.status = RouteStatus.FAILED_TO_START.name - unhealthy_route.health_status = RouteHealthStatus.UNHEALTHY.name mock_endpoint.routings = [unhealthy_route, unhealthy_route] result = await Endpoint.resolve_status(mock_endpoint, info=Mock()) @@ -37,10 +36,8 @@ async def test_status_degraded_when_healthy_and_provisioning_routes_mixed_1(self healthy_route = Mock() healthy_route.status = RouteStatus.RUNNING.name - healthy_route.health_status = RouteHealthStatus.HEALTHY.name provisioning_route = Mock() provisioning_route.status = RouteStatus.PROVISIONING.name - provisioning_route.health_status = RouteHealthStatus.NOT_CHECKED.name mock_endpoint.routings = [healthy_route, provisioning_route] @@ -60,10 +57,8 @@ async def test_status_degraded_when_healthy_and_provisioning_routes_mixed(self) healthy_route = Mock() healthy_route.status = RouteStatus.RUNNING.name - healthy_route.health_status = RouteHealthStatus.HEALTHY.name provisioning_route = Mock() provisioning_route.status = RouteStatus.PROVISIONING.name - provisioning_route.health_status = RouteHealthStatus.NOT_CHECKED.name mock_endpoint.routings = [healthy_route, provisioning_route] @@ -81,7 +76,6 @@ async def test_status_unhealthy_when_all_routes_terminated(self) -> None: terminated_route = Mock() terminated_route.status = RouteStatus.TERMINATED.name - terminated_route.health_status = RouteHealthStatus.UNHEALTHY.name mock_endpoint.routings = [terminated_route, terminated_route] @@ -102,10 +96,8 @@ async def test_status_healthy_when_terminated_routes_are_mixed_with_healthy(self healthy_route = Mock() healthy_route.status = RouteStatus.RUNNING.name - healthy_route.health_status = RouteHealthStatus.HEALTHY.name terminated_route = Mock() terminated_route.status = RouteStatus.TERMINATED.name - terminated_route.health_status = RouteHealthStatus.UNHEALTHY.name mock_endpoint.routings = [healthy_route, healthy_route, terminated_route] diff --git a/tests/unit/manager/api/gql_legacy/test_endpoint_resolve_status.py b/tests/unit/manager/api/gql_legacy/test_endpoint_resolve_status.py index 3da8fb1c6da..d2f27f370b5 100644 --- a/tests/unit/manager/api/gql_legacy/test_endpoint_resolve_status.py +++ b/tests/unit/manager/api/gql_legacy/test_endpoint_resolve_status.py @@ -6,7 +6,7 @@ from ai.backend.common.data.endpoint.types import EndpointLifecycle, EndpointStatus from ai.backend.manager.api.gql_legacy.endpoint import Endpoint -from ai.backend.manager.data.deployment.types import RouteHealthStatus, RouteStatus +from ai.backend.manager.data.deployment.types import RouteStatus class TestEndpointResolveStatus: @@ -33,7 +33,6 @@ async def test_empty_routings_returns_degraded(self, info: MagicMock) -> None: async def test_all_inactive_routings_returns_unhealthy(self, info: MagicMock) -> None: routing = MagicMock() routing.status = RouteStatus.TERMINATED.name - routing.health_status = RouteHealthStatus.UNHEALTHY.name ep = Endpoint() ep.lifecycle_stage = EndpointLifecycle.READY.name ep.routings = [routing] diff --git a/tests/unit/manager/dependencies/agents/test_registry.py b/tests/unit/manager/dependencies/agents/test_registry.py index 7ee1f2de877..cf4223e0022 100644 --- a/tests/unit/manager/dependencies/agents/test_registry.py +++ b/tests/unit/manager/dependencies/agents/test_registry.py @@ -36,6 +36,7 @@ async def test_provide_agent_registry( hook_plugin_ctx=MagicMock(), network_plugin_ctx=MagicMock(), scheduling_controller=MagicMock(), + user_repository=MagicMock(), debug=False, manager_public_key=MagicMock(), manager_secret_key=MagicMock(), diff --git a/tests/unit/manager/repositories/agent/test_repository.py b/tests/unit/manager/repositories/agent/test_repository.py index 5f4291de485..1d90acb4f92 100644 --- a/tests/unit/manager/repositories/agent/test_repository.py +++ b/tests/unit/manager/repositories/agent/test_repository.py @@ -817,7 +817,6 @@ async def agent_with_kernels( domain_name=domain_name, group_id=UUID(group_id_str), user_uuid=uuid4(), - access_key="AKTEST" + uuid4().hex[:12], environ={}, mounts=[], vfolder_mounts=[], @@ -855,7 +854,6 @@ async def agent_with_kernels( domain_name=domain_name, group_id=UUID(group_id_str), user_uuid=uuid4(), - access_key="AKTEST" + uuid4().hex[:12], environ={}, mounts=[], vfolder_mounts=[], diff --git a/tests/unit/manager/repositories/deployment/test_deployment_repository.py b/tests/unit/manager/repositories/deployment/test_deployment_repository.py index 786c343972f..cdf17ff5e51 100644 --- a/tests/unit/manager/repositories/deployment/test_deployment_repository.py +++ b/tests/unit/manager/repositories/deployment/test_deployment_repository.py @@ -420,7 +420,6 @@ async def test_session_id( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_access_key, scaling_group_name=test_scaling_group_name, status=SessionStatus.RUNNING, cluster_mode=ClusterMode.SINGLE_NODE, @@ -467,7 +466,6 @@ async def test_kernel_with_inference_port( kernel = KernelRow( id=kernel_id, session_id=test_session_id, - access_key=test_access_key, agent=test_agent_id, agent_addr="127.0.0.1:2001", scaling_group=test_scaling_group_name, @@ -529,7 +527,6 @@ async def test_kernel_without_inference_port( kernel = KernelRow( id=kernel_id, session_id=test_session_id, - access_key=test_access_key, agent=test_agent_id, agent_addr="127.0.0.1:2001", scaling_group=test_scaling_group_name, @@ -881,7 +878,6 @@ async def test_fetch_multiple_routes( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_access_key, scaling_group_name=test_scaling_group_name, status=SessionStatus.RUNNING, cluster_mode=ClusterMode.SINGLE_NODE, @@ -908,7 +904,6 @@ async def test_fetch_multiple_routes( kernel = KernelRow( id=kernel_id, session_id=session_id, - access_key=test_access_key, agent=test_agent_id, agent_addr="127.0.0.1:2001", scaling_group=test_scaling_group_name, @@ -3815,78 +3810,3 @@ async def test_create_endpoint_succeeds_with_same_name_when_existing_is_destroye assert result.metadata.name == "reusable-endpoint" assert result.metadata.project == test_group.id - - @pytest.fixture - async def coexisting_active_and_destroying_endpoints( - self, - db_with_cleanup: ExtendedAsyncSAEngine, - test_domain: DomainRow, - test_group: GroupRow, - test_scaling_group: ScalingGroupRow, - ) -> tuple[uuid.UUID, uuid.UUID]: - """Seed two endpoints sharing (name, domain, project) — one CREATED and - one already in DESTROYING — bypassing the application-level uniqueness - check to reproduce the corrupt state described in BA-5698. - - Returns (target_id, sibling_id) where target is the active row to be - destroyed and sibling is the DESTROYING row. - """ - duplicate_name = f"dup-destroy-{uuid.uuid4().hex[:8]}" - target_id = uuid.uuid4() - sibling_id = uuid.uuid4() - user_id = uuid.uuid4() - - async with db_with_cleanup.begin_session() as db_sess: - target = EndpointRow( - id=target_id, - name=duplicate_name, - created_user=user_id, - session_owner=user_id, - domain=test_domain.name, - project=test_group.id, - resource_group=test_scaling_group.name, - replicas=1, - desired_replicas=1, - url=None, - open_to_public=False, - lifecycle_stage=EndpointLifecycle.CREATED, - ) - sibling = EndpointRow( - id=sibling_id, - name=duplicate_name, - created_user=user_id, - session_owner=user_id, - domain=test_domain.name, - project=test_group.id, - resource_group=test_scaling_group.name, - replicas=0, - desired_replicas=0, - url=None, - open_to_public=False, - lifecycle_stage=EndpointLifecycle.DESTROYING, - ) - db_sess.add_all([target, sibling]) - await db_sess.commit() - - return target_id, sibling_id - - async def test_destroy_endpoint_with_destroying_sibling_does_not_conflict( - self, - deployment_repository: DeploymentRepository, - db_with_cleanup: ExtendedAsyncSAEngine, - coexisting_active_and_destroying_endpoints: tuple[uuid.UUID, uuid.UUID], - ) -> None: - """Destroying an active endpoint succeeds even when a sibling row with - the same (name, domain, project) is already in DESTROYING.""" - target_id, _ = coexisting_active_and_destroying_endpoints - - succeeded = await deployment_repository.destroy_endpoint(target_id) - assert succeeded is True - - async with db_with_cleanup.begin_session() as db_sess: - target_stage = ( - await db_sess.execute( - sa.select(EndpointRow.lifecycle_stage).where(EndpointRow.id == target_id) - ) - ).scalar_one() - assert target_stage == EndpointLifecycle.DESTROYING diff --git a/tests/unit/manager/repositories/group/test_group_db_source.py b/tests/unit/manager/repositories/group/test_group_db_source.py index 7ba6842e361..ddf20eca633 100644 --- a/tests/unit/manager/repositories/group/test_group_db_source.py +++ b/tests/unit/manager/repositories/group/test_group_db_source.py @@ -327,7 +327,6 @@ async def inactive_endpoint_with_session_and_routing( domain_name=test_domain, group_id=test_group, user_uuid=test_user, - access_key="test-access-key", cluster_mode="single-node", cluster_size=1, occupying_slots=ResourceSlot(), @@ -455,7 +454,6 @@ async def multiple_endpoints_with_sessions( domain_name=test_domain, group_id=test_group, user_uuid=test_user, - access_key=f"test-access-key-{i}", cluster_mode="single-node", cluster_size=1, occupying_slots=ResourceSlot(), diff --git a/tests/unit/manager/repositories/group/test_group_repository.py b/tests/unit/manager/repositories/group/test_group_repository.py index a8fa42e0731..c057b800bf2 100644 --- a/tests/unit/manager/repositories/group/test_group_repository.py +++ b/tests/unit/manager/repositories/group/test_group_repository.py @@ -551,7 +551,6 @@ async def group_with_active_kernel( domain_name=test_domain, group_id=group_id, user_uuid=test_user, - access_key="test-access-key", cluster_mode="single-node", cluster_size=1, occupying_slots=ResourceSlot(), @@ -571,7 +570,6 @@ async def group_with_active_kernel( domain_name=test_domain, group_id=group_id, user_uuid=test_user, - access_key="test-access-key", agent=agent_id, agent_addr="tcp://127.0.0.1:5001", cluster_role="main", @@ -712,7 +710,6 @@ async def group_with_mounted_vfolders( domain_name=test_domain, group_id=group_id, user_uuid=test_user, - access_key="test-access-key", cluster_mode="single-node", cluster_size=1, occupying_slots=ResourceSlot(), @@ -733,7 +730,6 @@ async def group_with_mounted_vfolders( domain_name=test_domain, group_id=group_id, user_uuid=test_user, - access_key="test-access-key", agent=agent_id, agent_addr="tcp://127.0.0.1:5001", cluster_role="main", diff --git a/tests/unit/manager/repositories/resource_preset/test_check_presets.py b/tests/unit/manager/repositories/resource_preset/test_check_presets.py index 1a9d7f8aef0..911a72885b2 100644 --- a/tests/unit/manager/repositories/resource_preset/test_check_presets.py +++ b/tests/unit/manager/repositories/resource_preset/test_check_presets.py @@ -606,7 +606,6 @@ async def test_running_kernels_count_towards_occupied_slots( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, scaling_group_name=test_scaling_group_name, result=SessionResult.UNDEFINED, agent_ids=[], @@ -623,7 +622,6 @@ async def test_running_kernels_count_towards_occupied_slots( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, image="test-image:latest", status=KernelStatus.RUNNING, status_changed=datetime.now(tzutc()), @@ -739,7 +737,6 @@ async def test_terminating_kernels_count_towards_occupied_slots( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, scaling_group_name=test_scaling_group_name, result=SessionResult.UNDEFINED, agent_ids=[], @@ -756,7 +753,6 @@ async def test_terminating_kernels_count_towards_occupied_slots( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, image="test-image:latest", status=KernelStatus.TERMINATING, status_changed=datetime.now(tzutc()), @@ -869,7 +865,6 @@ async def test_pending_kernels_do_not_count_towards_occupied_slots( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, scaling_group_name=test_scaling_group_name, result=SessionResult.UNDEFINED, agent_ids=[], @@ -886,7 +881,6 @@ async def test_pending_kernels_do_not_count_towards_occupied_slots( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, image="test-image:latest", status=KernelStatus.PENDING, status_changed=datetime.now(tzutc()), @@ -1028,7 +1022,6 @@ async def test_ignores_cached_occupied_slots_in_agent_row( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, scaling_group_name=test_scaling_group_name, result=SessionResult.UNDEFINED, agent_ids=[], @@ -1045,7 +1038,6 @@ async def test_ignores_cached_occupied_slots_in_agent_row( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_keypair_access_key, image="test-image:latest", status=KernelStatus.RUNNING, status_changed=datetime.now(tzutc()), diff --git a/tests/unit/manager/repositories/scaling_group/test_resource_info.py b/tests/unit/manager/repositories/scaling_group/test_resource_info.py index 80ccaf6286f..cbb8d2894ff 100644 --- a/tests/unit/manager/repositories/scaling_group/test_resource_info.py +++ b/tests/unit/manager/repositories/scaling_group/test_resource_info.py @@ -48,7 +48,7 @@ def _create_kernel( session_id: SessionId, domain_name: str, group_id: uuid.UUID, - user_uuid: uuid.UUID, + owner_id: uuid.UUID, scaling_group: str, agent_id: str, status: KernelStatus, @@ -62,7 +62,7 @@ def _create_kernel( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + user_uuid=owner_id, scaling_group=scaling_group, agent=agent_id, status=status, @@ -464,7 +464,7 @@ async def scaling_group_with_running_kernels( Returns: Tuple of (scaling_group_name, agent_capacity, list of kernel occupied_slots) """ - user_uuid, domain_name, group_id = test_user_domain_group + owner_id, domain_name, group_id = test_user_domain_group agent_capacity = ResourceSlot({"cpu": Decimal("16"), "mem": Decimal("34359738368")}) kernel_slots = [ @@ -506,7 +506,7 @@ async def scaling_group_with_running_kernels( id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + user_uuid=owner_id, scaling_group_name=base_scaling_group, cluster_size=2, vfolder_mounts={}, @@ -520,7 +520,7 @@ async def scaling_group_with_running_kernels( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + owner_id=owner_id, scaling_group=base_scaling_group, agent_id=agent_id, status=KernelStatus.RUNNING, @@ -536,7 +536,7 @@ async def scaling_group_with_running_kernels( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + owner_id=owner_id, scaling_group=base_scaling_group, agent_id=agent_id, status=KernelStatus.TERMINATING, @@ -562,7 +562,7 @@ async def scaling_group_with_mixed_kernel_statuses( Returns: Tuple of (scaling_group_name, agent_capacity, expected_used) """ - user_uuid, domain_name, group_id = test_user_domain_group + owner_id, domain_name, group_id = test_user_domain_group agent_capacity = ResourceSlot({"cpu": Decimal("32"), "mem": Decimal("68719476736")}) running_slots = ResourceSlot({"cpu": Decimal("2"), "mem": Decimal("4294967296")}) @@ -606,7 +606,7 @@ async def scaling_group_with_mixed_kernel_statuses( id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + user_uuid=owner_id, scaling_group_name=base_scaling_group, cluster_size=4, vfolder_mounts={}, @@ -620,7 +620,7 @@ async def scaling_group_with_mixed_kernel_statuses( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + owner_id=owner_id, scaling_group=base_scaling_group, agent_id=agent_id, status=KernelStatus.RUNNING, @@ -636,7 +636,7 @@ async def scaling_group_with_mixed_kernel_statuses( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + owner_id=owner_id, scaling_group=base_scaling_group, agent_id=agent_id, status=KernelStatus.TERMINATING, @@ -652,7 +652,7 @@ async def scaling_group_with_mixed_kernel_statuses( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + owner_id=owner_id, scaling_group=base_scaling_group, agent_id=agent_id, status=KernelStatus.TERMINATED, @@ -668,7 +668,7 @@ async def scaling_group_with_mixed_kernel_statuses( session_id=session_id, domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, + owner_id=owner_id, scaling_group=base_scaling_group, agent_id=agent_id, status=KernelStatus.PENDING, diff --git a/tests/unit/manager/repositories/scheduler/test_cancellation_resource_freeing.py b/tests/unit/manager/repositories/scheduler/test_cancellation_resource_freeing.py index 73d8d01fd73..8c48d8c0567 100644 --- a/tests/unit/manager/repositories/scheduler/test_cancellation_resource_freeing.py +++ b/tests/unit/manager/repositories/scheduler/test_cancellation_resource_freeing.py @@ -366,7 +366,6 @@ async def _insert_session( domain_name=domain_name, group_id=group_id, user_uuid=user_uuid, - access_key=access_key, mounts=[], environ={}, vfolder_mounts=[], diff --git a/tests/unit/manager/repositories/scheduler/test_resource_allocation.py b/tests/unit/manager/repositories/scheduler/test_resource_allocation.py index 757b46bfff8..751ed037019 100644 --- a/tests/unit/manager/repositories/scheduler/test_resource_allocation.py +++ b/tests/unit/manager/repositories/scheduler/test_resource_allocation.py @@ -391,7 +391,6 @@ async def _create_kernel_with_pending_allocations( domain_name=domain_name, group_id=group_id, user_uuid=user_uuid, - access_key=access_key, mounts=[], environ={}, vfolder_mounts=[], diff --git a/tests/unit/manager/repositories/scheduler/test_resource_deallocation.py b/tests/unit/manager/repositories/scheduler/test_resource_deallocation.py index f16a22481db..bd573578b33 100644 --- a/tests/unit/manager/repositories/scheduler/test_resource_deallocation.py +++ b/tests/unit/manager/repositories/scheduler/test_resource_deallocation.py @@ -236,7 +236,6 @@ async def test_access_key( db_sess.add( KeyPairRow( user_id=f"test-user-{uuid.uuid4().hex[:8]}@test.com", - access_key=access_key, secret_key=SecretKey(f"SK{uuid.uuid4().hex}"), is_active=True, is_admin=False, @@ -318,8 +317,7 @@ async def _create_session_with_kernel_and_resources( domain_name: str, scaling_group_name: str, group_id: uuid.UUID, - user_uuid: uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, agent_id: str | None, cpu_used: Decimal = Decimal("2"), mem_used: Decimal = Decimal("4096"), @@ -370,8 +368,7 @@ async def _create_session_with_kernel_and_resources( requested_slots=ResourceSlot({"cpu": cpu_used, "mem": mem_used}), domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, - access_key=access_key, + user_uuid=owner_id, mounts=[], environ={}, vfolder_mounts=[], @@ -459,8 +456,7 @@ async def test_force_terminate_frees_resources( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, cpu_used=Decimal("2"), mem_used=Decimal("4096"), @@ -526,8 +522,7 @@ async def test_force_terminate_with_null_agent_still_sets_free_at( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=None, ) @@ -573,8 +568,7 @@ async def test_force_terminate_from_terminating_session_frees_resources( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, cpu_used=Decimal("2"), mem_used=Decimal("4096"), @@ -781,7 +775,6 @@ async def test_access_key( db_sess.add( KeyPairRow( user_id=f"test-user-{uuid.uuid4().hex[:8]}@test.com", - access_key=access_key, secret_key=SecretKey(f"SK{uuid.uuid4().hex}"), is_active=True, is_admin=False, @@ -861,8 +854,7 @@ async def _create_kernel_with_resources( domain_name: str, scaling_group_name: str, group_id: uuid.UUID, - user_uuid: uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, agent_id: str, cpu_used: Decimal = Decimal("2"), mem_used: Decimal = Decimal("4096"), @@ -913,8 +905,7 @@ async def _create_kernel_with_resources( requested_slots=ResourceSlot({"cpu": cpu_used, "mem": mem_used}), domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, - access_key=access_key, + user_uuid=owner_id, mounts=[], environ={}, vfolder_mounts=[], @@ -997,8 +988,7 @@ async def test_bulk_terminate_frees_resources( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, ) @@ -1211,7 +1201,6 @@ async def test_access_key( db_sess.add( KeyPairRow( user_id=f"test-user-{uuid.uuid4().hex[:8]}@test.com", - access_key=access_key, secret_key=SecretKey(f"SK{uuid.uuid4().hex}"), is_active=True, is_admin=False, @@ -1344,7 +1333,6 @@ async def test_double_terminate_does_not_go_negative( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_access_key, mounts=[], environ={}, vfolder_mounts=[], diff --git a/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py b/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py index 687bd85c648..157521c45e1 100644 --- a/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py +++ b/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py @@ -250,7 +250,6 @@ async def test_access_key( async with db_with_cleanup.begin_session() as db_sess: keypair = KeyPairRow( user_id=f"test-user-{uuid.uuid4().hex[:8]}@test.com", - access_key=access_key, secret_key=SecretKey(f"SK{uuid.uuid4().hex}"), is_active=True, is_admin=False, @@ -587,7 +586,6 @@ async def test_access_key( async with db_with_cleanup.begin_session() as db_sess: keypair = KeyPairRow( user_id=f"test-user-{uuid.uuid4().hex[:8]}@test.com", - access_key=access_key, secret_key=SecretKey(f"SK{uuid.uuid4().hex}"), is_active=True, is_admin=False, @@ -662,8 +660,7 @@ async def _create_session_with_kernel( domain_name: str, scaling_group_name: str, group_id: uuid.UUID, - user_uuid: uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, agent_id: str, ) -> SessionId: """Helper to create a session with a kernel in given statuses.""" @@ -710,8 +707,7 @@ async def _create_session_with_kernel( requested_slots=ResourceSlot({"cpu": Decimal("2"), "mem": Decimal("4096")}), domain_name=domain_name, group_id=group_id, - user_uuid=user_uuid, - access_key=access_key, + user_uuid=owner_id, mounts=[], environ={}, vfolder_mounts=[], @@ -746,8 +742,7 @@ async def test_mark_sessions_as_terminating_creates_scheduling_history( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, ) @@ -787,8 +782,7 @@ async def test_force_terminate_creates_scheduling_history( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, ) @@ -827,8 +821,7 @@ async def test_force_terminate_from_terminating_creates_scheduling_history( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, ) @@ -868,8 +861,7 @@ async def test_mark_sessions_as_terminating_captures_correct_from_status( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, ) scheduled_session_id = await self._create_session_with_kernel( @@ -879,8 +871,7 @@ async def test_mark_sessions_as_terminating_captures_correct_from_status( domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, agent_id=test_agent_id, ) diff --git a/tests/unit/manager/repositories/scheduler/test_termination.py b/tests/unit/manager/repositories/scheduler/test_termination.py index c7abf509546..2feff5b1a39 100644 --- a/tests/unit/manager/repositories/scheduler/test_termination.py +++ b/tests/unit/manager/repositories/scheduler/test_termination.py @@ -381,7 +381,6 @@ async def test_terminating_kernel_id( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_access_key, mounts=[], environ={}, vfolder_mounts=[], @@ -432,7 +431,6 @@ async def test_running_kernel_id( domain_name=test_domain_name, group_id=test_group_id, user_uuid=test_user_uuid, - access_key=test_access_key, mounts=[], environ={}, vfolder_mounts=[], diff --git a/tests/unit/manager/repositories/session/test_session_repository.py b/tests/unit/manager/repositories/session/test_session_repository.py index 7681676ed30..d9842123213 100644 --- a/tests/unit/manager/repositories/session/test_session_repository.py +++ b/tests/unit/manager/repositories/session/test_session_repository.py @@ -222,7 +222,6 @@ async def session_with_kernel(self, db_with_cleanup: ExtendedAsyncSAEngine) -> S domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, tag=None, status=SessionStatus.RUNNING, status_info=None, @@ -254,7 +253,6 @@ async def session_with_kernel(self, db_with_cleanup: ExtendedAsyncSAEngine) -> S domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, cluster_mode=ClusterMode.SINGLE_NODE.value, cluster_size=1, cluster_role="main", @@ -375,8 +373,7 @@ async def test_search_sessions( assert session_data.name == "test-session" assert session_data.domain_name == session_with_kernel.domain_name assert session_data.group_id == session_with_kernel.group_id - assert session_data.user_uuid == session_with_kernel.user_id - assert session_data.access_key == session_with_kernel.access_key + assert session_data.owner_id == session_with_kernel.user_id async def test_search_sessions_empty_result( self, @@ -548,7 +545,6 @@ async def session_with_allocations( domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, tag=None, status=SessionStatus.RUNNING, status_info=None, @@ -578,7 +574,6 @@ async def session_with_allocations( domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, cluster_mode=ClusterMode.SINGLE_NODE.value, cluster_size=1, cluster_role="main", diff --git a/tests/unit/manager/repositories/session/test_session_search_in_project.py b/tests/unit/manager/repositories/session/test_session_search_in_project.py index bbec45be7af..1ad3bd34054 100644 --- a/tests/unit/manager/repositories/session/test_session_search_in_project.py +++ b/tests/unit/manager/repositories/session/test_session_search_in_project.py @@ -220,7 +220,6 @@ async def test_data( domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, tag=None, status=SessionStatus.RUNNING, status_info=None, @@ -250,7 +249,6 @@ async def test_data( domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, cluster_mode=ClusterMode.SINGLE_NODE.value, cluster_size=1, cluster_role="main", diff --git a/tests/unit/manager/services/deployment/test_apply_deployment_level_preset.py b/tests/unit/manager/services/deployment/test_apply_deployment_level_preset.py index de437eda095..cb037cb3e6d 100644 --- a/tests/unit/manager/services/deployment/test_apply_deployment_level_preset.py +++ b/tests/unit/manager/services/deployment/test_apply_deployment_level_preset.py @@ -112,6 +112,7 @@ def _make_creator( execution=ExecutionSpec(runtime_variant=RuntimeVariant("custom")), model_definition=None, revision_preset_id=preset_id, + auto_activate=True, ), policy=policy, ) diff --git a/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py b/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py index f9e1c45fa27..b1cf8aa8d4b 100644 --- a/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py +++ b/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py @@ -124,8 +124,6 @@ def mock_deployment_controller(self) -> MagicMock: def mock_deployment_repository(self) -> MagicMock: mock = MagicMock() mock.get_default_architecture_from_scaling_group = AsyncMock(return_value=None) - mock.get_endpoint_info = AsyncMock(return_value=MagicMock(current_revision_id=None)) - mock.fetch_model_definition = AsyncMock(return_value=None) return mock @pytest.fixture @@ -230,9 +228,7 @@ def mock_modify_endpoint(self, mocker: Any, mock_repositories: Any) -> AsyncMock ), ) - def _make_updater_spec( - self, *, replicas: int | None = None, has_revision_changes: bool = False - ) -> MagicMock: + def _make_updater_spec(self, *, replicas: int | None = None) -> MagicMock: spec = MagicMock(spec=EndpointUpdaterSpec) if replicas is not None: spec.replicas = OptionalState.update(replicas) @@ -240,7 +236,6 @@ def _make_updater_spec( else: spec.replicas = OptionalState[int].nop() spec.replica_count_modified.return_value = False - spec.has_revision_changes.return_value = has_revision_changes return spec async def test_replica_count_change_marks_check_replica( @@ -266,34 +261,7 @@ async def test_replica_count_change_marks_check_replica( assert result.success is True assert result.data == mock_endpoint_data - mock_deployment_controller.mark_lifecycle_needed.assert_awaited_once_with( - DeploymentLifecycleType.CHECK_REPLICA - ) - - async def test_revision_change_marks_check_replica( - self, - model_serving_processors: ModelServingProcessors, - mock_modify_endpoint: AsyncMock, - mock_deployment_controller: MagicMock, - endpoint_id: uuid.UUID, - ) -> None: - """Revision-level field change triggers CHECK_REPLICA to auto-activate the new revision.""" - updater_spec = self._make_updater_spec(replicas=None, has_revision_changes=True) - mock_updater = MagicMock() - mock_updater.spec = updater_spec - - mock_endpoint_data = MagicMock() - mock_endpoint_data.id = endpoint_id - mock_modify_endpoint.return_value = MutationResult( - success=True, message="ok", data=mock_endpoint_data - ) - - action = ModifyEndpointAction(endpoint_id=endpoint_id, updater=mock_updater) - result = await model_serving_processors.modify_endpoint.wait_for_complete(action) - - assert result.success is True - assert result.data == mock_endpoint_data - mock_deployment_controller.mark_lifecycle_needed.assert_awaited_once_with( + mock_deployment_controller.mark_lifecycle_needed.assert_called_once_with( DeploymentLifecycleType.CHECK_REPLICA ) diff --git a/tests/unit/manager/services/session/BUILD b/tests/unit/manager/services/session/BUILD index 8dfc7444c13..3b2600c7ce8 100644 --- a/tests/unit/manager/services/session/BUILD +++ b/tests/unit/manager/services/session/BUILD @@ -2,3 +2,8 @@ python_tests( name="tests", sources=["test_*.py"], ) + +python_test_utils( + name="test_utils", + sources=["conftest.py"], +) diff --git a/tests/unit/manager/services/session/conftest.py b/tests/unit/manager/services/session/conftest.py new file mode 100644 index 00000000000..e3978085c36 --- /dev/null +++ b/tests/unit/manager/services/session/conftest.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Iterator +from uuid import uuid4 + +import pytest + +from ai.backend.common.contexts.user import with_user +from ai.backend.common.data.user.types import UserData +from ai.backend.manager.models.user import UserRole + + +@pytest.fixture(autouse=True) +def _user_context(request: pytest.FixtureRequest) -> Iterator[None]: + """Set up a default authenticated user in the request context for service tests. + + Session service methods resolve the requester via current_user(); without this + autouse fixture, tests calling such methods would fail because the auth + middleware (which normally populates the context) is not exercised in unit + tests. + + If the test defines a ``sample_user_id`` fixture, that UUID is used so that + assertions referencing the same value continue to match. + """ + try: + user_id = request.getfixturevalue("sample_user_id") + except pytest.FixtureLookupError: + user_id = uuid4() + with with_user( + UserData( + user_id=user_id, + is_authorized=True, + is_admin=False, + is_superadmin=False, + role=UserRole.USER, + domain_name="default", + ) + ): + yield diff --git a/tests/unit/manager/services/session/test_session_lifecycle_service.py b/tests/unit/manager/services/session/test_session_lifecycle_service.py index 5919d5171a2..58a4030df25 100644 --- a/tests/unit/manager/services/session/test_session_lifecycle_service.py +++ b/tests/unit/manager/services/session/test_session_lifecycle_service.py @@ -168,7 +168,7 @@ async def session_service( session_repository=mock_session_repository, scheduling_controller=mock_scheduling_controller, appproxy_client_pool=mock_appproxy_client_pool, - user_repository=MagicMock(), + user_repository=AsyncMock(), ) return SessionService(args) @@ -210,6 +210,7 @@ def _make_session_data( ) -> SessionData: return SessionData( id=session_id, + owner_id=user_id, creation_id="test-creation-id", name=name, session_type=session_type, @@ -220,8 +221,6 @@ def _make_session_data( agent_ids=["i-ubuntu"], domain_name="default", group_id=group_id, - user_uuid=user_id, - access_key=access_key, images=["cr.backend.ai/stable/python:latest"], tag=None, occupying_slots=ResourceSlot({"cpu": 1, "mem": 1024}), @@ -312,7 +311,6 @@ async def test_commit_success( action = CommitSessionAction( session_name="test-session", - owner_access_key=sample_access_key, filename=None, ) @@ -327,6 +325,7 @@ async def test_commit_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("Session not found") @@ -334,7 +333,6 @@ async def test_commit_session_not_found( action = CommitSessionAction( session_name="nonexistent", - owner_access_key=sample_access_key, filename=None, ) @@ -363,7 +361,6 @@ async def test_commit_custom_filename( action = CommitSessionAction( session_name="test-session", - owner_access_key=sample_access_key, filename="my-snapshot.tar.gz", ) @@ -398,7 +395,6 @@ async def test_success( action = GetCommitStatusAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_commit_status(action) @@ -410,6 +406,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("Session not found") @@ -417,7 +414,6 @@ async def test_session_not_found( action = GetCommitStatusAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -460,7 +456,6 @@ async def test_v1_query_mode( action = ExecuteSessionAction( session_name="test-session", api_version=(1,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode=None, options=None, @@ -505,7 +500,6 @@ async def test_v2_batch_mode( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="batch", options=None, @@ -544,7 +538,6 @@ async def test_v2_complete_mode( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="complete", options={}, @@ -575,7 +568,6 @@ async def test_v2_continue_without_run_id_raises( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="continue", options=None, @@ -606,7 +598,6 @@ async def test_v2_invalid_mode_raises( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="invalid_mode", options=None, @@ -637,7 +628,6 @@ async def test_v2_null_mode_raises( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode=None, options=None, @@ -678,7 +668,6 @@ async def test_null_code_defaults_to_empty_string( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="query", options=None, @@ -712,11 +701,11 @@ def delegated_session_action( self, sample_access_key: AccessKey, sample_user_id: UUID, - delegated_owner_access_key: AccessKey, + delegated_owner_id: UUID, ) -> CreateFromParamsAction: """ CreateFromParamsAction representing an admin (sample_user_id) creating - a session on behalf of another user via owner_access_key. + a session on behalf of another user via owner_id. """ return CreateFromParamsAction( params=CreateFromParamsActionParams( @@ -732,7 +721,7 @@ def delegated_session_action( tag="", priority=0, is_preemptible=True, - owner_access_key=delegated_owner_access_key, + owner_id=delegated_owner_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -784,7 +773,7 @@ async def test_image_resolve_failure_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -830,7 +819,7 @@ async def test_invalid_domain_group_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -888,7 +877,7 @@ async def test_create_distributed_session( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -951,7 +940,7 @@ async def test_reuse_if_exists_returns_existing( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1011,7 +1000,7 @@ async def test_quota_exceeded_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1053,6 +1042,11 @@ async def test_owner_access_key_uses_owner_user_scope( identity leaked into the session row, causing scaling group access checks and container UID/GID lookups to use the wrong user. """ + user_repo_mock = MagicMock() + user_repo_mock.get_user_by_uuid = AsyncMock( + return_value=MagicMock(main_access_key=str(delegated_owner_access_key)) + ) + session_service._user_repository = user_repo_mock new_session_id = str(uuid4()) mock_session_repository.query_userinfo = AsyncMock( return_value=SessionOwnerContext( @@ -1159,7 +1153,7 @@ async def test_create_from_template_success( tag=undefined, priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1205,7 +1199,7 @@ async def test_template_not_found_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1266,7 +1260,7 @@ async def test_create_cluster_success( domain_name="default", scaling_group_name="default", requester_access_key=sample_access_key, - owner_access_key=sample_access_key, + owner_id=sample_user_id, tag="", enqueue_only=False, keypair_resource_policy=None, @@ -1297,7 +1291,7 @@ async def test_template_not_found_raises( domain_name="default", scaling_group_name="default", requester_access_key=sample_access_key, - owner_access_key=sample_access_key, + owner_id=sample_user_id, tag="", enqueue_only=False, keypair_resource_policy=None, @@ -1343,7 +1337,7 @@ async def test_too_many_sessions_converts_to_already_exists( domain_name="default", scaling_group_name="default", requester_access_key=sample_access_key, - owner_access_key=sample_access_key, + owner_id=sample_user_id, tag="", enqueue_only=False, keypair_resource_policy=None, @@ -1381,7 +1375,6 @@ async def test_prefix_matching_returns_sessions( action = MatchSessionsAction( id_or_name_prefix="test", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1401,7 +1394,6 @@ async def test_no_match_returns_empty( action = MatchSessionsAction( id_or_name_prefix="nonexistent", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1419,12 +1411,11 @@ async def test_owner_access_key_filtering( action = MatchSessionsAction( id_or_name_prefix="test", - owner_access_key=sample_access_key, user_id=sample_user_id, ) await session_service.match_sessions(action) - mock_session_repository.match_sessions.assert_called_once_with("test", sample_access_key) + mock_session_repository.match_sessions.assert_called_once_with("test", sample_user_id) # ==================== GetAbusingReport Tests ==================== @@ -1451,7 +1442,6 @@ async def test_valid_session_returns_report( action = GetAbusingReportAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_abusing_report(action) @@ -1462,6 +1452,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("not found") @@ -1469,7 +1460,6 @@ async def test_session_not_found( action = GetAbusingReportAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1503,7 +1493,6 @@ async def test_system_session_returns_ports( action = GetDirectAccessInfoAction( session_name="system-session", - owner_access_key=sample_access_key, ) result = await session_service.get_direct_access_info(action) @@ -1533,7 +1522,6 @@ async def test_interactive_session_returns_empty_dict( action = GetDirectAccessInfoAction( session_name="interactive-session", - owner_access_key=sample_access_key, ) result = await session_service.get_direct_access_info(action) @@ -1562,7 +1550,6 @@ async def test_agent_row_none_raises_kernel_not_ready( action = GetDirectAccessInfoAction( session_name="system-session", - owner_access_key=sample_access_key, ) with pytest.raises(KernelNotReady): @@ -1597,7 +1584,6 @@ async def test_session_with_dependencies( action = GetDependencyGraphAction( root_session_name="root-session", - owner_access_key=sample_access_key, ) result = await session_service.get_dependency_graph(action) @@ -1626,7 +1612,6 @@ async def test_root_only_session( action = GetDependencyGraphAction( root_session_name="root-session", - owner_access_key=sample_access_key, ) result = await session_service.get_dependency_graph(action) @@ -1637,6 +1622,7 @@ async def test_empty_session_id_raises_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: dep_graph: dict[str, Any] = {"session_id": "", "children": []} mock_session_repository.find_dependency_sessions = AsyncMock(return_value=dep_graph) @@ -1644,7 +1630,6 @@ async def test_empty_session_id_raises_not_found( action = GetDependencyGraphAction( root_session_name="root-session", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1871,7 +1856,6 @@ async def test_shutdown_success( action = ShutdownServiceAction( session_name="test-session", - owner_access_key=sample_access_key, service_name="jupyter", ) result = await session_service.shutdown_service(action) @@ -1884,6 +1868,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("not found") @@ -1891,7 +1876,6 @@ async def test_session_not_found( action = ShutdownServiceAction( session_name="nonexistent", - owner_access_key=sample_access_key, service_name="jupyter", ) diff --git a/tests/unit/manager/services/session/test_session_service.py b/tests/unit/manager/services/session/test_session_service.py index 82827bf834d..df06b052e80 100644 --- a/tests/unit/manager/services/session/test_session_service.py +++ b/tests/unit/manager/services/session/test_session_service.py @@ -304,7 +304,7 @@ async def test_success( assert result.result[0]["id"] == str(sample_session_data.id) assert result.result[0]["name"] == sample_session_data.name assert result.result[0]["status"] == sample_session_data.status.name - mock_session_repository.match_sessions.assert_called_once_with("test", sample_access_key) + mock_session_repository.match_sessions.assert_called_once_with("test", sample_user_id) async def test_no_matches( self, @@ -403,6 +403,7 @@ async def test_success( mock_session_repository: MagicMock, sample_session_id: SessionId, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully getting status history""" status_history: dict[str, Any] = { @@ -430,6 +431,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test getting status history for non-existent session""" mock_session_repository.get_session_validated = AsyncMock( @@ -449,6 +451,7 @@ async def test_empty_status_history( mock_session_repository: MagicMock, sample_session_id: SessionId, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test getting empty status history returns empty dict when None""" mock_session = MagicMock() @@ -478,6 +481,7 @@ async def test_success_cancelled( mock_scheduling_controller: MagicMock, sample_session_id: SessionId, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully destroying session (cancelled status)""" mock_session_repository.get_target_session_ids = AsyncMock(return_value=[sample_session_id]) @@ -500,7 +504,7 @@ async def test_success_cancelled( assert result.result == {"stats": {"status": "cancelled"}} mock_session_repository.get_target_session_ids.assert_called_once_with( - "test-session", sample_access_key, recursive=False + "test-session", sample_user_id, recursive=False ) mock_scheduling_controller.mark_sessions_for_termination.assert_called_once() @@ -511,6 +515,7 @@ async def test_success_terminated( mock_scheduling_controller: MagicMock, sample_session_id: SessionId, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully destroying session (terminated status via normal termination)""" mock_session_repository.get_target_session_ids = AsyncMock(return_value=[sample_session_id]) @@ -545,6 +550,7 @@ async def test_force_terminate_directly_terminated( mock_scheduling_controller: MagicMock, sample_session_id: SessionId, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test force-terminate skips TERMINATING and goes directly to TERMINATED""" mock_session_repository.get_target_session_ids = AsyncMock(return_value=[sample_session_id]) @@ -578,6 +584,7 @@ async def test_recursive_destroy( mock_session_repository: MagicMock, mock_scheduling_controller: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test destroying sessions recursively""" session_ids = [SessionId(uuid4()) for _ in range(3)] @@ -600,7 +607,7 @@ async def test_recursive_destroy( result = await session_service.destroy_session(action) mock_session_repository.get_target_session_ids.assert_called_once_with( - "test-session", sample_access_key, recursive=True + "test-session", sample_user_id, recursive=True ) assert result.result == {"stats": {"status": "cancelled"}} @@ -610,6 +617,7 @@ async def test_no_sessions_to_destroy( mock_session_repository: MagicMock, mock_scheduling_controller: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test destroying when no sessions found""" mock_session_repository.get_target_session_ids = AsyncMock(return_value=[]) @@ -646,6 +654,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully completing code""" expected_response = CodeCompletionResp( @@ -680,6 +689,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test completing code when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -747,6 +757,7 @@ async def test_success( mock_running_session: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully getting session info""" mock_session_repository.get_session_validated = AsyncMock(return_value=mock_running_session) @@ -772,6 +783,7 @@ async def test_success_with_no_container_id( mock_session_repository: MagicMock, mock_running_session: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test getting session info when container_id is None (pre-RUNNING state)""" mock_running_session.main_kernel.container_id = None @@ -791,6 +803,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test getting session info when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -897,6 +910,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully getting direct access info""" mock_session = MagicMock() @@ -920,6 +934,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test getting direct access info when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -946,6 +961,7 @@ async def test_success( mock_session_repository: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully renaming session""" mock_session = MagicMock() @@ -962,7 +978,7 @@ async def test_success( assert isinstance(result, RenameSessionActionResult) assert result.session_data == sample_session_data mock_session_repository.update_session_name.assert_called_once_with( - "test-session", "new-session-name", sample_access_key + "test-session", "new-session-name", sample_user_id ) async def test_not_running_session( @@ -971,6 +987,7 @@ async def test_not_running_session( mock_session_repository: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test renaming non-running session raises error""" mock_session = MagicMock() @@ -999,6 +1016,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully restarting session""" mock_session = MagicMock() @@ -1023,6 +1041,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test restarting session when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -1050,6 +1069,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully shutting down service""" mock_session = MagicMock() @@ -1074,6 +1094,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test shutting down service when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -1102,6 +1123,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully uploading files""" # Create a mock reader @@ -1142,6 +1164,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test uploading files when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -1172,6 +1195,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully executing code""" expected_execute_response = { @@ -1214,6 +1238,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test executing code when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -1249,6 +1274,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully interrupting session""" mock_session = MagicMock() @@ -1272,6 +1298,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test interrupting session when not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -1357,6 +1384,7 @@ async def test_success( mock_agent_registry: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully getting container logs""" # get_logs_from_agent returns the logs directly @@ -1386,6 +1414,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test getting logs when session not found""" mock_session_repository.get_session_validated = AsyncMock( @@ -1674,7 +1703,7 @@ def sample_kernel_info(self) -> KernelInfo: ), user_permission=UserPermission( owner_id=user_id, - main_access_key="TESTKEY", + main_access_key=None, domain_name="default", group_id=group_id, uid=1000, diff --git a/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py b/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py index 959b4f30e34..2377b8d0bc4 100644 --- a/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py +++ b/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py @@ -64,7 +64,7 @@ def handler( def _make_terminating_session_data(session_id: SessionId) -> TerminatingSessionData: return TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-access-key"), + main_access_key=AccessKey("test-access-key"), creation_id="test-creation-id", status=SessionStatus.TERMINATED, status_info="FORCE_TERMINATED", diff --git a/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py b/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py index 787c969edde..01f0a1740c7 100644 --- a/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py +++ b/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py @@ -123,7 +123,7 @@ async def test_partial_scheduling_returns_skipped( ScheduledSessionData( session_id=first_session.session_info.identity.id, creation_id=first_session.session_info.identity.creation_id, - access_key=AccessKey(first_session.session_info.metadata.access_key), + main_access_key=AccessKey("test-access-key"), reason="scheduled-successfully", ) ] diff --git a/tests/unit/manager/sokovan/scheduler/launcher/conftest.py b/tests/unit/manager/sokovan/scheduler/launcher/conftest.py index 3ccf846a507..8b12a6933c8 100644 --- a/tests/unit/manager/sokovan/scheduler/launcher/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/launcher/conftest.py @@ -163,7 +163,7 @@ def _create_session_for_pull( return SessionDataForPull( session_id=session_id or SessionId(uuid4()), creation_id=str(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), kernels=kernels, ) @@ -234,10 +234,10 @@ def _create_session_for_start( return SessionDataForStart( session_id=session_id or SessionId(uuid4()), creation_id=str(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), session_type=SessionTypes.INTERACTIVE, name="test-session", - user_uuid=uuid4(), + owner_id=uuid4(), user_email="test@example.com", user_name="testuser", cluster_mode=cluster_mode, diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py index e3425dd9db1..199eaf1c085 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py @@ -126,9 +126,9 @@ async def test_single_user_workloads( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -136,9 +136,9 @@ async def test_single_user_workloads( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("20"), mem=Decimal("20")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -162,9 +162,9 @@ async def test_multiple_users_different_dominant_shares( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), # 30% dominant share + main_access_key=AccessKey("user2"), # 30% dominant share requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -172,9 +172,9 @@ async def test_multiple_users_different_dominant_shares( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), # 5% dominant share (lowest) + main_access_key=AccessKey("user3"), # 5% dominant share (lowest) requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -182,9 +182,9 @@ async def test_multiple_users_different_dominant_shares( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), # 20% dominant share + main_access_key=AccessKey("user1"), # 20% dominant share requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -198,9 +198,9 @@ async def test_multiple_users_different_dominant_shares( # Should be ordered by dominant share (ascending): user3 (5%), user1 (20%), user2 (30%) assert len(result) == 3 - assert result[0].access_key == AccessKey("user3") - assert result[1].access_key == AccessKey("user1") - assert result[2].access_key == AccessKey("user2") + assert result[0].main_access_key == AccessKey("user3") + assert result[1].main_access_key == AccessKey("user1") + assert result[2].main_access_key == AccessKey("user2") async def test_multiple_users_same_dominant_share( self, scaling_group: str, sequencer: DRFSequencer, empty_system_snapshot: SystemSnapshot @@ -209,9 +209,9 @@ async def test_multiple_users_same_dominant_share( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -219,9 +219,9 @@ async def test_multiple_users_same_dominant_share( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -229,9 +229,9 @@ async def test_multiple_users_same_dominant_share( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), + main_access_key=AccessKey("user3"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -256,9 +256,9 @@ async def test_new_user_gets_priority( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), # 30% dominant share + main_access_key=AccessKey("user2"), # 30% dominant share requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -266,9 +266,9 @@ async def test_new_user_gets_priority( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("new_user"), # 0% dominant share (new user) + main_access_key=AccessKey("new_user"), # 0% dominant share (new user) requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -282,8 +282,8 @@ async def test_new_user_gets_priority( # New user with 0% dominant share should get priority assert len(result) == 2 - assert result[0].access_key == AccessKey("new_user") - assert result[1].access_key == AccessKey("user2") + assert result[0].main_access_key == AccessKey("new_user") + assert result[1].main_access_key == AccessKey("user2") async def test_dominant_share_calculation_with_zero_capacity( self, scaling_group: str, sequencer: DRFSequencer @@ -331,9 +331,9 @@ async def test_dominant_share_calculation_with_zero_capacity( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py index 0815f5e161b..2a9c48add60 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py @@ -71,9 +71,9 @@ async def test_preserves_order( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -81,9 +81,9 @@ async def test_preserves_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("20"), mem=Decimal("20")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -91,9 +91,9 @@ async def test_preserves_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), + main_access_key=AccessKey("user3"), requested_slots=ResourceSlot(cpu=Decimal("30"), mem=Decimal("30")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -161,9 +161,9 @@ async def test_ignores_system_snapshot( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), # User with more allocation + main_access_key=AccessKey("user2"), # User with more allocation requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -171,9 +171,9 @@ async def test_ignores_system_snapshot( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), # User with less allocation + main_access_key=AccessKey("user1"), # User with less allocation requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py index 41394684266..240915fb253 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py @@ -71,9 +71,9 @@ async def test_reverses_order( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -81,9 +81,9 @@ async def test_reverses_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("20"), mem=Decimal("20")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -91,9 +91,9 @@ async def test_reverses_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), + main_access_key=AccessKey("user3"), requested_slots=ResourceSlot(cpu=Decimal("30"), mem=Decimal("30")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -115,9 +115,9 @@ async def test_single_workload( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -183,9 +183,9 @@ async def test_ignores_system_snapshot( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -193,9 +193,9 @@ async def test_ignores_system_snapshot( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -203,9 +203,9 @@ async def test_ignores_system_snapshot( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), # New user + main_access_key=AccessKey("user3"), # New user requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py b/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py index ad2e6e75ae9..b257d9d937a 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py @@ -66,9 +66,9 @@ def _create_scheduling_data_with_strategy( # Create one pending session session = PendingSessionData( id=SessionId(uuid.uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal("1"), "mem": Decimal("1024")}), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group_name="test-sg", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py index 230ece49647..46acad45823 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py @@ -35,9 +35,9 @@ def sftp_validator(self) -> ConcurrencyValidator: def workload(self) -> SessionWorkload: return SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -48,9 +48,9 @@ def workload(self) -> SessionWorkload: def sftp_workload(self) -> SessionWorkload: return SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py index 4884f49da5c..63c010a524a 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py @@ -31,9 +31,9 @@ def validator(self) -> DependenciesValidator: def test_passes_when_no_dependencies(self, validator: DependenciesValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -60,9 +60,9 @@ def test_passes_when_dependencies_satisfied(self, validator: DependenciesValidat dep_id = SessionId(uuid.uuid4()) workload = SessionWorkload( session_id=session_id, - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -99,9 +99,9 @@ def test_fails_when_dependencies_not_satisfied(self, validator: DependenciesVali dep_id = SessionId(uuid.uuid4()) workload = SessionWorkload( session_id=session_id, - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -143,9 +143,9 @@ def test_fails_when_multiple_dependencies_not_satisfied( dep_id2 = SessionId(uuid.uuid4()) workload = SessionWorkload( session_id=session_id, - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py index a494f3519e6..a5d9d95cac1 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py @@ -35,9 +35,9 @@ def test_passes_when_under_limit( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("2"), mem=Decimal("2")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", @@ -73,9 +73,9 @@ def test_fails_when_exceeds_limit( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", @@ -112,9 +112,9 @@ def test_passes_when_no_limit( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("100"), mem=Decimal("100")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", @@ -153,9 +153,9 @@ def test_passes_when_no_current_occupancy( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py index 1edd0b6540e..c2bb588536f 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py @@ -31,9 +31,9 @@ def validator(self) -> KeypairResourceLimitValidator: def test_passes_when_under_limit(self, validator: KeypairResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("2"), mem=Decimal("2")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -83,9 +83,9 @@ def test_passes_when_under_limit(self, validator: KeypairResourceLimitValidator) def test_fails_when_exceeds_limit(self, validator: KeypairResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -136,9 +136,9 @@ def test_fails_when_exceeds_limit(self, validator: KeypairResourceLimitValidator def test_passes_when_no_policy(self, validator: KeypairResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("100"), mem=Decimal("100")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -181,9 +181,9 @@ def test_passes_when_no_current_occupancy( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py index 6699e73697b..700c850c89f 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py @@ -31,9 +31,9 @@ def validator(self) -> PendingSessionResourceLimitValidator: def test_passes_when_under_limit(self, validator: PendingSessionResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("2"), mem=Decimal("2")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -82,9 +82,9 @@ def test_passes_when_under_limit(self, validator: PendingSessionResourceLimitVal def test_passes_when_no_limit(self, validator: PendingSessionResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("100"), mem=Decimal("100")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -131,9 +131,9 @@ def test_passes_when_no_limit(self, validator: PendingSessionResourceLimitValida def test_passes_when_no_policy(self, validator: PendingSessionResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -173,9 +173,9 @@ def test_handles_multiple_pending_sessions( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py index 01de7ba1b13..03aff338cfd 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py @@ -37,7 +37,7 @@ def test_passes_when_under_limit( resource_occupancy=ResourceOccupancySnapshot( by_keypair={}, by_user={ - workload.user_uuid: [ + workload.owner_id: [ SlotQuantity("cpu", Decimal("3")), SlotQuantity("mem", Decimal("3")), ] @@ -49,7 +49,7 @@ def test_passes_when_under_limit( resource_policy=ResourcePolicySnapshot( keypair_policies={}, user_policies={ - workload.user_uuid: UserResourcePolicy( + workload.owner_id: UserResourcePolicy( name="default", total_resource_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), ) @@ -77,7 +77,7 @@ def test_fails_when_exceeds_limit( resource_occupancy=ResourceOccupancySnapshot( by_keypair={}, by_user={ - workload.user_uuid: [ + workload.owner_id: [ SlotQuantity("cpu", Decimal("8")), SlotQuantity("mem", Decimal("8")), ] @@ -89,7 +89,7 @@ def test_fails_when_exceeds_limit( resource_policy=ResourcePolicySnapshot( keypair_policies={}, user_policies={ - workload.user_uuid: UserResourcePolicy( + workload.owner_id: UserResourcePolicy( name="default", total_resource_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), ) @@ -151,7 +151,7 @@ def test_passes_when_no_current_occupancy( resource_policy=ResourcePolicySnapshot( keypair_policies={}, user_policies={ - workload.user_uuid: UserResourcePolicy( + workload.owner_id: UserResourcePolicy( name="default", total_resource_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), ) diff --git a/tests/unit/manager/sokovan/scheduler/test_scheduler.py b/tests/unit/manager/sokovan/scheduler/test_scheduler.py index f35e5bf9797..ddc293080b7 100644 --- a/tests/unit/manager/sokovan/scheduler/test_scheduler.py +++ b/tests/unit/manager/sokovan/scheduler/test_scheduler.py @@ -65,9 +65,9 @@ def create_session_workload( return SessionWorkload( session_id=session_id, - access_key=access_key, + main_access_key=access_key, requested_slots=requested_slots, - user_uuid=user_uuid, + owner_id=user_uuid, group_id=group_id, domain_name=domain_name, scaling_group=scaling_group,