202 changes: 199 additions & 3 deletions sarc/cache.py
@@ -4,13 +4,15 @@
import logging
import os
import re
from collections.abc import Iterable
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime, time, timedelta
from enum import Enum
from functools import partial, wraps
from pathlib import Path
from typing import IO, Any, Callable, ClassVar, Literal, Protocol, overload
from zipfile import ZIP_LZMA, ZipFile

from .config import config

@@ -64,6 +66,199 @@ class CacheException(Exception):
pass


def ensure_utc(d: datetime) -> datetime:
assert d.tzinfo is not None
return d.astimezone(UTC)


class CacheEntry:
"""Describe a single cache entry at a point in time.

The cache entry contains multiple key-value pairs."""

_zf: ZipFile

def __init__(self, zf: ZipFile):
self._zf = zf

def add_value(self, key: str, value: bytes) -> None:
"""Add a key-value pair to the cache entry"""
self._zf.writestr(key, value)

def get_value(self, key: str) -> bytes:
"""Get the value for an existing key in this cache entry"""
return self._zf.read(key)

def get_keys(self) -> list[str]:
"""Get the list of keys in the order they were added."""
return self._zf.namelist()

def close(self) -> None:
"""Close the cache entry. MUST be called for new entries."""
self._zf.close()


class Cache:
"""A simple file-based cache that stores data organized by date.

This cache stores binary data in a hierarchical directory structure based on
the date when the data was cached. Files are organized as:
cache_root/subdirectory/YYYY/MM/DD/HH:MM:SS

Attributes:
subdirectory: The subdirectory name within the cache root where data
will be stored.
"""

subdirectory: str

def __init__(self, subdirectory: str):
self.subdirectory = subdirectory

@property
def cache_dir(self) -> Path:
"""Get the cache directory path for this cache instance.

Creates the directory if it doesn't exist.

Returns:
Path: The absolute path to the cache directory.
"""
root = config().cache
assert root is not None
res = root / self.subdirectory
res.mkdir(parents=True, exist_ok=True)
return res

def _dir_from_date(self, cdir: Path, d: datetime) -> Path:
"""Get the directory path for a specific date within the cache.

Args:
cdir: The base cache directory path.
d: The datetime for which to get the directory path.

Returns:
Path: The path to the date-specific directory.
"""
return cdir / f"{d.year:04}" / f"{d.month:02}" / f"{d.day:02}"

def create_entry(self, at_time: datetime) -> CacheEntry:
"""Create a writable CacheEntry for the specified time.

You MUST call close() on the resulting entry when you are finished
adding data to it, otherwise the entry could get corrupted."""
cdir = self.cache_dir

at_time = ensure_utc(at_time)

output_file = self._dir_from_date(cdir, at_time) / at_time.time().isoformat(
"seconds"
)
output_file.parent.mkdir(parents=True, exist_ok=True)
zf = ZipFile(
output_file,
mode="x",
compression=ZIP_LZMA,
)
return CacheEntry(zf)

def save(self, key: str, at_time: datetime, value: bytes) -> None:
"""Save binary data to the cache for a specific key and timestamp.

Only use this method if you want to save a single value for a given timestamp.

Args:
key: The cache key identifier.
at_time: The datetime when this data was generated; must be timezone-aware (it is converted to UTC).
value: The binary data to store in the cache.

Example:
>>> cache = Cache("my_data")
>>> cache.save("data", datetime.now(UTC), b"binary data")
"""
ce = self.create_entry(at_time)
ce.add_value(key, value)
ce.close()

def _paths_from(self, from_time: datetime) -> Iterable[Path]:
"""Returns paths starting from a specific datetime.

Returns an iterator over all cached entries that were created at or
after the specified time. Searches through the date hierarchy starting
from the given date and continuing forward through all subsequent dates.

Args:
from_time: The earliest datetime to include in results.

Yields:
Path: The path for each matching cache entry.
"""
cdir = self.cache_dir
from_time = ensure_utc(from_time)

first_dir = self._dir_from_date(cdir, from_time)

if first_dir.exists():
from_time_nodays = from_time.time()
for file in filter(
lambda fname: time.fromisoformat(fname.parts[-1]) >= from_time_nodays,
sorted(first_dir.iterdir()),
):
yield file

from_time = from_time.replace(hour=0, minute=0, second=0, microsecond=0)
from_time += timedelta(days=1)

first_year_done = False
first_month_done = False

for year_dir in sorted(
filter(lambda y: int(y.parts[-1]) >= from_time.year, cdir.iterdir())
):
for month_dir in sorted(
filter(
lambda m: first_year_done or int(m.parts[-1]) >= from_time.month,
year_dir.iterdir(),
)
):
for day_dir in sorted(
filter(
lambda d: first_month_done or int(d.parts[-1]) >= from_time.day,
month_dir.iterdir(),
)
):
for file in sorted(day_dir.iterdir()):
yield file
first_month_done = True
first_year_done = True

def read_from(self, from_time: datetime) -> Iterable[CacheEntry]:
"""Read all cached entries starting from a specific datetime.

Returns an iterator over all cached entries that were created at or after
the specified time. Unlike `read_from()`, this method returns entries for
all keys, not just a specific key. The cache files are searched through
the date hierarchy starting from the given date and continuing forward
through all subsequent dates.

Args:
from_time: The earliest datetime to include in results. Must be UTC.

Yields:
tuple[str, bytes]: A tuple containing:
- The cache key
- The binary data from the cache entry

Example:
>>> cache = Cache("my_data")
>>> start_time = datetime(2024, 1, 15, 10, 0, 0)
>>> for key, data in cache.read_from_all(start_time):
... print(f"Key: {key}, Data size: {len(data)} bytes")
"""
for file in self._paths_from(from_time):
yield CacheEntry(ZipFile(file, mode="r"))


@dataclass
class CachedResult[T]:
"""Represents a result computed at some time."""
@@ -109,7 +304,7 @@ def _cache_policy_from_env() -> CachePolicy:


@dataclass(kw_only=True)
class Cache[T]:
class OldCache[T]:
formatter: type[FormatterProto[T]] = JSONFormatter[T]
cache_root: Path | None
subdirectory: str
@@ -178,6 +373,7 @@ def _read_for_key(
possible = [c for c in candidates if c.name <= maximum]
for candidate in possible:
if valid is True:
# The specific value doesn't matter, it's ignored later
candidate_time = datetime.now(UTC)
else:
m = re.match(
@@ -221,7 +417,7 @@ def _read_for_key(


@dataclass(kw_only=True)
class CachedFunction[**P, R](Cache[R]): # pylint: disable=too-many-instance-attributes
class CachedFunction[**P, R](OldCache[R]): # pylint: disable=too-many-instance-attributes
fn: Callable[P, R]
key: Callable[P, str | None]
validity: timedelta | Callable[P, timedelta] | Literal[True] = True
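To make the new API concrete, here is a minimal usage sketch of the `Cache`/`CacheEntry` pair added above. The subdirectory name and payloads are made up, and a cache root is assumed to be set in the SARC config:

from datetime import UTC, datetime, timedelta

from sarc.cache import Cache

cache = Cache("users")  # hypothetical subdirectory under the cache root

# Each entry is one ZIP_LZMA archive holding any number of key-value pairs,
# stored at cache_root/users/YYYY/MM/DD/HH:MM:SS.
now = datetime.now(UTC)
entry = cache.create_entry(now)
entry.add_value("mila", b'{"users": []}')
entry.add_value("drac", b'{"users": []}')
entry.close()  # mandatory for new entries, or the archive may be corrupted

# Single-value shortcut: save() creates, fills and closes an entry. Use a
# distinct timestamp: entry files are named to the second and opened with
# mode="x", so two entries in the same second would collide.
cache.save("snapshot", now + timedelta(seconds=1), b"binary data")

# Read back everything cached over the last day.
for ce in cache.read_from(now - timedelta(days=1)):
    for key in ce.get_keys():
        print(key, len(ce.get_value(key)), "bytes")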
4 changes: 3 additions & 1 deletion sarc/cli/__init__.py
@@ -12,7 +12,9 @@

from .acquire import Acquire
from .db import Db
from .fetch import Fetch
from .health import Health
from .parse import Parse

colors = SimpleNamespace(
grey="\033[38;21m",
@@ -51,7 +53,7 @@ def format(self, record: logging.LogRecord) -> str:
@dataclass
class CLI:
command: Union[Acquire, Db, Fetch, Health, Parse] = subparsers(
{"acquire": Acquire, "db": Db, "health": Health}
{"acquire": Acquire, "db": Db, "health": Health, "fetch": Fetch, "parse": Parse}
)

color: bool = False
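As a hedged sketch of the new wiring (the `sarc` entry-point name and the generated `--force` flag spelling are assumptions, not shown in this diff), the added subcommands would be reachable through simple_parsing like so:

from simple_parsing import ArgumentParser

from sarc.cli import CLI

parser = ArgumentParser()
parser.add_arguments(CLI, dest="cli")

# Roughly what `sarc fetch users --force` would parse to:
args = parser.parse_args(["fetch", "users", "--force"])
exit_code = args.cli.command.execute()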
2 changes: 0 additions & 2 deletions sarc/cli/acquire/__init__.py
@@ -8,7 +8,6 @@
from sarc.cli.acquire.prometheus import AcquirePrometheus
from sarc.cli.acquire.slurmconfig import AcquireSlurmConfig
from sarc.cli.acquire.storages import AcquireStorages
from sarc.cli.acquire.users import AcquireUsers


@dataclass
@@ -18,7 +17,6 @@ class Acquire:
"allocations": AcquireAllocations,
"jobs": AcquireJobs,
"storages": AcquireStorages,
"users": AcquireUsers,
"slurmconfig": AcquireSlurmConfig,
"prometheus": AcquirePrometheus,
}
41 changes: 0 additions & 41 deletions sarc/cli/acquire/users.py

This file was deleted.

18 changes: 18 additions & 0 deletions sarc/cli/fetch/__init__.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import Union

from simple_parsing import subparsers

from .users import FetchUsers


@dataclass
class Fetch:
command: Union[FetchUsers] = subparsers(
{
"users": FetchUsers,
}
)

def execute(self) -> int:
return self.command.execute()
21 changes: 21 additions & 0 deletions sarc/cli/fetch/users.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from simple_parsing import field

from sarc.config import config
from sarc.core.scraping.users import fetch_users


@dataclass
class FetchUsers:
force: bool = field(
action="store_true",
help="Force recalculating the data rather than use the cache",
)

def execute(self) -> int:
users_cfg = config("scraping").users
assert users_cfg is not None

fetch_users(list(users_cfg.scrapers.items()))
return 0
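For reference, the same fetch can be driven without the CLI; this sketch only reuses the calls visible in FetchUsers.execute above, and assumes `scrapers` is a mapping of scraper name to scraper config, as that code implies:

from sarc.config import config
from sarc.core.scraping.users import fetch_users

users_cfg = config("scraping").users
assert users_cfg is not None

# fetch_users consumes (scraper_name, scraper_config) pairs.
fetch_users(list(users_cfg.scrapers.items()))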
18 changes: 18 additions & 0 deletions sarc/cli/parse/__init__.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import Union

from simple_parsing import subparsers

from .users import ParseUsers


@dataclass
class Parse:
command: Union[ParseUsers] = subparsers(
{
"users": ParseUsers,
}
)

def execute(self) -> int:
return self.command.execute()
19 changes: 19 additions & 0 deletions sarc/cli/parse/users.py
@@ -0,0 +1,19 @@
from dataclasses import dataclass
from datetime import datetime

from simple_parsing import field

from sarc.core.scraping.users import parse_users
from sarc.users.db import get_user_collection


@dataclass
class ParseUsers:
from_: datetime = field(help="Start parsing the cache from the specified date")

def execute(self) -> int:
coll = get_user_collection()
for um in parse_users(from_=self.from_):
coll.update_user(um)

return 0
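The parse side can be driven the same way; the start date below is illustrative and is made timezone-aware since the cache layer converts timestamps to UTC:

from datetime import UTC, datetime

from sarc.core.scraping.users import parse_users
from sarc.users.db import get_user_collection

coll = get_user_collection()
# Re-parse everything cached since Jan 1, 2024 and upsert it into the user DB.
for um in parse_users(from_=datetime(2024, 1, 1, tzinfo=UTC)):
    coll.update_user(um)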