Skip to content

Commit bb6aafc

Browse files
authored
Add support for custom Option and Argument subclasses (#16)
1 parent 199e825 commit bb6aafc

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed

dataclass_click/dataclass_click.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,9 @@ def _patch_click_types(
185185
hint: typing.Type[Any]
186186
_, hint = _strip_optional(type_hints[key])
187187
if "type" not in annotation.kwargs:
188-
stub: click.core.Option | click.core.Argument
189-
if annotation.callable is click.option:
190-
stub = click.core.Option(annotation.args, **annotation.kwargs)
191-
if stub.is_flag:
192-
continue
193-
else:
194-
stub = click.core.Argument(annotation.args, **annotation.kwargs)
188+
stub = _build_stub(annotation)
189+
if hasattr(stub, "is_flag") and stub.is_flag:
190+
continue
195191
annotation.kwargs["type"] = _eval_type(key, hint, stub, complete_type_inferences)
196192

197193

@@ -218,6 +214,23 @@ def _eval_type(
218214
raise TypeError(f"Could not infer ParamType for {key} type {hint!r}. Explicitly annotate type=<type>")
219215

220216

217+
def _build_stub(annotation: _DelayedCall) -> click.core.Option | click.core.Argument:
218+
"""Stub uses click's parser rather than trying to second guess how click will behave."""
219+
if "cls" in annotation.kwargs:
220+
return annotation.kwargs["cls"](
221+
annotation.args,
222+
**{k: v
223+
for k, v in annotation.kwargs.items() if k != "cls"},
224+
)
225+
226+
if annotation.callable is click.option:
227+
return click.core.Option(annotation.args, **annotation.kwargs)
228+
elif annotation.callable is click.argument:
229+
return click.core.Argument(annotation.args, **annotation.kwargs)
230+
231+
raise TypeError(f"Unknown annotation callable {annotation.callable!r}")
232+
233+
221234
def _patch_required(arg_class: typing.Type[Arg], annotations: dict[str, _DelayedCall]) -> None:
222235
"""Default click option to required if typehint is not OPTIONAL
223236
@@ -228,15 +241,15 @@ def _patch_required(arg_class: typing.Type[Arg], annotations: dict[str, _Delayed
228241
:return: None, annotations are updated in place"""
229242
type_hints = typing.get_type_hints(arg_class)
230243
for key, annotation in annotations.items():
231-
hint: typing.Type[Any]
232-
is_optional, hint = _strip_optional(type_hints[key])
244+
is_optional, _ = _strip_optional(type_hints[key])
233245
if not is_optional:
234246
if annotation.callable is click.option:
235247
# If required or default set directly.
236248
if "required" not in annotation.kwargs and "default" not in annotation.kwargs:
237-
# Stub uses click's parser rather than trying to second guess how click will behave
249+
stub = _build_stub(annotation)
250+
if not isinstance(stub, click.core.Option):
251+
raise TypeError(f"Expected click.core.Option or its subclass, got {stub!r}")
238252
# If click would imply is_flag or multiple
239-
stub = click.core.Option(annotation.args, **annotation.kwargs)
240253
if not stub.is_flag and not stub.multiple:
241254
annotation.kwargs["required"] = True
242255

tests/mypy_check.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def main_d(a: Config, b: Config2):
6565
def main_e(a: Config, b: Config2):
6666
...
6767

68+
6869
@click.command()
6970
@dataclass_click(Config2)
7071
@dataclass_click(Config)

tests/test_end_to_end.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,57 @@ def test_mypy_check_loads():
315315
"""Check that mypy_check.py will parse by python"""
316316
from . import mypy_check
317317
assert mypy_check
318+
319+
320+
def test_cls_kwarg():
321+
322+
class MutexOption(click.Option):
323+
324+
def __init__(self, *args, **kwargs):
325+
self.not_required_if: list[str] = kwargs.pop("not_required_if")
326+
327+
assert self.not_required_if, "'not_required_if' parameter required"
328+
kwargs["help"] = (
329+
f"{kwargs.get('help', '')} Option is mutually exclusive with "
330+
f"{', '.join(self.not_required_if)}.".strip())
331+
super().__init__(*args, **kwargs)
332+
333+
def handle_parse_result(self, ctx, opts, args):
334+
if set(self.not_required_if).intersection(set(opts)):
335+
self.required = False
336+
337+
return super().handle_parse_result(ctx, opts, args)
338+
339+
class MutexArgument(click.Argument):
340+
341+
def __init__(self, *args, **kwargs):
342+
self.not_required_if: list[str] = kwargs.pop("not_required_if")
343+
344+
assert self.not_required_if, "'not_required_if' parameter required"
345+
super().__init__(*args, **kwargs)
346+
347+
def handle_parse_result(self, ctx, opts, args):
348+
if set(self.not_required_if).intersection(set(opts)):
349+
self.required = False
350+
351+
return super().handle_parse_result(ctx, opts, args)
352+
353+
@dataclass
354+
class Config:
355+
username: Annotated[str | None, option(required=True, cls=MutexOption, not_required_if=["token"])]
356+
password: Annotated[str | None, option(required=True, cls=MutexOption, not_required_if=["token"])]
357+
token: Annotated[str | None,
358+
argument(required=True, cls=MutexArgument, not_required_if=["username", "password"])]
359+
360+
@click.command()
361+
@dataclass_click(Config)
362+
def main(*args, **kwargs):
363+
results.append((args, kwargs))
364+
365+
results: list[CallRecord] = []
366+
quick_run(main, "--username", "foo", "--password", "bar")
367+
assert results == [((Config(username="foo", password="bar", token=None), ), {})]
368+
369+
results: list[CallRecord] = []
370+
quick_run(main, "foo")
371+
assert results == [((Config(username=None, password=None, token="foo"), ), {})]

0 commit comments

Comments
 (0)