2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -114,7 +114,7 @@ jobs:
embedding-atlas --help

python -c "import embedding_atlas"
python -c "from embedding_atlas.projection import compute_text_projection, compute_image_projection, compute_vector_projection"
python -c "from embedding_atlas.projection import compute_projection"

pip install jupyterlab anywidget
python -c "from embedding_atlas.widget import EmbeddingAtlasWidget"
7 changes: 6 additions & 1 deletion .vscode/settings.json
@@ -1,6 +1,11 @@
{
  "python.analysis.typeCheckingMode": "standard",
  "python.defaultInterpreterPath": "packages/backend/.venv/bin/python",
  "python-envs.pythonProjects": [
    {
      "path": "packages/backend",
      "envManager": "ms-python.python:venv"
    }
  ],
  "files.trimTrailingWhitespace": true,
  "[python]": {
    "editor.defaultFormatter": "charliermarsh.ruff",
112 changes: 112 additions & 0 deletions packages/backend/embedding_atlas/async_map.py
@@ -0,0 +1,112 @@
import asyncio
import random
from typing import Awaitable, Callable, TypeVar

from tqdm.auto import tqdm

from .utils import logger


class _BackoffState:
    def __init__(self, base_delay: float, max_delay: float):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.current_delay = 0.0
        self.consecutive_errors = 0

    def on_error(self):
        self.consecutive_errors += 1
        self.current_delay = min(
            self.max_delay, self.base_delay * (2 ** (self.consecutive_errors - 1))
        )

    def on_success(self):
        self.consecutive_errors = 0
        self.current_delay = 0.0


T = TypeVar("T")
R = TypeVar("R")


async def async_map(
    inputs: list[T],
    func: Callable[[T], Awaitable[R]],
    *,
    concurrency: int = 4,
    max_retry: int = 0,
    retry_base_delay: float = 1.0,
    retry_max_delay: float = 30.0,
    description: str = "Task",
    fallback: R | None = None,
) -> list[R]:
    """
    Map the inputs with an async function and return the results in input order.

    Args:
        inputs: List of items to process.
        func: Async function to apply to each item.
        concurrency: Maximum number of concurrent calls.
        max_retry: Maximum number of retry attempts on failure (0 means no retries).
        retry_base_delay: Base delay in seconds for exponential backoff (default 1.0).
        retry_max_delay: Maximum delay in seconds for the backoff cap (default 30.0).
        description: Description shown in the progress bar.
        fallback: Value used as the result for an item whose attempts all fail.
            If None, the error is raised instead and no new tasks are started.
    """
    count = len(inputs)
    results: list[R | None] = [None] * count
    semaphore = asyncio.Semaphore(concurrency)
    backoff = _BackoffState(retry_base_delay, retry_max_delay)
    # Event to signal that processing should stop (used when fallback is None and an error occurs)
    stop_event = asyncio.Event()
    # Store the first error encountered when fallback is None
    first_error: list[Exception | None] = [None]

    pbar = tqdm(total=count, desc=description)

    async def process(index: int, item: T) -> None:
        async with semaphore:
            last_error: Exception | None = None
            for attempt in range(max_retry + 1):
                # Check if we should stop before each retry attempt
                if stop_event.is_set():
                    return

                try:
                    # All tasks respect the shared backoff
                    if backoff.current_delay > 0:
                        delay = random.uniform(0, backoff.current_delay)
                        logger.warning(
                            f"Backoff: waiting {delay:.1f}s before attempt {attempt + 1} for item {index}"
                        )
                        await asyncio.sleep(delay)
                    results[index] = await func(item)
                    backoff.on_success()
                    pbar.update(1)
                    return
                except Exception as e:
                    logger.error(e)
                    backoff.on_error()
                    last_error = e
                    if attempt < max_retry:
                        continue
            if last_error is not None:
                if fallback is None:
                    # Signal other tasks to stop and store the error
                    if first_error[0] is None:
                        first_error[0] = last_error
                    stop_event.set()
                else:
                    results[index] = fallback
                    pbar.update(1)

    await asyncio.gather(*(process(i, item) for i, item in enumerate(inputs)))

    pbar.close()

    # If we stopped due to an error, raise it
    if first_error[0] is not None:
        raise first_error[0]

    return results  # type: ignore[return-value]
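A minimal usage sketch of the new async_map helper (the fetch_label coroutine and the sample inputs are hypothetical, added here only for illustration): results come back in the same order as the inputs, and failed items take the given fallback value.

import asyncio

from embedding_atlas.async_map import async_map


async def main() -> None:
    texts = ["a photo of a cat", "a photo of a dog"]

    async def fetch_label(text: str) -> str:
        # Hypothetical stand-in for a real async call (e.g. an HTTP request).
        await asyncio.sleep(0.1)
        return text.upper()

    labels = await async_map(
        texts,
        fetch_label,
        concurrency=4,
        max_retry=2,
        description="Labeling",
        fallback="",
    )
    print(labels)  # ['A PHOTO OF A CAT', 'A PHOTO OF A DOG']


asyncio.run(main())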
61 changes: 61 additions & 0 deletions packages/backend/embedding_atlas/cache.py
@@ -239,6 +239,67 @@ def file_cache_value(
    return value


async def async_file_cache_value(
    key: Any,
    value_func: Callable[[], Any],
    *,
    scope: str | None = None,
    cache_root: str | Path | None = None,
    serializer: Callable[[Any, IO[bytes]], None] | None = None,
    deserializer: Callable[[IO[bytes]], Any] | None = None,
    callback: Callable[[Path], None] | None = None,
):
    """Async version of ``file_cache_value``.

    Identical behaviour but *value_func* is awaited on a cache miss.
    """
    cache_root = _resolve_cache_root(cache_root)
    if serializer is None:
        serializer = default_serializer
    if deserializer is None:
        deserializer = default_deserializer

    cache_key, encryption_key = _derive_cache_key_and_encryption_key(
        key, scope, cache_root
    )

    cache_path = cache_root / cache_key[:2] / cache_key

    if cache_path.exists():
        try:
            with open(cache_path, "rb") as file:
                data = _decrypt_data(file.read(), key=encryption_key)

            result = deserializer(BytesIO(data))

            if callback is not None:
                callback(cache_path)

            return result
        except Exception:
            logger.debug("Cache read failed for key %s", cache_key, exc_info=True)

    value = await value_func()

    try:
        random_suffix = secrets.token_hex(8)
        cache_path_tmp = cache_root / cache_key[:2] / f"{cache_key}.tmp-{random_suffix}"
        cache_path.parent.mkdir(parents=True, exist_ok=True)

        buffer = BytesIO()
        serializer(value, buffer)
        encrypted_data = _encrypt_data(buffer.getvalue(), key=encryption_key)

        with open(cache_path_tmp, "wb") as file:
            file.write(encrypted_data)

        cache_path_tmp.rename(cache_path)
    except Exception:
        logger.debug("Cache write failed for key %s", cache_key, exc_info=True)

    return value


def _resolve_cache_root(cache_root: str | Path | None = None) -> Path:
    if cache_root is None:
        return (user_cache_path("embedding_atlas") / "cache").resolve()
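A rough usage sketch of async_file_cache_value (the compute_embeddings coroutine, key, and scope below are made up for illustration): the first call awaits value_func and writes the encrypted cache file; later calls with the same key return the cached value.

import asyncio

from embedding_atlas.cache import async_file_cache_value


async def main() -> None:
    async def compute_embeddings() -> list[list[float]]:
        # Hypothetical stand-in for an expensive async computation.
        await asyncio.sleep(0.1)
        return [[0.1, 0.2], [0.3, 0.4]]

    value = await async_file_cache_value(
        "example-model/example-inputs",
        compute_embeddings,
        scope="example-embeddings",
    )
    print(value)


asyncio.run(main())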