Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 125 additions & 58 deletions src/stac_fastapi/geoparquet/api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import json
import logging
import urllib.parse
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, TypedDict
from typing import Any, TypedDict, cast

import obstore.store
import pystac.utils
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Request, Response
from rustac import DuckdbClient
from stac_fastapi.api.app import StacApi
from starlette.background import BackgroundTask

from .client import Client
from .models import (
Expand All @@ -20,9 +23,51 @@
)
from .settings import Settings

logger = logging.getLogger(__name__)

GEOPARQUET_MEDIA_TYPE = "application/vnd.apache.parquet"


async def load_collections(settings: Settings) -> list[dict[str, Any]]:
if settings.stac_fastapi_collections_href:
if urllib.parse.urlparse(settings.stac_fastapi_collections_href).scheme:
href = settings.stac_fastapi_collections_href
else:
href = "file://" + str(
Path(settings.stac_fastapi_collections_href).absolute()
)
prefix, file_name = href.rsplit("/", 1)
store = obstore.store.from_url(prefix)
result = store.get(file_name)
collections = cast(list[dict[str, Any]], json.loads(bytes(result.bytes())))
else:
collections = []
return collections


def _parse_collections(
collections: list[dict[str, Any]], settings: Settings
) -> tuple[dict[str, dict[str, Any]], dict[str, str]]:
"""Parse a raw collections list into (collection_dict, hrefs)."""
collection_dict: dict[str, dict[str, Any]] = {}
hrefs: dict[str, str] = {}
for collection in collections:
collection_id = collection["id"]
if collection_id in collection_dict:
raise ValueError(f"two collections with the same id: {collection_id}")
collection_dict[collection_id] = collection
for asset in collection["assets"].values():
if asset.get("type") == GEOPARQUET_MEDIA_TYPE:
if collection_id in hrefs:
raise ValueError(f"two hrefs for one collection: {collection_id}")
hrefs[collection_id] = pystac.utils.make_absolute_href(
asset["href"],
settings.stac_fastapi_collections_href,
start_is_dir=False,
)
return collection_dict, hrefs


class State(TypedDict):
"""Application state."""

Expand All @@ -32,44 +77,72 @@ class State(TypedDict):
It's just an in-memory DuckDB connection with the spatial extension enabled.
"""

collections: dict[str, dict[str, Any]]
"""A mapping of collection id to collection."""

hrefs: dict[str, str]
"""A mapping of collection id to geoparquet href."""
def make_collections_middleware(
settings: Settings,
) -> Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]]:
"""Return a TTL-based hot-reload middleware for collections.

On every request the current ``app.state.collections`` / ``app.state.hrefs``
are injected into ``request.state`` so that the rest of the stack is
unaffected. After the response is sent, a background task re-reads
``collections.json`` from object storage and updates ``app.state`` when the
configured TTL has elapsed.
"""

async def _refresh(app: FastAPI) -> None:
try:
raw = await load_collections(settings)
collection_dict, hrefs = _parse_collections(raw, settings)
except Exception:
logger.exception("Failed to reload collections; keeping stale state")
return
app.state.collections = collection_dict
app.state.hrefs = hrefs
app.state.collections_last_updated = datetime.now()
logger.debug(
"Collections reloaded; %d collection(s) active", len(collection_dict)
)

async def middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
request.state.collections = request.app.state.collections
request.state.hrefs = request.app.state.hrefs

background: BackgroundTask | None = None
last_updated: datetime | None = getattr(
request.app.state, "collections_last_updated", None
)
ttl = settings.stac_fastapi_collections_reload_seconds
if last_updated is None or datetime.now() > last_updated + timedelta(
seconds=ttl
):
request.app.state.collections_last_updated = datetime.now()
background = BackgroundTask(_refresh, request.app)

response = await call_next(request)
if background is not None:
response.background = background
return response

return middleware


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
client = app.extra["duckdb_client"]
client: DuckdbClient = app.extra["duckdb_client"]
settings: Settings = app.extra["settings"]
collections = app.extra["collections"]
collection_dict = dict()
hrefs = dict()
for collection in collections:
if collection["id"] in collection_dict:
raise HTTPException(
500, f"two collections with the same id: {collection['id']}"
)
else:
collection_dict[collection["id"]] = collection
for key, asset in collection["assets"].items():
if asset.get("type") == GEOPARQUET_MEDIA_TYPE:
if collection["id"] in hrefs:
raise HTTPException(
500, f"two hrefs for one collection: {collection['id']}"
)
else:
hrefs[collection["id"]] = pystac.utils.make_absolute_href(
asset["href"],
settings.stac_fastapi_collections_href,
start_is_dir=False,
)
yield {
"client": client,
"collections": collection_dict,
"hrefs": hrefs,
}

# Perform an initial blocking load so the first request is never served
# with an empty catalog.
raw = await load_collections(settings)
collection_dict, hrefs = _parse_collections(raw, settings)
app.state.collections = collection_dict
app.state.hrefs = hrefs
app.state.collections_last_updated = datetime.now()

yield {"client": client}


def create(
Expand All @@ -85,20 +158,10 @@ def create(
stac_fastapi_description="A stac-fastapi server backend by stac-geoparquet",
)

if settings.stac_fastapi_collections_href:
if urllib.parse.urlparse(settings.stac_fastapi_collections_href).scheme:
href = settings.stac_fastapi_collections_href
else:
href = "file://" + str(
Path(settings.stac_fastapi_collections_href).absolute()
)
prefix, file_name = href.rsplit("/", 1)
store = obstore.store.from_url(prefix)
result = store.get(file_name)
collections = json.loads(bytes(result.bytes()))
else:
collections = []

# Collections from stac_fastapi_collections_href are loaded in the lifespan
# and kept fresh by the hot-reload middleware.
# Collections from stac_fastapi_geoparquet_href are static (loaded once here).
collections = []
if settings.stac_fastapi_geoparquet_href:
collections.extend(
collections_from_geoparquet_href(
Expand All @@ -107,18 +170,22 @@ def create(
)
)

app = FastAPI(
lifespan=lifespan,
openapi_url=settings.openapi_url,
docs_url=settings.docs_url,
redoc_url=settings.docs_url,
settings=settings,
collections=collections,
duckdb_client=duckdb_client,
)
# Add hot-reload middleware
app.middleware("http")(make_collections_middleware(settings))

api = StacApi(
settings=settings,
client=Client(),
app=FastAPI(
lifespan=lifespan,
openapi_url=settings.openapi_url,
docs_url=settings.docs_url,
redoc_url=settings.docs_url,
settings=settings,
collections=collections,
duckdb_client=duckdb_client,
),
app=app,
search_get_request_model=GetSearchRequestModel,
search_post_request_model=PostSearchRequestModel,
items_get_request_model=ItemsGetRequestModel,
Expand Down
3 changes: 3 additions & 0 deletions src/stac_fastapi/geoparquet/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
class Settings(ApiSettings):
"""stac-fastapi-geoparquet settings"""

stac_fastapi_collections_reload_seconds: int = 60
"""Interval in seconds to reload collections.json (default: 60)."""

stac_fastapi_collections_href: str | None = None
"""The href of a file containing JSON list of collections.

Expand Down
54 changes: 54 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import AsyncMock, patch

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -32,3 +34,55 @@ def test_create_from_parquet_file() -> None:
with TestClient(api.app) as client:
response = client.get("/search")
assert response.status_code == 200


def test_collections_reload_on_ttl_expiry() -> None:
settings = Settings(
stac_fastapi_collections_href=str(COLLECTIONS_PATH),
stac_fastapi_collections_reload_seconds=60,
)
api = stac_fastapi.geoparquet.api.create(settings=settings)

with TestClient(api.app) as client:
# Sanity check: initial collections are populated.
response = client.get("/collections")
assert len(response.json()["collections"]) > 0

# Expire the TTL so the next request schedules a background refresh.
api.app.state.collections_last_updated = datetime.now() - timedelta(seconds=120)

# Patch load_collections to return an empty list for the reload.
with patch(
"stac_fastapi.geoparquet.api.load_collections",
new_callable=AsyncMock,
return_value=[],
):
# TestClient awaits BackgroundTask before returning, so the refresh
# has already updated app.state by the time this call returns.
client.get("/collections")

# The next request should see the reloaded (empty) state.
response = client.get("/collections")
assert response.json()["collections"] == []


def test_collections_no_reload_within_ttl() -> None:
settings = Settings(
stac_fastapi_collections_href=str(COLLECTIONS_PATH),
stac_fastapi_collections_reload_seconds=3600,
)
api = stac_fastapi.geoparquet.api.create(settings=settings)

with TestClient(api.app) as client:
initial_count = len(client.get("/collections").json()["collections"])

with patch(
"stac_fastapi.geoparquet.api.load_collections",
new_callable=AsyncMock,
return_value=[],
) as mock_load:
client.get("/collections")
mock_load.assert_not_called()

# Collections should be unchanged.
assert len(client.get("/collections").json()["collections"]) == initial_count