Skip to content

Commit 9cc876e

Browse files
committed
refactor(BA-5650-I): test and remaining ORM updates
Final slice of the BA-5650 stack split. Contains:
- Remaining ORM touch-ups: models/endpoint/row.py, models/keypair/row.py, repositories/scheduler/{repository,db_source}.py, api/adapters/vfolder.py, api/gql_legacy/endpoint.py, api/gql_legacy/routing.py.
- Test updates that depend on the action/service/DTO renames already landed in earlier slices: adapter/session, scheduler repositories, session lifecycle/service, sokovan scheduler suite, compute_sessions handler, dependency injection tests.
- An autouse ``_user_context`` fixture under tests/unit/manager/services/session/conftest.py so that service tests work without the auth middleware.
1 parent 6180d2d commit 9cc876e

82 files changed

Lines changed: 436 additions & 807 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

changes/10916.breaking.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Change the REST v1 session API's delegation mechanism from the `owner_access_key` query parameter to an `owner_id` (user UUID) field. The `owner_access_key` parameter is removed from all session endpoints, including `GET /session/{name}`, `DELETE /session/{name}`, `POST /session/_/create-from-template`, `POST /session/_/create`, `POST /session/_/create-cluster`, `GET /session/{name}/logs`, and `GET /session/{name}/status-history`. Clients that previously passed `owner_access_key=<keypair>` to act on behalf of another user must now pass `owner_id=<user uuid>` instead. `owner_id` is accepted only on the session-creation endpoints; on read/control endpoints the caller always acts as themselves. The semantics of the session and kernel `access_key` fields also change: the value is no longer tied to the keypair used to create the session, but is resolved at read time from the owner's `main_access_key`.

changes/BA-5650-I.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Test updates and remaining ORM/repository touch-ups for the BA-5650 stack: scheduler db_source, keypair and endpoint row cleanup, gql_legacy endpoint/routing, session lifecycle/service tests, sokovan scheduler tests, and dependency-injection tests. No external behavior change beyond what earlier slices documented.

dev

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -149,32 +149,14 @@ cmd_restart() {
149149

150150
cmd_log() {
151151
local svc=$1
152-
local follow=${2:-}
153152
local winname
154153
winname=$(_tmux_window_name "$svc")
155154
local win
156155
win=$(tmux list-windows -t "$TMUX_SESSION" -F "#{window_name}" 2>/dev/null | grep "^${winname}$" | head -1) || true
157-
if [ -z "$win" ]; then
158-
echo "$(_color red "No tmux window found for $svc")"
159-
return 1
160-
fi
161-
if [ "$follow" = "-f" ]; then
162-
local last_hash=""
163-
trap 'exit 0' INT
164-
while true; do
165-
local output
166-
output=$(tmux capture-pane -t "$TMUX_SESSION:$win" -p -S -50 2>/dev/null)
167-
local cur_hash
168-
cur_hash=$(echo "$output" | md5sum | cut -d' ' -f1)
169-
if [ "$cur_hash" != "$last_hash" ]; then
170-
clear
171-
echo "$output"
172-
last_hash=$cur_hash
173-
fi
174-
sleep 1
175-
done
176-
else
156+
if [ -n "$win" ]; then
177157
tmux capture-pane -t "$TMUX_SESSION:$win" -p -S -50
158+
else
159+
echo "$(_color red "No tmux window found for $svc")"
178160
fi
179161
}
180162

@@ -188,7 +170,7 @@ Commands:
188170
start <service|all> Start a service
189171
stop <service|all> Stop a service
190172
restart <service|all> Restart a service
191-
log <service> [-f] Show recent log output (-f to follow)
173+
log <service> Show recent log output
192174
193175
Services:
194176
mgr, agent, storage, web, proxy-coordinator, proxy-worker, all
@@ -225,9 +207,9 @@ case "$1" in
225207
"cmd_$1" "$2"
226208
;;
227209
log)
228-
[ $# -lt 2 ] && { echo "$(_color red "Usage: ./dev log <service> [-f]")"; exit 1; }
210+
[ $# -lt 2 ] && { echo "$(_color red "Usage: ./dev log <service>")"; exit 1; }
229211
_validate_service "$2"
230-
cmd_log "$2" "${3:-}"
212+
cmd_log "$2"
231213
;;
232214
*) echo "$(_color red "Unknown command: $1")"; usage; exit 1 ;;
233215
esac

docs/manager/graphql-reference/schema.graphql

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,9 +1836,6 @@ type Routing implements Item {
18361836
endpoint: String
18371837
session: UUID
18381838
status: String
1839-
1840-
"""Added in 26.4.1."""
1841-
health_status: String
18421839
traffic_ratio: Float
18431840
created_at: DateTime
18441841
error_data: JSONString

docs/manager/graphql-reference/supergraph.graphql

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,6 @@ input AddRevisionInput
166166
extraMounts: [ExtraVFolderMountInput!] = null
167167
}
168168

169-
"""Added in 26.4.1. Options for the add_model_revision mutation."""
170-
input AddRevisionOptions
171-
@join__type(graph: STRAWBERRY)
172-
{
173-
"""
174-
When true, automatically activate the newly added revision immediately after creation.
175-
"""
176-
autoActivate: Boolean! = false
177-
}
178-
179169
"""Added in 25.19.0. Payload for adding a revision."""
180170
type AddRevisionPayload
181171
@join__type(graph: STRAWBERRY)
@@ -10019,7 +10009,7 @@ type Mutation
1001910009
syncReplicas(input: SyncReplicaInput!): SyncReplicaPayload! @join__field(graph: STRAWBERRY)
1002010010

1002110011
"""Added in 25.16.0. Add model revision."""
10022-
addModelRevision(input: AddRevisionInput!, options: AddRevisionOptions = null): AddRevisionPayload! @join__field(graph: STRAWBERRY)
10012+
addModelRevision(input: AddRevisionInput!): AddRevisionPayload! @join__field(graph: STRAWBERRY)
1002310013

1002410014
"""
1002510015
Added in 26.4.1. Create or update the deployment policy for a given deployment (upsert semantics). If the deployment already has a policy, it is replaced entirely with the new configuration
@@ -14933,9 +14923,6 @@ type Routing implements Item
1493314923
endpoint: String
1493414924
session: UUID
1493514925
status: String
14936-
14937-
"""Added in 26.4.1."""
14938-
health_status: String
1493914926
traffic_ratio: Float
1494014927
created_at: DateTime
1494114928
error_data: JSONString

docs/manager/graphql-reference/v2-schema.graphql

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,6 @@ input AddRevisionInput {
126126
extraMounts: [ExtraVFolderMountInput!] = null
127127
}
128128

129-
"""Added in 26.4.1. Options for the add_model_revision mutation."""
130-
input AddRevisionOptions {
131-
"""
132-
When true, automatically activate the newly added revision immediately after creation.
133-
"""
134-
autoActivate: Boolean! = false
135-
}
136-
137129
"""Added in 25.19.0. Payload for adding a revision."""
138130
type AddRevisionPayload {
139131
"""Added revision"""
@@ -6021,7 +6013,7 @@ type Mutation {
60216013
syncReplicas(input: SyncReplicaInput!): SyncReplicaPayload!
60226014

60236015
"""Added in 25.16.0. Add model revision."""
6024-
addModelRevision(input: AddRevisionInput!, options: AddRevisionOptions = null): AddRevisionPayload!
6016+
addModelRevision(input: AddRevisionInput!): AddRevisionPayload!
60256017

60266018
"""
60276019
Added in 26.4.1. Create or update the deployment policy for a given deployment (upsert semantics). If the deployment already has a policy, it is replaced entirely with the new configuration

src/ai/backend/common/clients/prometheus/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from .client import PrometheusClient
2-
from .preset import LabelMatcher, LabelOperator, MetricPreset
2+
from .preset import MetricPreset
33
from .querier import ContainerMetricQuerier, MetricQuerier
44
from .types import ValueType
55

66
__all__ = [
7-
"LabelMatcher",
8-
"LabelOperator",
97
"PrometheusClient",
108
"MetricPreset",
119
"MetricQuerier",

src/ai/backend/common/clients/prometheus/preset.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,5 @@
11
from collections.abc import Mapping, Set
22
from dataclasses import dataclass, field
3-
from enum import StrEnum
4-
from typing import Self
5-
6-
7-
class LabelOperator(StrEnum):
8-
EQUAL = "="
9-
NOT_EQUAL = "!="
10-
REGEX = "=~"
11-
NOT_REGEX = "!~"
12-
13-
14-
@dataclass(frozen=True)
15-
class LabelMatcher:
16-
"""PromQL label matcher with an explicit operator."""
17-
18-
value: str
19-
operator: LabelOperator = LabelOperator.EQUAL
20-
21-
@classmethod
22-
def exact(cls, value: str) -> Self:
23-
return cls(value=value, operator=LabelOperator.EQUAL)
24-
25-
@classmethod
26-
def regex(cls, value: str) -> Self:
27-
return cls(value=value, operator=LabelOperator.REGEX)
283

294

305
def _escape_label_value(value: str) -> str:
@@ -40,7 +15,7 @@ class MetricPreset:
4015
template: str
4116

4217
# Query labels (injected into {labels} placeholder)
43-
labels: Mapping[str, LabelMatcher] = field(default_factory=dict)
18+
labels: Mapping[str, str] = field(default_factory=dict)
4419

4520
# Group by labels (injected into {group_by} placeholder)
4621
group_by: Set[str] = field(default_factory=frozenset)
@@ -50,10 +25,7 @@ class MetricPreset:
5025

5126
def render(self) -> str:
5227
"""Render the PromQL query with all values injected."""
53-
label_str = ",".join(
54-
f'{key}{value.operator}"{_escape_label_value(value.value)}"'
55-
for key, value in self.labels.items()
56-
)
28+
label_str = ",".join(f'{k}="{_escape_label_value(v)}"' for k, v in self.labels.items())
5729
return self.template.format(
5830
labels=label_str,
5931
window=self.window,

src/ai/backend/common/clients/prometheus/querier.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import dataclass
44
from uuid import UUID
55

6-
from ai.backend.common.clients.prometheus.preset import LabelMatcher
76
from ai.backend.common.clients.prometheus.types import ValueType
87

98

@@ -15,7 +14,7 @@ class MetricQuerier(ABC):
1514
"""
1615

1716
@abstractmethod
18-
def labels(self) -> Mapping[str, LabelMatcher]:
17+
def labels(self) -> Mapping[str, str]:
1918
"""Return the labels to be used in the Prometheus query."""
2019
...
2120

@@ -35,22 +34,22 @@ class ContainerMetricQuerier(MetricQuerier):
3534
user_id: UUID | None = None
3635
project_id: UUID | None = None
3736

38-
def labels(self) -> Mapping[str, LabelMatcher]:
37+
def labels(self) -> Mapping[str, str]:
3938
"""Return the labels for the container metric query."""
40-
result: dict[str, LabelMatcher] = {
41-
"container_metric_name": LabelMatcher.exact(self.metric_name),
42-
"value_type": LabelMatcher.exact(self.value_type),
39+
result: dict[str, str] = {
40+
"container_metric_name": self.metric_name,
41+
"value_type": self.value_type,
4342
}
4443
if self.kernel_id is not None:
45-
result["kernel_id"] = LabelMatcher.exact(str(self.kernel_id))
44+
result["kernel_id"] = str(self.kernel_id)
4645
if self.session_id is not None:
47-
result["session_id"] = LabelMatcher.exact(str(self.session_id))
46+
result["session_id"] = str(self.session_id)
4847
if self.agent_id is not None:
49-
result["agent_id"] = LabelMatcher.exact(self.agent_id)
48+
result["agent_id"] = self.agent_id
5049
if self.user_id is not None:
51-
result["user_id"] = LabelMatcher.exact(str(self.user_id))
50+
result["user_id"] = str(self.user_id)
5251
if self.project_id is not None:
53-
result["project_id"] = LabelMatcher.exact(str(self.project_id))
52+
result["project_id"] = str(self.project_id)
5453
return result
5554

5655
def group_by_labels(self) -> frozenset[str]:

src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,11 @@ class RouteHealthRecord:
9595

9696
route_id: str
9797
created_at: int # Unix timestamp when route was created
98-
initial_delay_until: int # Unix timestamp = running_at + initial_delay
98+
initial_delay_until: int # Unix timestamp = created_at + initial_delay
9999
health_path: str # extracted from model_definition
100100
inference_port: int # extracted from kernel
101101
replica_host: str # extracted from kernel
102102

103-
# Timestamp when route entered RUNNING state (set by coordinator)
104-
running_at: int | None = None
105-
106103
# Agent check results
107104
agent_healthy: bool = False
108105
agent_last_check: int = 0 # Unix timestamp
@@ -131,7 +128,7 @@ def is_stale(self, current_time: int, staleness_sec: int = MAX_HEALTH_STALENESS_
131128

132129
def to_valkey_hash(self) -> Mapping[str, str]:
133130
"""Serialize to Valkey hash fields."""
134-
data: dict[str, str] = {
131+
return {
135132
"route_id": self.route_id,
136133
"created_at": str(self.created_at),
137134
"initial_delay_until": str(self.initial_delay_until),
@@ -143,9 +140,6 @@ def to_valkey_hash(self) -> Mapping[str, str]:
143140
"manager_healthy": "1" if self.manager_healthy else "0",
144141
"manager_last_check": str(self.manager_last_check),
145142
}
146-
if self.running_at is not None:
147-
data["running_at"] = str(self.running_at)
148-
return data
149143

150144
@classmethod
151145
def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthRecord:
@@ -157,7 +151,6 @@ def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthRecord:
157151
health_path=data["health_path"],
158152
inference_port=int(data["inference_port"]),
159153
replica_host=data["replica_host"],
160-
running_at=int(raw) if (raw := data.get("running_at")) and raw != "0" else None,
161154
agent_healthy=data.get("agent_healthy", "0") == "1",
162155
agent_last_check=int(data.get("agent_last_check", "0")),
163156
manager_healthy=data.get("manager_healthy", "0") == "1",
@@ -675,52 +668,6 @@ async def update_route_liveness(self, route_id: str, liveness: bool) -> None:
675668
async with self._client.client() as conn:
676669
await conn.exec(batch, raise_on_error=True)
677670

678-
@valkey_schedule_resilience.apply()
679-
async def mark_route_running_at(self, route_id: str) -> None:
680-
"""
681-
Record the RUNNING transition timestamp for a route.
682-
Called when a route transitions to RUNNING status.
683-
Uses Redis time for consistency with health check comparisons.
684-
685-
:param route_id: The route ID that entered RUNNING state
686-
"""
687-
key = self._get_route_health_key(route_id)
688-
current_time = str(await self._get_redis_time())
689-
async with self._client.client() as conn:
690-
await conn.hset(key, {"running_at": current_time})
691-
await conn.expire(key, ROUTE_HEALTH_TTL_SEC)
692-
693-
@valkey_schedule_resilience.apply()
694-
async def get_route_running_at_batch(self, route_ids: Sequence[str]) -> dict[str, int | None]:
695-
"""
696-
Batch read running_at field from route health hashes.
697-
Works even on partial hashes (before full RouteHealthRecord is initialized).
698-
699-
:param route_ids: Route IDs to look up
700-
:return: Mapping of route_id to running_at timestamp (None if not set)
701-
"""
702-
if not route_ids:
703-
return {}
704-
705-
batch = Batch(is_atomic=False)
706-
for route_id in route_ids:
707-
key = self._get_route_health_key(route_id)
708-
batch.hget(key, "running_at")
709-
710-
async with self._client.client() as conn:
711-
results = await conn.exec(batch, raise_on_error=False)
712-
if results is None:
713-
return dict.fromkeys(route_ids)
714-
715-
running_at_map: dict[str, int | None] = {}
716-
for i, route_id in enumerate(route_ids):
717-
raw = results[i] if len(results) > i else None
718-
if raw and raw != b"0":
719-
running_at_map[route_id] = int(raw)
720-
else:
721-
running_at_map[route_id] = None
722-
return running_at_map
723-
724671
@valkey_schedule_resilience.apply()
725672
async def refresh_route_health_ttl(self, route_id: str) -> None:
726673
"""
@@ -881,8 +828,6 @@ async def get_route_health_record(self, route_id: str) -> RouteHealthRecord | No
881828
return None
882829

883830
data = {k.decode(): v.decode() for k, v in result.items()}
884-
if "route_id" not in data:
885-
return None
886831
return RouteHealthRecord.from_valkey_hash(data)
887832

888833
@valkey_schedule_resilience.apply()
@@ -921,10 +866,6 @@ async def get_route_health_records_batch(
921866
continue
922867

923868
data = {k.decode(): v.decode() for k, v in raw.items()}
924-
if "route_id" not in data:
925-
# Partial hash (e.g., only running_at set by mark_route_running_at)
926-
records[route_id] = None
927-
continue
928869
records[route_id] = RouteHealthRecord.from_valkey_hash(data)
929870

930871
return records

0 commit comments

Comments
 (0)