[limit_n_basins mode] [opt mem] [opt runtime] add mode to allow training and validation on practically any dataset size. #246
Conversation
The exclude-basins logic needs to run once the dataset's dataset object is available (i.e. after `load_basins` is called).
with contextlib.suppress(AttributeError):
    del self._num_samples
with contextlib.suppress(AttributeError):
    del self._per_basin_target_stds
Can you merge all lines to use a single "with" clause?
I want it to be safe; otherwise, if one line fails, the remaining dels are skipped.
I considered looping over attribute names with `self.__dict__.pop(name, None)`, but something like that is harder to read, doesn't link the symbols, and limits search. This way, `self.<field name>` is findable.
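The trade-off discussed above can be demonstrated with a minimal sketch (a hypothetical `Cache` class, not the PR's actual dataset code): with a single `with contextlib.suppress(...)` wrapping several dels, the first missing attribute aborts the whole block, while one suppress per del clears each attribute independently.

```python
import contextlib


class Cache:
    def __init__(self):
        self.b = 2  # note: attribute `a` is deliberately never set

    def clear_single_with(self):
        # A single suppress block: the first failing del exits the block,
        # so later dels are skipped.
        with contextlib.suppress(AttributeError):
            del self.a  # raises AttributeError -> block exits here
            del self.b  # never reached

    def clear_per_line(self):
        # One suppress per del: each attribute is cleared independently.
        with contextlib.suppress(AttributeError):
            del self.a
        with contextlib.suppress(AttributeError):
            del self.b


c = Cache()
c.clear_single_with()
print(hasattr(c, 'b'))  # True: `b` survived because `del self.a` failed first

c = Cache()
c.clear_per_line()
print(hasattr(c, 'b'))  # False: `b` was cleared
```

This is why the review reply keeps one `with` clause per del rather than merging them.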
LOGGER.debug('# forecast dataset init complete (%s)', self._period)

self._dataset_all = self._dataset
del self._dataset
I want `self._dataset` to be undefined, for safety and clarity below. Another option is to rename `dataset` to `dataset_all` all the way up, or to assess whether it can stay a plain local (without `self.`) until the end, but that would pollute this PR with a refactor; I could add a TODO.
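The rename-then-delete pattern described above can be sketched as follows (a hypothetical minimal class, not the PR's actual dataset code): after the move, any stale access to the old attribute raises `AttributeError` instead of silently reading the full dataset.

```python
class ForecastDataset:
    def __init__(self, dataset):
        self._dataset = dataset  # full dataset, built during init

    def freeze(self):
        # Keep the full data under a new name, then delete the old
        # attribute so any later use of self._dataset fails loudly.
        self._dataset_all = self._dataset
        del self._dataset


ds = ForecastDataset(dataset=[1, 2, 3])
ds.freeze()
print(hasattr(ds, '_dataset'))  # False: stale accesses now raise
print(ds._dataset_all)          # [1, 2, 3]
```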
LOGGER.debug(
    'Dataset size: %f MB (%s)',
    self._dataset.nbytes / 1024**2,
    self._period,
Why print `self._period` in all the debug lines?
It helped me spot it faster as a standalone (contextless) line. I can remove it, though.
omrishefi left a comment
Let's review the changes together
    del self._per_basin_target_stds

def load_basins(self, basins: list[str] | None = None) -> None:
    self._data_cache: dict[str, xr.DataArray] = {}
self.basins = [e for e in self.basins if e not in exclude_basins]
if cfg.limit_n_basins < 1:
    self.dataset.load_basins()
self.basins = self._calc_and_apply_excluded_basins(self.basins)
This seems like a change of logic. Let's not pass `self.basins` to the function.
    f'Could not resolve the following module parts for finetuning: {unresolved_modules}'
)

def init_loader(
Please split this into smaller PRs.
`train`: Using e.g. `limit_n_basins: 100` in the config results in materializing into memory the data of only up to 100 basins during training. Data is freed once an epoch is complete.

NOTE: This results in shorter epochs, but users can adjust `save_weights_every` and `max_updates_per_epoch` to match their selection.

`validation`: During validation, support is partial at this time, to keep this PR focused on training, since training is the main memory bottleneck and validation isn't usually done over e.g. 46 yrs of data. When enabled (i.e. specified and positive), validation data is loaded only during validation phases, and only for up to `validate_n_random_basins` basins if specified.

`test`/`infer`: These modes load all data at this time.

---

`limit_n_basins` allows training on datasets spanning e.g. 16k basins over 46 yrs using roughly 6 GB of memory while training.

NOTE: Memory spikes from multimet's `__init__` and `compute()` calls are handled separately.

---

A complementary mode is `lazy_load`, which lowers memory usage even further (e.g. to 2-3 GB), though at a significant cost in runtime.
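To tie the options above together, here is a hypothetical config sketch. Only `limit_n_basins`, `validate_n_random_basins`, `save_weights_every`, `max_updates_per_epoch`, and `lazy_load` are named in this PR; the YAML layout and the specific values are assumptions for illustration.

```yaml
# Keep at most 100 basins' worth of data materialized in memory during
# training; data is freed once an epoch completes.
limit_n_basins: 100

# Epochs get shorter under limit_n_basins, so adjust these to taste.
save_weights_every: 5
max_updates_per_epoch: 1000

# During validation, load data only for up to this many random basins.
validate_n_random_basins: 10

# Optional: trade runtime for even lower memory usage (e.g. 2-3 GB).
lazy_load: false
```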