Skip to content

Commit 5ae90d9

Browse files
committed
First pass at prototyping async getmap
1 parent b8d19cc commit 5ae90d9

3 files changed

Lines changed: 36 additions & 24 deletions

File tree

xpublish_wms/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter:
3333

3434
@router.get("", include_in_schema=False)
3535
@router.get("/")
36-
def wms_root(
36+
async def wms_root(
3737
request: Request,
3838
dataset: xr.Dataset = Depends(deps.dataset),
3939
cache: cachey.Cache = Depends(deps.cache),
4040
):
41-
return wms_handler(request, dataset, cache)
41+
return await wms_handler(request, dataset, cache)
4242

4343
return router

xpublish_wms/wms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
logger = logging.getLogger("uvicorn")
2222

2323

24-
def wms_handler(
24+
async def wms_handler(
2525
request: Request,
2626
dataset: xr.Dataset = Depends(get_dataset),
2727
cache: cachey.Cache = Depends(get_cache),
@@ -35,7 +35,7 @@ def wms_handler(
3535
return get_capabilities(dataset, request, query_params)
3636
elif method == "getmap":
3737
getmap_service = GetMap(cache=cache)
38-
return getmap_service.get_map(dataset, query_params)
38+
return await getmap_service.get_map(dataset, query_params)
3939
elif method == "getfeatureinfo" or method == "gettimeseries":
4040
return get_feature_info(dataset, query_params)
4141
elif method == "getverticalprofile":

xpublish_wms/wms/get_map.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import io
23
import logging
34
import time
@@ -32,7 +33,7 @@ class GetMap:
3233
DEFAULT_STYLE: str = "raster/default"
3334
DEFAULT_PALETTE: str = "turbo"
3435

35-
BBOX_BUFFER = 0.18
36+
BBOX_BUFFER = 30_000 # meters
3637

3738
cache: cachey.Cache
3839

@@ -58,7 +59,7 @@ class GetMap:
5859
def __init__(self, cache: cachey.Cache):
5960
self.cache = cache
6061

61-
def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
62+
async def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
6263
"""
6364
Return the WMS map for the dataset and given parameters
6465
"""
@@ -76,13 +77,13 @@ def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
7677
# The grid type for now. This can be revisited if we choose to interpolate or
7778
# use the contoured renderer for regular grid datasets
7879
image_buffer = io.BytesIO()
79-
render_result = self.render(ds, da, image_buffer, False)
80+
render_result = await self.render(ds, da, image_buffer, False)
8081
if render_result:
8182
image_buffer.seek(0)
8283

8384
return StreamingResponse(image_buffer, media_type="image/png")
8485

85-
def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
86+
async def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
8687
"""
8788
Return the range of values for the dataset and given parameters
8889
"""
@@ -109,7 +110,7 @@ def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
109110
if entire_layer:
110111
return {"min": float(da.min()), "max": float(da.max())}
111112
else:
112-
return self.render(ds, da, None, minmax_only=True)
113+
return await self.render(ds, da, None, minmax_only=True)
113114

114115
def ensure_query_types(self, ds: xr.Dataset, query: dict):
115116
"""
@@ -255,7 +256,7 @@ def select_custom_dim(self, da: xr.DataArray) -> xr.DataArray:
255256

256257
return da
257258

258-
def render(
259+
async def render(
259260
self,
260261
ds: xr.Dataset,
261262
da: xr.DataArray,
@@ -279,29 +280,40 @@ def render(
279280
if minmax_only:
280281
logger.warning("Falling back to default minmax")
281282
return {"min": float(da.min()), "max": float(da.max())}
283+
284+
try:
285+
da = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
286+
except Exception as e:
287+
logger.error(f"Error filtering data within bbox: {e}")
288+
logger.warning("Falling back to full layer")
282289

283-
logger.debug(f"Projection time: {time.time() - projection_start}")
290+
print(f"Projection time: {time.time() - projection_start}")
284291

285292
start_dask = time.time()
286293

287-
da = da.persist()
288-
if x is not None and y is not None:
289-
x = x.persist()
290-
y = y.persist()
291-
else:
292-
da["x"] = da.x.persist()
293-
da["y"] = da.y.persist()
294+
da = await asyncio.to_thread(da.compute)
295+
296+
# da = da.persist()
297+
# if x is not None and y is not None:
298+
# x = x.persist()
299+
# y = y.persist()
300+
# else:
301+
# da["x"] = da.x.persist()
302+
# da["y"] = da.y.persist()
303+
304+
print(da.x[1].values -da.x[0].values)
305+
print(da.y[1].values - da.y[0].values)
294306

295-
logger.debug(f"dask compute: {time.time() - start_dask}")
307+
print(f"dask compute: {time.time() - start_dask}")
296308

297309
if minmax_only:
298-
da = da.persist()
299-
data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
310+
# da = da.persist()
311+
# data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
300312

301313
try:
302314
return {
303-
"min": float(np.nanmin(data_sel)),
304-
"max": float(np.nanmax(data_sel)),
315+
"min": float(np.nanmin(da)),
316+
"max": float(np.nanmax(da)),
305317
}
306318
except Exception as e:
307319
logger.error(
@@ -354,7 +366,7 @@ def render(
354366
how="linear",
355367
span=(vmin, vmax),
356368
)
357-
logger.debug(f"Shade time: {time.time() - start_shade}")
369+
print(f"Shade time: {time.time() - start_shade}")
358370

359371
im = shaded.to_pil()
360372
im.save(buffer, format="PNG")

0 commit comments

Comments
 (0)