Skip to content

Commit c5ecd3c

Browse files
authored
Merge pull request #38 from flyingcircusio/PL-135530-save-state-race-condition
PL-135530: save state race condition
2 parents 2396d3a + 24807ea commit c5ecd3c

7 files changed

Lines changed: 243 additions & 54 deletions

File tree

src/skvaider/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ async def lifespan(
9393

9494
loop.set_exception_handler(global_exception_handler)
9595

96+
skvaider.auth.start_verify_pool()
97+
9698
backends: list[skvaider.proxy.backends.Backend] = []
9799
for backend_config in config.backend:
98100
if backend_config.type == "skvaider":
@@ -135,6 +137,7 @@ async def lifespan(
135137
if aramaki:
136138
aramaki.stop()
137139
pool.close()
140+
skvaider.auth.stop_verify_pool()
138141

139142

140143
def app_factory(config: Config, lifespan: Any) -> FastAPI:

src/skvaider/auth.py

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import asyncio
12
import base64
23
import binascii
4+
import concurrent.futures
35
import json
6+
import multiprocessing
47
import time
58
from json import JSONDecodeError
6-
from typing import Annotated, cast
9+
from typing import Annotated, Any, cast
710

811
import svcs
912
from argon2 import PasswordHasher
@@ -16,6 +19,41 @@
1619

1720
hasher = PasswordHasher()
1821

22+
VERIFY_POOL_WORKERS = 10
23+
24+
# argon2 verification is CPU-bound and holds the GIL, so running it on the
25+
# event loop (or in a thread) stalls every other request on the worker. Offload
26+
# it to a process pool, which has independent GILs. "spawn" avoids forking a
27+
# process that already runs the asyncio loop and aiosqlite threads.
28+
verify_pool: concurrent.futures.ProcessPoolExecutor | None = None
29+
30+
31+
def start_verify_pool() -> None:
32+
global verify_pool
33+
verify_pool = concurrent.futures.ProcessPoolExecutor(
34+
max_workers=VERIFY_POOL_WORKERS,
35+
mp_context=multiprocessing.get_context("spawn"),
36+
)
37+
38+
39+
def stop_verify_pool() -> None:
40+
global verify_pool
41+
if verify_pool is None:
42+
return
43+
verify_pool.shutdown(wait=False, cancel_futures=True)
44+
verify_pool = None
45+
46+
47+
def verify_password(secret_hash: str, secret: str) -> bool:
48+
# Runs in a pool worker process. Catch everything (not just mismatch) and
49+
# report failure, so an unexpected library error can't leak into the caller
50+
# as anything other than a rejected token.
51+
try:
52+
hasher.verify(secret_hash, secret)
53+
return True
54+
except Exception:
55+
return False
56+
1957

2058
class AuthTokens(aramaki.Collection):
2159
collection = "fc.directory.ai.token"
@@ -36,19 +74,20 @@ class Cache:
3674
TTL = 300
3775

3876
def __init__(self):
39-
self.cache: dict[str, float] = {}
77+
self.cache: dict[str, tuple[float, Any]] = {}
4078

41-
def __contains__(self, key: str):
79+
def __getitem__(self, key: str):
4280
now = time.time()
4381
if key not in self.cache:
44-
return False
45-
if self.cache[key] < now:
82+
raise KeyError(key)
83+
expiry, payload = self.cache[key]
84+
if expiry < now:
4685
del self.cache[key]
47-
return False
48-
return True
86+
raise KeyError(key)
87+
return payload
4988

50-
def add(self, key: str):
51-
self.cache[key] = time.time() + self.TTL
89+
def add(self, key: str, payload: Any):
90+
self.cache[key] = (time.time() + self.TTL, payload)
5291

5392

5493
cache = Cache() # XXX turn into service
@@ -59,16 +98,20 @@ async def verify_token(
5998
credentials: Annotated[HTTPAuthorizationCredentials, Depends(_bearer_auth)],
6099
services: svcs.fastapi.DepContainer,
61100
) -> None:
62-
if credentials.credentials in cache:
63-
return
101+
try:
102+
token_id = cache[credentials.credentials]
103+
except KeyError:
104+
pass
105+
else:
106+
request.state.token_id = token_id
64107

65108
try:
66109
admin_tokens = services.get(AdminTokens)
67110
except svcs.exceptions.ServiceNotFoundError:
68111
pass
69112
else:
70113
if credentials.credentials in admin_tokens.tokens:
71-
request.state.token_id = "admin-token"
114+
request.state.token_id = "<admin-token>"
72115
return
73116

74117
# XXX There's a lot of type issues going on here, because the mechanics of passing through
@@ -92,14 +135,20 @@ async def verify_token(
92135
)
93136
if not db_token:
94137
raise HTTPException(401, detail="Bad authentication")
95-
try:
96-
assert isinstance(db_token["secret_hash"], str)
97-
hasher.verify(db_token["secret_hash"], client_token["secret"])
98-
request.state.token_id = client_token["id"]
99-
cache.add(credentials.credentials)
100-
# We could specify explicit exceptions here, but take the safe route and catch everything in case the library adds one
101-
except Exception:
138+
assert isinstance(db_token["secret_hash"], str)
139+
loop = asyncio.get_running_loop()
140+
# verify_pool is None outside a running app (e.g. tests); run_in_executor
141+
# then falls back to the loop's default thread executor.
142+
valid = await loop.run_in_executor(
143+
verify_pool,
144+
verify_password,
145+
db_token["secret_hash"],
146+
client_token["secret"],
147+
)
148+
if not valid:
102149
raise HTTPException(401, detail="Bad authentication")
150+
request.state.token_id = client_token["id"]
151+
cache.add(credentials.credentials, client_token["id"])
103152

104153

105154
async def verify_admin_token(
@@ -114,4 +163,4 @@ async def verify_admin_token(
114163
raise HTTPException(401, detail="No admin tokens configured")
115164
if credentials.credentials not in admin_tokens.tokens:
116165
raise HTTPException(401, detail="Bad authentication")
117-
request.state.token_id = "admin-token"
166+
request.state.token_id = "<admin-token>"

src/skvaider/debug.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import Any
77

8+
import aiofiles
89
from fastapi import Request
910
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1011

@@ -142,7 +143,6 @@ def data(self) -> bytes:
142143

143144

144145
class DebugRecorder:
145-
# XXX async file io!
146146
temp_file: Path | None
147147
triggers: list[str]
148148
time_start: float
@@ -211,11 +211,13 @@ def trigger_flag(self):
211211

212212
async def write_request(self, stem: str, data: dict[str, Any]) -> None:
213213
path = self.directory / f"{stem}.request"
214-
path.write_text(_format_request(data))
214+
async with aiofiles.open(path, mode="w") as f:
215+
await f.write(_format_request(data))
215216

216217
async def write_response(self, stem: str, data: dict[str, Any]) -> None:
217218
path = self.directory / f"{stem}.response"
218-
path.write_text(_format_response(data))
219+
async with aiofiles.open(path, mode="w") as f:
220+
await f.write(_format_response(data))
219221

220222
async def record(self) -> None:
221223
if 400 <= self.status_code < 500:

src/skvaider/inference/manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,20 @@ def manifest(self, value: set[str]) -> None:
133133
self.manifest_changed.set()
134134

135135
def update_manifest(self, model_ids: set[str], serial: Serial) -> None:
136-
if serial <= self.manifest_serial:
136+
if serial < self.manifest_serial:
137137
log.info(
138138
"ignoring manifest with stale serial",
139139
serial=serial,
140140
current=self.manifest_serial,
141141
)
142142
return
143+
if serial == self.manifest_serial:
144+
log.info(
145+
"ignoring manifest with current serial",
146+
serial=serial,
147+
current=self.manifest_serial,
148+
)
149+
return
143150
self.manifest_serial = serial
144151
self.manifest = model_ids
145152

src/skvaider/proxy/pool.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def __init__(
161161
self.model_configs = {m.id: m for m in model_configs}
162162

163163
self.model_management_lock = asyncio.Lock()
164+
self.save_state_lock = asyncio.Lock()
164165

165166
for model_id in self.model_configs:
166167
self.semaphores[model_id] = ModelSemaphore(model_id, self)
@@ -206,33 +207,36 @@ async def save_state(self) -> None:
206207
if not self.state_file:
207208
return
208209

209-
records = {
210-
backend.url: BackendStateRecord(
211-
url=backend.url,
212-
healthy=backend.healthy,
213-
memory=backend.memory,
214-
map_up=backend.map_up.state,
215-
map_up_last_change=backend.map_up.last_change,
216-
map_in=backend.map_in.state,
217-
map_in_last_change=backend.map_in.last_change,
218-
models={
219-
model_id: ModelStateRecord(
220-
id=model_id,
221-
memory_usage=model.memory_usage,
222-
)
223-
for model_id, model in backend.models.items()
224-
},
225-
)
226-
for backend in self.backends
227-
}
228-
state = ClusterState(backends=records)
229-
data = state.model_dump_json(indent=2)
230-
231-
tmp = self.state_file.with_suffix(".tmp")
232-
async with aiofiles.open(tmp, mode="w") as f:
233-
await f.write(data)
234-
await f.flush()
235-
tmp.rename(self.state_file)
210+
# Serialize concurrent saves: multiple backend health monitors call
211+
# this, and a shared temp file would race on the rename otherwise.
212+
async with self.save_state_lock:
213+
records = {
214+
backend.url: BackendStateRecord(
215+
url=backend.url,
216+
healthy=backend.healthy,
217+
memory=backend.memory,
218+
map_up=backend.map_up.state,
219+
map_up_last_change=backend.map_up.last_change,
220+
map_in=backend.map_in.state,
221+
map_in_last_change=backend.map_in.last_change,
222+
models={
223+
model_id: ModelStateRecord(
224+
id=model_id,
225+
memory_usage=model.memory_usage,
226+
)
227+
for model_id, model in backend.models.items()
228+
},
229+
)
230+
for backend in self.backends
231+
}
232+
state = ClusterState(backends=records)
233+
data = state.model_dump_json(indent=2)
234+
235+
tmp = self.state_file.with_suffix(".tmp")
236+
async with aiofiles.open(tmp, mode="w") as f:
237+
await f.write(data)
238+
await f.flush()
239+
tmp.rename(self.state_file)
236240

237241
def placement_map(self) -> ModelMap:
238242
return placement.placement_map(

src/skvaider/proxy/tests/test_pool.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import asyncio
2+
import contextlib
23
import datetime
34
from pathlib import Path
4-
from typing import Callable
5+
from typing import Any, AsyncGenerator, Callable
56
from unittest.mock import patch
67

8+
import aiofiles
9+
710
from skvaider.config import ModelInstanceConfig, parse_size
811
from skvaider.conftest import registered_model_factory
912
from skvaider.manifest import Serial
1013
from skvaider.proxy.backends import DummyBackend
1114
from skvaider.proxy.models import AIModel
1215
from skvaider.utils import TaskManager
1316

14-
from ..pool import Pool
17+
from ..pool import ClusterState, Pool
1518

1619

1720
async def test_maps_only_includes_desired_models(
@@ -520,3 +523,57 @@ async def test_save_state_writes_state_file(
520523
assert pool.state_file.exists()
521524
content = pool.state_file.read_text()
522525
assert "m1" in content
526+
527+
528+
async def test_save_state_serializes_concurrent_calls(
529+
dummy_backend: DummyBackend,
530+
tmp_path: Path,
531+
):
532+
"""Concurrent save_state() calls must not race on the shared temp file.
533+
534+
Multiple backend health monitors call save_state() concurrently. Without
535+
serialization they share one temp file and the second rename() fails with
536+
FileNotFoundError once the first has moved the temp file into place.
537+
"""
538+
dummy_backend.healthy = True
539+
dummy_backend.memory = {"ram": {"free": 924, "total": 1024}}
540+
registered_model_factory("m1", dummy_backend, ram=100)
541+
542+
pool = Pool(
543+
[
544+
ModelInstanceConfig(
545+
id="m1", instances=1, memory={"ram": 100}, task="chat"
546+
)
547+
],
548+
[dummy_backend],
549+
data_dir=tmp_path,
550+
)
551+
552+
active = 0
553+
max_active = 0
554+
real_open: Callable[..., Any] = aiofiles.open
555+
556+
@contextlib.asynccontextmanager
557+
async def tracking_open(
558+
*args: Any, **kwargs: Any
559+
) -> AsyncGenerator[Any, None]:
560+
nonlocal active, max_active
561+
async with real_open(*args, **kwargs) as f:
562+
active += 1
563+
max_active = max(max_active, active)
564+
# Force a scheduling point inside the critical section so any
565+
# unserialized concurrency becomes observable.
566+
await asyncio.sleep(0)
567+
try:
568+
yield f
569+
finally:
570+
active -= 1
571+
572+
with patch("skvaider.proxy.pool.aiofiles.open", tracking_open):
573+
await asyncio.gather(*(pool.save_state() for _ in range(10)))
574+
575+
assert max_active == 1
576+
577+
assert pool.state_file is not None
578+
# The final file is complete, valid JSON (no torn write left behind).
579+
ClusterState.model_validate_json(pool.state_file.read_text())

0 commit comments

Comments
 (0)