
LightningCLI instantiator receives values applied by instantiation links to set in hparams #20777

Open · wants to merge 3 commits into base: master
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,7 +5,7 @@
matplotlib>3.1, <3.9.0
omegaconf >=2.2.3, <2.4.0
hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures] >=4.28.0, <=4.40.0
jsonargparse[signatures] >=4.39.0, <4.40.0
Contributor Author:
I think the tests fail because the newer version of jsonargparse is not installed.

rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin"
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `save_hyperparameters` not working correctly with `LightningCLI` when there are parsing links applied on instantiation ([#20777](https://github.com/Lightning-AI/pytorch-lightning/pull/20777))


---
40 changes: 37 additions & 3 deletions src/lightning/pytorch/cli.py
@@ -320,6 +320,7 @@ def __init__(
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
load_from_checkpoint_support: bool = True,
) -> None:
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
called / instantiated using a parsed configuration file and / or command line args.
@@ -360,6 +361,11 @@ def __init__(
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
auto_configure_optimizers: Whether to automatically add default optimizer and lr_scheduler arguments.
load_from_checkpoint_support: Whether ``save_hyperparameters`` should save the original parsed
hyperparameters (instead of what ``__init__`` receives), such that it is possible for
``load_from_checkpoint`` to correctly instantiate classes even when using complex nesting and
dependency injection.

"""
self.save_config_callback = save_config_callback
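
Note (not part of this diff): a hypothetical usage sketch of the new constructor flag described in the docstring above. `MyModel` and `MyData` are placeholder names for this sketch only; disabling the flag keeps the previous `save_hyperparameters` behavior, where what `__init__` receives is recorded.

```python
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI

class MyModel(LightningModule):          # placeholder module for the sketch
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.save_hyperparameters()      # records what __init__ receives when support is disabled

class MyData(LightningDataModule):       # placeholder datamodule for the sketch
    def __init__(self, batch_size: int = 8):
        super().__init__()

cli = LightningCLI(
    MyModel,
    MyData,
    args=[],                             # parse no CLI arguments in this sketch
    run=False,                           # only instantiate the classes, do not run a subcommand
    load_from_checkpoint_support=False,  # new flag: opt out of the instantiator-based recording
)
```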
@@ -389,7 +395,8 @@ def __init__(

self._set_seed()

self._add_instantiators()
if load_from_checkpoint_support:
self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()
self.after_instantiate_classes()
@@ -537,11 +544,14 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)

def _add_instantiators(self) -> None:
def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
return
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

def _add_instantiators(self) -> None:
self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="model"),
_get_module_type(self._model_class),
@@ -792,12 +802,27 @@ def _get_module_type(value: Union[Callable, type]) -> type:
return value


def _set_dict_nested(data: dict, key: str, value: Any) -> None:
keys = key.split(".")
for k in keys[:-1]:
assert k in data, f"Expected key {key} to be in data"
data = data[k]
data[keys[-1]] = value


class _InstantiatorFn:
def __init__(self, cli: LightningCLI, key: str) -> None:
self.cli = cli
self.key = key

def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
def __call__(
self,
class_type: type[ModuleType],
*args: Any,
applied_instantiation_links: dict,
**kwargs: Any,
) -> ModuleType:
self.cli._dump_config()
hparams = self.cli.config_dump.get(self.key, {})
if "class_path" in hparams:
# To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
@@ -808,6 +833,15 @@ def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> M
**hparams.get("init_args", {}),
**hparams.get("dict_kwargs", {}),
}
# get instantiation link target values from kwargs
for key, value in applied_instantiation_links.items():
if not key.startswith(f"{self.key}."):
continue
key = key[len(f"{self.key}.") :]
if key.startswith("init_args."):
key = key[len("init_args.") :]
_set_dict_nested(hparams, key, value)

with _given_hyperparameters_context(
hparams=hparams,
instantiator="lightning.pytorch.cli.instantiate_module",
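
Note for reviewers (illustration only, not part of this diff): the interplay between the dumped config, `applied_instantiation_links`, and `_set_dict_nested` can be sketched standalone. The dictionary values below are assumed, loosely mirroring the deep-link test added in this PR; values that are only known after instantiation are missing from the parse-time dump and are folded back in by the instantiator.

```python
# Assumed example hparams for the "model" key: the optimizer's num_classes is linked
# from data.num_classes on instantiation, so it is absent from the parse-time dump.
hparams = {
    "optimizer": {
        "class_path": "tests_pytorch.test_cli.CustomAdam",
        "init_args": {},  # num_classes not present yet
    }
}
applied_instantiation_links = {"model.init_args.optimizer.init_args.num_classes": 5}

for key, value in applied_instantiation_links.items():
    if not key.startswith("model."):
        continue  # links targeting other components (e.g. the datamodule) are skipped here
    key = key[len("model."):]
    if key.startswith("init_args."):
        key = key[len("init_args."):]  # hparams are stored without the subclass-mode prefix
    # equivalent of _set_dict_nested: walk the dotted path and set the leaf value
    parts = key.split(".")
    node = hparams
    for part in parts[:-1]:
        node = node[part]
    node[parts[-1]] = value

assert hparams["optimizer"]["init_args"]["num_classes"] == 5
```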
79 changes: 74 additions & 5 deletions tests/tests_pytorch/test_cli.py
@@ -550,6 +550,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
class BoringModelRequiredClasses(BoringModel):
def __init__(self, num_classes: int, batch_size: int = 8):
super().__init__()
self.save_hyperparameters()
self.num_classes = num_classes
self.batch_size = batch_size

@@ -561,35 +562,103 @@ def __init__(self, batch_size: int = 8):
self.num_classes = 5 # only available after instantiation


def test_lightning_cli_link_arguments():
def test_lightning_cli_link_arguments(cleandir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.batch_size")
parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

cli_args = ["--data.batch_size=12"]
cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

assert cli.model.batch_size == 12
assert cli.model.num_classes == 5

class MyLightningCLI(LightningCLI):
cli.trainer.fit(cli.model)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

hparams.pop("_instantiator")
assert hparams == {"batch_size": 12, "num_classes": 5}

class MyLightningCLI2(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(
cli = MyLightningCLI2(
BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
)

assert cli.model.batch_size == 8
assert cli.model.num_classes == 5

cli.trainer.fit(cli.model)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

hparams.pop("_instantiator")
assert hparams == {"batch_size": 8, "num_classes": 5}


class CustomAdam(torch.optim.Adam):
def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
super().__init__(params, **kwargs)


class DeepLinkTargetModel(BoringModel):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
return {"optimizer": optimizer}


def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments(
"data.num_classes",
"model.init_args.optimizer.init_args.num_classes",
apply_on="instantiate",
)

cli_args = [
"fit",
"--data.batch_size=12",
"--trainer.max_epochs=1",
"--model=tests_pytorch.test_cli.DeepLinkTargetModel",
"--model.optimizer=tests_pytorch.test_cli.CustomAdam",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(
DeepLinkTargetModel,
BoringDataModuleBatchSizeAndClasses,
subclass_mode_model=True,
auto_configure_optimizers=False,
)

hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
assert hparams["optimizer"]["init_args"]["num_classes"] == 5


class EarlyExitTestModel(BoringModel):
def on_fit_start(self):
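
Illustrative note (not part of the diff): based on the assertions in the nested-target test above, the `optimizer` entry recorded in `hparams.yaml` would contain roughly the structure below. Additional `init_args` keys, such as optimizer defaults, may also appear.

```python
# What yaml.safe_load(hparams_path.read_text())["optimizer"] is asserted to contain.
expected_optimizer = {
    "class_path": "tests_pytorch.test_cli.CustomAdam",
    # injected from data.num_classes by the link applied on instantiation
    "init_args": {"num_classes": 5},
}
```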