Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/test optional dep (onnx, ray, optuna) #2702

Merged
merged 21 commits into from
Mar 8, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/develop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
if [ "${{ matrix.os }}" == "macos-13" ]; then
source $HOME/.local/bin/env
fi
uv pip compile requirements/dev-all.txt > requirements-latest.txt
uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt

- name: "Cache python environment"
uses: actions/cache@v4
Expand All @@ -67,7 +67,7 @@ jobs:
- name: "Install Dependencies"
run: |
# install latest dependencies (potentially updating cached dependencies)
pip install -U -r requirements/dev-all.txt
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

- name: "Install libomp (for LightGBM)"
run: |
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/merge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
elif [ "${{ matrix.flavour }}" == "torch" ]; then
pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/dev.txt
elif [ "${{ matrix.flavour }}" == "all" ]; then
pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/notorch.txt -r requirements/dev.txt
pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/notorch.txt -r requirements/optional.txt -r requirements/dev.txt
fi

- name: "Install libomp (for LightGBM)"
Expand Down Expand Up @@ -94,7 +94,7 @@ jobs:
- name: "Compile Dependency Versions"
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
uv pip compile requirements/dev-all.txt > requirements-latest.txt
uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt

# only restore cache but do not upload
- name: "Restore cached python environment"
Expand Down Expand Up @@ -141,7 +141,7 @@ jobs:
- name: "Compile Dependency Versions"
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
uv pip compile requirements/dev-all.txt > requirements-latest.txt
uv pip compile requirements/dev-all.txt -o requirements-latest.txt

# only restore cache but do not upload
- name: "Restore cached python environment"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/update-cache.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
if [ "${{ matrix.os }}" == "macos-13" ]; then
source $HOME/.local/bin/env
fi
uv pip compile requirements/dev-all.txt > requirements-latest.txt
uv pip compile requirements/dev-all.txt -r requirements/optional.txt -o requirements-latest.txt

- name: "Cache python environment"
uses: actions/cache@v4
Expand All @@ -47,4 +47,4 @@ jobs:
- name: "Install Latest Dependencies"
run: |
# install latest dependencies (potentially updating cached dependencies)
pip install -U -r requirements/dev-all.txt
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
2 changes: 1 addition & 1 deletion darts/models/forecasting/tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def __init__(
If ``False``, only attends to previous time steps in the decoder. If ``True`` attends to previous,
current, and future time steps. Defaults to ``False``.
feed_forward
A feedforward network is a fully-connected layer with an activation. TFT Can be one of the glu variant's
A feedforward network is a fully-connected layer with an activation. Can be one of the GLU variant
FeedForward Networks (FFN)[2]. The GLU variant FeedForward Networks are a series of FFNs designed to work
better with Transformer based models. Defaults to ``"GatedResidualNetwork"``. ["GLU", "Bilinear", "ReGLU",
"GEGLU", "SwiGLU", "ReLU", "GELU"] or the TFT original FeedForward Network ["GatedResidualNetwork"].
Expand Down
56 changes: 55 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,12 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
logger=logger,
)

@abstractmethod
def _update_covariates_use(self):
    """Update the ``_uses_[past/future/static]_covariates`` attributes.

    Concrete subclasses inspect ``self.train_sample`` — whose tuple layout
    depends on the training dataset type the model class relies on — and set
    the covariate-usage flags from which slots are non-``None``.
    """
    pass

def to_onnx(self, path: Optional[str] = None, **kwargs):
"""Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
:func:`torch.onnx.export` method (`official documentation <https://lightning.ai/docs/pytorch/
Expand Down Expand Up @@ -677,6 +683,8 @@ def to_onnx(self, path: Optional[str] = None, **kwargs):
``input_sample``, ``input_name``). For more information, read the `official documentation
<https://pytorch.org/docs/master/onnx.html#torch.onnx.export>`_.
"""
# TODO: LSTM model should be exported with a batch size of 1
# TODO: predictions with TFT and TCN models is incorrect, might be caused by helper function to process inputs
if not self._fit_called:
raise_log(
ValueError("`fit()` needs to be called before `to_onnx()`."), logger
Expand Down Expand Up @@ -1774,7 +1782,7 @@ def save(
self.trainer.save_checkpoint(path_ptl_ckpt, weights_only=clean)

# TODO: keep track of PyTorch Lightning to see if they implement model checkpoint saving
# without having to call fit/predict/validate/test before
# without having to call fit/predict/validate/test before
# try to recover original automatic PL checkpoint
elif self.load_ckpt_path:
if os.path.exists(self.load_ckpt_path):
Expand Down Expand Up @@ -2133,6 +2141,9 @@ def load_weights_from_checkpoint(
self.model.load_state_dict(ckpt["state_dict"], strict=strict)
# update the fit_called attribute to allow for direct inference
self._fit_called = True
# based on the shape of train_sample, figure out which covariates are used by the model
# (usually set in the Darts model prior to fitting it)
self._update_covariates_use()

def load_weights(
self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs
Expand Down Expand Up @@ -2683,6 +2694,13 @@ def extreme_lags(
None,
)

def _update_covariates_use(self):
"""The model is expected to rely on the `PastCovariatesTrainingDataset`"""
_, past_covs, static_covs, _ = self.train_sample
self._uses_past_covariates = past_covs is not None
self._uses_future_covariates = False
self._uses_static_covariates = static_covs is not None


class FutureCovariatesTorchModel(TorchForecastingModel, ABC):
supports_past_covariates = False
Expand Down Expand Up @@ -2776,6 +2794,13 @@ def extreme_lags(
None,
)

def _update_covariates_use(self):
"""The model is expected to rely on the `FutureCovariatesTrainingDataset`"""
_, future_covs, static_covs, _ = self.train_sample
self._uses_past_covariates = False
self._uses_future_covariates = future_covs is not None
self._uses_static_covariates = static_covs is not None


class DualCovariatesTorchModel(TorchForecastingModel, ABC):
supports_past_covariates = False
Expand Down Expand Up @@ -2870,6 +2895,15 @@ def extreme_lags(
None,
)

def _update_covariates_use(self):
"""The model is expected to rely on the `DualCovariatesTrainingDataset`"""
_, historic_future_covs, future_covs, static_covs, _ = self.train_sample
self._uses_past_covariates = False
self._uses_future_covariates = (
historic_future_covs is not None or future_covs is not None
)
self._uses_static_covariates = static_covs is not None


class MixedCovariatesTorchModel(TorchForecastingModel, ABC):
def _build_train_dataset(
Expand Down Expand Up @@ -2964,6 +2998,17 @@ def extreme_lags(
None,
)

def _update_covariates_use(self):
"""The model is expected to rely on the `MixedCovariatesTrainingDataset`"""
_, past_covs, historic_future_covs, future_covs, static_covs, _ = (
self.train_sample
)
self._uses_past_covariates = past_covs is not None
self._uses_future_covariates = (
historic_future_covs is not None or future_covs is not None
)
self._uses_static_covariates = static_covs is not None


class SplitCovariatesTorchModel(TorchForecastingModel, ABC):
def _build_train_dataset(
Expand Down Expand Up @@ -3058,3 +3103,12 @@ def extreme_lags(
self.output_chunk_shift,
None,
)

def _update_covariates_use(self):
"""The model is expected to rely on the `SplitCovariatesTrainingDataset`"""
_, past_covs, historic_future_covs, future_covs, static_covs, _ = (
self.train_sample
)
self._uses_past_covariates = past_covs is not None
self._uses_future_covariates = future_covs is not None
self._uses_static_covariates = static_covs is not None
25 changes: 25 additions & 0 deletions darts/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,31 @@
logger.warning("Torch not installed - Some tests will be skipped.")
TORCH_AVAILABLE = False

# Probe optional test-only dependencies and expose availability flags so the
# test suite can skip tests that need a dependency which is not installed.
ONNX_AVAILABLE = True
try:
    import onnx  # noqa: F401
    import onnxruntime  # noqa: F401
except ImportError:
    ONNX_AVAILABLE = False
    logger.warning("Onnx not installed - Some tests will be skipped.")

OPTUNA_AVAILABLE = True
try:
    import optuna  # noqa: F401
except ImportError:
    OPTUNA_AVAILABLE = False
    logger.warning("Optuna not installed - Some tests will be skipped.")

RAY_AVAILABLE = True
try:
    import ray  # noqa: F401
except ImportError:
    RAY_AVAILABLE = False
    logger.warning("Ray not installed - Some tests will be skipped.")

tfm_kwargs = {
"pl_trainer_kwargs": {
"accelerator": "cpu",
Expand Down
57 changes: 57 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,63 @@ def create_model(**kwargs):
model_new = create_model(**kwargs_)
model_new.load_weights(model_path_manual)

@pytest.mark.parametrize(
    "params",
    itertools.product(
        [DLinearModel, NBEATSModel, RNNModel],  # model_cls
        [True, False],  # past_covs
        [True, False],  # future_covs
        [True, False],  # static covs
    ),
)
def test_save_and_load_weights_covs_usage_attributes(self, tmpdir_fn, params):
    """
    Verify that save/load correctly preserve the use_[past/future/static]_covariates attribute.
    """
    model_cls, with_pc, with_fc, with_sc = params
    model = model_cls(
        input_chunk_length=4,
        output_chunk_length=1,
        n_epochs=1,
        **tfm_kwargs,
    )
    # nothing to verify when the model cannot use the requested covariates
    combination_unsupported = (
        (with_pc and not model.supports_past_covariates)
        or (with_fc and not model.supports_future_covariates)
        or (with_sc and not model.supports_static_covariates)
    )
    if combination_unsupported:
        return

    target = self.series
    if with_sc:
        target = target.with_static_covariates(pd.Series([12], ["loc"]))
    model.fit(
        series=target,
        past_covariates=self.series + 10 if with_pc else None,
        future_covariates=self.series - 5 if with_fc else None,
    )

    # round-trip: persist the fitted model, then load its weights into a fresh one
    ckpt_path = f"{model.model_name}.pt"
    model.save(ckpt_path)
    fresh_model = model_cls(
        input_chunk_length=4,
        output_chunk_length=1,
        **tfm_kwargs,
    )
    fresh_model.load_weights(ckpt_path)

    # both the fitted and the reloaded model must report the covariates actually used
    assert model.uses_past_covariates == fresh_model.uses_past_covariates == with_pc
    assert (
        model.uses_future_covariates == fresh_model.uses_future_covariates == with_fc
    )
    assert (
        model.uses_static_covariates == fresh_model.uses_static_covariates == with_sc
    )

def test_save_and_load_weights_w_encoders(self, tmpdir_fn):
"""
Verify that save/load does not break encoders.
Expand Down
Loading
Loading