Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
407 changes: 407 additions & 0 deletions packages/backend/embedding_atlas/cache.py

Large diffs are not rendered by default.

9 changes: 2 additions & 7 deletions packages/backend/embedding_atlas/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
import pandas as pd
import uvicorn

from .cache import sha256_hexdigest
from .data_source import DataSource
from .options import make_embedding_atlas_props
from .server import make_server
from .utils import (
Hasher,
apply_logging_config,
load_huggingface_data,
load_pandas_data,
Expand Down Expand Up @@ -488,12 +488,7 @@ def main(
"props": props,
}

hasher = Hasher()
hasher.update(__version__)
hasher.update(inputs)
hasher.update(metadata)
identifier = hasher.hexdigest()

identifier = sha256_hexdigest([__version__, inputs, metadata], scope="DataSource")
dataset = DataSource(identifier, df, metadata)

if static is None:
Expand Down
75 changes: 49 additions & 26 deletions packages/backend/embedding_atlas/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import shutil
import zipfile
from io import BytesIO
from typing import Any

import pandas as pd

from .utils import cache_path, to_parquet_bytes
from .cache import file_cache_get, file_cache_set
from .utils import to_parquet_bytes


def _deep_merge(base: dict, overrides: dict) -> dict:
Expand All @@ -32,20 +34,45 @@ def __init__(
self.identifier = identifier
self.dataset = dataset
self.metadata = metadata
self.cache_path = cache_path("cache", self.identifier)
self._cache_index: set[str] = set(self._cache_index_load())

def _cache_index_key(self):
return [self.identifier, "__index__"]

def _cache_index_load(self) -> list[str]:
    """Fetch the persisted index of cached entry names; [] when none exists."""
    stored = file_cache_get(self._cache_index_key(), scope="DataSource")
    return [] if stored is None else stored

def _cache_index_save(self):
    """Persist the in-memory index, sorted for stable on-disk output."""
    entries = sorted(self._cache_index)
    file_cache_set(self._cache_index_key(), entries, scope="DataSource")

def _cache_index_add(self, name: str):
    """Record `name` in the in-memory index and persist the merged result."""
    if name in self._cache_index:
        return
    self._cache_index.add(name)
    # Merge with whatever is on disk so entries written by other
    # processes are not lost when we overwrite the index.
    on_disk = set(self._cache_index_load())
    combined = sorted(self._cache_index | on_disk)
    file_cache_set(self._cache_index_key(), combined, scope="DataSource")

def cache_set(self, name: str, data):
    """Store `data` (assumed JSON-serializable) in the shared file cache
    under this dataset's identifier, and record `name` in the per-dataset
    index so the entry can be enumerated later (e.g. by `cache_items`).
    """
    # The pre-refactor direct-to-disk write (json.dump into self.cache_path)
    # was removed by this change; storage now goes through file_cache_set.
    file_cache_set([self.identifier, name], data, scope="DataSource")
    self._cache_index_add(name)

def cache_get(self, name: str):
    """Look up a cached entry for this dataset; returns None when absent.

    The pre-refactor body (reading a JSON file from self.cache_path) was
    removed by this change; lookup now goes through file_cache_get.
    """
    return file_cache_get([self.identifier, name], scope="DataSource")

def cache_items(self) -> dict[str, Any]:
    """Return all cached entries as a dict of {name: value}.

    Names whose lookup returns None (missing/evicted entries) are skipped.
    """
    # Walrus keeps a single cache_get call per name while filtering misses.
    return {
        name: value
        for name in self._cache_index
        if (value := self.cache_get(name)) is not None
    }

def _build_metadata(self, metadata_overrides: dict | None = None) -> dict:
metadata = self.metadata | {
Expand All @@ -68,13 +95,11 @@ def make_archive(self, static_path: str, metadata_overrides: dict | None = None)
for fn in files:
p = os.path.relpath(os.path.join(root, fn), static_path)
zip.write(os.path.join(root, fn), p)
for root, _, files in os.walk(self.cache_path):
for fn in files:
p = os.path.join(
"data/cache",
os.path.relpath(os.path.join(root, fn), str(self.cache_path)),
)
zip.write(os.path.join(root, fn), p)
for name, value in self.cache_items().items():
zip.writestr(
f"data/cache/{name}",
json.dumps(value),
)
return io.getvalue()

def export_to_folder(
Expand Down Expand Up @@ -103,11 +128,9 @@ def export_to_folder(
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src, dst)

# Copy cache files
for root, _, files in os.walk(self.cache_path):
for fn in files:
src = os.path.join(root, fn)
rel = os.path.relpath(src, str(self.cache_path))
dst = data_dir / "cache" / rel
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src, dst)
# Write cache files
cache_dir = data_dir / "cache"
for name, value in self.cache_items().items():
cache_file = cache_dir / name
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(value))
Loading
Loading