@@ -1765,7 +1765,8 @@ def save(
1765
1765
path = self ._default_save_path () + ".pt"
1766
1766
1767
1767
# save the TorchForecastingModel (does not save the PyTorch LightningModule, and Trainer)
1768
- super ().save (path , clean = clean )
1768
+ with open (path , "wb" ) as f_out :
1769
+ torch .save (self if not clean else self ._clean (), f_out )
1769
1770
1770
1771
# save the LightningModule checkpoint (weights only with `clean=True`)
1771
1772
path_ptl_ckpt = path + ".ckpt"
@@ -1802,7 +1803,7 @@ def load(
1802
1803
model_loaded = RNNModel.load(path)
1803
1804
..
1804
1805
1805
- Example for loading an :class:`RNNModel` to GPU:
1806
+ Example for loading an :class:`RNNModel` to GPU that was trained on CPU:
1806
1807
1807
1808
.. highlight:: python
1808
1809
.. code-block:: python
@@ -1812,6 +1813,16 @@ def load(
1812
1813
model_loaded = RNNModel.load(path, pl_trainer_kwargs={"accelerator": "gpu"})
1813
1814
..
1814
1815
1816
+ Example for loading an :class:`RNNModel` to CPU that was saved on GPU:
1817
+
1818
+ .. highlight:: python
1819
+ .. code-block:: python
1820
+
1821
+ from darts.models import RNNModel
1822
+
1823
+ model_loaded = RNNModel.load(path, map_location="cpu", pl_trainer_kwargs={"accelerator": "cpu"})
1824
+ ..
1825
+
1815
1826
Parameters
1816
1827
----------
1817
1828
path
@@ -1825,11 +1836,15 @@ def load(
1825
1836
for more information about the supported kwargs.
1826
1837
**kwargs
1827
1838
Additional kwargs for PyTorch Lightning's :func:`LightningModule.load_from_checkpoint()` method,
1839
+ such as ``map_location`` to load the model onto a different device than the one on which it was saved.
1828
1840
For more information, read the `official documentation <https://pytorch-lightning.readthedocs.io/en/stable/
1829
1841
common/lightning_module.html#load-from-checkpoint>`_.
1830
1842
"""
1831
1843
# load the base TorchForecastingModel (does not contain the actual PyTorch LightningModule)
1832
- model : TorchForecastingModel = ForecastingModel .load (path )
1844
+ with open (path , "rb" ) as fin :
1845
+ model : TorchForecastingModel = torch .load (
1846
+ fin , weights_only = False , map_location = kwargs .get ("map_location" , None )
1847
+ )
1833
1848
1834
1849
# if a checkpoint was saved, we also load the PyTorch LightningModule from checkpoint
1835
1850
path_ptl_ckpt = path + ".ckpt"
@@ -1927,7 +1942,9 @@ def load_from_checkpoint(
1927
1942
f"Could not find base model save file `{ INIT_MODEL_NAME } ` in { model_dir } ." ,
1928
1943
logger ,
1929
1944
)
1930
- model : TorchForecastingModel = ForecastingModel .load (base_model_path )
1945
+ model : TorchForecastingModel = torch .load (
1946
+ base_model_path , weights_only = False , map_location = kwargs .get ("map_location" )
1947
+ )
1931
1948
1932
1949
# load PyTorch LightningModule from checkpoint
1933
1950
# if file_name is None, find the path of the best or most recent checkpoint in savepath
@@ -2093,7 +2110,12 @@ def load_weights_from_checkpoint(
2093
2110
)
2094
2111
2095
2112
# updating model attributes before self._init_model() which create new tfm ckpt
2096
- tfm_save : TorchForecastingModel = ForecastingModel .load (tfm_save_file_path )
2113
+ with open (tfm_save_file_path , "rb" ) as tfm_save_file :
2114
+ tfm_save : TorchForecastingModel = torch .load (
2115
+ tfm_save_file ,
2116
+ weights_only = False ,
2117
+ map_location = kwargs .get ("map_location" , None ),
2118
+ )
2097
2119
2098
2120
# encoders are necessary for direct inference
2099
2121
self .encoders , self .add_encoders = self ._load_encoders (
0 commit comments