Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions backend/app/models/chunk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import enum
from functools import lru_cache
from app.utils.singleflight_cache import singleflight_cache

import sys
from typing import Optional, Type
from sqlmodel import (
Field,
Expand Down Expand Up @@ -35,7 +34,7 @@ def get_kb_chunk_model(kb: KnowledgeBase) -> Type[SQLModel]:
return get_dynamic_chunk_model(vector_dimension, str(kb.id))


@lru_cache(maxsize=sys.maxsize)
@singleflight_cache
def get_dynamic_chunk_model(
vector_dimension: int,
namespace: Optional[str] = None,
Expand Down
5 changes: 2 additions & 3 deletions backend/app/models/entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
from functools import lru_cache
import sys
from app.utils.singleflight_cache import singleflight_cache
from typing import Optional, List, Dict, Type

from sqlmodel import (
Expand Down Expand Up @@ -41,7 +40,7 @@ def get_kb_entity_model(kb: KnowledgeBase) -> Type[SQLModel]:
return get_dynamic_entity_model(vector_dimension, str(kb.id))


@lru_cache(maxsize=sys.maxsize)
@singleflight_cache
def get_dynamic_entity_model(
vector_dimension: int,
namespace: Optional[str] = None,
Expand Down
5 changes: 2 additions & 3 deletions backend/app/models/relationship.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datetime import datetime
from functools import lru_cache
import sys
from app.utils.singleflight_cache import singleflight_cache
from typing import Optional, Type
from uuid import UUID

Expand Down Expand Up @@ -37,7 +36,7 @@ def get_kb_relationship_model(kb: KnowledgeBase) -> Type[SQLModel]:
return get_dynamic_relationship_model(vector_dimension, str(kb.id), entity_model)


@lru_cache(maxsize=sys.maxsize)
@singleflight_cache
def get_dynamic_relationship_model(
vector_dimension: int,
namespace: Optional[str] = None,
Expand Down
45 changes: 45 additions & 0 deletions backend/app/utils/singleflight_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import threading
from functools import wraps


def singleflight_cache(func):
    """
    A thread-safe cache decorator implementing the 'singleflight' pattern.

    The singleflight pattern ensures that for any given set of arguments,
    concurrent calls to the decorated function will only result in a single
    actual execution. Other threads with the same arguments will wait for
    the first execution to complete and then receive the same result,
    rather than triggering duplicate computations.

    This is especially useful for expensive or resource-intensive operations
    where you want to avoid redundant work and prevent cache stampede.

    Caching contract:
      * Results are stored per argument set for the lifetime of the
        decorated function; nothing is ever evicted.
      * Keyword arguments are order-insensitive (``f(a=1, b=2)`` and
        ``f(b=2, a=1)`` share one entry), but positional and keyword
        spellings of the same parameter are cached separately.
      * All arguments must be hashable; otherwise the cache lookup raises
        ``TypeError``.
      * If ``func`` raises, nothing is cached and the exception propagates;
        the next caller for that key runs ``func`` again.

    Example:
        @singleflight_cache
        def load_data(key):
            # expensive operation
            ...

        # In multiple threads:
        load_data('foo')  # Only one thread will actually execute the function for 'foo'
    """
    _cache = {}
    _locks = {}
    # Guards creation of per-key locks so two threads never end up with
    # different lock objects for the same key.
    _locks_lock = threading.Lock()

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Keep positional and keyword arguments in separate slots of the key.
        # Concatenating them into one flat tuple would make f(("a", 1)) and
        # f(a=1) collide and return each other's cached result.
        key = (args, tuple(sorted(kwargs.items())))

        # Fast path: no locking needed once the value has been published.
        if key in _cache:
            return _cache[key]

        with _locks_lock:
            lock = _locks.setdefault(key, threading.Lock())

        with lock:
            # Re-check under the per-key lock: another thread may have
            # computed the value while this one was waiting.
            if key in _cache:
                return _cache[key]
            result = func(*args, **kwargs)
            _cache[key] = result
            return result

    return wrapper
34 changes: 34 additions & 0 deletions backend/tests/test_dynamic_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import threading
from app.models.entity import get_dynamic_entity_model
from app.models.relationship import get_dynamic_relationship_model
from app.models.chunk import get_dynamic_chunk_model


def dynamic_model_creation(dim, ns):
    """Build the entity, relationship and chunk models for one namespace.

    The entity model is resolved first because the relationship model
    factory takes it as an argument.

    Returns:
        A ``(entity_model, relationship_model, chunk_model)`` tuple.
    """
    entity = get_dynamic_entity_model(dim, ns)
    return (
        entity,
        get_dynamic_relationship_model(dim, ns, entity),
        get_dynamic_chunk_model(dim, ns),
    )


def test_concurrent_dynamic_model_creation():
    """Ten threads requesting the same models must all get the same objects."""
    thread_count = 10
    results = [None] * thread_count

    def worker(idx):
        # Each thread writes only to its own slot of the results list.
        results[idx] = dynamic_model_creation(128, "test")

    threads = [
        threading.Thread(target=worker, args=(i,)) for i in range(thread_count)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Ensure each model is created only once across all threads
    entity_models, relationship_models, chunk_models = zip(*results)
    for models in (entity_models, relationship_models, chunk_models):
        first = models[0]
        assert all(model is first for model in models)