Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions googlehydrology/datasetzoo/multimet.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,9 @@ def _extract_dataset(
self,
data: xr.Dataset,
features: list[str],
indexers: dict[Hashable, int | range],
indexers: dict[Hashable, int | range | slice],
) -> dict[str, np.ndarray | np.float32]:
def extract(feature_name):
def extract(feature_name: str):
key = f"{id(data)}{feature_name}"
feature = self._data_cache.get(key)
if feature is None:
Expand Down Expand Up @@ -572,11 +572,9 @@ def _load_hindcast_features(self) -> list[xr.Dataset]:
if "lead_time" in product_ds:
# The same product may be used both for forecast and hindcast features. For hindcast, we load it with the
# full lead_time similar to forecast, and filter the minimal lead_time values during sampling.
lead_times = [
pd.Timedelta(days=i)
for i in range(MULTIMET_MINIMUM_LEAD_TIME, self.lead_time + 1)
]
product_ds = product_ds.sel(basin=self._basins, lead_time=lead_times)
product_ds = product_ds.sel(
basin=self._basins, lead_time=self._lead_time_slice()
)
else:
product_ds = product_ds.sel(basin=self._basins)

Expand All @@ -602,12 +600,6 @@ def _load_forecast_features(self) -> list[xr.Dataset]:
# Initialize storage for product/band dataframes that will eventually be concatenated.
product_dss = []

# Lead time array.
lead_times = [
pd.Timedelta(days=i)
for i in range(MULTIMET_MINIMUM_LEAD_TIME, self.lead_time + 1)
]

# Load data for the selected products, bands, and basins.
for product, bands in product_bands.items():
product_path = self._dynamics_data_path / product / "timeseries.zarr"
Expand All @@ -619,7 +611,9 @@ def _load_forecast_features(self) -> list[xr.Dataset]:
f"Lead times do not exist for forecast product ({product})."
)

product_ds = product_ds.sel(basin=self._basins, lead_time=lead_times)[bands]
product_ds = product_ds.sel(
basin=self._basins, lead_time=self._lead_time_slice()
)[bands]
product_dss.append(product_ds)

return product_dss
Expand Down Expand Up @@ -650,6 +644,13 @@ def _load_static_features(self) -> xr.Dataset:
features=self._static_features,
)

def _lead_time_slice(self) -> slice:
# https://pandas.pydata.org/pandas-docs/stable/user_guide/advanced.html#endpoints-are-inclusive
return slice(
pd.Timedelta(days=MULTIMET_MINIMUM_LEAD_TIME),
pd.Timedelta(days=self.lead_time),
)

@staticmethod
def collate_fn(
samples: list[
Expand Down Expand Up @@ -681,7 +682,7 @@ def collate_fn(


def _extract_dataarray(
data: xr.DataArray, indexers: dict[Hashable, int | range]
data: xr.DataArray, indexers: dict[Hashable, int | range | slice]
) -> np.ndarray | np.float32:
"""Returns the values in array according to dims given by indexers.

Expand Down