Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 75 additions & 18 deletions googlehydrology/datasetzoo/multimet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@
]
MULTIMET_MINIMUM_LEAD_TIME = 1

# Caravan Multimet products that are available in the GCS zarr store.
# NOTE(review): these names are only used for product-name normalization;
# PRODUCT_ALIASES below does the heavy lifting for the spellings that
# actually differ, so this set is primarily documentation of what exists.
KNOWN_GCS_PRODUCTS = {
    "CHIRPS",
    "CHIRPS_GEFS",
    "CPC",
    "ERA5_LAND",
    "GRAPHCAST",
    "HRES",
    "IMERG",
}

# Aliases for multimet product names with inconsistent naming conventions.
# Keys are normalized (uppercased, underscore/hyphen-free) spellings a user
# may supply; values are the canonical product directory names in the store.
PRODUCT_ALIASES = {
    "ERA5LAND": "ERA5_LAND",
    "CHIRPSGEFS": "CHIRPS_GEFS",
}
class MultimetDataLoader(torch.utils.data.DataLoader):
"""Custom DataLoader that handles lazy data loading.

Expand Down Expand Up @@ -165,6 +181,8 @@ def __init__(
raise ValueError('hindcast_inputs must be supplied.')
self._forecast_features = flatten_feature_list(cfg.forecast_inputs)
self._hindcast_features = flatten_feature_list(cfg.hindcast_inputs)
self._hindcast_inputs = cfg.hindcast_inputs
self._forecast_inputs = cfg.forecast_inputs
self._union_mapping = cfg.union_mapping

# Feature data paths by type. This allows the option to load some data from cloud and some locally.
Expand Down Expand Up @@ -339,6 +357,15 @@ def __init__(
LOGGER.debug(f'Dataset on disk: {self._dataset.nbytes / 1024**2} MB')
LOGGER.debug(f'Sample index size: {sizeof(indices) / 1024**2} MB')

# TODO(future) :: Move above to the scalar compute block
LOGGER.debug('scaler check zero scale')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I ask you to pull and merge the most recent changes from main? This replication of scaler saving was removed in a recent PR.

self.scaler.check_zero_scale()
LOGGER.debug('scaler save')

# Don't save the scaler if we are finetuning
if not cfg.is_finetuning:
self.scaler.save()

# Create sample index lookup table for `__getitem__`.
LOGGER.debug('create sample index')
self._create_sample_index(valid_sample_mask, indices)
Expand Down Expand Up @@ -699,8 +726,8 @@ def _load_hindcast_as_zarr(self) -> list[xr.Dataset]:
)

# Separate products and bands for each product from feature names.
product_bands = _get_products_and_bands_from_feature_strings(
features=features
product_bands = _get_products_and_bands_from_feature_dict(
self._hindcast_inputs
)

# Initialize storage for product/band dataframes that will eventually be concatenated.
Expand All @@ -711,7 +738,16 @@ def _load_hindcast_as_zarr(self) -> list[xr.Dataset]:
product_path = (
self._dynamics_data_path / product / 'timeseries.zarr'
)
LOGGER.info(f"Loading hindcast product '{product}' with bands {bands}")
product_ds = _open_zarr(product_path)

missing = set(bands) - set(product_ds.data_vars)

if missing:
raise ValueError(
f"Requested features {missing} not found in product '{product}'. "
f"Available variables: {list(product_ds.data_vars)}"
)

if 'lead_time' in product_ds:
# The same product may be used both for forecast and hindcast features. For hindcast, we load it with the
Expand Down Expand Up @@ -870,8 +906,8 @@ def _load_forecast_as_zarr(self) -> list[xr.Dataset]:
Dataset containing the loaded features with dimensions (date, lead_time, basin).
"""
# Separate products and bands for each product from feature names.
product_bands = _get_products_and_bands_from_feature_strings(
features=self._forecast_features
product_bands = _get_products_and_bands_from_feature_dict(
self._forecast_inputs
)

# Initialize storage for product/band dataframes that will eventually be concatenated.
Expand All @@ -882,8 +918,17 @@ def _load_forecast_as_zarr(self) -> list[xr.Dataset]:
product_path = (
self._dynamics_data_path / product / 'timeseries.zarr'
)
LOGGER.info(f"Loading forecast product '{product}' with bands {bands}")
product_ds = _open_zarr(product_path)

missing = set(bands) - set(product_ds.data_vars)

if missing:
raise ValueError(
f"Requested features {missing} not found in product '{product}'. "
f"Available variables: {list(product_ds.data_vars)}"
)

# If this is a forecast product, extract only leadtime 0 for hindcasts.
if 'lead_time' not in product_ds:
raise ValueError(
Expand Down Expand Up @@ -1039,32 +1084,44 @@ def _open_zarr(path: Path) -> xr.Dataset:
return xr.open_zarr(store=path, chunks='auto', decode_timedelta=True)


def _get_products_and_bands_from_feature_dict(feature_dict) -> dict[str, list[str]]:
    """Map canonical product names to their requested feature (band) names.

    Parameters
    ----------
    feature_dict : dict[str, list[str]] | Iterable[str]
        Either a dictionary where keys are product names from the config and
        values are lists of features belonging to that product, or a flat
        iterable of feature strings in the format ``<product>_<band>``
        (config inputs may be dicts or lists depending on whether feature
        groups are used).

    Returns
    -------
    dict[str, list[str]]
        Keys are canonical product names (normalized via ``PRODUCT_ALIASES``)
        and values are lists of features for that product. Features remain in
        the format ``<product>_<band>``.
    """
    # Flat feature lists carry the product as the first underscore-separated
    # token (e.g. 'era5land_temperature' -> 'era5land'). Group them into the
    # dict form before normalizing so both input shapes share one code path.
    if not isinstance(feature_dict, dict):
        grouped: dict[str, list[str]] = {}
        for feature in feature_dict:
            grouped.setdefault(feature.split('_')[0], []).append(feature)
        feature_dict = grouped

    product_bands: dict[str, list[str]] = {}
    for raw_key, features in feature_dict.items():
        # Normalize: uppercase, replace hyphens with underscores, then map
        # known inconsistent spellings to the canonical store name
        # (e.g. user says 'era5land', but the dataset is 'ERA5_LAND').
        # The alias table fully covers underscore-insensitive matching for
        # the known products, so no extra scan over KNOWN_GCS_PRODUCTS is
        # needed here.
        norm_key = raw_key.upper().replace("-", "_")
        product = PRODUCT_ALIASES.get(norm_key, norm_key)

        # Merge rather than overwrite, in case two raw keys normalize to the
        # same canonical product.
        product_bands.setdefault(product, []).extend(features)

    return product_bands


class SampleIndexer:
"""Reorg columns to rows.

Expand Down