Description
Motivation
I would like to use Optuna with the PyTorchLightningPruningCallback in a codebase that depends on a pre-2.0 version of PyTorch Lightning. As it stands, I need to vendor the callback so that it uses the pytorch_lightning package instead of lightning.pytorch.
The callback imports lightning.pytorch, which is the recommended import path from 2.0 onward. The pytorch-lightning package (importable as pytorch_lightning), which has feature parity and remains the backwards-compatible option, does not work with the callback. As far as Optuna is concerned, there is no fundamental reason the pruning callback couldn't try both packages.
Suggestion
Attempt to import lightning.pytorch. If that is not successful, try to import pytorch_lightning. If that doesn't work, give up.
Current code:

    with optuna._imports.try_import() as _imports:
        import lightning.pytorch as pl
        from lightning.pytorch import LightningModule
        from lightning.pytorch import Trainer
        from lightning.pytorch.callbacks import Callback

    if not _imports.is_successful():
        Callback = object  # type: ignore[assignment, misc]  # NOQA[F811]
        LightningModule = object  # type: ignore[assignment, misc]  # NOQA[F811]
        Trainer = object  # type: ignore[assignment, misc]  # NOQA[F811]

Suggested code:

    with optuna._imports.try_import() as _imports:
        try:
            # Prefer the 2.0+ namespace.
            import lightning.pytorch as pl
            from lightning.pytorch import LightningModule
            from lightning.pytorch import Trainer
            from lightning.pytorch.callbacks import Callback
        except ImportError:
            # Fall back to the legacy package.
            import pytorch_lightning as pl
            from pytorch_lightning import LightningModule
            from pytorch_lightning import Trainer
            from pytorch_lightning.callbacks import Callback

    if not _imports.is_successful():
        Callback = object  # type: ignore[assignment, misc]  # NOQA[F811]
        LightningModule = object  # type: ignore[assignment, misc]  # NOQA[F811]
        Trainer = object  # type: ignore[assignment, misc]  # NOQA[F811]

Additional context (optional)
optuna-integration version: 3.6.0.
The diff between v3.6.0 and v4.0.0 suggests there were no material changes to PyTorch Lightning compatibility.
#137 previously suggested changing from lightning.pytorch to pytorch_lightning, which was rejected. I think supporting both packages is reasonable, however.
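
For illustration only, the fallback order described under Suggestion can also be written generically with the standard library's importlib. This is a minimal sketch, not Optuna's implementation: the helper name import_first_available is hypothetical, and it deliberately does not use optuna._imports.try_import.

```python
import importlib
from types import ModuleType
from typing import Optional, Sequence


def import_first_available(names: Sequence[str]) -> Optional[ModuleType]:
    """Return the first module in `names` that imports cleanly, else None.

    Hypothetical helper for illustration; not part of Optuna.
    """
    for name in names:
        try:
            return importlib.import_module(name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError,
            # so a missing package simply moves on to the next candidate.
            continue
    return None


# Prefer the 2.0+ namespace, then fall back to the legacy package.
pl = import_first_available(["lightning.pytorch", "pytorch_lightning"])
if pl is None:
    print("Neither lightning.pytorch nor pytorch_lightning is installed.")
```

Either package exposes Trainer, LightningModule, and callbacks.Callback under the returned module, so the rest of the callback could reference pl.Trainer and friends without caring which one was found.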