Skip to content

Commit 498f3e2

Browse files
authored
Add hot reload collections (#42)
* Add hot reload collections * Fix linting * fix: remove globals and implement TTL middleware * chore: add tests
1 parent ea95b83 commit 498f3e2

3 files changed

Lines changed: 182 additions & 58 deletions

File tree

src/stac_fastapi/geoparquet/api.py

Lines changed: 125 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import json
2+
import logging
23
import urllib.parse
3-
from collections.abc import AsyncIterator
4+
from collections.abc import AsyncIterator, Awaitable, Callable
45
from contextlib import asynccontextmanager
6+
from datetime import datetime, timedelta
57
from pathlib import Path
6-
from typing import Any, TypedDict
8+
from typing import Any, TypedDict, cast
79

810
import obstore.store
911
import pystac.utils
10-
from fastapi import FastAPI, HTTPException
12+
from fastapi import FastAPI, Request, Response
1113
from rustac import DuckdbClient
1214
from stac_fastapi.api.app import StacApi
15+
from starlette.background import BackgroundTask
1316

1417
from .client import Client
1518
from .models import (
@@ -20,9 +23,51 @@
2023
)
2124
from .settings import Settings
2225

26+
logger = logging.getLogger(__name__)
27+
2328
GEOPARQUET_MEDIA_TYPE = "application/vnd.apache.parquet"
2429

2530

31+
async def load_collections(settings: Settings) -> list[dict[str, Any]]:
    """Load the raw list of collections from ``stac_fastapi_collections_href``.

    A value without a URL scheme is treated as a local path and turned into an
    absolute ``file://`` URL. The object is fetched through obstore's async
    API so that the event loop is not blocked while the document is read
    (this coroutine also runs from the hot-reload background task).

    Args:
        settings: Application settings; only ``stac_fastapi_collections_href``
            is read.

    Returns:
        The decoded JSON list of collection dictionaries, or an empty list
        when no collections href is configured.

    Raises:
        ValueError: If the fetched document is not a JSON list.
    """
    collections_href = settings.stac_fastapi_collections_href
    if not collections_href:
        return []

    if urllib.parse.urlparse(collections_href).scheme:
        href = collections_href
    else:
        href = "file://" + str(Path(collections_href).absolute())

    # The store is rooted at the parent "directory"; the file name is the key.
    prefix, file_name = href.rsplit("/", 1)
    store = obstore.store.from_url(prefix)
    result = await store.get_async(file_name)
    collections = json.loads(bytes(await result.bytes_async()))
    if not isinstance(collections, list):
        raise ValueError(
            f"expected a JSON list of collections at {href}, "
            f"got {type(collections).__name__}"
        )
    return cast(list[dict[str, Any]], collections)
46+
47+
48+
def _parse_collections(
    collections: list[dict[str, Any]], settings: Settings
) -> tuple[dict[str, dict[str, Any]], dict[str, str]]:
    """Parse a raw collections list into ``(collection_dict, hrefs)``.

    Args:
        collections: Raw collection dictionaries; each must have an ``id``.
        settings: Application settings; ``stac_fastapi_collections_href`` is
            used as the base when making asset hrefs absolute.

    Returns:
        A pair of mappings: collection id to collection, and collection id to
        the absolute href of that collection's geoparquet asset. A collection
        without a geoparquet asset has no entry in the second mapping.

    Raises:
        ValueError: If two collections share an id, or one collection has
            more than one geoparquet asset.
        KeyError: If a collection has no ``id`` or a geoparquet asset has no
            ``href``.
    """
    collection_dict: dict[str, dict[str, Any]] = {}
    hrefs: dict[str, str] = {}
    for collection in collections:
        collection_id = collection["id"]
        if collection_id in collection_dict:
            raise ValueError(f"two collections with the same id: {collection_id}")
        collection_dict[collection_id] = collection

        # ``assets`` may be absent or null; treat that the same as a
        # collection whose assets contain no geoparquet entry.
        assets = collection.get("assets") or {}
        for asset in assets.values():
            if asset.get("type") != GEOPARQUET_MEDIA_TYPE:
                continue
            if collection_id in hrefs:
                raise ValueError(f"two hrefs for one collection: {collection_id}")
            hrefs[collection_id] = pystac.utils.make_absolute_href(
                asset["href"],
                settings.stac_fastapi_collections_href,
                start_is_dir=False,
            )
    return collection_dict, hrefs
69+
70+
2671
class State(TypedDict):
2772
"""Application state."""
2873

@@ -32,44 +77,72 @@ class State(TypedDict):
3277
It's just an in-memory DuckDB connection with the spatial extension enabled.
3378
"""
3479

35-
collections: dict[str, dict[str, Any]]
36-
"""A mapping of collection id to collection."""
3780

38-
hrefs: dict[str, str]
39-
"""A mapping of collection id to geoparquet href."""
81+
def make_collections_middleware(
    settings: Settings,
) -> Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]]:
    """Return a TTL-based hot-reload middleware for collections.

    On every request the current ``app.state.collections`` / ``app.state.hrefs``
    are injected into ``request.state`` so that the rest of the stack is
    unaffected. When the configured TTL
    (``settings.stac_fastapi_collections_reload_seconds``) has elapsed, a
    background task is attached to the response; it re-reads the collections
    and replaces the values on ``app.state``.

    Args:
        settings: Application settings providing the collections href and TTL.

    Returns:
        An HTTP middleware callable suitable for ``app.middleware("http")``.
    """

    async def _refresh(app: FastAPI) -> None:
        """Reload collections into ``app.state``, keeping the old state on error."""
        try:
            raw = await load_collections(settings)
            # Collections handed to the app at construction time (the
            # ``collections`` extra) are not part of the reloaded document, so
            # they are merged back in; otherwise a refresh would drop them.
            raw = [*raw, *app.extra.get("collections", [])]
            collection_dict, hrefs = _parse_collections(raw, settings)
        except Exception:
            logger.exception("Failed to reload collections; keeping stale state")
            return
        app.state.collections = collection_dict
        app.state.hrefs = hrefs
        app.state.collections_last_updated = datetime.now()
        logger.debug(
            "Collections reloaded; %d collection(s) active", len(collection_dict)
        )

    async def middleware(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Expose the current collections on the request and schedule reloads."""
        request.state.collections = request.app.state.collections
        request.state.hrefs = request.app.state.hrefs

        background: BackgroundTask | None = None
        last_updated: datetime | None = getattr(
            request.app.state, "collections_last_updated", None
        )
        ttl = settings.stac_fastapi_collections_reload_seconds
        if last_updated is None or datetime.now() > last_updated + timedelta(
            seconds=ttl
        ):
            # Stamp the time before the refresh runs so that later requests do
            # not each schedule another refresh while this one is pending.
            request.app.state.collections_last_updated = datetime.now()
            background = BackgroundTask(_refresh, request.app)

        response = await call_next(request)
        if background is not None:
            response.background = background
        return response

    return middleware
40130

41131

42132
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
    """Populate ``app.state`` with collections before the first request.

    Reads ``duckdb_client``, ``settings`` and (when present) ``collections``
    from ``app.extra``. The collections loaded from
    ``stac_fastapi_collections_href`` are combined with the ``collections``
    extra, parsed, and stored on ``app.state`` together with the load time.

    Args:
        app: The FastAPI application being started.

    Yields:
        The lifespan state, containing the DuckDB client.

    Raises:
        ValueError: Propagated from parsing when collection ids or geoparquet
            hrefs are duplicated.
    """
    client: DuckdbClient = app.extra["duckdb_client"]
    settings: Settings = app.extra["settings"]

    # Perform an initial blocking load so the first request is never served
    # with an empty catalog.
    raw = await load_collections(settings)
    # Collections supplied at construction time are not part of the document
    # behind the collections href, so add them here or they are never served.
    raw = [*raw, *app.extra.get("collections", [])]
    collection_dict, hrefs = _parse_collections(raw, settings)
    app.state.collections = collection_dict
    app.state.hrefs = hrefs
    app.state.collections_last_updated = datetime.now()

    yield {"client": client}
73146

74147

75148
def create(
@@ -85,20 +158,10 @@ def create(
85158
stac_fastapi_description="A stac-fastapi server backend by stac-geoparquet",
86159
)
87160

88-
if settings.stac_fastapi_collections_href:
89-
if urllib.parse.urlparse(settings.stac_fastapi_collections_href).scheme:
90-
href = settings.stac_fastapi_collections_href
91-
else:
92-
href = "file://" + str(
93-
Path(settings.stac_fastapi_collections_href).absolute()
94-
)
95-
prefix, file_name = href.rsplit("/", 1)
96-
store = obstore.store.from_url(prefix)
97-
result = store.get(file_name)
98-
collections = json.loads(bytes(result.bytes()))
99-
else:
100-
collections = []
101-
161+
# Collections from stac_fastapi_collections_href are loaded in the lifespan
162+
# and kept fresh by the hot-reload middleware.
163+
# Collections from stac_fastapi_geoparquet_href are static (loaded once here).
164+
collections = []
102165
if settings.stac_fastapi_geoparquet_href:
103166
collections.extend(
104167
collections_from_geoparquet_href(
@@ -107,18 +170,22 @@ def create(
107170
)
108171
)
109172

173+
app = FastAPI(
174+
lifespan=lifespan,
175+
openapi_url=settings.openapi_url,
176+
docs_url=settings.docs_url,
177+
redoc_url=settings.docs_url,
178+
settings=settings,
179+
collections=collections,
180+
duckdb_client=duckdb_client,
181+
)
182+
# Add hot-reload middleware
183+
app.middleware("http")(make_collections_middleware(settings))
184+
110185
api = StacApi(
111186
settings=settings,
112187
client=Client(),
113-
app=FastAPI(
114-
lifespan=lifespan,
115-
openapi_url=settings.openapi_url,
116-
docs_url=settings.docs_url,
117-
redoc_url=settings.docs_url,
118-
settings=settings,
119-
collections=collections,
120-
duckdb_client=duckdb_client,
121-
),
188+
app=app,
122189
search_get_request_model=GetSearchRequestModel,
123190
search_post_request_model=PostSearchRequestModel,
124191
items_get_request_model=ItemsGetRequestModel,

src/stac_fastapi/geoparquet/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
class Settings(ApiSettings):
55
"""stac-fastapi-geoparquet settings"""
66

7+
stac_fastapi_collections_reload_seconds: int = 60
8+
"""Interval in seconds to reload collections.json (default: 60)."""
9+
710
stac_fastapi_collections_href: str | None = None
811
"""The href of a file containing JSON list of collections.
912

tests/test_api.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from datetime import datetime, timedelta
12
from pathlib import Path
3+
from unittest.mock import AsyncMock, patch
24

35
import pytest
46
from fastapi.testclient import TestClient
@@ -32,3 +34,55 @@ def test_create_from_parquet_file() -> None:
3234
with TestClient(api.app) as client:
3335
response = client.get("/search")
3436
assert response.status_code == 200
37+
38+
39+
def test_collections_reload_on_ttl_expiry() -> None:
    """An expired TTL triggers a reload whose result is served afterwards."""
    api = stac_fastapi.geoparquet.api.create(
        settings=Settings(
            stac_fastapi_collections_href=str(COLLECTIONS_PATH),
            stac_fastapi_collections_reload_seconds=60,
        )
    )
    # The reload is patched to yield no collections; the patch is only active
    # inside the ``with`` block below.
    empty_reload = patch(
        "stac_fastapi.geoparquet.api.load_collections",
        new_callable=AsyncMock,
        return_value=[],
    )

    with TestClient(api.app) as client:
        # Sanity check: the initial load produced at least one collection.
        initial = client.get("/collections")
        assert len(initial.json()["collections"]) > 0

        # Push the timestamp into the past so the next request schedules a
        # background refresh.
        api.app.state.collections_last_updated = datetime.now() - timedelta(
            seconds=120
        )

        with empty_reload:
            # TestClient awaits the BackgroundTask before returning, so
            # app.state has been refreshed once this call completes.
            client.get("/collections")

        # The following request observes the reloaded (empty) state.
        reloaded = client.get("/collections")
        assert reloaded.json()["collections"] == []
67+
68+
69+
def test_collections_no_reload_within_ttl() -> None:
    """No reload is attempted while the TTL has not yet elapsed."""
    api = stac_fastapi.geoparquet.api.create(
        settings=Settings(
            stac_fastapi_collections_href=str(COLLECTIONS_PATH),
            stac_fastapi_collections_reload_seconds=3600,
        )
    )
    reload_patch = patch(
        "stac_fastapi.geoparquet.api.load_collections",
        new_callable=AsyncMock,
        return_value=[],
    )

    with TestClient(api.app) as client:
        before = client.get("/collections").json()["collections"]
        initial_count = len(before)

        with reload_patch as mock_load:
            client.get("/collections")
            mock_load.assert_not_called()

        # The catalog must be the same size as before the patched request.
        after = client.get("/collections").json()["collections"]
        assert len(after) == initial_count

0 commit comments

Comments
 (0)