
Commit 8bea071

revert back to saving and loading a TFM with torch rather than pickle (#2692)
* revert back to saving and loading a TFM with torch rather than pickle
* update changelog
1 parent 4de0868 commit 8bea071
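For context on what "with torch rather than pickle" means in practice, here is a minimal, hedged sketch (not darts code; `TinyWrapper` is a hypothetical stand-in): both approaches pickle the full Python object, but `torch.save`/`torch.load` additionally route tensor storages through torch's serializer, which is what makes `map_location` possible when loading.

```python
import io
import pickle

import torch


# Stand-in for a TorchForecastingModel: a plain Python object that holds
# torch tensors (hypothetical class, for illustration only).
class TinyWrapper:
    def __init__(self):
        self.weights = torch.randn(3, 3)


obj = TinyWrapper()

# pickle round-trip: serializes the whole object, but knows nothing about
# tensor storages or devices.
restored_pickle = pickle.loads(pickle.dumps(obj))

# torch round-trip: torch.save also pickles the object, but handles the
# tensor storages itself, so torch.load can remap them onto another device
# via `map_location` (the reason for this revert).
buffer = io.BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
restored_torch = torch.load(buffer, weights_only=False, map_location="cpu")
```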

2 files changed, +29 -5 lines changed


CHANGELOG.md (+2)
@@ -18,6 +18,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 
 **Fixed**
 
+- 🔴 / 🟢 Fixed a bug which raised an error when loading torch models that were saved with Darts versions < 0.33.0. This is a breaking change and models saved with version 0.33.0 will not be loadable anymore. [#2692](https://github.com/unit8co/darts/pull/2692) by [Dennis Bader](https://github.com/dennisbader).
+
 **Dependencies**
 
 ### For developers of the library:

darts/models/forecasting/torch_forecasting_model.py (+27 -5)
@@ -1765,7 +1765,8 @@ def save(
         path = self._default_save_path() + ".pt"
 
         # save the TorchForecastingModel (does not save the PyTorch LightningModule, and Trainer)
-        super().save(path, clean=clean)
+        with open(path, "wb") as f_out:
+            torch.save(self if not clean else self._clean(), f_out)
 
         # save the LightningModule checkpoint (weights only with `clean=True`)
         path_ptl_ckpt = path + ".ckpt"
@@ -1802,7 +1803,7 @@ def load(
             model_loaded = RNNModel.load(path)
         ..
 
-        Example for loading an :class:`RNNModel` to GPU:
+        Example for loading an :class:`RNNModel` to GPU that was trained on CPU:
 
         .. highlight:: python
         .. code-block:: python
@@ -1812,6 +1813,16 @@ def load(
             model_loaded = RNNModel.load(path, pl_trainer_kwargs={"accelerator": "gpu"})
         ..
 
+        Example for loading an :class:`RNNModel` to CPU that was saved on GPU:
+
+        .. highlight:: python
+        .. code-block:: python
+
+            from darts.models import RNNModel
+
+            model_loaded = RNNModel.load(path, map_location="cpu", pl_trainer_kwargs={"accelerator": "cpu"})
+        ..
+
         Parameters
         ----------
         path
@@ -1825,11 +1836,15 @@ def load(
             for more information about the supported kwargs.
         **kwargs
             Additional kwargs for PyTorch Lightning's :func:`LightningModule.load_from_checkpoint()` method,
+            such as ``map_location`` to load the model onto a different device than the one on which it was saved.
             For more information, read the `official documentation <https://pytorch-lightning.readthedocs.io/en/stable/
             common/lightning_module.html#load-from-checkpoint>`_.
         """
         # load the base TorchForecastingModel (does not contain the actual PyTorch LightningModule)
-        model: TorchForecastingModel = ForecastingModel.load(path)
+        with open(path, "rb") as fin:
+            model: TorchForecastingModel = torch.load(
+                fin, weights_only=False, map_location=kwargs.get("map_location", None)
+            )
 
         # if a checkpoint was saved, we also load the PyTorch LightningModule from checkpoint
         path_ptl_ckpt = path + ".ckpt"
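To make the new save/load path concrete, here is a hedged end-to-end sketch (hyperparameters and file names are illustrative, not taken from the commit; it assumes a working darts/PyTorch install):

```python
import numpy as np

from darts import TimeSeries
from darts.models import RNNModel

# Tiny synthetic series, just enough for one quick training epoch.
series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 100)).astype(np.float32))

model = RNNModel(input_chunk_length=12, n_epochs=1)
model.fit(series)

# `save` writes the TorchForecastingModel itself with torch.save, and the
# LightningModule checkpoint as a separate ".pt.ckpt" file; `clean=True`
# would serialize the stripped copy returned by `self._clean()` instead.
model.save("rnn_model.pt")

# `map_location` travels through **kwargs to torch.load (for the .pt file)
# and on to Lightning's load_from_checkpoint (for the .ckpt file).
loaded = RNNModel.load("rnn_model.pt", map_location="cpu")
```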
@@ -1927,7 +1942,9 @@ def load_from_checkpoint(
                 f"Could not find base model save file `{INIT_MODEL_NAME}` in {model_dir}.",
                 logger,
             )
-        model: TorchForecastingModel = ForecastingModel.load(base_model_path)
+        model: TorchForecastingModel = torch.load(
+            base_model_path, weights_only=False, map_location=kwargs.get("map_location")
+        )
 
         # load PyTorch LightningModule from checkpoint
         # if file_name is None, find the path of the best or most recent checkpoint in savepath
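`load_from_checkpoint` resolves the base `.pt` save file inside a model's checkpoint directory, so `map_location` now reaches the `torch.load` call above as well. A hedged usage sketch, assuming a model was trained earlier with `model_name="my_rnn"` and `save_checkpoints=True` (both hypothetical names):

```python
from darts.models import RNNModel

# Hypothetical: "my_rnn" was trained earlier with
# RNNModel(..., model_name="my_rnn", save_checkpoints=True).
model = RNNModel.load_from_checkpoint(
    model_name="my_rnn",
    best=True,            # best checkpoint rather than the most recent one
    map_location="cpu",   # forwarded via **kwargs to torch.load above
)
```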
@@ -2093,7 +2110,12 @@ def load_weights_from_checkpoint(
             )
 
         # updating model attributes before self._init_model() which create new tfm ckpt
-        tfm_save: TorchForecastingModel = ForecastingModel.load(tfm_save_file_path)
+        with open(tfm_save_file_path, "rb") as tfm_save_file:
+            tfm_save: TorchForecastingModel = torch.load(
+                tfm_save_file,
+                weights_only=False,
+                map_location=kwargs.get("map_location", None),
+            )
 
         # encoders are necessary for direct inference
         self.encoders, self.add_encoders = self._load_encoders(
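`load_weights_from_checkpoint` only transfers weights (plus encoders, per the context lines above) into an already constructed model, and the same `map_location` kwarg reaches the `torch.load` call in this hunk. A hedged sketch with hypothetical model names:

```python
from darts.models import RNNModel

# A fresh instance that should reuse weights from a checkpoint written by a
# (hypothetical) earlier GPU training run named "my_rnn".
model = RNNModel(input_chunk_length=12, model_name="my_rnn_cpu_copy")
model.load_weights_from_checkpoint(
    model_name="my_rnn",  # name under which the original model saved checkpoints
    best=True,
    map_location="cpu",
)
```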
