Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/usage/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Data settings
- ``nan_handling_method``: ``masked_mean``, ``input_replacing``, or ``attention``. Strategy for handling missing input data.
- ``nan_handling_pos_encoding_size``: Size of positional encoding for NaN handling methods.
- ``lazy_load``: Whether to access data lazily rather than load all in-memory. Each batch is loaded dynamically. Default: `False`.
- ``limit_n_basins``: Maximum number of basins to keep loaded in memory at any one time. A value of ``0`` (the default) disables this limit. Currently affects ``train``. During ``validation``, only the needed basins are loaded at the start, and their memory is freed once validation finishes. ``evaluate`` and ``infer`` currently load all data even when this setting is set.

Finetune settings
-----------------
Expand Down
53 changes: 42 additions & 11 deletions googlehydrology/datasetzoo/multimet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import itertools
import logging
Expand Down Expand Up @@ -322,22 +323,56 @@ def __init__(
LOGGER.debug('scale data')
self._dataset = self.scaler.scale(self._dataset)

if not cfg.lazy_load:
LOGGER.debug('[eager load] compute dataset')
LOGGER.debug(f'Dataset size: {self._dataset.nbytes / 1024**2} MB')
LOGGER.debug('# forecast dataset init complete (%s)', self._period)

self._dataset_all = self._dataset
del self._dataset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why del here?

Copy link
Collaborator Author

@amitmarkel amitmarkel Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want self._dataset not be defined for safety and clarity below. Another option is to rename all the way up dataset to dataset_all or assess if can just keep without self. until the end, but that would pollute this PR with refactor, could add a todo.


def unload_basins(self) -> None:
    """Drop per-basin derived state so its memory can be reclaimed.

    Each attribute is deleted under its own ``contextlib.suppress`` so
    that one already-missing attribute does not stop the remaining
    deletions (a single shared ``with`` block would abort at the first
    ``AttributeError``). Safe to call repeatedly, including before any
    basins have been loaded.
    """
    with contextlib.suppress(AttributeError):
        del self._dataset
    with contextlib.suppress(AttributeError):
        del self._sample_index
    with contextlib.suppress(AttributeError):
        del self._num_samples
    with contextlib.suppress(AttributeError):
        del self._per_basin_target_stds
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you merge all lines to use a single "with" clause?

Copy link
Collaborator Author

@amitmarkel amitmarkel Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want it to be safe, otherwise a line failing means dels stop at that line.

I considered a loop on value names and __dict__.pop(name, None) but something like that is harder to read and doesn't link symbols and limits search. This way, self.<field name> is findable.


def load_basins(self, basins: list[str] | None = None) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment

self._data_cache: dict[str, xr.DataArray] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to unload_basins

self.unload_basins()

if basins is None:
self._dataset = self._dataset_all
else:
LOGGER.debug('[limit %d basins] (%s)', len(basins), self._period)
self._dataset = self._dataset_all.sel(basin=basins)

if not self._cfg.lazy_load:
LOGGER.debug('[eager load] compute dataset (%s)', self._period)
(self._dataset,) = dask.compute(self._dataset)
memory.release()
else:
LOGGER.debug('[lazy load] not computing dataset')
LOGGER.debug('[lazy load] not computing dataset (%s)', self._period)

LOGGER.debug(
'Dataset size: %f MB (%s)',
self._dataset.nbytes / 1024**2,
self._period,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why printing self._period in all debug lines?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Helped see it faster as a standalone line (contextless). Can remove though.

)

LOGGER.debug('create valid sample mask and indices plan')
valid_sample_mask, indices = self._create_valid_sample_mask()
LOGGER.debug('compute indices')
(indices,) = dask.compute(indices)
memory.release()

LOGGER.debug(f'Dataset size: {sizeof(self._dataset) / 1024**2} MB')
LOGGER.debug(f'Dataset on disk: {self._dataset.nbytes / 1024**2} MB')
LOGGER.debug(f'Sample index size: {sizeof(indices) / 1024**2} MB')
LOGGER.debug(
'Sample index size: %f MB (%s)',
sizeof(indices) / 1024**2,
self._period,
)

# Create sample index lookup table for `__getitem__`.
LOGGER.debug('create sample index')
Expand All @@ -347,7 +382,7 @@ def __init__(
# TODO (future) :: Find a better way to decide whether to calculate these. At least keep a list of
# losses that require them somewhere like `training.__init__.py`. Perhaps simply always calculate.
self._per_basin_target_stds = None
if cfg.loss.lower() in ['nse']:
if self._cfg.loss.lower() in ['nse']:
LOGGER.debug('create per_basin_target_stds')
self._per_basin_target_stds = self._dataset[
self._target_features
Expand All @@ -360,10 +395,6 @@ def __init__(
skipna=True,
)

self._data_cache: dict[str, xr.DataArray] = {}

LOGGER.debug('forecast dataset init complete (%s)', self._period)

def __len__(self) -> int:
    # Number of valid samples for the currently loaded basins.
    # NOTE(review): after unload_basins() this raises AttributeError
    # until basins are loaded again — confirm callers guard against it.
    return self._num_samples

Expand Down
11 changes: 10 additions & 1 deletion googlehydrology/evaluation/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def __init__(
exclude_basins = set(self._calc_exclude_basins()) # Needs self.dataset
self.basins = [e for e in self.basins if e not in exclude_basins]

if cfg.limit_n_basins < 1:
self.dataset.load_basins()

def _set_device(self):
if self.cfg.device is not None:
if self.cfg.device.startswith('cuda'):
Expand Down Expand Up @@ -213,16 +216,19 @@ def evaluate(
basins = self.basins
if (
self.period == 'validation'
and len(basins) > self.cfg.validate_n_random_basins
and 0 < self.cfg.validate_n_random_basins < len(basins)
):
basins = random.sample(basins, k=self.cfg.validate_n_random_basins)
if self.cfg.limit_n_basins > 0:
self.dataset.load_basins(basins)

# force model to train-mode when doing mc-dropout evaluation
if self.cfg.mc_dropout:
model.train()
else:
model.eval()

# TODO(future) :: batch runs also by `limit_n_basins` windows.
batch_sampler = BasinBatchSampler(
sample_index=self.dataset._sample_index,
batch_size=self.cfg.batch_size,
Expand Down Expand Up @@ -491,6 +497,9 @@ def evaluate(
median = np.nanmedian(metric)
LOGGER.info('%s %s median=%f', freq, name, median)

if self.cfg.limit_n_basins > 0:
self.dataset.unload_basins()

def _calc_exclude_basins(self) -> Iterator[str]:
if not self.cfg.tester_skip_obs_all_nan:
return
Expand Down
86 changes: 68 additions & 18 deletions googlehydrology/training/basetrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from datetime import datetime
from pathlib import Path

import more_itertools
import numpy as np
import torch
import torch.optim.lr_scheduler
Expand Down Expand Up @@ -57,6 +58,8 @@ class BaseTrainer(object):

def __init__(self, cfg: Config):
super(BaseTrainer, self).__init__()

self.ds: Dataset | None = None
self.cfg = cfg
self.model = None
self.optimizer = None
Expand Down Expand Up @@ -103,10 +106,13 @@ def __init__(self, cfg: Config):
self._set_random_seeds()
self._set_device()

def _get_dataset(self, compute_scaler: bool) -> Dataset:
def _get_dataset(
self, *, compute_scaler: bool, basins: list[str] | None = None
) -> Dataset:
return get_dataset(
cfg=self.cfg,
period='train',
basins=basins,
is_train=True,
compute_scaler=compute_scaler,
)
Expand Down Expand Up @@ -179,6 +185,54 @@ def _freeze_model_parts(self):
f'Could not resolve the following module parts for finetuning: {unresolved_modules}'
)

def init_loader(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please split into smaller PRs

self,
*,
max_random_basins: int = 0,
first_epoc: bool = False,
ds: Dataset | None = None,
) -> Dataset:
compute_scaler = (
(not self.cfg.is_finetuning)
and max_random_basins < 1
and ds is None
)

if ds is None:
ds = self._get_dataset(compute_scaler=compute_scaler)
if compute_scaler: # Break early from loading basins next
return ds

assert ds is not None
assert not compute_scaler

if max_random_basins > 0:
# Take random sequential basins, it's faster to extract
start = random.randrange(len(self.basins))
indices = range(start, start + max_random_basins)
basins = np.take(self.basins, indices, mode='wrap').tolist()
ds.load_basins(basins)
elif first_epoc:
ds.load_basins()

if (not compute_scaler) and len(ds) == 0:
raise ValueError('Dataset contains no samples.')
if not compute_scaler:
self.loader = self._get_data_loader(ds=ds)

if first_epoc:
self.experiment_logger = Logger(cfg=self.cfg)
if self.cfg.log_tensorboard:
self.experiment_logger.start_tb()

if self.cfg.is_continue_training:
# set epoch and iteration step counter to continue from the selected checkpoint
self.experiment_logger.epoch = self._epoch
self.experiment_logger.update = len(self.loader) * self._epoch

LOGGER.debug('init_loader for %d (%s)', max_random_basins, ds._period)
return ds

def initialize_training(self):
"""Initialize the training class.

Expand All @@ -187,10 +241,7 @@ def initialize_training(self):
If called in a ``continue_training`` context, this model will also restore the model and optimizer state.
"""
# Initialize dataset before the model is loaded.
ds = self._get_dataset(compute_scaler=(not self.cfg.is_finetuning))
if len(ds) == 0:
raise ValueError('Dataset contains no samples.')
self.loader = self._get_data_loader(ds=ds)
self.ds = self.init_loader() # Compute full scaler

LOGGER.debug('init model')
self.model = self._get_model().to(self.device)
Expand Down Expand Up @@ -238,15 +289,6 @@ def initialize_training(self):
if self.cfg.is_continue_training:
self._restore_training_state()

self.experiment_logger = Logger(cfg=self.cfg)
if self.cfg.log_tensorboard:
self.experiment_logger.start_tb()

if self.cfg.is_continue_training:
# set epoch and iteration step counter to continue from the selected checkpoint
self.experiment_logger.epoch = self._epoch
self.experiment_logger.update = len(self.loader) * self._epoch

if self.cfg.validate_every is not None:
if self.cfg.validate_n_random_basins < 1:
warn_msg = [
Expand All @@ -262,12 +304,12 @@ def initialize_training(self):
loc=0, scale=self.cfg.target_noise_std
)
target_means = [
ds.scaler.scaler.sel(parameter='mean')[feature].item()
self.ds.scaler.scaler.sel(parameter='mean')[feature].item()
for feature in self.cfg.target_variables
]
self._target_mean = torch.tensor(target_means).to(self.device)
target_stds = [
ds.scaler.scaler.sel(parameter='std')[feature].item()
self.ds.scaler.scaler.sel(parameter='std')[feature].item()
for feature in self.cfg.target_variables
]
self._target_std = torch.tensor(target_stds).to(self.device)
Expand Down Expand Up @@ -319,10 +361,18 @@ def train_and_validate(self):
"""
lr_scheduler, lr_step = self._create_lr_scheduler()

for epoch in range(self._epoch + 1, self._epoch + self.cfg.epochs + 1):
epoc_range = range(self._epoch + 1, self._epoch + self.cfg.epochs + 1)
for is_first, _, epoch in more_itertools.mark_ends(epoc_range):
self.ds = self.init_loader(
max_random_basins=self.cfg.limit_n_basins,
first_epoc=is_first,
ds=self.ds,
)
LOGGER.info(f'learning rate is {lr_scheduler.get_last_lr()}')

self._train_epoch(epoch=epoch)
if self.cfg.limit_n_basins > 0:
self.ds.unload_basins()

avg_losses = self.experiment_logger.summarise()
lr_step(avg_losses['avg_loss'])

Expand Down
8 changes: 8 additions & 0 deletions googlehydrology/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ def cache(self) -> Cache:
data = self._cfg.get('cache', {})
return pydantic.TypeAdapter(Cache).validate_python(data)

@property
def limit_n_basins(self) -> int:
    """Maximum number of basins to keep loaded at once; ``0`` (default) disables the limit."""
    return int(self._cfg.get('limit_n_basins', 0))

@limit_n_basins.setter
def limit_n_basins(self, value: int):
    # Written back into the underlying config mapping that the getter reads.
    self._cfg['limit_n_basins'] = value

@property
def lazy_load(self) -> bool:
return self._cfg.get('lazy_load', False)
Expand Down
Loading