Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions agentlightning/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
_C4 = TypeVar("_C4", bound=CliConfigurable)


_DEFAULT_SENTINEL = object()


# Custom type for CLI arguments that can be string or None
def nullable_str(value: str) -> str | None:
"""Converts specific string values (case-insensitive) to None, otherwise returns the string."""
Expand Down Expand Up @@ -174,6 +177,7 @@ def _add_argument_for_parameter(
param_obj: inspect.Parameter,
dest_name: str,
resolved_param_annotation: Any = None,
provided_default: Any = _DEFAULT_SENTINEL,
) -> None:
"""Configures and adds a single CLI argument for an __init__ parameter."""
if resolved_param_annotation is None:
Expand All @@ -189,16 +193,24 @@ def _add_argument_for_parameter(
has_init_default = param_obj.default is not inspect.Parameter.empty
init_default_value = param_obj.default if has_init_default else None

argparse_kwargs = _determine_argparse_type_and_nargs(core_type if is_list else param_type_annotation, is_list)
argparse_kwargs = _determine_argparse_type_and_nargs(
core_type if is_list else param_type_annotation, is_list
)

if has_init_default:
if provided_default is not _DEFAULT_SENTINEL:
argparse_kwargs["default"] = provided_default
elif has_init_default:
argparse_kwargs["default"] = init_default_value
elif is_overall_optional: # Parameter is Optional (e.g. Optional[int]) and no explicit default in __init__
argparse_kwargs["default"] = None # So, if not provided on CLI, it becomes None.

argparse_kwargs["help"] = _build_help_string(cls.__name__, param_name, core_type, is_overall_optional, is_list)

if not has_init_default and not is_overall_optional: # Required if no __init__ default AND not Optional
if (
provided_default is _DEFAULT_SENTINEL
and not has_init_default
and not is_overall_optional
): # Required if no defaults are available AND not Optional
argparse_kwargs["required"] = True
if "default" in argparse_kwargs: # Should not happen if logic is correct
del argparse_kwargs["default"]
Expand All @@ -211,6 +223,7 @@ def _add_arguments_for_class(
parser: argparse.ArgumentParser,
cls: Type[CliConfigurable],
class_arg_configs_maps: Dict[Type[CliConfigurable], Dict[str, str]], # Maps cls to {param_name: dest_name}
provided_defaults: Dict[str, Any] | None = None,
) -> None:
"""Adds all relevant CLI arguments for a given class by processing its __init__ parameters."""
cls_name_lower = cls.__name__.lower()
Expand Down Expand Up @@ -240,7 +253,18 @@ def _add_arguments_for_class(

# Use the resolved hint if available, otherwise fallback to param_obj.annotation (which might be a string)
actual_param_annotation = resolved_hints.get(param_name, param_obj.annotation)
_add_argument_for_parameter(parser, cls, param_name, param_obj, dest_name, actual_param_annotation)
default_override = None
if provided_defaults and param_name in provided_defaults:
default_override = provided_defaults[param_name]
_add_argument_for_parameter(
parser,
cls,
param_name,
param_obj,
dest_name,
actual_param_annotation,
default_override if provided_defaults and param_name in provided_defaults else _DEFAULT_SENTINEL,
Copy link

Copilot AI Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition provided_defaults and param_name in provided_defaults is duplicated from lines 257-258. Consider extracting this logic into a variable to avoid repetition and improve maintainability.

Suggested change
default_override if provided_defaults and param_name in provided_defaults else _DEFAULT_SENTINEL,
default_override if has_default_override else _DEFAULT_SENTINEL,

Copilot uses AI. Check for mistakes.
)


def _create_argument_parser() -> argparse.ArgumentParser:
Expand Down Expand Up @@ -294,18 +318,42 @@ def _instantiate_classes(


@overload
def lightning_cli(cls1: Type[_C1]) -> _C1: ...
def lightning_cli(cls1: Type[_C1], *, defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None) -> _C1: ...
@overload
def lightning_cli(cls1: Type[_C1], cls2: Type[_C2]) -> Tuple[_C1, _C2]: ...
def lightning_cli(
cls1: Type[_C1],
cls2: Type[_C2],
*,
defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None,
) -> Tuple[_C1, _C2]: ...
@overload
def lightning_cli(cls1: Type[_C1], cls2: Type[_C2], cls3: Type[_C3]) -> Tuple[_C1, _C2, _C3]: ...
def lightning_cli(
cls1: Type[_C1],
cls2: Type[_C2],
cls3: Type[_C3],
*,
defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None,
) -> Tuple[_C1, _C2, _C3]: ...
@overload
def lightning_cli(cls1: Type[_C1], cls2: Type[_C2], cls3: Type[_C3], cls4: Type[_C4]) -> Tuple[_C1, _C2, _C3, _C4]: ...
def lightning_cli(
cls1: Type[_C1],
cls2: Type[_C2],
cls3: Type[_C3],
cls4: Type[_C4],
*,
defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None,
) -> Tuple[_C1, _C2, _C3, _C4]: ...
@overload # Fallback for more than 4 or a dynamic number of classes
def lightning_cli(*classes: Type[CliConfigurable]) -> Tuple[CliConfigurable, ...]: ...
def lightning_cli(
*classes: Type[CliConfigurable],
defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None,
) -> Tuple[CliConfigurable, ...]: ...


def lightning_cli(*classes: Type[CliConfigurable]) -> CliConfigurable | Tuple[CliConfigurable, ...]:
def lightning_cli(
*classes: Type[CliConfigurable],
defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None,
) -> CliConfigurable | Tuple[CliConfigurable, ...]:
"""
Parses command-line arguments to configure and instantiate provided CliConfigurable classes.

Expand All @@ -325,7 +373,8 @@ def lightning_cli(*classes: Type[CliConfigurable]) -> CliConfigurable | Tuple[Cl
class_arg_configs_maps: Dict[Type[CliConfigurable], Dict[str, str]] = {}

for cls in classes:
_add_arguments_for_class(parser, cls, class_arg_configs_maps)
defaults_for_cls = defaults.get(cls) if defaults else None
_add_arguments_for_class(parser, cls, class_arg_configs_maps, defaults_for_cls)

parsed_args = parser.parse_args() # Uses sys.argv[1:] by default

Expand Down
22 changes: 20 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,11 +460,13 @@ def __init__(self, param):


# --- Integration Tests for lightning_cli ---
def run_lightning_cli(classes_to_configure, cli_args_list):
def run_lightning_cli(classes_to_configure, cli_args_list, defaults=None):
"""Helper to run lightning_cli with mocked sys.argv."""
if defaults is None:
defaults = {}
# Prepend a dummy program name to cli_args_list for sys.argv
with mock.patch.object(sys, "argv", ["test_program.py"] + cli_args_list):
result = config.lightning_cli(*classes_to_configure)
result = config.lightning_cli(*classes_to_configure, defaults=defaults)
if not isinstance(result, tuple):
return (result,)
return result
Expand Down Expand Up @@ -595,3 +597,19 @@ def test_lightning_cli_optional_no_default_behavior():
# Provided with a value
(cfg3,) = run_lightning_cli([OptionalNoDefaultConfig], ["--optionalnodefaultconfig.opt-val", "ActualValue"])
assert cfg3.opt_val == "ActualValue"


def test_lightning_cli_programmatic_defaults_override_required():
"""Tests that defaults passed to lightning_cli satisfy required args and can be overridden."""
defaults = {SimpleConfig: {"name": "Provided"}}

# No CLI args, should use provided default for required 'name'
(cfg1,) = run_lightning_cli([SimpleConfig], [], defaults=defaults)
assert cfg1.name == "Provided"
assert cfg1.value == 10

# CLI arg should override provided default
(cfg2,) = run_lightning_cli(
[SimpleConfig], ["--simpleconfig.name", "FromCLI"], defaults=defaults
)
assert cfg2.name == "FromCLI"
Loading