Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xpublish_wms/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def dataset_router(self, deps: Dependencies) -> APIRouter:

@router.get("", include_in_schema=False)
@router.get("/")
def wms_root(
async def wms_root(
request: Request,
wms_query: Annotated[WMSQuery, Query()],
dataset: xr.Dataset = Depends(deps.dataset),
Expand All @@ -54,7 +54,7 @@ def wms_root(
del query_params[query_key]

# TODO: Make threshold configurable
return wms_handler(
return await wms_handler(
request,
wms_query.root,
extra_query_params,
Expand Down
18 changes: 12 additions & 6 deletions xpublish_wms/wms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
OGC WMS router for datasets with CF convention metadata
"""

import asyncio
from typing import Union

import cachey
Expand All @@ -25,7 +26,7 @@
from .get_metadata import get_metadata


def wms_handler(
async def wms_handler(
request: Request,
query: Union[
WMSGetCapabilitiesQuery,
Expand All @@ -43,9 +44,9 @@ def wms_handler(

match query:
case WMSGetCapabilitiesQuery():
return get_capabilities(dataset, request, query)
return await asyncio.to_thread(get_capabilities, dataset, request, query)
Comment on lines 45 to +47
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand the context in which these other functions get called - is it important to make these properly async too?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My guess is that it's only functions which actually trigger loading of additional chunk data which need to be fully async - if all they do is process metadata that should already be in memory then they don't need to be async.

Copy link
Copy Markdown
Collaborator

@mpiannucci mpiannucci May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once getmap is good, we need to make all the other methods also load data async instead of using to thread.

They all load data just a little differently than getmap

case WMSGetMetadataQuery():
return get_metadata(
return await get_metadata(
dataset,
cache,
query,
Expand All @@ -57,11 +58,16 @@ def wms_handler(
cache=cache,
array_render_threshold_bytes=array_get_map_render_threshold_bytes,
)
return getmap_service.get_map(dataset, query, extra_query_params)
return await getmap_service.get_map(dataset, query, extra_query_params)
case WMSGetFeatureInfoQuery():
return get_feature_info(dataset, query, extra_query_params)
return await asyncio.to_thread(
get_feature_info,
dataset,
query,
extra_query_params,
)
case WMSGetLegendInfoQuery():
return get_legend_info(dataset, query)
return await asyncio.to_thread(get_legend_info, dataset, query)
case _:
raise HTTPException(
status_code=404,
Expand Down
22 changes: 11 additions & 11 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ class GetMap:
DEFAULT_STYLE: str = "raster/default"
DEFAULT_PALETTE: str = "turbo"

BBOX_BUFFER = 0.18

cache: cachey.Cache
array_render_threshold_bytes: int

Expand Down Expand Up @@ -63,7 +61,7 @@ def __init__(
self.cache = cache
self.array_render_threshold_bytes = array_render_threshold_bytes

def get_map(
async def get_map(
self,
ds: xr.Dataset,
query: WMSGetMapQuery,
Expand Down Expand Up @@ -125,7 +123,7 @@ def get_map(
# use the contoured renderer for regular grid datasets
image_buffer = io.BytesIO()
try:
render_result = self.render(ds, da, image_buffer, False)
render_result = await self.render(ds, da, image_buffer, False)
except HTTPException as e:
raise e
except Exception as e:
Expand All @@ -140,7 +138,7 @@ def get_map(

return StreamingResponse(image_buffer, media_type="image/png")

def get_minmax(
async def get_minmax(
self,
ds: xr.Dataset,
query: WMSGetMapQuery,
Expand All @@ -164,7 +162,7 @@ def get_minmax(
if entire_layer:
return {"min": float(da.min()), "max": float(da.max())}
else:
return self.render(ds, da, None, minmax_only=True)
return await self.render(ds, da, None, minmax_only=True)

def ensure_query_types(
self,
Expand Down Expand Up @@ -320,7 +318,7 @@ def select_custom_dim(self, da: xr.DataArray) -> xr.DataArray:

return da

def render(
async def render(
self,
ds: xr.Dataset,
da: xr.DataArray,
Expand Down Expand Up @@ -378,7 +376,8 @@ def render(
# if filter_by_bbox was successful, preload data for projection
if filter_success:
filter_load_time = time.time()
da = da.load()
# TODO requires https://github.com/pydata/xarray/pull/10327
da = await da.load_async()
logger.debug(
f"WMS GetMap load filtered data: {time.time() - filter_load_time}",
)
Expand Down Expand Up @@ -414,9 +413,10 @@ def render(
)
logger.debug(f"WMS GetMap loading DataArray size: {da_size:.2f} bytes")

start_dask = time.time()
da = da.load()
logger.debug(f"WMS GetMap load full data: {time.time() - start_dask}")
start_load = time.time()
# TODO requires https://github.com/pydata/xarray/pull/10327
da = await da.load_async()
logger.debug(f"WMS GetMap load full data: {time.time() - start_load}")

if da.size == 0:
logger.warning("No data to render")
Expand Down
15 changes: 8 additions & 7 deletions xpublish_wms/wms/get_metadata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime as dt

import cachey
Expand All @@ -13,7 +14,7 @@
from .get_map import GetMap


def get_metadata(
async def get_metadata(
ds: xr.Dataset,
cache: cachey.Cache,
query: WMSGetMetadataQuery,
Expand Down Expand Up @@ -42,14 +43,14 @@ def get_metadata(
)

if metadata_type == "menu":
payload = get_menu(ds)
payload = await asyncio.to_thread(get_menu, ds)
elif metadata_type == "layerdetails":
payload = get_layer_details(ds, layer_name)
payload = await asyncio.to_thread(get_layer_details, ds, layer_name)
elif metadata_type == "timesteps":
da = ds[layer_name]
payload = get_timesteps(da, query)
payload = await asyncio.to_thread(get_timesteps, da, query)
elif metadata_type == "minmax":
payload = get_minmax(
payload = await get_minmax(
ds,
cache,
query,
Expand Down Expand Up @@ -97,7 +98,7 @@ def get_timesteps(da: xr.DataArray, query: WMSGetMetadataQuery) -> dict:
}


def get_minmax(
async def get_minmax(
ds: xr.Dataset,
cache: cachey.Cache,
query: WMSGetMetadataQuery,
Expand Down Expand Up @@ -129,7 +130,7 @@ def get_minmax(
cache=cache,
array_render_threshold_bytes=array_get_map_render_threshold_bytes,
)
return getmap.get_minmax(ds, getmap_query, query_params, entire_layer)
return await getmap.get_minmax(ds, getmap_query, query_params, entire_layer)


def get_layer_details(ds: xr.Dataset, layer_name: str) -> dict:
Expand Down
Loading