diff --git a/core/framework/storage/concurrent.py b/core/framework/storage/concurrent.py
index 1672822ac9..b164f26791 100644
--- a/core/framework/storage/concurrent.py
+++ b/core/framework/storage/concurrent.py
@@ -10,10 +10,12 @@
 import asyncio
 import logging
 import time
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
+from contextlib import AsyncExitStack
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
+from weakref import WeakValueDictionary
 
 from framework.schemas.run import Run, RunStatus, RunSummary
 from framework.storage.backend import FileStorage
@@ -61,6 +63,7 @@ def __init__(
         cache_ttl: float = 60.0,
         batch_interval: float = 0.1,
         max_batch_size: int = 100,
+        max_locks: int = 1000,
     ):
         """
         Initialize concurrent storage.
@@ -70,6 +73,7 @@ def __init__(
             cache_ttl: Cache time-to-live in seconds
             batch_interval: Interval between batch flushes
             max_batch_size: Maximum items before forcing flush
+            max_locks: Maximum number of run locks pinned by the LRU (<= 0 disables pinning)
         """
         self.base_path = Path(base_path)
         self._base_storage = FileStorage(base_path)
@@ -84,9 +88,13 @@ def __init__(
         self._max_batch_size = max_batch_size
         self._batch_task: asyncio.Task | None = None
 
-        # Locking
-        self._file_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
+        # Locking: locks are held weakly so unused ones can be garbage
+        # collected; the most recently used "run:" locks are additionally
+        # pinned by strong references in a bounded LRU.
+        self._file_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+        self._lru_tracking: OrderedDict[str, asyncio.Lock] = OrderedDict()
+        self._max_locks = max_locks
         self._global_lock = asyncio.Lock()
 
         # State
         self._running = False
@@ -121,6 +129,41 @@ async def stop(self) -> None:
 
         logger.info("ConcurrentStorage stopped")
 
+    async def _get_lock(self, lock_key: str) -> asyncio.Lock:
+        """
+        Get or create the lock for a key.
+
+        Every lock is registered in a WeakValueDictionary, so it stays
+        shared for as long as any caller still references it and is
+        garbage collected afterwards. Locks whose key starts with "run:"
+        are additionally pinned in a bounded LRU of strong references;
+        all other locks (index, summary) are only ever held weakly.
+
+        Args:
+            lock_key: Key identifying the resource to lock
+
+        Returns:
+            The asyncio.Lock shared by all current users of the key
+        """
+        track = self._max_locks > 0 and lock_key.startswith("run:")
+
+        lock = self._file_locks.get(lock_key)
+        if lock is None:
+            lock = asyncio.Lock()
+            self._file_locks[lock_key] = lock
+
+        if track:
+            if lock_key not in self._lru_tracking:
+                # Make room first. Evicting only drops the strong reference:
+                # a lock that is still in use stays in the weak dictionary.
+                while len(self._lru_tracking) >= self._max_locks:
+                    self._lru_tracking.popitem(last=False)
+                self._lru_tracking[lock_key] = lock
+            else:
+                self._lru_tracking.move_to_end(lock_key)
+
+        return lock
+
     # === RUN OPERATIONS (Async, Thread-Safe) ===
 
     async def save_run(self, run: Run, immediate: bool = False) -> None:
@@ -140,12 +183,33 @@ async def save_run(self, run: Run, immediate: bool = False) -> None:
         self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
 
     async def _save_run_locked(self, run: Run) -> None:
-        """Save a run with file locking."""
+        """
+        Save a run while holding its run lock and the locks of its indexes.
+
+        Args:
+            run: Run to persist
+        """
         lock_key = f"run:{run.id}"
-        async with self._file_locks[lock_key]:
+        async with AsyncExitStack() as stack:
+            # Run lock first; the stack releases everything in reverse order
+            await stack.enter_async_context(await self._get_lock(lock_key))
+
+            # A set removes duplicates (a repeated node id would otherwise
+            # re-acquire the same non-reentrant asyncio.Lock); sorting gives
+            # every save the same acquisition order so saves cannot deadlock.
+            index_lock_keys = sorted(
+                {
+                    f"index:by_goal:{run.goal_id}",
+                    f"index:by_status:{run.status.value}",
+                    *(f"index:by_node:{node_id}" for node_id in run.metrics.nodes_executed),
+                }
+            )
+            for index_key in index_lock_keys:
+                await stack.enter_async_context(await self._get_lock(index_key))
+
             # Run in executor to avoid blocking event loop
             loop = asyncio.get_event_loop()
             await loop.run_in_executor(None, self._base_storage.save_run, run)
 
     async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
         """
@@ -159,16 +223,20 @@ async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
             Run object or None if not found
         """
         cache_key = f"run:{run_id}"
+        lock_key = f"run:{run_id}"
 
         # Check cache
-        if use_cache and cache_key in self._cache:
-            entry = self._cache[cache_key]
-            if not entry.is_expired(self._cache_ttl):
-                return entry.value
+        if use_cache:
+            entry = self._cache.get(cache_key)
+            if entry is not None and not entry.is_expired(self._cache_ttl):
+                # Touch the LRU on a cache hit too, so the lock of a hot run
+                # is not evicted just because reads are served from cache
+                if lock_key in self._lru_tracking:
+                    self._lru_tracking.move_to_end(lock_key)
+                return entry.value
 
         # Load from storage
-        lock_key = f"run:{run_id}"
-        async with self._file_locks[lock_key]:
+        async with await self._get_lock(lock_key):
             loop = asyncio.get_event_loop()
             run = await loop.run_in_executor(None, self._base_storage.load_run, run_id)
 
@@ -189,8 +257,12 @@ async def load_summary(self, run_id: str, use_cache: bool = True) -> RunSummary
                 return entry.value
 
         # Load from storage
-        loop = asyncio.get_event_loop()
-        summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
+        # NOTE(review): save/delete hold "run:{id}", not this key, so this lock
+        # only serializes concurrent summary loads - confirm that is intended.
+        lock_key = f"summary:{run_id}"
+        async with await self._get_lock(lock_key):
+            loop = asyncio.get_event_loop()
+            summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
 
         # Update cache
         if summary:
@@ -201,7 +273,7 @@ async def load_summary(self, run_id: str, use_cache: bool = True) -> RunSummary
     async def delete_run(self, run_id: str) -> bool:
         """Delete a run from storage."""
         lock_key = f"run:{run_id}"
-        async with self._file_locks[lock_key]:
+        async with await self._get_lock(lock_key):
             loop = asyncio.get_event_loop()
             result = await loop.run_in_executor(None, self._base_storage.delete_run, run_id)
 
@@ -215,7 +287,7 @@ async def delete_run(self, run_id: str) -> bool:
 
     async def get_runs_by_goal(self, goal_id: str) -> list[str]:
         """Get all run IDs for a goal."""
-        async with self._file_locks[f"index:by_goal:{goal_id}"]:
+        async with await self._get_lock(f"index:by_goal:{goal_id}"):
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(None, self._base_storage.get_runs_by_goal, goal_id)
 
@@ -223,13 +295,13 @@ async def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
         """Get all run IDs with a status."""
         if isinstance(status, RunStatus):
             status = status.value
-        async with self._file_locks[f"index:by_status:{status}"]:
+        async with await self._get_lock(f"index:by_status:{status}"):
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(None, self._base_storage.get_runs_by_status, status)
 
     async def get_runs_by_node(self, node_id: str) -> list[str]:
         """Get all run IDs that executed a node."""
-        async with self._file_locks[f"index:by_node:{node_id}"]:
+        async with await self._get_lock(f"index:by_node:{node_id}"):
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(None, self._base_storage.get_runs_by_node, node_id)
 