-
Notifications
You must be signed in to change notification settings - Fork 10
[limit_n_basins mode] [opt mem] [opt runtime] add mode to allow training and validation on practically any dataset size.
#246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import contextlib | ||
| import functools | ||
| import itertools | ||
| import logging | ||
|
|
@@ -322,22 +323,56 @@ def __init__( | |
| LOGGER.debug('scale data') | ||
| self._dataset = self.scaler.scale(self._dataset) | ||
|
|
||
| if not cfg.lazy_load: | ||
| LOGGER.debug('[eager load] compute dataset') | ||
| LOGGER.debug(f'Dataset size: {self._dataset.nbytes / 1024**2} MB') | ||
| LOGGER.debug('# forecast dataset init complete (%s)', self._period) | ||
|
|
||
| self._dataset_all = self._dataset | ||
| del self._dataset | ||
|
|
||
| def unload_basins(self) -> None: | ||
| with contextlib.suppress(AttributeError): | ||
| del self._dataset | ||
| with contextlib.suppress(AttributeError): | ||
| del self._sample_index | ||
| with contextlib.suppress(AttributeError): | ||
| del self._num_samples | ||
| with contextlib.suppress(AttributeError): | ||
| del self._per_basin_target_stds | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you merge all lines to use a single "with" clause?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want it to be safe: with a single "with" clause, one `del` raising would skip the remaining deletions. I considered a loop over the attribute names instead. |
||
|
|
||
| def load_basins(self, basins: list[str] | None = None) -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment |
||
| self._data_cache: dict[str, xr.DataArray] = {} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move to unload_basins |
||
| self.unload_basins() | ||
|
|
||
| if basins is None: | ||
| self._dataset = self._dataset_all | ||
| else: | ||
| LOGGER.debug('[limit %d basins] (%s)', len(basins), self._period) | ||
| self._dataset = self._dataset_all.sel(basin=basins) | ||
|
|
||
| if not self._cfg.lazy_load: | ||
| LOGGER.debug('[eager load] compute dataset (%s)', self._period) | ||
| (self._dataset,) = dask.compute(self._dataset) | ||
| memory.release() | ||
| else: | ||
| LOGGER.debug('[lazy load] not computing dataset') | ||
| LOGGER.debug('[lazy load] not computing dataset (%s)', self._period) | ||
|
|
||
| LOGGER.debug( | ||
| 'Dataset size: %f MB (%s)', | ||
| self._dataset.nbytes / 1024**2, | ||
| self._period, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why print self._period in all debug lines?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It helped me identify the log line faster when read standalone (without surrounding context). I can remove it, though. |
||
| ) | ||
|
|
||
| LOGGER.debug('create valid sample mask and indices plan') | ||
| valid_sample_mask, indices = self._create_valid_sample_mask() | ||
| LOGGER.debug('compute indices') | ||
| (indices,) = dask.compute(indices) | ||
| memory.release() | ||
|
|
||
| LOGGER.debug(f'Dataset size: {sizeof(self._dataset) / 1024**2} MB') | ||
| LOGGER.debug(f'Dataset on disk: {self._dataset.nbytes / 1024**2} MB') | ||
| LOGGER.debug(f'Sample index size: {sizeof(indices) / 1024**2} MB') | ||
| LOGGER.debug( | ||
| 'Sample index size: %f MB (%s)', | ||
| sizeof(indices) / 1024**2, | ||
| self._period, | ||
| ) | ||
|
|
||
| # Create sample index lookup table for `__getitem__`. | ||
| LOGGER.debug('create sample index') | ||
|
|
@@ -347,7 +382,7 @@ def __init__( | |
| # TODO (future) :: Find a better way to decide whether to calculate these. At least keep a list of | ||
| # losses that require them somewhere like `training.__init__.py`. Perhaps simply always calculate. | ||
| self._per_basin_target_stds = None | ||
| if cfg.loss.lower() in ['nse']: | ||
| if self._cfg.loss.lower() in ['nse']: | ||
| LOGGER.debug('create per_basin_target_stds') | ||
| self._per_basin_target_stds = self._dataset[ | ||
| self._target_features | ||
|
|
@@ -360,10 +395,6 @@ def __init__( | |
| skipna=True, | ||
| ) | ||
|
|
||
| self._data_cache: dict[str, xr.DataArray] = {} | ||
|
|
||
| LOGGER.debug('forecast dataset init complete (%s)', self._period) | ||
|
|
||
| def __len__(self) -> int: | ||
| return self._num_samples | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| from datetime import datetime | ||
| from pathlib import Path | ||
|
|
||
| import more_itertools | ||
| import numpy as np | ||
| import torch | ||
| import torch.optim.lr_scheduler | ||
|
|
@@ -57,6 +58,8 @@ class BaseTrainer(object): | |
|
|
||
| def __init__(self, cfg: Config): | ||
| super(BaseTrainer, self).__init__() | ||
|
|
||
| self.ds: Dataset | None = None | ||
| self.cfg = cfg | ||
| self.model = None | ||
| self.optimizer = None | ||
|
|
@@ -103,10 +106,13 @@ def __init__(self, cfg: Config): | |
| self._set_random_seeds() | ||
| self._set_device() | ||
|
|
||
| def _get_dataset(self, compute_scaler: bool) -> Dataset: | ||
| def _get_dataset( | ||
| self, *, compute_scaler: bool, basins: list[str] | None = None | ||
| ) -> Dataset: | ||
| return get_dataset( | ||
| cfg=self.cfg, | ||
| period='train', | ||
| basins=basins, | ||
| is_train=True, | ||
| compute_scaler=compute_scaler, | ||
| ) | ||
|
|
@@ -179,6 +185,54 @@ def _freeze_model_parts(self): | |
| f'Could not resolve the following module parts for finetuning: {unresolved_modules}' | ||
| ) | ||
|
|
||
| def init_loader( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please split into smaller PRs |
||
| self, | ||
| *, | ||
| max_random_basins: int = 0, | ||
| first_epoc: bool = False, | ||
| ds: Dataset | None = None, | ||
| ) -> Dataset: | ||
| compute_scaler = ( | ||
| (not self.cfg.is_finetuning) | ||
| and max_random_basins < 1 | ||
| and ds is None | ||
| ) | ||
|
|
||
| if ds is None: | ||
| ds = self._get_dataset(compute_scaler=compute_scaler) | ||
| if compute_scaler: # Break early from loading basins next | ||
| return ds | ||
|
|
||
| assert ds is not None | ||
| assert not compute_scaler | ||
|
|
||
| if max_random_basins > 0: | ||
| # Take random sequential basins, it's faster to extract | ||
| start = random.randrange(len(self.basins)) | ||
| indices = range(start, start + max_random_basins) | ||
| basins = np.take(self.basins, indices, mode='wrap').tolist() | ||
| ds.load_basins(basins) | ||
| elif first_epoc: | ||
| ds.load_basins() | ||
|
|
||
| if (not compute_scaler) and len(ds) == 0: | ||
| raise ValueError('Dataset contains no samples.') | ||
| if not compute_scaler: | ||
| self.loader = self._get_data_loader(ds=ds) | ||
|
|
||
| if first_epoc: | ||
| self.experiment_logger = Logger(cfg=self.cfg) | ||
| if self.cfg.log_tensorboard: | ||
| self.experiment_logger.start_tb() | ||
|
|
||
| if self.cfg.is_continue_training: | ||
| # set epoch and iteration step counter to continue from the selected checkpoint | ||
| self.experiment_logger.epoch = self._epoch | ||
| self.experiment_logger.update = len(self.loader) * self._epoch | ||
|
|
||
| LOGGER.debug('init_loader for %d (%s)', max_random_basins, ds._period) | ||
| return ds | ||
|
|
||
| def initialize_training(self): | ||
| """Initialize the training class. | ||
|
|
||
|
|
@@ -187,10 +241,7 @@ def initialize_training(self): | |
| If called in a ``continue_training`` context, this model will also restore the model and optimizer state. | ||
| """ | ||
| # Initialize dataset before the model is loaded. | ||
| ds = self._get_dataset(compute_scaler=(not self.cfg.is_finetuning)) | ||
| if len(ds) == 0: | ||
| raise ValueError('Dataset contains no samples.') | ||
| self.loader = self._get_data_loader(ds=ds) | ||
| self.ds = self.init_loader() # Compute full scaler | ||
|
|
||
| LOGGER.debug('init model') | ||
| self.model = self._get_model().to(self.device) | ||
|
|
@@ -238,15 +289,6 @@ def initialize_training(self): | |
| if self.cfg.is_continue_training: | ||
| self._restore_training_state() | ||
|
|
||
| self.experiment_logger = Logger(cfg=self.cfg) | ||
| if self.cfg.log_tensorboard: | ||
| self.experiment_logger.start_tb() | ||
|
|
||
| if self.cfg.is_continue_training: | ||
| # set epoch and iteration step counter to continue from the selected checkpoint | ||
| self.experiment_logger.epoch = self._epoch | ||
| self.experiment_logger.update = len(self.loader) * self._epoch | ||
|
|
||
| if self.cfg.validate_every is not None: | ||
| if self.cfg.validate_n_random_basins < 1: | ||
| warn_msg = [ | ||
|
|
@@ -262,12 +304,12 @@ def initialize_training(self): | |
| loc=0, scale=self.cfg.target_noise_std | ||
| ) | ||
| target_means = [ | ||
| ds.scaler.scaler.sel(parameter='mean')[feature].item() | ||
| self.ds.scaler.scaler.sel(parameter='mean')[feature].item() | ||
| for feature in self.cfg.target_variables | ||
| ] | ||
| self._target_mean = torch.tensor(target_means).to(self.device) | ||
| target_stds = [ | ||
| ds.scaler.scaler.sel(parameter='std')[feature].item() | ||
| self.ds.scaler.scaler.sel(parameter='std')[feature].item() | ||
| for feature in self.cfg.target_variables | ||
| ] | ||
| self._target_std = torch.tensor(target_stds).to(self.device) | ||
|
|
@@ -319,10 +361,18 @@ def train_and_validate(self): | |
| """ | ||
| lr_scheduler, lr_step = self._create_lr_scheduler() | ||
|
|
||
| for epoch in range(self._epoch + 1, self._epoch + self.cfg.epochs + 1): | ||
| epoc_range = range(self._epoch + 1, self._epoch + self.cfg.epochs + 1) | ||
| for is_first, _, epoch in more_itertools.mark_ends(epoc_range): | ||
| self.ds = self.init_loader( | ||
| max_random_basins=self.cfg.limit_n_basins, | ||
| first_epoc=is_first, | ||
| ds=self.ds, | ||
| ) | ||
| LOGGER.info(f'learning rate is {lr_scheduler.get_last_lr()}') | ||
|
|
||
| self._train_epoch(epoch=epoch) | ||
| if self.cfg.limit_n_basins > 0: | ||
| self.ds.unload_basins() | ||
|
|
||
| avg_losses = self.experiment_logger.summarise() | ||
| lr_step(avg_losses['avg_loss']) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why del here?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want
`self._dataset` to not be defined, for safety and clarity below. Another option is to rename `dataset` to `dataset_all` all the way up, or to assess whether it can be kept without `self.` until the end — but that would pollute this PR with refactoring; I could add a TODO instead.