2020
2121# Only import if running type checking
2222if TYPE_CHECKING :
23+ from odc .loader .types import Band_DType
2324 from pystac import ItemCollection
2425 from xarray import Dataset
2526
@@ -64,7 +65,7 @@ def stac_load(
6465 mask_geopolygon : bool = False ,
6566 stac_query : dict | None = None ,
6667 stac_url : str = "https://planetarycomputer.microsoft.com/api/stac/v1" ,
67- dtype : Any | None = None ,
68+ dtype : Band_DType | None = None ,
6869 ** load_params ,
6970) -> tuple [Dataset , ItemCollection ]:
7071 """Query and load satellite data from a STAC API.
@@ -129,10 +130,13 @@ def stac_load(
129130 modifier = (planetary_computer .sign_inplace if "planetarycomputer" in stac_url else None ),
130131 )
131132
132- # Set dtype; use provided unless `mask_geopolygon` is provided,
133- # in which case use `float32`.
133+ # Use provided dtype if exists, or "float32" if `mask_geopolygon` is provided
134134 dtype = "float32" if mask_geopolygon else dtype
135135
136+ # Add dtype to load parameters if required
137+ if dtype is not None :
138+ load_params ["dtype" ] = dtype
139+
136140 # Set up time for query
137141 time = "/" .join (time ) if time is not None else None
138142
@@ -159,7 +163,6 @@ def stac_load(
159163 geopolygon = geopolygon ,
160164 lon = lon ,
161165 lat = lat ,
162- dtype = dtype ,
163166 ** load_params ,
164167 )
165168
0 commit comments