-
Notifications
You must be signed in to change notification settings - Fork 10
Fix multimet dataloader to support dynamic variable product names with underscores #238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
88bbff0
604aa4b
f4f38e9
2c9426d
445ec6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,6 +60,22 @@ | |
| ] | ||
| MULTIMET_MINIMUM_LEAD_TIME = 1 | ||
|
|
||
| # Caravan Multimet products that are available in the GCS zarr store | ||
| KNOWN_GCS_PRODUCTS = { | ||
| "CHIRPS", | ||
| "CHIRPS_GEFS", | ||
| "CPC", | ||
| "ERA5_LAND", | ||
| "GRAPHCAST", | ||
| "HRES", | ||
| "IMERG" | ||
| } | ||
|
|
||
| # Aliases for multimet product names with inconsistent naming conventions | ||
| PRODUCT_ALIASES = { | ||
| "ERA5LAND": "ERA5_LAND", | ||
| "CHIRPSGEFS": "CHIRPS_GEFS" | ||
|
|
||
| class MultimetDataLoader(torch.utils.data.DataLoader): | ||
| """Custom DataLoader that handles lazy data loading. | ||
|
|
||
|
|
@@ -165,6 +181,8 @@ def __init__( | |
| raise ValueError('hindcast_inputs must be supplied.') | ||
| self._forecast_features = flatten_feature_list(cfg.forecast_inputs) | ||
| self._hindcast_features = flatten_feature_list(cfg.hindcast_inputs) | ||
| self._hindcast_inputs = cfg.hindcast_inputs | ||
| self._forecast_inputs = cfg.forecast_inputs | ||
| self._union_mapping = cfg.union_mapping | ||
|
|
||
| # Feature data paths by type. This allows the option to load some data from cloud and some locally. | ||
|
|
@@ -339,6 +357,15 @@ def __init__( | |
| LOGGER.debug(f'Dataset on disk: {self._dataset.nbytes / 1024**2} MB') | ||
| LOGGER.debug(f'Sample index size: {sizeof(indices) / 1024**2} MB') | ||
|
|
||
| # TODO(future) :: Move above to the scalar compute block | ||
| LOGGER.debug('scaler check zero scale') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could I ask you to pull and merge the most recent changes from |
||
| self.scaler.check_zero_scale() | ||
| LOGGER.debug('scaler save') | ||
|
|
||
| # Don't save the scaler if we are finetuning | ||
| if not cfg.is_finetuning: | ||
| self.scaler.save() | ||
|
|
||
| # Create sample index lookup table for `__getitem__`. | ||
| LOGGER.debug('create sample index') | ||
| self._create_sample_index(valid_sample_mask, indices) | ||
|
|
@@ -699,8 +726,8 @@ def _load_hindcast_as_zarr(self) -> list[xr.Dataset]: | |
| ) | ||
|
|
||
| # Separate products and bands for each product from feature names. | ||
| product_bands = _get_products_and_bands_from_feature_strings( | ||
| features=features | ||
| product_bands = _get_products_and_bands_from_feature_dict( | ||
| self._hindcast_inputs | ||
| ) | ||
|
|
||
| # Initialize storage for product/band dataframes that will eventually be concatenated. | ||
|
|
@@ -711,7 +738,16 @@ def _load_hindcast_as_zarr(self) -> list[xr.Dataset]: | |
| product_path = ( | ||
| self._dynamics_data_path / product / 'timeseries.zarr' | ||
| ) | ||
| LOGGER.info(f"Loading hindcast product '{product}' with bands {bands}") | ||
| product_ds = _open_zarr(product_path) | ||
|
|
||
| missing = set(bands) - set(product_ds.data_vars) | ||
|
|
||
| if missing: | ||
| raise ValueError( | ||
| f"Requested features {missing} not found in product '{product}'. " | ||
| f"Available variables: {list(product_ds.data_vars)}" | ||
| ) | ||
|
|
||
| if 'lead_time' in product_ds: | ||
| # The same product may be used both for forecast and hindcast features. For hindcast, we load it with the | ||
|
|
@@ -870,8 +906,8 @@ def _load_forecast_as_zarr(self) -> list[xr.Dataset]: | |
| Dataset containing the loaded features with dimensions (date, lead_time, basin). | ||
| """ | ||
| # Separate products and bands for each product from feature names. | ||
| product_bands = _get_products_and_bands_from_feature_strings( | ||
| features=self._forecast_features | ||
| product_bands = _get_products_and_bands_from_feature_dict( | ||
| self._forecast_inputs | ||
| ) | ||
|
|
||
| # Initialize storage for product/band dataframes that will eventually be concatenated. | ||
|
|
@@ -882,8 +918,17 @@ def _load_forecast_as_zarr(self) -> list[xr.Dataset]: | |
| product_path = ( | ||
| self._dynamics_data_path / product / 'timeseries.zarr' | ||
| ) | ||
| LOGGER.info(f"Loading forecast product '{product}' with bands {bands}") | ||
| product_ds = _open_zarr(product_path) | ||
|
|
||
| missing = set(bands) - set(product_ds.data_vars) | ||
|
|
||
| if missing: | ||
| raise ValueError( | ||
| f"Requested features {missing} not found in product '{product}'. " | ||
| f"Available variables: {list(product_ds.data_vars)}" | ||
| ) | ||
|
|
||
| # If this is a forecast product, extract only leadtime 0 for hindcasts. | ||
| if 'lead_time' not in product_ds: | ||
| raise ValueError( | ||
|
|
@@ -1039,32 +1084,44 @@ def _open_zarr(path: Path) -> xr.Dataset: | |
| return xr.open_zarr(store=path, chunks='auto', decode_timedelta=True) | ||
|
|
||
|
|
||
| def _get_products_and_bands_from_feature_strings( | ||
| features: Iterable[str], | ||
| ) -> dict[str, list[str]]: | ||
| def _get_products_and_bands_from_feature_dict(feature_dict): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The config input properties are not always dicts. They can be either dicts or lists, depending on whether you want to use a model with vs. without feature groups. This is why the cfg.hindcast_inputs and cfg.forecast_inputs are flattened in lines 182 and 183. However, this caused me to notice that the typehints in |
||
|
|
||
| """ | ||
| Processes feature strings to create a dictionary of product to band(s). | ||
| Processes feature dictionary to create a mapping of product to band(s). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| features : list[str] | ||
| A list features in the format `<product>_<band>. This is the format for feature | ||
| names in the Multimet dataset. | ||
| feature_dict : dict[str, list[str]] | ||
| Dictionary where keys are product names from the config and values | ||
| are lists of features belonging to that product. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, list[str]] | ||
| Keys are product names and values are a list of features for that product. Features | ||
| remain in the format <product>_<band>. | ||
| Keys are normalized product names and values are lists of features | ||
| for that product. | ||
| """ | ||
|
|
||
| product_bands = {} | ||
| for feature in features: | ||
| product = feature.split('_')[0].upper() | ||
| if product == 'ERA5LAND': | ||
| product = 'ERA5_LAND' | ||
| product_bands.setdefault(product, []).append(feature) | ||
| for raw_key, features in feature_dict.items(): | ||
| # Normalize: Uppercase, replace hyphens with underscores | ||
| norm_key = raw_key.upper().replace("-", "_") | ||
|
|
||
| # Try alias mapping first | ||
| product = PRODUCT_ALIASES.get(norm_key, norm_key) | ||
|
|
||
| # If it's a known product, use canonical name | ||
| # (e.g., user says 'era5land', but dataset is 'ERA5_LAND') | ||
| for known in KNOWN_GCS_PRODUCTS: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this loop doing something after the mapping call in line 1111? |
||
| if product.replace("_", "") == known.replace("_", ""): | ||
| product = known | ||
| break | ||
|
|
||
| product_bands[product] = features | ||
|
|
||
| return product_bands | ||
|
|
||
|
|
||
| class SampleIndexer: | ||
| """Reorg columns to rows. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the basic strategy here of generalizing the removal of underscores in variable names, but only to known products. Is it necessary to keep a list of known products? It looks like most of these are not used for anything functional, and instead the PRODUCT_ALIASES is doing the heavy lifting. Is there a need to keep this known products list?