Skip to content

Fix: Respect required=False in add_lightning_class_args when subclass_mode=False #20856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def add_lightning_class_args(
Args:
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
nested_key: Name of the nested namespace to store arguments.
subclass_mode: Whether allow any subclass of the given class.
subclass_mode: Whether to allow any subclass of the given class.
required: Whether the argument group is required.

Returns:
Expand All @@ -145,15 +145,26 @@ def add_lightning_class_args(
):
if issubclass(lightning_class, Callback):
self.callback_keys.append(nested_key)

# NEW LOGIC: If subclass_mode=False and required=False, only add if config provides this key
if not subclass_mode and not required:
config_path = f"{self.subcommand}.{nested_key}" if getattr(self, "subcommand", None) else nested_key
config = getattr(self, "config", {})
if not any(k.startswith(config_path) for k in config):
# Skip adding class arguments
return []

if subclass_mode:
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)

return self.add_class_arguments(
lightning_class,
nested_key,
fail_untyped=False,
instantiate=not issubclass(lightning_class, Trainer),
sub_configs=True,
)

raise MisconfigurationException(
f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
"Trainer, LightningModule, LightningDataModule, or Callback."
Expand Down
28 changes: 28 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,3 +1789,31 @@ def test_lightning_cli_with_args_given(args):
def test_lightning_cli_args_and_sys_argv_warning():
with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.warns(Warning, match="LightningCLI's args parameter "):
LightningCLI(TestModel, run=False, args=["--model.foo=789"])


def test_add_class_args_required_false_skips_addition(tmp_path):
from lightning.pytorch import callbacks, cli

class FooCheckpoint(callbacks.ModelCheckpoint):
def __init__(self, dirpath, *args, **kwargs):
super().__init__(dirpath, *args, **kwargs)

class SimpleModel:
def __init__(self):
pass

class SimpleDataModule:
def __init__(self):
pass

class FooCLI(cli.LightningCLI):
def __init__(self):
super().__init__(
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None
)

def add_arguments_to_parser(self, parser):
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False)

# Expectation: No error raised even though FooCheckpoint requires `dirpath`
FooCLI()
Loading