Skip to content

Commit cda8f81

Browse files
Feat/onnx support (#2620)
* feat: wrapping around pl.to_onnx to export models to ONNX, still require testing * feat: cleaned implementation of the to_onnx method * fix: generation of input name, shape of input_batch for PastCov torch module * feat: adding example of onnx usage in userguide * update changelog * fix: revert some changes * fix: export to onnx for RNNModel * feat: added a comment about RNNModel for onnx inference * update changelog * fix: address review comments * update changelog * update torch user guide * update to_onnx --------- Co-authored-by: dennisbader <[email protected]>
1 parent 3aa5c97 commit cda8f81

File tree

5 files changed

+195
-12
lines changed

5 files changed

+195
-12
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- Added ONNX support for torch-based models with method `TorchForecastingModel.to_onnx()`. Check out [this example](https://unit8co.github.io/darts/userguide/gpu_and_tpu_usage.html#exporting-model-to-onnx-format-for-inference) from the user guide on how to export and load a model for inference. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou)
1415
- Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon)
1516

1617
**Fixed**

darts/models/forecasting/pl_forecasting_module.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
When subclassing this class, please make sure to add the following methods with the given signatures:
9494
- :func:`PLForecastingModule.__init__()`
9595
- :func:`PLForecastingModule.forward()`
96+
- :func:`PLForecastingModule._process_input_batch()`
9697
- :func:`PLForecastingModule._produce_train_output()`
9798
- :func:`PLForecastingModule._get_batch_prediction()`
9899
@@ -632,17 +633,48 @@ def _produce_train_output(self, input_batch: tuple):
632633
input_batch
633634
``(past_target, past_covariates, static_covariates)``
634635
"""
635-
past_target, past_covariates, static_covariates = input_batch
636+
return self(self._process_input_batch(input_batch))
637+
638+
def _process_input_batch(
639+
self, input_batch: tuple
640+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
641+
"""
642+
Converts output of PastCovariatesDataset (training dataset) into an input/past- and
643+
output/future chunk.
644+
645+
Parameters
646+
----------
647+
input_batch
648+
``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.
649+
650+
Returns
651+
-------
652+
tuple
653+
``(x_past, x_static)`` the input/past and output/future chunks.
654+
"""
655+
# because of future past covariates, the batch shape is different during training and prediction
656+
if len(input_batch) == 3:
657+
(
658+
past_target,
659+
past_covariates,
660+
static_covariates,
661+
) = input_batch
662+
else:
663+
(
664+
past_target,
665+
past_covariates,
666+
future_past_covariates,
667+
static_covariates,
668+
) = input_batch
636669
# Currently all our PastCovariates models require past target and covariates concatenated
637-
inpt = (
670+
return (
638671
(
639672
torch.cat([past_target, past_covariates], dim=2)
640673
if past_covariates is not None
641674
else past_target
642675
),
643676
static_covariates,
644677
)
645-
return self(inpt)
646678

647679
def _get_batch_prediction(
648680
self, n: int, input_batch: tuple, roll_size: int
@@ -674,12 +706,9 @@ def _get_batch_prediction(
674706
past_covariates.shape[dim_component] if past_covariates is not None else 0
675707
)
676708

677-
input_past = torch.cat(
678-
[ds for ds in [past_target, past_covariates] if ds is not None],
679-
dim=dim_component,
680-
)
709+
input_past, input_static = self._process_input_batch(input_batch)
681710

682-
out = self._produce_predict_output(x=(input_past, static_covariates))[
711+
out = self._produce_predict_output(x=(input_past, input_static))[
683712
:, self.first_prediction_index :, :
684713
]
685714

@@ -796,7 +825,7 @@ def _process_input_batch(
796825
future_covariates,
797826
static_covariates,
798827
) = input_batch
799-
dim_variable = 2
828+
dim_comp = 2
800829

801830
x_past = torch.cat(
802831
[
@@ -808,7 +837,7 @@ def _process_input_batch(
808837
]
809838
if tensor is not None
810839
],
811-
dim=dim_variable,
840+
dim=dim_comp,
812841
)
813842
return x_past, future_covariates, static_covariates
814843

darts/models/forecasting/rnn_model.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ def forward(
104104
pass
105105

106106
def _produce_train_output(self, input_batch: tuple) -> torch.Tensor:
107+
# only return the forecast, not the hidden state
108+
return self(self._process_input_batch(input_batch))[0]
109+
110+
def _process_input_batch(
111+
self, input_batch: tuple
112+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
107113
(
108114
past_target,
109115
historic_future_covariates,
@@ -112,15 +118,14 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor:
112118
) = input_batch
113119
# For the RNN we concatenate the past_target with the future_covariates
114120
# (they have the same length because we enforce a Shift dataset for RNNs)
115-
model_input = (
121+
return (
116122
(
117123
torch.cat([past_target, future_covariates], dim=2)
118124
if future_covariates is not None
119125
else past_target
120126
),
121127
static_covariates,
122128
)
123-
return self(model_input)[0]
124129

125130
def _produce_predict_output(
126131
self, x: tuple, last_hidden_state: Optional[torch.Tensor] = None

darts/models/forecasting/torch_forecasting_model.py

+60
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,66 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
646646
logger=logger,
647647
)
648648

649+
def to_onnx(self, path: Optional[str] = None, **kwargs):
650+
"""Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
651+
:func:`torch.onnx.export` method (`official documentation <https://lightning.ai/docs/pytorch/
652+
stable/common/lightning_module.html#to-onnx>`_).
653+
654+
Note: requires `onnx` library (optional dependency) to be installed.
655+
656+
Example for exporting a :class:`DLinearModel`:
657+
658+
.. highlight:: python
659+
.. code-block:: python
660+
661+
from darts.datasets import AirPassengersDataset
662+
from darts.models import DLinearModel
663+
664+
series = AirPassengersDataset().load()
665+
model = DLinearModel(input_chunk_length=4, output_chunk_length=1)
666+
model.fit(series, epochs=1)
667+
model.to_onnx("my_model.onnx")
668+
..
669+
670+
Parameters
671+
----------
672+
path
673+
Path under which to save the model at its current state. If no path is specified, the model
674+
is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx"``.
675+
**kwargs
676+
Additional kwargs for PyTorch's :func:`torch.onnx.export` method (except parameters ``file_path``,
677+
``input_sample``, ``input_name``). For more information, read the `official documentation
678+
<https://pytorch.org/docs/master/onnx.html#torch.onnx.export>`_.
679+
"""
680+
if not self._fit_called:
681+
raise_log(
682+
ValueError("`fit()` needs to be called before `to_onnx()`."), logger
683+
)
684+
685+
if path is None:
686+
path = self._default_save_path() + ".onnx"
687+
688+
# last dimension in train_sample_shape is the expected target
689+
mock_batch = tuple(
690+
torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None
691+
for shape in self.model.train_sample_shape[:-1]
692+
)
693+
input_sample = self.model._process_input_batch(mock_batch)
694+
695+
# torch models necessarily use historic target values as features in current implementation
696+
input_names = ["x_past"]
697+
if self.uses_future_covariates:
698+
input_names.append("x_future")
699+
if self.uses_static_covariates:
700+
input_names.append("x_static")
701+
702+
self.model.to_onnx(
703+
file_path=path,
704+
input_sample=(input_sample,),
705+
input_names=input_names,
706+
**kwargs,
707+
)
708+
649709
@random_method
650710
def fit(
651711
self,

docs/userguide/torch_forecasting_models.md

+88
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ We assume that you already know about covariates in Darts. If you're new to the
2222
- [Manual saving / loading](#manual-saving--loading)
2323
- [Train & save on GPU, load on CPU](#trainingsaving-on-gpu-and-loading-on-cpu)
2424
- [Load pre-trained model for fine-tuning](#re-training-or-fine-tuning-a-pre-trained-model)
25+
- [Exporting model to ONNX format for inference](#exporting-model-to-ONNX-format-for-inference)
2526
- [Callbacks](#callbacks)
2627
- [Early Stopping](#example-with-early-stopping)
2728
- [Custom Callback](#example-of-custom-callback-to-store-losses)
@@ -350,6 +351,93 @@ model_finetune = SomeTorchForecastingModel(..., # use identical parameters & va
350351
model_finetune.load_weights("/your/path/to/save/model.pt")
351352
```
352353

354+
#### Exporting model to ONNX format for inference
355+
356+
It is also possible to export the model weights to the ONNX format to run inference in a lightweight environment. The example below works for any `TorchForecastingModel` except `RNNModel` and for optional usage of past, future and / or static covariates. Note that all series and covariates must extend far enough into the past (`input_chunk_length)` and future (`output_chunk_length`) relative to the end of the target `series`. It will not be possible to forecast a horizon `n > output_chunk_length` without implementing the auto-regression logic.
357+
358+
```python
359+
model = SomeTorchForecastingModel(...)
360+
model.fit(...)
361+
362+
# make sure to have `onnx` and `onnxruntime` installed
363+
onnx_filename = "example_onnx.onnx"
364+
model.to_onnx(onnx_filename, export_params=True)
365+
```
366+
367+
Now, to load the model and predict steps after the end of the series:
368+
369+
```python
370+
from typing import Optional
371+
import onnx
372+
import onnxruntime as ort
373+
import numpy as np
374+
from darts import TimeSeries
375+
376+
def prepare_onnx_inputs(
377+
model,
378+
series: TimeSeries,
379+
past_covariates : Optional[TimeSeries] = None,
380+
future_covariates : Optional[TimeSeries] = None,
381+
) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
382+
"""Helper function to slice and concatenate the input features"""
383+
past_feats, future_feats, static_feats = None, None, None
384+
# get input & output windows
385+
past_start = series.end_time() - (model.input_chunk_length - 1) * series.freq
386+
past_end = series.end_time()
387+
future_start = past_end + 1 * series.freq
388+
future_end = past_end + model.output_chunk_length * series.freq
389+
# extract all historic and future features from target, past and future covariates
390+
past_feats = series[past_start:past_end].values()
391+
if past_covariates and model.uses_past_covariates:
392+
# extract past covariates
393+
past_feats = np.concatenate(
394+
[
395+
past_feats,
396+
past_covariates[past_start:past_end].values()
397+
],
398+
axis=1
399+
)
400+
if future_covariates and model.uses_future_covariates:
401+
# extract past part of future covariates
402+
past_feats = np.concatenate(
403+
[
404+
past_feats,
405+
future_covariates[past_start:past_end].values()
406+
],
407+
axis=1
408+
)
409+
# extract future part of future covariates
410+
future_feats = future_covariates[future_start:future_end].values()
411+
# add batch dimension -> (batch, n time steps, n components)
412+
past_feats = np.expand_dims(past_feats, axis=0).astype(series.dtype)
413+
future_feats = np.expand_dims(future_feats, axis=0).astype(series.dtype)
414+
# extract static covariates
415+
if series.has_static_covariates and model.uses_static_covariates:
416+
static_feats = np.expand_dims(series.static_covariates_values(), axis=0).astype(series.dtype)
417+
return past_feats, future_feats, static_feats
418+
419+
onnx_model = onnx.load(onnx_filename)
420+
onnx.checker.check_model(onnx_model)
421+
ort_session = ort.InferenceSession(onnx_filename)
422+
423+
# use helper function to extract the features from the series
424+
past_feats, future_feats, static_feats = prepare_onnx_inputs(
425+
model=model,
426+
series=series,
427+
past_covariates=ts_past,
428+
future_covariates=ts_future,
429+
)
430+
431+
# extract only the features expected by the model
432+
ort_inputs = {}
433+
for name, arr in zip(['x_past', 'x_future', 'x_static'], [past_feats, future_feats, static_feats]):
434+
if name in [inp.name for inp in list(ort_session.get_inputs())]:
435+
ort_inputs[name] = arr
436+
437+
# output has shape (batch, output_chunk_length, n components, 1 or n likelihood params)
438+
ort_out = ort_session.run(None, ort_inputs)
439+
```
440+
353441
### Callbacks
354442

355443
Callbacks are a powerful way to monitor or control the behavior of the model during the training process. Some examples:

0 commit comments

Comments
 (0)