@@ -161,8 +161,23 @@ def __init__(
161
161
Number of epochs to wait before evaluating the validation loss (if a validation
162
162
``TimeSeries`` is passed to the :func:`fit()` method).
163
163
torch_device_str
164
- Optionally, a string indicating the torch device to use. (default: "cuda:0" if a GPU
165
- is available, otherwise "cpu")
164
+ Optionally, a string indicating the torch device to use. By default, ``torch_device_str`` is ``None``
165
+ which will run on CPU. Set it to ``"cuda"`` to use all available GPUs or ``"cuda:i"`` to only use
166
+ GPU ``i`` (``i`` must be an integer). For example "cuda:0" will use the first GPU only.
167
+
168
+ .. deprecated:: v0.17.0
169
+ ``torch_device_str`` has been deprecated in v0.17.0 and will be removed in a future version.
170
+ Instead, specify this with keys ``"accelerator", "gpus", "auto_select_gpus"`` in your
171
+ ``pl_trainer_kwargs`` dict. Some examples for setting the devices inside the ``pl_trainer_kwargs``
172
+ dict:
173
+
174
+ - ``{"accelerator": "cpu"}`` for CPU,
175
+ - ``{"accelerator": "gpu", "gpus": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
176
+ - ``{"accelerator": "gpu", "gpus": -1, "auto_select_gpus": True}`` to use all available GPUS.
177
+
178
+ For more info, see here:
179
+ https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and
180
+ https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#select-gpu-devices
166
181
force_reset
167
182
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
168
183
be discarded).
@@ -336,7 +351,9 @@ def __init__(
336
351
self .pl_module_params : Optional [Dict ] = None
337
352
338
353
@staticmethod
339
- def _extract_torch_devices (torch_device_str ) -> Tuple [str , Optional [list ], bool ]:
354
+ def _extract_torch_devices (
355
+ torch_device_str ,
356
+ ) -> Tuple [str , Optional [Union [list , int ]], bool ]:
340
357
"""This method handles the deprecated `torch_device_str` and should be removed in a future Darts version.
341
358
342
359
Returns
@@ -346,7 +363,7 @@ def _extract_torch_devices(torch_device_str) -> Tuple[str, Optional[list], bool]
346
363
"""
347
364
348
365
if torch_device_str is None :
349
- return "auto " , None , False
366
+ return "cpu " , None , False
350
367
351
368
device_warning = (
352
369
"`torch_device_str` is deprecated and will be removed in a coming Darts version. For full support "
@@ -372,13 +389,13 @@ def _extract_torch_devices(torch_device_str) -> Tuple[str, Optional[list], bool]
372
389
373
390
gpus = None
374
391
auto_select_gpus = False
375
- accelerator = device_split [0 ]
376
- if len (device_split ) == 2 and accelerator == "cuda" :
392
+ accelerator = "gpu" if device_split [0 ] == "cuda" else device_split [0 ]
393
+
394
+ if len (device_split ) == 2 and accelerator == "gpu" :
377
395
gpus = device_split [1 ]
378
396
gpus = [int (gpus )]
379
397
elif len (device_split ) == 1 :
380
- if accelerator == "cuda" :
381
- accelerator = "gpu"
398
+ if accelerator == "gpu" :
382
399
gpus = - 1
383
400
auto_select_gpus = True
384
401
else :
@@ -389,9 +406,29 @@ def _extract_torch_devices(torch_device_str) -> Tuple[str, Optional[list], bool]
389
406
)
390
407
return accelerator , gpus , auto_select_gpus
391
408
392
- @staticmethod
393
- def _extract_torch_model_params (** kwargs ):
409
+ @classmethod
410
+ def _validate_model_params (cls , ** kwargs ):
411
+ """validate that parameters used at model creation are part of :class:`TorchForecastingModel`,
412
+ :class:`PLForecastingModule` or cls __init__ methods.
413
+ """
414
+ valid_kwargs = (
415
+ set (inspect .signature (TorchForecastingModel .__init__ ).parameters .keys ())
416
+ | set (inspect .signature (PLForecastingModule .__init__ ).parameters .keys ())
417
+ | set (inspect .signature (cls .__init__ ).parameters .keys ())
418
+ )
419
+
420
+ invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_kwargs ]
421
+
422
+ raise_if (
423
+ len (invalid_kwargs ) > 0 ,
424
+ f"Invalid model creation parameters. Model `{ cls .__name__ } ` has no args/kwargs `{ invalid_kwargs } `" ,
425
+ logger = logger ,
426
+ )
427
+
428
+ @classmethod
429
+ def _extract_torch_model_params (cls , ** kwargs ):
394
430
"""extract params from model creation to set up TorchForecastingModels"""
431
+ cls ._validate_model_params (** kwargs )
395
432
get_params = list (
396
433
inspect .signature (TorchForecastingModel .__init__ ).parameters .keys ()
397
434
)
@@ -619,6 +656,13 @@ def fit(
619
656
override Darts' default trainer.
620
657
verbose
621
658
Optionally, whether to print progress.
659
+
660
+ .. deprecated:: v0.17.0
661
+ ``verbose`` has been deprecated in v0.17.0 and will be removed in a future version.
662
+ Instead, control verbosity with PyTorch Lightning Trainer parameters ``enable_progress_bar``,
663
+ ``progress_bar_refresh_rate`` and ``enable_model_summary`` in the ``pl_trainer_kwargs`` dict
664
+ at model creation. See for example here:
665
+ https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#enable-progress-bar
622
666
epochs
623
667
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
624
668
was provided to the model constructor.
@@ -764,6 +808,13 @@ def fit_from_dataset(
764
808
override Darts' default trainer.
765
809
verbose
766
810
Optionally, whether to print progress.
811
+
812
+ .. deprecated:: v0.17.0
813
+ ``verbose`` has been deprecated in v0.17.0 and will be removed in a future version.
814
+ Instead, control verbosity with PyTorch Lightning Trainer parameters ``enable_progress_bar``,
815
+ ``progress_bar_refresh_rate`` and ``enable_model_summary`` in the ``pl_trainer_kwargs`` dict
816
+ at model creation. See for example here:
817
+ https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#enable-progress-bar
767
818
epochs
768
819
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
769
820
was provided to the model constructor.
@@ -965,6 +1016,13 @@ def predict(
965
1016
Size of batches during prediction. Defaults to the models' training ``batch_size`` value.
966
1017
verbose
967
1018
Optionally, whether to print progress.
1019
+
1020
+ .. deprecated:: v0.17.0
1021
+ ``verbose`` has been deprecated in v0.17.0 and will be removed in a future version.
1022
+ Instead, control verbosity with PyTorch Lightning Trainer parameters ``enable_progress_bar``,
1023
+ ``progress_bar_refresh_rate`` and ``enable_model_summary`` in the ``pl_trainer_kwargs`` dict
1024
+ at model creation. See for example here:
1025
+ https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#enable-progress-bar
968
1026
n_jobs
969
1027
The number of jobs to run in parallel. ``-1`` means using all processors. Defaults to ``1``.
970
1028
roll_size
@@ -1084,7 +1142,14 @@ def predict_from_dataset(
1084
1142
batch_size
1085
1143
Size of batches during prediction. Defaults to the models ``batch_size`` value.
1086
1144
verbose
1087
- Shows the progress bar for batch predicition. Off by default.
1145
+ Optionally, whether to print progress.
1146
+
1147
+ .. deprecated:: v0.17.0
1148
+ ``verbose`` has been deprecated in v0.17.0 and will be removed in a future version.
1149
+ Instead, control verbosity with PyTorch Lightning Trainer parameters ``enable_progress_bar``,
1150
+ ``progress_bar_refresh_rate`` and ``enable_model_summary`` in the ``pl_trainer_kwargs`` dict
1151
+ at model creation. See for example here:
1152
+ https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#enable-progress-bar
1088
1153
n_jobs
1089
1154
The number of jobs to run in parallel. ``-1`` means using all processors. Defaults to ``1``.
1090
1155
roll_size
0 commit comments