Skip to content

Commit 58446ee

Browse files
seayang-nvclaude
andauthored
fix(configurator): add AutoParamType so --flag=auto works for Auto*Param (#353)
<!-- SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. --> <!-- SPDX-License-Identifier: Apache-2.0 --> <!-- Thank you for contributing to Safe Synthesizer! --> # Summary - Introduces AutoParamType, a custom Click parameter type that accepts the "auto" sentinel string or delegates to a wrapped base type (click.INT, click.FLOAT, click.BOOL) for validation. This replaces the previous behavior where Auto*Param union fields (Literal["auto"] | int, etc.) fell through to click.STRING, which gave no numeric validation and displayed an unhelpful TEXT type in --help. - Updates _click_type() to detect Literal["auto"] | <numeric> unions and return AutoParamType(base_type) instead of click.STRING, so --help now shows descriptive types like integer|auto, float|auto, and boolean|auto. - Pure Literal["auto"] | str unions (no numeric component) still correctly resolve to click.STRING. ## Pre-Review Checklist <!-- These checks should be completed before a PR is reviewed, --> <!-- but you can submit a draft early to indicate that the issue is being worked on. --> Ensure that the following pass: - [x] `make format && make check` or via prek validation. - [x] `make test` passes locally - [x] `make test-e2e` passes locally - [ ] `make test-ci-container` passes locally (recommended) - [ ] GPU CI status check passes -- comment `/sync` on this PR to trigger a run (auto-triggers on ready-for-review) ## Pre-Merge Checklist <!-- These checks need to be completed before a PR is merged, --> <!-- but as PRs often change significantly during review, --> <!-- it's OK for them to be incomplete when review is first requested. --> - [ ] New or updated tests for any fix or new behavior - [ ] Updated documentation for new features and behaviors, including docstrings for API docs. ## Other Notes <!-- Please add the issue number that should be closed when this PR is merged. --> - Closes #159 --------- Signed-off-by: Sean Yang <seayang@nvidia.com> Signed-off-by: seayang <seayang@nvidia.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 2a138d5 commit 58446ee

3 files changed

Lines changed: 282 additions & 10 deletions

File tree

src/nemo_safe_synthesizer/configurator/pydantic_click_options.py

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
from pydantic import BaseModel
2828
from pydantic.fields import FieldInfo
2929

30-
__all__ = ["pydantic_options", "parse_overrides"]
30+
from ..config.types import AUTO_STR
31+
32+
__all__ = ["pydantic_options", "parse_overrides", "AutoParamType"]
3133

3234

3335
def parse_overrides(values: dict[str, Any] | None = None, field_sep: str = "__") -> dict[str, Any]:
@@ -128,26 +130,103 @@ def _nullable_model_arg(union_args: tuple) -> type[BaseModel] | None:
128130
]
129131

130132

133+
class AutoParamType(click.ParamType):
134+
"""A Click type that accepts the sentinel ``AUTO_STR`` or a base numeric/bool value.
135+
136+
Used for ``Auto*Param`` fields (``AutoIntParam``, ``AutoFloatParam``,
137+
``AutoBoolParam``) so that ``--flag auto`` and ``--flag 2`` both work.
138+
The ``--help`` display shows ``integer|auto``, ``float|auto``, or
139+
``boolean|auto`` instead of the generic ``TEXT`` label.
140+
141+
Args:
142+
base_type: The underlying Click type (``click.INT``, ``click.FLOAT``,
143+
or ``click.BOOL``) used to parse and validate non-``AUTO_STR`` values.
144+
"""
145+
146+
def __init__(self, base_type: click.ParamType) -> None:
147+
self.base_type = base_type
148+
self.name = f"{base_type.name}|{AUTO_STR}"
149+
150+
def convert(
151+
self,
152+
value: str,
153+
param: click.Parameter | None,
154+
ctx: click.Context | None,
155+
) -> str | int | float | bool:
156+
"""Convert the raw CLI value to ``AUTO_STR`` or the base numeric/bool type.
157+
158+
The ``value`` parameter is typed ``str`` to match the parent
159+
``click.ParamType.convert`` stub signature; in practice Click may also
160+
pass through default values of any type, but the equality check and
161+
delegated ``base_type.convert`` both handle that correctly.
162+
163+
Args:
164+
value: Raw value from the CLI or the option default.
165+
param: The Click parameter object (passed through to the base type).
166+
ctx: The Click context (passed through to the base type).
167+
168+
Returns:
169+
``AUTO_STR`` if ``value`` equals it, otherwise the result of
170+
delegating to ``self.base_type.convert()``.
171+
"""
172+
if value == AUTO_STR:
173+
return AUTO_STR
174+
return self.base_type.convert(value, param, ctx) # type: ignore[return-value]
175+
176+
131177
def _has_string_literal(args: set) -> bool:
132178
"""Check if any member is a ``Literal`` containing a string value."""
133179
return any(get_origin(a) is Literal and any(isinstance(v, str) for v in get_args(a)) for a in args)
134180

135181

182+
def _is_auto_only_literal_union(args: set) -> bool:
183+
"""Check that the union's string-valued ``Literal`` members are exactly ``{AUTO_STR}``.
184+
185+
Returns ``True`` only if every string-valued Literal member contributes the
186+
single value ``AUTO_STR`` -- e.g. ``Literal["auto"] | int``. Returns
187+
``False`` for unions like ``Literal["disabled"] | int`` or
188+
``Literal["auto", "manual"] | int``, where ``AutoParamType`` would
189+
silently reject the non-``"auto"`` sentinels at parse time.
190+
"""
191+
string_values: set[str] = set()
192+
for a in args:
193+
if get_origin(a) is Literal:
194+
for v in get_args(a):
195+
if isinstance(v, str):
196+
string_values.add(v)
197+
return string_values == {AUTO_STR}
198+
199+
136200
def _click_type(annotation: Any) -> click.ParamType:
137201
"""Map a Pydantic field annotation to a Click type.
138202
139203
Unwraps ``Annotated[T, ...]`` and ``T | None`` unions, then returns the
140-
widest Click type that covers any member of the union. String-valued
141-
``Literal`` members (e.g. ``Literal["auto"]``) force ``click.STRING``
142-
so Click won't reject the sentinel before Pydantic validates it.
143-
Falls back to ``click.STRING`` for unrecognized types.
204+
widest Click type that covers any member of the union. ``Auto*Param``
205+
fields (``Literal["auto"] | <numeric|bool>``) get an ``AutoParamType``
206+
wrapping the numeric/bool base so Click can validate non-sentinel values
207+
while still accepting the ``"auto"`` sentinel. Other string-valued
208+
``Literal`` members fall through to ``click.STRING`` so Click won't
209+
reject the sentinel before Pydantic validates it. Falls back to
210+
``click.STRING`` for unrecognized types.
144211
"""
145212
t = annotation
146213
if get_origin(t) is Annotated:
147214
t = get_args(t)[0]
148215
args = set(get_args(t)) if get_origin(t) in (Union, types.UnionType) else {t}
149216
args.discard(type(None))
150217
if _has_string_literal(args):
218+
# Auto*Param: Literal["auto"] | <numeric|bool> -- wrap in AutoParamType so
219+
# --help shows "integer|auto" instead of "TEXT" and Click validates the
220+
# numeric side before handing the value to Pydantic. The detection is
221+
# tightened to only fire when the literal values are exactly {AUTO_STR};
222+
# other sentinels (e.g. Literal["disabled"] | int) fall through to STRING
223+
# so AutoParamType doesn't reject them with a confusing error.
224+
if _is_auto_only_literal_union(args):
225+
for py_type, click_type in _CLICK_TYPE_PRIORITY:
226+
if py_type is str:
227+
continue
228+
if py_type in args:
229+
return AutoParamType(click_type)
151230
return click.STRING
152231
for py_type, click_type in _CLICK_TYPE_PRIORITY:
153232
if py_type in args:

tests/cli/test_run.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import nemo_safe_synthesizer.sdk.library_builder # noqa: F401 - ensure submodule is loaded for mock.patch
1313
from nemo_safe_synthesizer.cli.run import run
1414
from nemo_safe_synthesizer.cli.settings import CLISettings
15+
from nemo_safe_synthesizer.cli.utils import merge_overrides
1516
from nemo_safe_synthesizer.tooling import PreflightRenderContext
1617

1718
# =============================================================================
@@ -661,3 +662,113 @@ def test_generate_with_dataset_registry_calls_common_setup(
661662
call_kwargs = mock_common_setup.call_args.kwargs
662663
settings: CLISettings = call_kwargs["settings"]
663664
assert settings.dataset_registry == "./registry.yaml"
665+
666+
667+
class TestAutoParamCliOverrides:
668+
"""End-to-end tests for ``Auto*Param`` field CLI overrides (issue #159).
669+
670+
These tests drive the real ``run`` Click command and capture the parsed
671+
``synthesis_overrides`` reaching ``common_setup`` to verify each CLI value
672+
is parsed correctly. A separate test checks that the value also lands on
673+
the resolved ``SafeSynthesizerParameters`` object (using fields that pass
674+
through Pydantic validation unchanged -- some ``Auto*Param`` fields have
675+
model validators that resolve ``"auto"`` to a concrete value).
676+
"""
677+
678+
# Note: ``AutoBoolParam`` is defined in ``config.types`` but no field of
679+
# ``SafeSynthesizerParameters`` currently uses it (the only such field,
680+
# ``training.use_unsloth``, was removed when the Unsloth backend was
681+
# dropped). Bool conversion is exercised at the unit level by
682+
# ``tests/configurator/test_pydantic_click_options.py``.
683+
@pytest.mark.parametrize(
684+
"flag,raw_value,nested_path,expected",
685+
[
686+
# AutoIntParam / OptionalAutoInt
687+
("--training__rope_scaling_factor", "auto", ("training", "rope_scaling_factor"), "auto"),
688+
("--training__rope_scaling_factor", "2", ("training", "rope_scaling_factor"), 2),
689+
("--training__num_input_records_to_sample", "auto", ("training", "num_input_records_to_sample"), "auto"),
690+
("--training__num_input_records_to_sample", "100", ("training", "num_input_records_to_sample"), 100),
691+
("--data__max_sequences_per_example", "auto", ("data", "max_sequences_per_example"), "auto"),
692+
("--data__max_sequences_per_example", "5", ("data", "max_sequences_per_example"), 5),
693+
# AutoFloatParam
694+
("--privacy__delta", "auto", ("privacy", "delta"), "auto"),
695+
("--privacy__delta", "0.001", ("privacy", "delta"), 0.001),
696+
],
697+
)
698+
def test_auto_param_override_is_parsed_into_synthesis_overrides(
699+
self,
700+
cli_runner: CliRunner,
701+
dummy_csv: Path,
702+
patched_run_dependencies: dict,
703+
flag: str,
704+
raw_value: str,
705+
nested_path: tuple[str, ...],
706+
expected: object,
707+
):
708+
"""Auto*Param CLI flags accept ``"auto"`` and typed values, and reach ``common_setup``."""
709+
result = cli_runner.invoke(
710+
run,
711+
["--data-source", str(dummy_csv), flag, raw_value],
712+
catch_exceptions=False,
713+
)
714+
715+
assert result.exit_code == 0, result.output
716+
717+
mock_common_setup = patched_run_dependencies["common_setup"]
718+
mock_common_setup.assert_called_once()
719+
settings: CLISettings = mock_common_setup.call_args.kwargs["settings"]
720+
721+
# Click parses the raw CLI string (``"auto"`` stays a string, numbers
722+
# and bools are coerced) and ``parse_overrides`` reshapes the flat
723+
# kwargs into the nested overrides dict before it reaches settings.
724+
node: object = settings.synthesis_overrides
725+
for key in nested_path:
726+
assert isinstance(node, dict) and key in node, (
727+
f"missing {'.'.join(nested_path)} in synthesis_overrides: {settings.synthesis_overrides}"
728+
)
729+
node = node[key]
730+
assert node == expected
731+
assert type(node) is type(expected)
732+
733+
@pytest.mark.parametrize(
734+
"flag,raw_value,nested_path,expected",
735+
[
736+
# rope_scaling_factor, num_input_records_to_sample, and
737+
# privacy.delta pass through Pydantic validation unchanged for both
738+
# 'auto' and explicit values. max_sequences_per_example is excluded
739+
# because its model validator rewrites 'auto' to a concrete default
740+
# (10 with DP disabled, 1 with DP enabled).
741+
("--training__rope_scaling_factor", "auto", ("training", "rope_scaling_factor"), "auto"),
742+
("--training__rope_scaling_factor", "2", ("training", "rope_scaling_factor"), 2),
743+
("--training__num_input_records_to_sample", "auto", ("training", "num_input_records_to_sample"), "auto"),
744+
("--training__num_input_records_to_sample", "100", ("training", "num_input_records_to_sample"), 100),
745+
("--privacy__delta", "auto", ("privacy", "delta"), "auto"),
746+
("--privacy__delta", "0.001", ("privacy", "delta"), 0.001),
747+
],
748+
)
749+
def test_auto_param_override_reaches_params_object(
750+
self,
751+
cli_runner: CliRunner,
752+
dummy_csv: Path,
753+
patched_run_dependencies: dict,
754+
flag: str,
755+
raw_value: str,
756+
nested_path: tuple[str, ...],
757+
expected: object,
758+
):
759+
"""The parsed CLI value also lands on the validated ``SafeSynthesizerParameters`` object."""
760+
result = cli_runner.invoke(
761+
run,
762+
["--data-source", str(dummy_csv), flag, raw_value],
763+
catch_exceptions=False,
764+
)
765+
766+
assert result.exit_code == 0, result.output
767+
settings: CLISettings = patched_run_dependencies["common_setup"].call_args.kwargs["settings"]
768+
769+
params = merge_overrides(None, settings.synthesis_overrides)
770+
resolved: object = params
771+
for key in nested_path:
772+
resolved = getattr(resolved, key)
773+
assert resolved == expected
774+
assert type(resolved) is type(expected)

tests/configurator/test_pydantic_click_options.py

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from nemo_safe_synthesizer.config import SafeSynthesizerParameters
1515
from nemo_safe_synthesizer.configurator.pydantic_click_options import (
16+
AutoParamType,
1617
FlagParam,
1718
LeafParam,
1819
_click_type,
@@ -140,17 +141,98 @@ def test_click_type_str_or_int_returns_string():
140141
assert _click_type(str | int | None) == click.STRING
141142

142143

143-
def test_click_type_literal_auto_int_returns_string():
144-
"""Literal['auto'] | int must use STRING so Click accepts 'auto'."""
144+
def test_click_type_literal_auto_int_returns_auto_param_type():
145+
"""Literal['auto'] | int must use AutoParamType so Click accepts both 'auto' and integers."""
145146
from typing import Literal
146147

147-
assert _click_type(Literal["auto"] | int) == click.STRING
148+
result = _click_type(Literal["auto"] | int)
149+
assert isinstance(result, AutoParamType)
150+
assert result.base_type is click.INT
148151

149152

150-
def test_click_type_literal_auto_float_returns_string():
153+
def test_click_type_literal_auto_float_returns_auto_param_type():
151154
from typing import Literal
152155

153-
assert _click_type(Literal["auto"] | float) == click.STRING
156+
result = _click_type(Literal["auto"] | float)
157+
assert isinstance(result, AutoParamType)
158+
assert result.base_type is click.FLOAT
159+
160+
161+
def test_click_type_literal_auto_bool_returns_auto_param_type():
162+
from typing import Literal
163+
164+
result = _click_type(Literal["auto"] | bool)
165+
assert isinstance(result, AutoParamType)
166+
assert result.base_type is click.BOOL
167+
168+
169+
def test_click_type_literal_auto_str_returns_string():
170+
"""Literal['auto'] | str has no numeric side -- STRING is correct."""
171+
from typing import Literal
172+
173+
assert _click_type(Literal["auto"] | str) == click.STRING
174+
175+
176+
def test_click_type_literal_non_auto_int_returns_string():
177+
"""Literal['disabled'] | int must NOT use AutoParamType.
178+
179+
AutoParamType only accepts the AUTO_STR sentinel, so a sentinel like
180+
'disabled' would be rejected with a confusing 'not a valid integer' error
181+
even though Pydantic would accept it. Falling through to STRING keeps the
182+
sentinel routable to Pydantic for validation.
183+
"""
184+
from typing import Literal
185+
186+
assert _click_type(Literal["disabled"] | int) == click.STRING
187+
188+
189+
def test_click_type_literal_auto_plus_other_int_returns_string():
190+
"""Literal['auto', 'manual'] | int must NOT use AutoParamType (multiple sentinels)."""
191+
from typing import Literal
192+
193+
assert _click_type(Literal["auto", "manual"] | int) == click.STRING
194+
195+
196+
# ---------------------------------------------------------------------------
197+
# AutoParamType
198+
# ---------------------------------------------------------------------------
199+
200+
201+
def test_auto_param_type_accepts_auto_sentinel():
202+
assert AutoParamType(click.INT).convert("auto", None, None) == "auto"
203+
204+
205+
@pytest.mark.parametrize(
206+
"base_type,value,expected",
207+
[
208+
(click.INT, "2", 2),
209+
(click.INT, "100", 100),
210+
(click.FLOAT, "1.5", 1.5),
211+
(click.FLOAT, "0.001", 0.001),
212+
(click.BOOL, "true", True),
213+
(click.BOOL, "false", False),
214+
],
215+
)
216+
def test_auto_param_type_converts_numeric_values(base_type, value, expected):
217+
assert AutoParamType(base_type).convert(value, None, None) == expected
218+
219+
220+
def test_auto_param_type_rejects_non_auto_string():
221+
runner = CliRunner()
222+
223+
@click.command()
224+
@click.option("--val", type=AutoParamType(click.INT))
225+
def cmd(val):
226+
pass
227+
228+
result = runner.invoke(cmd, ["--val", "notanumber"])
229+
assert result.exit_code != 0
230+
231+
232+
def test_auto_param_type_name():
233+
assert AutoParamType(click.INT).name == "integer|auto"
234+
assert AutoParamType(click.FLOAT).name == "float|auto"
235+
assert AutoParamType(click.BOOL).name == "boolean|auto"
154236

155237

156238
def test_click_type_plain_int_returns_int():

0 commit comments

Comments
 (0)