[limit_n_basins mode] [opt mem] [opt runtime] add mode to allow training and validation on practically any dataset size.#246

Open
amitmarkel wants to merge 3 commits into main from amarkel-limit_n_basins__train
Conversation

@amitmarkel
Collaborator

@amitmarkel amitmarkel commented Feb 18, 2026

train:
Using e.g. `limit_n_basins: 100` in the config materializes into memory the data of at most 100 basins at a time during training. The data is freed once an epoch completes.

NOTE: This results in shorter epochs; users should adjust `save_weights_every` and `max_updates_per_epoch` according to their selection.
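For illustration, a hypothetical config fragment (the key names `limit_n_basins`, `save_weights_every`, and `max_updates_per_epoch` come from this PR's description; the values and surrounding structure are only a sketch):

```yaml
# sketch of a training config using limit_n_basins (values illustrative)
limit_n_basins: 100          # materialize at most 100 basins per epoch; freed at epoch end
save_weights_every: 50       # lower than usual, since epochs are shorter under limit_n_basins
max_updates_per_epoch: 1000  # adjust alongside limit_n_basins as well
```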

validation:
Validation support is partial at this time, to keep this PR focused on training, since training is the main memory bottleneck and validation isn't usually run over e.g. 46 years of data. When enabled (i.e. specified and positive), validation data is loaded only during validation phases, and only for up to `validate_n_random_basins` basins if specified.

test/infer:
These modes load all data at this time.


`limit_n_basins` makes it possible to train on datasets spanning e.g. 16k basins over 46 years using roughly 6 GB of memory while training.

NOTE: Memory spikes from multimet's init and compute()s are handled separately.


A complementary mode is `lazy_load`, which lowers memory usage even further (e.g. 2-3 GB), though at a significant runtime cost.

@amitmarkel amitmarkel requested a review from omrishefi February 18, 2026 14:16
@amitmarkel amitmarkel changed the title [limit_n_basins mode] [opt mem] add mode to allow training and validation on practically any dataset size. [limit_n_basins mode] [opt mem] [opt runtime] add mode to allow training and validation on practically any dataset size. Feb 18, 2026
omrishefi
omrishefi previously approved these changes Feb 18, 2026
The exclude-basins logic needs to run once the dataset's dataset object is available (i.e. after `load_basins` is called).
with contextlib.suppress(AttributeError):
del self._num_samples
with contextlib.suppress(AttributeError):
del self._per_basin_target_stds
Collaborator

Can you merge all the lines to use a single `with` clause?

Collaborator Author

@amitmarkel amitmarkel Feb 24, 2026

I want it to be safe: otherwise, if one line fails, the dels stop at that line.

I considered a loop over attribute names with `__dict__.pop(name, None)`, but that is harder to read, doesn't link symbols, and limits search. This way, `self.<field name>` remains findable.
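To illustrate the safety point with a minimal sketch (the `Cache` class and its attributes are stand-ins, not the real dataset class): with a single `suppress` block, the first missing attribute aborts the remaining `del`s, while one block per `del` clears each attribute independently.

```python
import contextlib

class Cache:
    def __init__(self):
        # simulate a partially-populated cache: only one attribute is set
        self._per_basin_target_stds = [1.0]

    def clear_single_block(self):
        # one suppress block: the first AttributeError aborts the rest
        with contextlib.suppress(AttributeError):
            del self._num_samples            # raises: was never set
            del self._per_basin_target_stds  # never reached

    def clear_separate_blocks(self):
        # one suppress per del: each attribute is cleared independently
        with contextlib.suppress(AttributeError):
            del self._num_samples
        with contextlib.suppress(AttributeError):
            del self._per_basin_target_stds

c = Cache()
c.clear_single_block()
assert hasattr(c, '_per_basin_target_stds')  # the del was skipped

c = Cache()
c.clear_separate_blocks()
assert not hasattr(c, '_per_basin_target_stds')  # cleared
```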

LOGGER.debug('# forecast dataset init complete (%s)', self._period)

self._dataset_all = self._dataset
del self._dataset
Collaborator

Why del here?

Collaborator Author

@amitmarkel amitmarkel Feb 24, 2026

I want `self._dataset` to not be defined below this point, for safety and clarity. Another option is to rename `dataset` to `dataset_all` all the way up, or to assess whether it can be kept without `self.` until the end, but that would pollute this PR with a refactor; I could add a TODO.
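A minimal sketch of the fail-fast rationale (the class and attribute contents below are stand-ins for the real forecast dataset): after the rename-and-delete, any stale reference to `self._dataset` raises `AttributeError` instead of silently using the full dataset.

```python
class ForecastData:
    def __init__(self):
        self._dataset = {'all': 'data'}  # stand-in for the loaded dataset
        # keep the full data under an explicit name and drop the old attribute,
        # so later code cannot accidentally keep using self._dataset
        self._dataset_all = self._dataset
        del self._dataset

fd = ForecastData()
assert fd._dataset_all == {'all': 'data'}
assert not hasattr(fd, '_dataset')  # a stale access would fail fast
```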

LOGGER.debug(
    'Dataset size: %f MB (%s)',
    self._dataset.nbytes / 1024**2,
    self._period,
)
Collaborator

Why print `self._period` in all the debug lines?

Collaborator Author

It helped me identify the line faster when it appears standalone (without context). I can remove it, though.

Collaborator

@omrishefi omrishefi left a comment

Let's review the changes together

@amitmarkel amitmarkel requested a review from omrishefi February 24, 2026 09:21
del self._per_basin_target_stds

def load_basins(self, basins: list[str] | None = None) -> None:
self._data_cache: dict[str, xr.DataArray] = {}
Collaborator

Move this to `unload_basins`.

with contextlib.suppress(AttributeError):
del self._per_basin_target_stds

def load_basins(self, basins: list[str] | None = None) -> None:
Collaborator

Add a comment

self.basins = [e for e in self.basins if e not in exclude_basins]
if cfg.limit_n_basins < 1:
self.dataset.load_basins()
self.basins = self._calc_and_apply_excluded_basins(self.basins)
Collaborator

It seems like a change of logic. Let's not pass `self.basins` to the function.

f'Could not resolve the following module parts for finetuning: {unresolved_modules}'
)

def init_loader(
Collaborator

Please split into smaller PRs

Collaborator

@omrishefi omrishefi left a comment

Please see comments
