Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 59 additions & 8 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,16 @@ def __init__(

main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
self.setup_parser(run, main_kwargs, subparser_kwargs)

ckpt_path_present = self._check_ckpt_path(args)
if ckpt_path_present:
self._relax_model_requirements()

self.parse_arguments(self.parser, args)
self._parse_ckpt_path()
if ckpt_path_present:
self._enforce_model_requirements()

self._parse_ckpt_path(self.parser, args)

self.subcommand = self.config["subcommand"] if run else None

Expand All @@ -420,6 +428,37 @@ def __init__(
if self.subcommand is not None:
self._run_subcommand(self.subcommand)

def _check_ckpt_path(self, args: ArgsType) -> bool:
"""Check if --ckpt_path is present in arguments."""
argv = sys.argv[1:] if args is None else args

if not isinstance(argv, list):
return False

return any(arg.startswith("--ckpt_path") for arg in argv)

def _relax_model_requirements(self) -> None:
self._removed_requirements: dict[Any, list[str]] = {}
subcommands = self.parser._subcommands_action

if subcommands is None:
return

for subparser in subcommands._name_parser_map.values():
self._removed_requirements[subparser] = []

if "model" in subparser.required_args:
subparser.required_args.remove("model")
self._removed_requirements[subparser].append("model")

def _enforce_model_requirements(self) -> None:
for subparser, removed_args in self._removed_requirements.items():
for arg_name in removed_args:
if arg_name not in subparser.required_args:
subparser.required_args.add(arg_name)

del self._removed_requirements

def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
subcommand_names = self.subcommands().keys()
main_kwargs = {k: v for k, v in parser_kwargs.items() if k not in subcommand_names}
Expand Down Expand Up @@ -560,9 +599,19 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)

def _parse_ckpt_path(self) -> None:
"""If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
if not self.config.get("subcommand"):
def _parse_ckpt_path(self, parser: LightningArgumentParser, args: ArgsType) -> None:
"""Parses the checkpoint path, loads hyperparameters, and injects them as new defaults.

If `ckpt_path` is provided, this method:
1. Loads hyperparameters from the checkpoint file.
2. Sets them as new default values for the specific subcommand parser.
3. Re-runs argument parsing.

This ensures the correct priority order:
__init__ defaults < ckpt hparams < cfg file < CLI args

"""
if not self.config.get("subcommand") or parser._subcommands_action is None:
return
ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
if ckpt_path and Path(ckpt_path).is_file():
Expand All @@ -576,12 +625,14 @@ def _parse_ckpt_path(self) -> None:
"class_path": hparams.pop("_class_path"),
"dict_kwargs": hparams,
}
hparams = {self.config.subcommand: {"model": hparams}}
hparams = {"model": hparams}
try:
self.config = self.parser.parse_object(hparams, self.config)
except SystemExit:
subparser = parser._subcommands_action._name_parser_map[self.config.subcommand]
subparser.set_defaults(hparams)
except KeyError as ex:
sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
raise
parser.error(str(ex), ex)
self.parse_arguments(parser, args)
Comment thread
Berezin-Leonid marked this conversation as resolved.

def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
Expand Down
163 changes: 163 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,24 @@ def add_arguments_to_parser(self, parser):
assert cli.model.extra is True
assert cli.model.layer.out_features == 4

# check that empty ckpt raising error parsing
garbage_ckpt_path = Path(cli.trainer.log_dir) / "garbage.ckpt"
torch.save(
{
"state_dict": {},
"hyper_parameters": {"useless_param": 42, "broken_conf": True},
},
garbage_ckpt_path,
)

cli_args = ["predict", f"--ckpt_path={garbage_ckpt_path}"]

err = StringIO()
with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stderr(err), pytest.raises(SystemExit):
CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)
output = err.getvalue()
assert 'error: Parser key "model"' in output


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
Expand Down Expand Up @@ -588,6 +606,151 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
assert isinstance(cli.model.submodule2, BoringModel)


class DemoModel(BoringModel):
def __init__(
self,
num_classes: int = 10,
learning_rate: float = 0.01,
dropout: float = 0.1,
backbone_hidden_dim: int = 128,
):
super().__init__()
self.save_hyperparameters()
self.num_classes = num_classes
self.learning_rate = learning_rate
self.dropout = dropout
self.backbone_hidden_dim = backbone_hidden_dim


def test_lightning_cli_args_override_checkpoint_hparams(cleandir):
"""
Check priority: ckpt hparams < CLI Args

Scenario:
1. Save checkpoint with specific `dropout`, `backbone_hidden_dim`
2. Load checkpoint, but explicitly override 'learning_rate` and `backbone_hidden_dim`
3. Verify that `num_classes` and `dropout` is restored from ckpt,
but `learning_rate` and `backbone_hidden_dim` is update from the CLI arg.
"""

# --- Phase 1: Create a base checkpoint
orig_hidden_dim = 256
orig_dropout = 0.5

save_args = [
"fit",
f"--model.dropout={orig_dropout}",
f"--model.backbone_hidden_dim={orig_hidden_dim}",
"--trainer.devices=1",
"--trainer.max_steps=1",
"--trainer.limit_train_batches=1",
"--trainer.limit_val_batches=1",
"--trainer.default_root_dir=./",
]

with mock.patch("sys.argv", ["any.py"] + save_args):
cli = LightningCLI(DemoModel)

checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))

# --- Phase 2: Predict with CLI overrides ---
new_lr = 0.123
new_hidden_dim = 512
override_args = [
"predict",
"--trainer.devices=1",
f"--model.learning_rate={new_lr}",
f"--model.backbone_hidden_dim={new_hidden_dim}",
f"--ckpt_path={checkpoint_path}",
]

with mock.patch("sys.argv", ["any.py"] + override_args):
new_cli = LightningCLI(DemoModel)

# --- Phase 3: Assertions ---
assert new_cli.model.learning_rate == new_lr, (
f"CLI override failed! Expected LR {new_lr}, got {new_cli.model.learning_rate}"
)

assert new_cli.model.dropout == orig_dropout, (
f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
)
assert new_cli.model.backbone_hidden_dim == new_hidden_dim, (
f"CLI override failed! Expected dim {new_hidden_dim}, got {new_cli.model.backbone_hidden_dim}"
)


def test_lightning_cli_config_priority_over_checkpoint_hparams(cleandir):
"""
Test the full priority hierarchy:
ckpt hparams < Config < CLI Args

Scenario:
1. Save checkpoint with specific `num_classes`, `learning_rate` and `dropout`
2. Load checkpoint, but explicitly override:
num_classes by: config, cli
learning_rate: config
3. Verify that:
num_classes from: CLI Args
learning_rate: Config
dropout: dropout

"""
orig_classes = 60_000
orig_lr = 1e-4
orig_dropout = 0.01
save_args = [
"fit",
f"--model.num_classes={orig_classes}",
f"--model.learning_rate={orig_lr}",
f"--model.dropout={orig_dropout}",
"--trainer.devices=1",
"--trainer.max_steps=1",
"--trainer.limit_train_batches=1",
"--trainer.limit_val_batches=1",
"--trainer.default_root_dir=./",
]

with mock.patch("sys.argv", ["any.py"] + save_args):
cli = LightningCLI(DemoModel)

cfg_lr = 2e-5
config = f"""
model:
num_classes: 1000
learning_rate: {cfg_lr}
"""

config_path = Path("config.yaml")
config_path.write_text(config)

checkpoint_path = str(next(Path(cli.trainer.default_root_dir).rglob("*.ckpt")))

cli_classes = 1024
cli_args = [
"predict",
f"--config={config_path}",
f"--model.num_classes={cli_classes}",
"--trainer.devices=1",
f"--ckpt_path={checkpoint_path}",
]
with mock.patch("sys.argv", ["any.py"] + cli_args):
new_cli = LightningCLI(DemoModel)

assert new_cli.model.num_classes == cli_classes, (
f"CLI priority failed! Expected num_classes {cli_classes}, got {new_cli.model.num_classes}"
)
assert new_cli.model.learning_rate == cfg_lr, (
f"Config override failed! Expected LR {cfg_lr}, got {new_cli.model.learning_rate}"
)
assert new_cli.model.dropout == orig_dropout, (
f"Checkpoint restoration failed! Expected dropout {orig_dropout}, got {new_cli.model.dropout}"
)
assert new_cli.model.backbone_hidden_dim == 128, (
f"Checkpoint restoration failed! Expected dim {128}, got {new_cli.model.backbone_hidden_dim}"
)


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
def test_lightning_cli_torch_modules(cleandir):
class TestModule(BoringModel):
Expand Down
Loading