Skip to content

Commit 8d46cb2

Browse files
authored
Fix CLI args type conversion (#3)
1 parent 301c163 commit 8d46cb2

File tree

2 files changed

+56
-12
lines changed

2 files changed

+56
-12
lines changed

src/toml_fmt_common/__init__.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
import os
77
import sys
88
from abc import ABC, abstractmethod
9-
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError, Namespace
9+
from argparse import (
10+
ArgumentDefaultsHelpFormatter,
11+
ArgumentParser,
12+
ArgumentTypeError,
13+
Namespace,
14+
_ArgumentGroup, # noqa: PLC2701
15+
)
1016
from collections import deque
1117
from copy import deepcopy
1218
from dataclasses import dataclass
@@ -16,13 +22,15 @@
1622
from typing import TYPE_CHECKING, Any, Generic, TypeVar
1723

1824
if TYPE_CHECKING:
19-
from collections.abc import Iterable, Sequence
25+
from collections.abc import Callable, Iterable, Mapping, Sequence
2026

2127
if sys.version_info >= (3, 11): # pragma: >=3.11 cover
2228
import tomllib
2329
else: # pragma: <3.11 cover
2430
import tomli as tomllib
2531

32+
ArgumentGroup = _ArgumentGroup
33+
2634

2735
class FmtNamespace(Namespace):
2836
"""Options for pyproject-fmt tool."""
@@ -63,7 +71,7 @@ def filename(self) -> str:
6371
raise NotImplementedError
6472

6573
@abstractmethod
66-
def add_format_flags(self, parser: ArgumentParser) -> None:
74+
def add_format_flags(self, parser: ArgumentGroup) -> None:
6775
"""
6876
Add any additional flags to configure the formatter.
6977
@@ -126,7 +134,7 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]:
126134
:param args: CLI arguments
127135
:return: the parsed options
128136
"""
129-
parser = _build_cli(info)
137+
parser, type_conversion = _build_cli(info)
130138
parser.parse_args(namespace=info.opt, args=args)
131139
res = []
132140
for pyproject_toml in info.opt.inputs:
@@ -144,7 +152,9 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]:
144152
if isinstance(config, dict):
145153
for key in set(vars(override_opt).keys()) - {"inputs", "stdout", "check", "no_print_diff"}:
146154
if key in config:
147-
setattr(override_opt, key, config[key])
155+
raw = config[key]
156+
converted = type_conversion[key](raw) if key in type_conversion else raw
157+
setattr(override_opt, key, converted)
148158
res.append(
149159
_Config(
150160
toml_filename=pyproject_toml,
@@ -159,7 +169,7 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]:
159169
return res
160170

161171

162-
def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser:
172+
def _build_cli(of: TOMLFormatter[T]) -> tuple[ArgumentParser, Mapping[str, Callable[[Any], Any]]]:
163173
parser = ArgumentParser(
164174
formatter_class=ArgumentDefaultsHelpFormatter,
165175
prog=of.prog,
@@ -200,15 +210,16 @@ def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser:
200210
help="number of spaces to use for indentation",
201211
metavar="count",
202212
)
203-
of.add_format_flags(format_group) # type: ignore[arg-type]
213+
of.add_format_flags(format_group)
214+
type_conversion = {a.dest: a.type for a in format_group._actions if a.type and a.dest} # noqa: SLF001
204215
msg = "pyproject.toml file(s) to format, use '-' to read from stdin"
205216
parser.add_argument(
206217
"inputs",
207218
nargs="+",
208219
type=partial(_toml_path_creator, of.filename),
209220
help=msg,
210221
)
211-
return parser
222+
return parser, type_conversion
212223

213224

214225
def _toml_path_creator(filename: str, argument: str) -> Path | None:
@@ -289,6 +300,7 @@ def _color_diff(diff: Iterable[str]) -> Iterable[str]:
289300

290301

291302
__all__ = [
303+
"ArgumentGroup",
292304
"FmtNamespace",
293305
"TOMLFormatter",
294306
"run",

tests/test_app.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66

77
import pytest
88

9-
from toml_fmt_common import GREEN, RED, RESET, FmtNamespace, TOMLFormatter, run
9+
from toml_fmt_common import GREEN, RED, RESET, ArgumentGroup, FmtNamespace, TOMLFormatter, run
1010

1111
if TYPE_CHECKING:
12-
from argparse import ArgumentParser
1312
from pathlib import Path
1413

1514
from pytest_mock import MockerFixture
1615

1716

1817
class DumpNamespace(FmtNamespace):
1918
extra: str
19+
tuple_magic: tuple[str, ...]
2020

2121

2222
class Dumb(TOMLFormatter[DumpNamespace]):
@@ -35,11 +35,18 @@ def filename(self) -> str:
3535
def override_cli_from_section(self) -> tuple[str, ...]:
3636
return "start", "sub"
3737

38-
def add_format_flags(self, parser: ArgumentParser) -> None: # noqa: PLR6301
38+
def add_format_flags(self, parser: ArgumentGroup) -> None: # noqa: PLR6301
3939
parser.add_argument("extra", help="this is something extra")
40+
parser.add_argument("-t", "--tuple-magic", default=(), type=lambda t: tuple(t.split(".")))
4041

4142
def format(self, text: str, opt: DumpNamespace) -> str: # noqa: PLR6301
42-
return text if os.environ.get("NO_FMT") else f"{text}\nextras = {opt.extra!r}"
43+
if os.environ.get("NO_FMT"):
44+
return text
45+
return "\n".join([
46+
text,
47+
f"extras = {opt.extra!r}",
48+
*([f"magic = {','.join(opt.tuple_magic)!r}"] if opt.tuple_magic else []),
49+
])
4350

4451

4552
def test_dumb_help(capsys: pytest.CaptureFixture[str]) -> None:
@@ -77,6 +84,31 @@ def test_dumb_format_with_override(capsys: pytest.CaptureFixture[str], tmp_path:
7784
]
7885

7986

87+
def test_dumb_format_with_override_custom_type(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
88+
dumb = tmp_path / "dumb.toml"
89+
dumb.write_text("[start.sub]\ntuple_magic = '1.2.3'")
90+
91+
exit_code = run(Dumb(), ["E", str(dumb)])
92+
assert exit_code == 1
93+
94+
assert dumb.read_text() == "[start.sub]\ntuple_magic = '1.2.3'\nextras = 'E'\nmagic = '1,2,3'"
95+
96+
out, err = capsys.readouterr()
97+
assert not err
98+
assert out.splitlines() == [
99+
f"{RED}--- {dumb}",
100+
f"{RESET}",
101+
f"{GREEN}+++ {dumb}",
102+
f"{RESET}",
103+
"@@ -1,2 +1,4 @@",
104+
"",
105+
" [start.sub]",
106+
" tuple_magic = '1.2.3'",
107+
f"{GREEN}+extras = 'E'{RESET}",
108+
f"{GREEN}+magic = '1,2,3'{RESET}",
109+
]
110+
111+
80112
def test_dumb_format_no_print_diff(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
81113
dumb = tmp_path / "dumb.toml"
82114
dumb.write_text("[start.sub]\nextra = 'B'")

0 commit comments

Comments
 (0)