diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml
index bf62fc285d..856a6edce7 100644
--- a/.github/workflows/develop.yml
+++ b/.github/workflows/develop.yml
@@ -51,7 +51,7 @@ jobs:
           if [ "${{ matrix.os }}" == "macos-13" ]; then
             source $HOME/.local/bin/env
           fi
-          uv pip compile requirements/dev-all.txt > requirements-latest.txt
+          uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt

       - name: "Cache python environment"
         uses: actions/cache@v4
@@ -67,7 +67,7 @@ jobs:
       - name: "Install Dependencies"
         run: |
           # install latest dependencies (potentially updating cached dependencies)
-          pip install -U -r requirements/dev-all.txt
+          pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

       - name: "Install libomp (for LightGBM)"
         run: |
@@ -99,7 +99,7 @@ jobs:
       - name: "Compile Dependency Versions"
         run: |
           curl -LsSf https://astral.sh/uv/install.sh | sh
-          uv pip compile requirements/dev-all.txt > requirements-latest.txt
+          uv pip compile requirements/dev-all.txt requirements/optional.txt > requirements-latest.txt

       # only restore cache but do not upload
       - name: "Restore cached python environment"
@@ -120,7 +120,7 @@ jobs:
       - name: "Install Dependencies"
         run: |
           # install latest dependencies (potentially updating cached dependencies)
-          pip install -U -r requirements/dev-all.txt
+          pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

       - name: "Install libomp (for LightGBM)"
         run: |
@@ -152,7 +152,7 @@ jobs:
       - name: "Compile Dependency Versions"
         run: |
           curl -LsSf https://astral.sh/uv/install.sh | sh
-          uv pip compile requirements/dev-all.txt > requirements-latest.txt
+          uv pip compile requirements/dev-all.txt requirements/optional.txt > requirements-latest.txt

       # only restore cache but do not upload
       - name: "Restore cached python environment"
@@ -169,7 +169,7 @@ jobs:
       - name: "Install Dependencies"
         run: |
           # install latest dependencies (potentially updating cached dependencies)
-          pip install -U -r requirements/dev-all.txt
+          pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

       - name: "Install libomp (for LightGBM)"
         run: |
diff --git a/.github/workflows/merge.yml b/.github/workflows/merge.yml
index b74cd0a26f..b57ab115f4 100644
--- a/.github/workflows/merge.yml
+++ b/.github/workflows/merge.yml
@@ -54,7 +54,7 @@ jobs:
           elif [ "${{ matrix.flavour }}" == "torch" ]; then
             pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/dev.txt
           elif [ "${{ matrix.flavour }}" == "all" ]; then
-            pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/notorch.txt -r requirements/dev.txt
+            pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/notorch.txt -r requirements/optional.txt -r requirements/dev.txt
           fi

       - name: "Install libomp (for LightGBM)"
@@ -94,7 +94,7 @@ jobs:
       - name: "Compile Dependency Versions"
         run: |
           curl -LsSf https://astral.sh/uv/install.sh | sh
-          uv pip compile requirements/dev-all.txt > requirements-latest.txt
+          uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt

       # only restore cache but do not upload
       - name: "Restore cached python environment"
@@ -111,7 +111,7 @@ jobs:
       - name: "Install Dependencies"
         run: |
           # install latest dependencies (potentially updating cached dependencies)
-          pip install -U -r requirements/dev-all.txt
+          pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

       - name: "Install libomp (for LightGBM)"
         run: |
@@ -141,7 +141,7 @@ jobs:
       - name: "Compile Dependency Versions"
         run: |
           curl -LsSf https://astral.sh/uv/install.sh | sh
-          uv pip compile requirements/dev-all.txt > requirements-latest.txt
+          uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt

       # only restore cache but do not upload
       - name: "Restore cached python environment"
@@ -162,7 +162,7 @@ jobs:
       - name: "Install Dependencies"
         run: |
           # install latest dependencies (potentially updating cached dependencies)
-          pip install -U -r requirements/dev-all.txt
+          pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

       - name: "Install libomp (for LightGBM)"
         run: |
diff --git a/.github/workflows/update-cache.yml b/.github/workflows/update-cache.yml
index 3243bc4dc2..85d03c6f8c 100644
--- a/.github/workflows/update-cache.yml
+++ b/.github/workflows/update-cache.yml
@@ -31,7 +31,7 @@ jobs:
           if [ "${{ matrix.os }}" == "macos-13" ]; then
             source $HOME/.local/bin/env
          fi
-          uv pip compile requirements/dev-all.txt > requirements-latest.txt
+          uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt

       - name: "Cache python environment"
         uses: actions/cache@v4
@@ -47,4 +47,4 @@ jobs:
       - name: "Install Latest Dependencies"
         run: |
           # install latest dependencies (potentially updating cached dependencies)
-          pip install -U -r requirements/dev-all.txt
+          pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
Versions" run: | curl -LsSf https://astral.sh/uv/install.sh | sh - uv pip compile requirements/dev-all.txt > requirements-latest.txt + uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt # only restore cache but do not upload - name: "Restore cached python environment" @@ -162,7 +162,7 @@ jobs: - name: "Install Dependencies" run: | # install latest dependencies (potentially updating cached dependencies) - pip install -U -r requirements/dev-all.txt + pip install -U -r requirements/dev-all.txt -r requirements/optional.txt - name: "Install libomp (for LightGBM)" run: | diff --git a/.github/workflows/update-cache.yml b/.github/workflows/update-cache.yml index 3243bc4dc2..85d03c6f8c 100644 --- a/.github/workflows/update-cache.yml +++ b/.github/workflows/update-cache.yml @@ -31,7 +31,7 @@ jobs: if [ "${{ matrix.os }}" == "macos-13" ]; then source $HOME/.local/bin/env fi - uv pip compile requirements/dev-all.txt > requirements-latest.txt + uv pip compile requirements/dev-all.txt -r requirements/optional.txt -o requirements-latest.txt - name: "Cache python environment" uses: actions/cache@v4 @@ -47,4 +47,4 @@ jobs: - name: "Install Latest Dependencies" run: | # install latest dependencies (potentially updating cached dependencies) - pip install -U -r requirements/dev-all.txt + pip install -U -r requirements/dev-all.txt -r requirements/optional.txt diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index 2f53e12af5..a75661be35 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -726,7 +726,7 @@ def __init__( If ``False``, only attends to previous time steps in the decoder. If ``True`` attends to previous, current, and future time steps. Defaults to ``False``. feed_forward - A feedforward network is a fully-connected layer with an activation. TFT Can be one of the glu variant's + A feedforward network is a fully-connected layer with an activation. Can be one of the glu variant's FeedForward Network (FFN)[2]. The glu variant's FeedForward Network are a series of FFNs designed to work better with Transformer based models. Defaults to ``"GatedResidualNetwork"``. ["GLU", "Bilinear", "ReGLU", "GEGLU", "SwiGLU", "ReLU", "GELU"] or the TFT original FeedForward Network ["GatedResidualNetwork"]. diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index b949efb914..c1e4b80dee 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -646,6 +646,12 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): logger=logger, ) + @abstractmethod + def _update_covariates_use(self): + """Based on the Forecasting class and the training_sample attribute, update the + uses_[past/future/static]_covariates attributes.""" + pass + def to_onnx(self, path: Optional[str] = None, **kwargs): """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's :func:`torch.onnx.export` method (`official documentation `_. 
""" + # TODO: LSTM model should be exported with a batch size of 1 + # TODO: predictions with TFT and TCN models is incorrect, might be caused by helper function to process inputs if not self._fit_called: raise_log( ValueError("`fit()` needs to be called before `to_onnx()`."), logger @@ -2133,6 +2141,9 @@ def load_weights_from_checkpoint( self.model.load_state_dict(ckpt["state_dict"], strict=strict) # update the fit_called attribute to allow for direct inference self._fit_called = True + # based on the shape of train_sample, figure out which covariates are used by the model + # (usually set in the Darts model prior to fitting it) + self._update_covariates_use() def load_weights( self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs @@ -2683,6 +2694,13 @@ def extreme_lags( None, ) + def _update_covariates_use(self): + """The model is expected to rely on the `PastCovariatesTrainingDataset`""" + _, past_covs, static_covs, _ = self.train_sample + self._uses_past_covariates = past_covs is not None + self._uses_future_covariates = False + self._uses_static_covariates = static_covs is not None + class FutureCovariatesTorchModel(TorchForecastingModel, ABC): supports_past_covariates = False @@ -2776,6 +2794,13 @@ def extreme_lags( None, ) + def _update_covariates_use(self): + """The model is expected to rely on the `FutureCovariatesTrainingDataset`""" + _, future_covs, static_covs, _ = self.train_sample + self._uses_past_covariates = False + self._uses_future_covariates = future_covs is not None + self._uses_static_covariates = static_covs is not None + class DualCovariatesTorchModel(TorchForecastingModel, ABC): supports_past_covariates = False @@ -2870,6 +2895,15 @@ def extreme_lags( None, ) + def _update_covariates_use(self): + """The model is expected to rely on the `DualCovariatesTrainingDataset`""" + _, historic_future_covs, future_covs, static_covs, _ = self.train_sample + self._uses_past_covariates = False + self._uses_future_covariates = ( + historic_future_covs is not None or future_covs is not None + ) + self._uses_static_covariates = static_covs is not None + class MixedCovariatesTorchModel(TorchForecastingModel, ABC): def _build_train_dataset( @@ -2964,6 +2998,17 @@ def extreme_lags( None, ) + def _update_covariates_use(self): + """The model is expected to rely on the `MixedCovariatesTrainingDataset`""" + _, past_covs, historic_future_covs, future_covs, static_covs, _ = ( + self.train_sample + ) + self._uses_past_covariates = past_covs is not None + self._uses_future_covariates = ( + historic_future_covs is not None or future_covs is not None + ) + self._uses_static_covariates = static_covs is not None + class SplitCovariatesTorchModel(TorchForecastingModel, ABC): def _build_train_dataset( @@ -3058,3 +3103,12 @@ def extreme_lags( self.output_chunk_shift, None, ) + + def _update_covariates_use(self): + """The model is expected to rely on the `SplitCovariatesTrainingDataset`""" + _, past_covs, historic_future_covs, future_covs, static_covs, _ = ( + self.train_sample + ) + self._uses_past_covariates = past_covs is not None + self._uses_future_covariates = future_covs is not None + self._uses_static_covariates = static_covs is not None diff --git a/darts/tests/conftest.py b/darts/tests/conftest.py index 90bf29e20b..48ba0c5e25 100644 --- a/darts/tests/conftest.py +++ b/darts/tests/conftest.py @@ -17,6 +17,31 @@ logger.warning("Torch not installed - Some tests will be skipped.") TORCH_AVAILABLE = False +try: + import onnx # noqa: F401 + import onnxruntime # noqa: 
+
+    ONNX_AVAILABLE = True
+except ImportError:
+    logger.warning("Onnx not installed - Some tests will be skipped.")
+    ONNX_AVAILABLE = False
+
+try:
+    import optuna  # noqa: F401
+
+    OPTUNA_AVAILABLE = True
+except ImportError:
+    logger.warning("Optuna not installed - Some tests will be skipped.")
+    OPTUNA_AVAILABLE = False
+
+try:
+    import ray  # noqa: F401
+
+    RAY_AVAILABLE = True
+except ImportError:
+    logger.warning("Ray not installed - Some tests will be skipped.")
+    RAY_AVAILABLE = False
+
 tfm_kwargs = {
     "pl_trainer_kwargs": {
         "accelerator": "cpu",
@@ -25,6 +50,15 @@
     }
 }

+tfm_kwargs_dev = {
+    "pl_trainer_kwargs": {
+        "accelerator": "cpu",
+        "enable_progress_bar": False,
+        "enable_model_summary": False,
+        "fast_dev_run": True,
+    }
+}
+

 @pytest.fixture(scope="session", autouse=True)
 def set_up_tests(request):
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 4a52eb4128..37997f9d0a 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -13,7 +13,7 @@
 from darts.dataprocessing.encoders import SequentialEncoder
 from darts.dataprocessing.transformers import BoxCox, Scaler
 from darts.metrics import mape
-from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs
+from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs, tfm_kwargs_dev

 if not TORCH_AVAILABLE:
     pytest.skip(
@@ -429,6 +429,63 @@ def create_model(**kwargs):
         model_new = create_model(**kwargs_)
         model_new.load_weights(model_path_manual)

+    @pytest.mark.parametrize(
+        "params",
+        itertools.product(
+            [DLinearModel, NBEATSModel, RNNModel],  # model_cls
+            [True, False],  # past_covs
+            [True, False],  # future_covs
+            [True, False],  # static_covs
+        ),
+    )
+    def test_save_and_load_weights_covs_usage_attributes(self, tmpdir_fn, params):
+        """
+        Verify that save/load correctly preserves the uses_[past/future/static]_covariates attributes.
+        """
+        model_cls, use_pc, use_fc, use_sc = params
+        model = model_cls(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            n_epochs=1,
+            **tfm_kwargs_dev,
+        )
+        # skip test if the combination of covariates is not supported by the model
+        if (
+            (use_pc and not model.supports_past_covariates)
+            or (use_fc and not model.supports_future_covariates)
+            or (use_sc and not model.supports_static_covariates)
+        ):
+            return
+
+        model.fit(
+            series=self.series
+            if not use_sc
+            else self.series.with_static_covariates(pd.Series([12], ["loc"])),
+            past_covariates=self.series + 10 if use_pc else None,
+            future_covariates=self.series - 5 if use_fc else None,
+        )
+        # save and load the model
+        filename_ckpt = f"{model.model_name}.pt"
+        model.save(filename_ckpt)
+        model_loaded = model_cls(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            **tfm_kwargs_dev,
+        )
+        model_loaded.load_weights(filename_ckpt)
+
+        assert model.uses_past_covariates == model_loaded.uses_past_covariates == use_pc
+        assert (
+            model.uses_future_covariates
+            == model_loaded.uses_future_covariates
+            == use_fc
+        )
+        assert (
+            model.uses_static_covariates
+            == model_loaded.uses_static_covariates
+            == use_sc
+        )
+
     def test_save_and_load_weights_w_encoders(self, tmpdir_fn):
         """
         Verify that save/load does not break encoders.
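For context on the behavior pinned down by the test above: with `_update_covariates_use()` now called from `load_weights_from_checkpoint()`, the covariate-usage flags survive a save/load round trip even though only the weights are restored. A minimal sketch of the user-visible effect (model class, chunk lengths, and series are illustrative, not part of this diff):

```python
from darts.models import DLinearModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=50)
past_covs = sine_timeseries(length=50, value_frequency=0.1)

# fit a model that uses past covariates, then persist it
model = DLinearModel(input_chunk_length=4, output_chunk_length=1, n_epochs=1)
model.fit(series, past_covariates=past_covs)
model.save("dlinear.pt")

# a fresh instance restored from the weights now reports the covariates it
# was trained with, derived from the shape of the stored train sample
restored = DLinearModel(input_chunk_length=4, output_chunk_length=1)
restored.load_weights("dlinear.pt")
assert restored.uses_past_covariates and not restored.uses_future_covariates
```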
diff --git a/darts/tests/optional_deps/test_onnx.py b/darts/tests/optional_deps/test_onnx.py
new file mode 100644
index 0000000000..1d31d27d20
--- /dev/null
+++ b/darts/tests/optional_deps/test_onnx.py
@@ -0,0 +1,204 @@
+from itertools import product
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+import pytest
+
+import darts.utils.timeseries_generation as tg
+from darts import TimeSeries
+from darts.tests.conftest import ONNX_AVAILABLE, TORCH_AVAILABLE, tfm_kwargs_dev
+from darts.utils.onnx_utils import prepare_onnx_inputs
+
+if not (TORCH_AVAILABLE and ONNX_AVAILABLE):
+    pytest.skip(
+        f"Torch or Onnx not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+import onnx
+import onnxruntime as ort
+
+from darts.models import (
+    BlockRNNModel,
+    NHiTSModel,
+    TiDEModel,
+)
+
+# TODO: check how RINorm can be handled with respect to ONNX
+torch_model_cls = [
+    BlockRNNModel,
+    NHiTSModel,
+    TiDEModel,
+]
+
+
+class TestOnnx:
+    ts_tg = tg.linear_timeseries(start_value=0, end_value=100, length=30).astype(
+        "float32"
+    )
+    ts_tg_with_static = ts_tg.with_static_covariates(
+        pd.Series(data=[12], index=["loc"])
+    )
+    ts_pc = tg.constant_timeseries(value=123.4, length=300).astype("float32")
+    ts_fc = tg.sine_timeseries(length=32).astype("float32")
+
+    @pytest.mark.parametrize("model_cls", torch_model_cls)
+    def test_onnx_save_load(self, tmpdir_fn, model_cls):
+        model = model_cls(
+            input_chunk_length=4, output_chunk_length=2, n_epochs=1, **tfm_kwargs_dev
+        )
+        onnx_filename = f"test_onnx_{model.model_name}.onnx"
+
+        # exporting without fitting the model fails
+        with pytest.raises(ValueError):
+            model.to_onnx("dummy_name.onnx")
+
+        model.fit(
+            series=self.ts_tg_with_static
+            if model.supports_static_covariates
+            else self.ts_tg,
+            past_covariates=self.ts_pc if model.supports_past_covariates else None,
+            future_covariates=self.ts_fc if model.supports_future_covariates else None,
+        )
+        # native inference
+        pred = model.predict(2)
+
+        # model export
+        model.to_onnx(onnx_filename)
+
+        # onnx model verification
+        onnx_model = onnx.load(onnx_filename)
+        onnx.checker.check_model(onnx_model)
+
+        # onnx model loading and inference
+        onnx_pred = self._helper_onnx_inference(
+            model=model,
+            onnx_filename=onnx_filename,
+            series=self.ts_tg_with_static,
+            past_covariates=self.ts_pc,
+            future_covariates=self.ts_fc,
+        )[0][0]
+
+        # check that the predictions are similar
+        assert pred.shape == onnx_pred.shape, "forecasts don't have the same shape."
+        np.testing.assert_array_almost_equal(onnx_pred, pred.all_values(), decimal=4)
+
+    @pytest.mark.parametrize(
+        "params",
+        product(
+            torch_model_cls,
+            [True, False],  # clean
+        ),
+    )
+    def test_onnx_from_ckpt(self, tmpdir_fn, params):
+        """Check that creating the onnx export from a model directly loaded from a checkpoint works as expected"""
+        model_cls, clean = params
+        model = model_cls(
+            input_chunk_length=4, output_chunk_length=2, n_epochs=1, **tfm_kwargs_dev
+        )
+        onnx_filename = f"test_onnx_{model.model_name}.onnx"
+        onnx_filename2 = f"test_onnx_{model.model_name}_weights.onnx"
+        ckpt_filename = f"test_ckpt_{model.model_name}.pt"
+
+        model.fit(
+            series=self.ts_tg_with_static
+            if model.supports_static_covariates
+            else self.ts_tg,
+            past_covariates=self.ts_pc if model.supports_past_covariates else None,
+            future_covariates=self.ts_fc if model.supports_future_covariates else None,
+        )
+        model.save(ckpt_filename, clean=clean)
+
+        # load the entire checkpoint
+        model_loaded = model_cls.load(ckpt_filename)
+        pred = model_loaded.predict(
+            n=2,
+            series=self.ts_tg_with_static
+            if model_loaded.uses_static_covariates
+            else self.ts_tg,
+            past_covariates=self.ts_pc if model_loaded.uses_past_covariates else None,
+            future_covariates=self.ts_fc
+            if model_loaded.uses_future_covariates
+            else None,
+        )
+
+        # export the loaded model
+        model_loaded.to_onnx(onnx_filename)
+
+        # onnx model loading and inference
+        onnx_pred = self._helper_onnx_inference(
+            model=model_loaded,
+            onnx_filename=onnx_filename,
+            series=self.ts_tg_with_static,
+            past_covariates=self.ts_pc,
+            future_covariates=self.ts_fc,
+        )[0][0]
+
+        # check that the predictions are similar
+        assert pred.shape == onnx_pred.shape, "forecasts don't have the same shape."
+        np.testing.assert_array_almost_equal(onnx_pred, pred.all_values(), decimal=4)
+
+        # load only the weights
+        model_weights = model_cls(
+            input_chunk_length=4, output_chunk_length=2, n_epochs=1, **tfm_kwargs_dev
+        )
+        model_weights.load_weights(ckpt_filename)
+        pred_weights = model_weights.predict(
+            n=2,
+            series=self.ts_tg_with_static
+            if model_weights.uses_static_covariates
+            else self.ts_tg,
+            past_covariates=self.ts_pc if model_weights.uses_past_covariates else None,
+            future_covariates=self.ts_fc
+            if model_weights.uses_future_covariates
+            else None,
+        )
+
+        # export the loaded model
+        model_weights.to_onnx(onnx_filename2)
+
+        # onnx model loading and inference
+        onnx_pred_weights = self._helper_onnx_inference(
+            model=model_weights,
+            onnx_filename=onnx_filename2,
+            series=self.ts_tg_with_static,
+            past_covariates=self.ts_pc,
+            future_covariates=self.ts_fc,
+        )[0][0]
+
+        assert pred_weights.shape == onnx_pred_weights.shape, (
+            "forecasts don't have the same shape."
+        )
+        np.testing.assert_array_almost_equal(
+            onnx_pred_weights, pred_weights.all_values(), decimal=4
+        )
+
+    def _helper_onnx_inference(
+        self,
+        model,
+        onnx_filename: str,
+        series: TimeSeries,
+        past_covariates: Optional[TimeSeries],
+        future_covariates: Optional[TimeSeries],
+    ):
+        """Darts model is only used to detect which covariates are supported by the weights."""
+        ort_session = ort.InferenceSession(onnx_filename)
+
+        # extract the input arrays from the series
+        past_feats, future_feats, static_feats = prepare_onnx_inputs(
+            model=model,
+            series=series,
+            past_covariates=past_covariates,
+            future_covariates=future_covariates,
+        )
+
+        # extract only the features expected by the model
+        ort_inputs = {}
+        for name, arr in zip(
+            ["x_past", "x_future", "x_static"], [past_feats, future_feats, static_feats]
+        ):
+            if name in [inp.name for inp in list(ort_session.get_inputs())]:
+                ort_inputs[name] = arr
+
+        # output has shape (batch, output_chunk_length, n components, 1 or n likelihood params)
+        return ort_session.run(None, ort_inputs)
diff --git a/darts/tests/optional_deps/test_optuna.py b/darts/tests/optional_deps/test_optuna.py
new file mode 100644
index 0000000000..b49bdec7f2
--- /dev/null
+++ b/darts/tests/optional_deps/test_optuna.py
@@ -0,0 +1,182 @@
+import os
+from itertools import product
+
+import numpy as np
+import pytest
+from sklearn.preprocessing import MaxAbsScaler
+
+from darts.dataprocessing.transformers import Scaler
+from darts.datasets import AirPassengersDataset
+from darts.metrics import smape
+from darts.models import LinearRegressionModel
+from darts.tests.conftest import OPTUNA_AVAILABLE, TORCH_AVAILABLE, tfm_kwargs
+
+if not OPTUNA_AVAILABLE:
+    pytest.skip(
+        f"Optuna not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+import optuna
+
+if TORCH_AVAILABLE:
+    import torch
+    from pytorch_lightning.callbacks import Callback, EarlyStopping
+
+    # hacky workaround found in https://github.com/Lightning-AI/pytorch-lightning/issues/17485
+    # to avoid import of both lightning and pytorch_lightning
+    class PatchedPruningCallback(
+        optuna.integration.PyTorchLightningPruningCallback, Callback
+    ):
+        pass
+
+    from darts.models import TCNModel
+    from darts.utils.likelihood_models import GaussianLikelihood
+
+
+class TestOptuna:
+    series = AirPassengersDataset().load().astype(np.float32)
+
+    val_length = 36
+    train, val = series.split_after(val_length)
+
+    # scale
+    scaler = Scaler(MaxAbsScaler())
+    train = scaler.fit_transform(train)
+    val = scaler.transform(val)
+
+    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
+    def test_optuna_torch_model(self, tmpdir_fn):
+        """Check that optuna works as expected with a torch-based model"""
+
+        # define objective function
+        def objective(trial):
+            # select input and output chunk lengths
+            in_len = trial.suggest_int("in_len", 4, 8)
+            out_len = trial.suggest_int("out_len", 1, 3)
+
+            # Other hyperparameters
+            kernel_size = trial.suggest_int("kernel_size", 2, 3)
+            num_filters = trial.suggest_int("num_filters", 1, 2)
+            lr = trial.suggest_float("lr", 5e-5, 1e-3, log=True)
+            include_year = trial.suggest_categorical("year", [False, True])
+
+            # throughout training we'll monitor the validation loss for both pruning and early stopping
+            pruner = PatchedPruningCallback(trial, monitor="val_loss")
+            early_stopper = EarlyStopping(
+                "val_loss", min_delta=0.001, patience=3, verbose=True
+            )
+
+            # optionally also add the (scaled) year value as a past covariate
+            if include_year:
+                encoders = {
"datetime_attribute": {"past": ["year"]}, + "transformer": Scaler(), + } + else: + encoders = None + + # reproducibility + torch.manual_seed(42) + + # build the TCN model + model = TCNModel( + input_chunk_length=in_len, + output_chunk_length=out_len, + batch_size=8, + n_epochs=2, + nr_epochs_val_period=1, + kernel_size=kernel_size, + num_filters=num_filters, + optimizer_kwargs={"lr": lr}, + add_encoders=encoders, + likelihood=GaussianLikelihood(), + pl_trainer_kwargs={ + **tfm_kwargs["pl_trainer_kwargs"], + "callbacks": [pruner, early_stopper], + }, + model_name="tcn_model", + force_reset=True, + save_checkpoints=True, + work_dir=os.getcwd(), + ) + + # when validating during training, we can use a slightly longer validation + # set which also contains the first input_chunk_length time steps + model_val_set = self.scaler.transform( + self.series[-(self.val_length + in_len) :] + ) + + # train the model + model.fit( + series=self.train, + val_series=model_val_set, + ) + + # reload best model over course of training + model = TCNModel.load_from_checkpoint( + model_name="tcn_model", work_dir=os.getcwd() + ) + + # Evaluate how good it is on the validation set, using sMAPE + preds = model.predict(series=self.train, n=self.val_length) + smapes = smape(self.val, preds, n_jobs=-1) + smape_val = np.mean(smapes) + + return smape_val if smape_val != np.nan else float("inf") + + # optimize hyperparameters by minimizing the sMAPE on the validation set + study = optuna.create_study(direction="minimize") + study.optimize(objective, n_trials=3) + + @pytest.mark.parametrize( + "params", + product( + [True, False], # multi_models + [1, 3], # ocl + ), + ) + def test_optuna_regression_model(self, params): + """Check that optuna works as expected with a regression model""" + + multi_models, ocl = params + + # define objective function + def objective(trial): + # select input and encoder usage + target_lags = trial.suggest_int("lags", 1, 12) + include_year = trial.suggest_categorical("year", [False, True]) + + # optionally also add the (scaled) year value as a past covariate + if include_year: + encoders = { + "datetime_attribute": {"past": ["year"]}, + "transformer": Scaler(), + } + past_lags = trial.suggest_int("lags_past_covariates", 1, 12) + else: + encoders = None + past_lags = None + + # build the model + model = LinearRegressionModel( + lags=target_lags, + lags_past_covariates=past_lags, + output_chunk_length=ocl, + multi_models=multi_models, + add_encoders=encoders, + ) + model.fit( + series=self.train, + ) + + # Evaluate how good it is on the validation set, using sMAPE + preds = model.predict(series=self.train, n=self.val_length) + smapes = smape(self.val, preds, n_jobs=-1) + smape_val = np.mean(smapes) + + return smape_val if smape_val != np.nan else float("inf") + + # optimize hyperparameters by minimizing the sMAPE on the validation set + study = optuna.create_study(direction="minimize") + study.optimize(objective, n_trials=3) diff --git a/darts/tests/optional_deps/test_ray.py b/darts/tests/optional_deps/test_ray.py new file mode 100644 index 0000000000..daf2eaa056 --- /dev/null +++ b/darts/tests/optional_deps/test_ray.py @@ -0,0 +1,116 @@ +import numpy as np +import pytest +from sklearn.preprocessing import MaxAbsScaler + +from darts.dataprocessing.transformers import Scaler +from darts.datasets import AirPassengersDataset +from darts.tests.conftest import RAY_AVAILABLE, TORCH_AVAILABLE, tfm_kwargs + +if not RAY_AVAILABLE: + pytest.skip( + f"Ray not available. 
+        allow_module_level=True,
+    )
+
+from ray import tune
+from ray.tune.tuner import Tuner
+
+if TORCH_AVAILABLE:
+    from pytorch_lightning.callbacks import Callback, EarlyStopping
+    from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
+    from torchmetrics import (
+        MeanAbsoluteError,
+        MeanAbsolutePercentageError,
+        MetricCollection,
+    )
+
+    from darts.models import NBEATSModel
+
+
+class TestRay:
+    series = AirPassengersDataset().load().astype(np.float32)
+
+    val_length = 36
+    train, val = series.split_after(val_length)
+
+    # scale
+    scaler = Scaler(MaxAbsScaler())
+    train = scaler.fit_transform(train)
+    val = scaler.transform(val)
+
+    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
+    def test_ray_torch_model(self, tmpdir_fn):
+        """Check that ray works as expected with a torch-based model"""
+
+        def train_model(model_args, callbacks, train, val):
+            torch_metrics = MetricCollection([
+                MeanAbsolutePercentageError(),
+                MeanAbsoluteError(),
+            ])
+
+            # Create the model using model_args from Ray Tune
+            model = NBEATSModel(
+                input_chunk_length=4,
+                output_chunk_length=3,
+                n_epochs=2,
+                torch_metrics=torch_metrics,
+                pl_trainer_kwargs={
+                    **tfm_kwargs["pl_trainer_kwargs"],
+                    "callbacks": callbacks,
+                },
+                **model_args,
+            )
+
+            model.fit(
+                series=train,
+                val_series=val,
+            )
+
+        # Early stop callback
+        my_stopper = EarlyStopping(
+            monitor="val_MeanAbsolutePercentageError",
+            patience=5,
+            min_delta=0.05,
+            mode="min",
+        )
+
+        # set up ray tune callback
+        class TuneReportCallback(TuneReportCheckpointCallback, Callback):
+            pass
+
+        tune_callback = TuneReportCallback(
+            {
+                "loss": "val_loss",
+                "MAPE": "val_MeanAbsolutePercentageError",
+            },
+            on="validation_end",
+        )
+
+        # Define the trainable function that will be tuned by Ray Tune
+        train_fn_with_parameters = tune.with_parameters(
+            train_model,
+            callbacks=[tune_callback, my_stopper],
+            train=self.train,
+            val=self.val,
+        )
+
+        # define the hyperparameter space
+        param_space = {
+            "batch_size": tune.choice([8, 16]),
+            "num_blocks": tune.choice([1, 2]),
+            "num_stacks": tune.choice([2, 4]),
+        }
+
+        # the number of combinations to try
+        num_samples = 2
+
+        # Create the Tuner object and run the hyperparameter search
+        tuner = Tuner(
+            trainable=train_fn_with_parameters,
+            param_space=param_space,
+            tune_config=tune.TuneConfig(
+                metric="MAPE", mode="min", num_samples=num_samples
+            ),
+            run_config=tune.RunConfig(name="tune_darts"),
+        )
+        tuner.fit()
diff --git a/darts/utils/onnx_utils.py b/darts/utils/onnx_utils.py
new file mode 100644
index 0000000000..39d3b67c33
--- /dev/null
+++ b/darts/utils/onnx_utils.py
@@ -0,0 +1,52 @@
+from typing import Optional
+
+import numpy as np
+
+from darts import TimeSeries
+
+
+def prepare_onnx_inputs(
+    model,
+    series: TimeSeries,
+    past_covariates: Optional[TimeSeries] = None,
+    future_covariates: Optional[TimeSeries] = None,
+) -> tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
+    """Helper function to slice and concatenate the input features.
+
+    In order to remove the dependency on the `model` argument, it can be decomposed into
+    the following arguments (and simplified depending on the characteristics of the model used):
+
+    - model_icl
+    - model_ocl
+    - model_uses_past_covs
+    - model_uses_future_covs
+    - model_uses_static_covs
+    """
+    past_feats, future_feats, static_feats = None, None, None
+    # get input & output windows
+    past_start = series.end_time() - (model.input_chunk_length - 1) * series.freq
+    past_end = series.end_time()
+    future_start = past_end + 1 * series.freq
+    future_end = past_end + model.output_chunk_length * series.freq
+    # extract all historic and future features from target, past and future covariates
+    past_feats = series[past_start:past_end].values()
+    if past_covariates and model.uses_past_covariates:
+        # extract past covariates
+        past_feats = np.concatenate(
+            [past_feats, past_covariates[past_start:past_end].values()], axis=1
+        )
+    if future_covariates and model.uses_future_covariates:
+        # extract past part of future covariates
+        past_feats = np.concatenate(
+            [past_feats, future_covariates[past_start:past_end].values()], axis=1
+        )
+        # extract future part of future covariates
+        future_feats = future_covariates[future_start:future_end].values()
+    # add batch dimension -> (batch, n time steps, n components)
+    past_feats = np.expand_dims(past_feats, axis=0).astype(series.dtype)
+    if future_feats is not None:
+        future_feats = np.expand_dims(future_feats, axis=0).astype(series.dtype)
+    # extract static covariates
+    if series.has_static_covariates and model.uses_static_covariates:
+        static_feats = np.expand_dims(
+            series.static_covariates_values(), axis=0
+        ).astype(series.dtype)
+    return past_feats, future_feats, static_feats
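Combined with `TorchForecastingModel.to_onnx()`, this helper supports standalone ONNX inference. A minimal sketch of the intended flow (it assumes a fitted Darts `model` already exported to `"model.onnx"`, plus matching `series`/covariate variables; the `x_past`/`x_future`/`x_static` input names follow the convention used in the tests above):

```python
import onnx
import onnxruntime as ort

from darts.utils.onnx_utils import prepare_onnx_inputs

# verify the exported graph, then open an inference session
onnx.checker.check_model(onnx.load("model.onnx"))
session = ort.InferenceSession("model.onnx")

# slice the target and covariate series into the arrays the graph expects
past_feats, future_feats, static_feats = prepare_onnx_inputs(
    model=model,
    series=series,
    past_covariates=past_covariates,
    future_covariates=future_covariates,
)

# feed only the inputs the exported graph actually declares
expected = {inp.name for inp in session.get_inputs()}
feeds = {
    name: arr
    for name, arr in zip(
        ["x_past", "x_future", "x_static"],
        [past_feats, future_feats, static_feats],
    )
    if name in expected
}
# output: (batch, output_chunk_length, n components, 1 or n likelihood params)
pred = session.run(None, feeds)[0]
```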
diff --git a/docs/userguide/hyperparameter_optimization.md b/docs/userguide/hyperparameter_optimization.md
index 5097532424..8094b15698 100644
--- a/docs/userguide/hyperparameter_optimization.md
+++ b/docs/userguide/hyperparameter_optimization.md
@@ -20,7 +20,7 @@ import numpy as np
 import optuna
 import torch
 from optuna.integration import PyTorchLightningPruningCallback
-from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import Callback, EarlyStopping
 from sklearn.preprocessing import MaxAbsScaler

 from darts.dataprocessing.transformers import Scaler
@@ -41,6 +41,11 @@ scaler = Scaler(MaxAbsScaler())
 train = scaler.fit_transform(train)
 val = scaler.transform(val)

+# workaround found in https://github.com/Lightning-AI/pytorch-lightning/issues/17485
+# to avoid import of both lightning and pytorch_lightning
+class PatchedPruningCallback(optuna.integration.PyTorchLightningPruningCallback, Callback):
+    pass
+
 # define objective function
 def objective(trial):
     # select input and output chunk lengths
@@ -57,7 +62,7 @@ def objective(trial):
     include_year = trial.suggest_categorical("year", [False, True])

     # throughout training we'll monitor the validation loss for both pruning and early stopping
-    pruner = PyTorchLightningPruningCallback(trial, monitor="val_loss")
+    pruner = PatchedPruningCallback(trial, monitor="val_loss")
     early_stopper = EarlyStopping("val_loss", min_delta=0.001, patience=3, verbose=True)
     callbacks = [pruner, early_stopper]

@@ -112,7 +117,6 @@ def objective(trial):
     model.fit(
         series=train,
         val_series=model_val_set,
-        num_loader_workers=num_workers,
     )

     # reload best model over course of training
diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md
index 928612e7f5..4aed55cf55 100644
--- a/docs/userguide/torch_forecasting_models.md
+++ b/docs/userguide/torch_forecasting_models.md
@@ -372,49 +372,7 @@ import onnx
 import onnxruntime as ort
 import numpy as np
 from darts import TimeSeries
-
-def prepare_onnx_inputs(
-    model,
-    series: TimeSeries,
-    past_covariates : Optional[TimeSeries] = None,
-    future_covariates : Optional[TimeSeries] = None,
-) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
-    """Helper function to slice and concatenate the input features"""
-    past_feats, future_feats, static_feats = None, None, None
-    # get input & output windows
-    past_start = series.end_time() - (model.input_chunk_length - 1) * series.freq
-    past_end = series.end_time()
-    future_start = past_end + 1 * series.freq
-    future_end = past_end + model.output_chunk_length * series.freq
-    # extract all historic and future features from target, past and future covariates
-    past_feats = series[past_start:past_end].values()
-    if past_covariates and model.uses_past_covariates:
-        # extract past covariates
-        past_feats = np.concatenate(
-            [
-                past_feats,
-                past_covariates[past_start:past_end].values()
-            ],
-            axis=1
-        )
-    if future_covariates and model.uses_future_covariates:
-        # extract past part of future covariates
-        past_feats = np.concatenate(
-            [
-                past_feats,
-                future_covariates[past_start:past_end].values()
-            ],
-            axis=1
-        )
-        # extract future part of future covariates
-        future_feats = future_covariates[future_start:future_end].values()
-    # add batch dimension -> (batch, n time steps, n components)
-    past_feats = np.expand_dims(past_feats, axis=0).astype(series.dtype)
-    future_feats = np.expand_dims(future_feats, axis=0).astype(series.dtype)
-    # extract static covariates
-    if series.has_static_covariates and model.uses_static_covariates:
-        static_feats = np.expand_dims(series.static_covariates_values(), axis=0).astype(series.dtype)
-    return past_feats, future_feats, static_feats
+from darts.utils.onnx_utils import prepare_onnx_inputs

 onnx_model = onnx.load(onnx_filename)
 onnx.checker.check_model(onnx_model)
diff --git a/requirements/optional.txt b/requirements/optional.txt
new file mode 100644
index 0000000000..b5ba05815a
--- /dev/null
+++ b/requirements/optional.txt
@@ -0,0 +1,5 @@
+onnx
+onnxruntime
+optuna
+optuna-integration[pytorch_lightning]
+ray
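To reproduce the CI setup locally, the new optional extras install next to the regular dev requirements, after which the new test directory can be run directly (a sketch; the commands mirror the workflow steps above):

```sh
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
pytest darts/tests/optional_deps/
```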