Skip to content

Commit 26ec008

Browse files
joshlklebrice
andauthored
Add option to specify config path argument name (#334)
* Add option to specify config path argument name * Reformat using pre-commit * Address comments * Update test_conf_path.py rename arg * Update mdformat pre-commit hook Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]> Co-authored-by: Fabrice Normandin <[email protected]>
1 parent 4388893 commit 26ec008

File tree

4 files changed

+81
-11
lines changed

4 files changed

+81
-11
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ repos:
6666

6767
# md formatting
6868
- repo: https://github.com/executablebooks/mdformat
69-
rev: 0.7.16
69+
rev: 0.7.21
7070
hooks:
7171
- id: mdformat
7272
args: ["--number"]

examples/config_files/one_config.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ class TrainConfig:
1313

1414

1515
def main(args=None) -> None:
16-
cfg = simple_parsing.parse(config_class=TrainConfig, args=args, add_config_path_arg=True)
16+
cfg = simple_parsing.parse(
17+
config_class=TrainConfig, args=args, add_config_path_arg="config-file"
18+
)
1719
print(f"Training {cfg.exp_name} with {cfg.workers} workers...")
1820

1921

@@ -28,7 +30,7 @@ def main(args=None) -> None:
2830
"""
2931

3032
# NOTE: When running as in the readme:
31-
main("--config_path one_config.yaml --exp_name my_first_exp")
33+
main("--config-file one_config.yaml --exp_name my_first_exp")
3234
expected += """\
3335
Training my_first_exp with 42 workers...
3436
"""

simple_parsing/parsing.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,15 @@ class ArgumentParser(argparse.ArgumentParser):
9797
`argparse.MetavarTypeHelpFormatter` and
9898
`argparse.RawDescriptionHelpFormatter` classes.
9999
100-
- add_config_path_arg : bool, optional
101-
When set to `True`, adds a `--config_path` argument, of type Path, which is used to parse
100+
- add_config_path_arg : bool, str, optional
101+
When set to `True`, adds a `--config_path` argument, of type Path, which is used to parse.
102+
If set to a string then this is the name of the config_path argument.
103+
104+
- config_path: str, optional
105+
The values read from this file will overwrite the default values from the dataclass definitions.
106+
When `add_config_path_arg` is also set the defaults are first updated using `config_path`, and then
107+
updated with the contents of the `--config_path` file(s). By setting this value it will be default set
108+
`add_config_path_arg` to True.
102109
"""
103110

104111
def __init__(
@@ -111,7 +118,7 @@ def __init__(
111118
argument_generation_mode=ArgumentGenerationMode.FLAT,
112119
nested_mode: NestedMode = NestedMode.DEFAULT,
113120
formatter_class: type[HelpFormatter] = SimpleHelpFormatter,
114-
add_config_path_arg: bool | None = None,
121+
add_config_path_arg: bool | str | None = None,
115122
config_path: Path | str | Sequence[Path | str] | None = None,
116123
add_dest_to_option_strings: bool | None = None,
117124
**kwargs,
@@ -298,6 +305,11 @@ def parse_known_args(
298305
self.set_defaults(config_file)
299306

300307
if self.add_config_path_arg:
308+
config_path_name = (
309+
self.add_config_path_arg
310+
if isinstance(self.add_config_path_arg, str)
311+
else "config_path"
312+
)
301313
temp_parser = ArgumentParser(
302314
add_config_path_arg=False,
303315
add_help=False,
@@ -306,14 +318,14 @@ def parse_known_args(
306318
nested_mode=FieldWrapper.nested_mode,
307319
)
308320
temp_parser.add_argument(
309-
"--config_path",
321+
f"--{config_path_name}",
310322
type=Path,
311323
nargs="*",
312324
default=self.config_path,
313325
help="Path to a config file containing default values to use.",
314326
)
315327
args_with_config_path, args = temp_parser.parse_known_args(args)
316-
config_path = args_with_config_path.config_path
328+
config_path = getattr(args_with_config_path, config_path_name.replace("-", "_"))
317329

318330
if config_path is not None:
319331
config_paths = config_path if isinstance(config_path, list) else [config_path]
@@ -323,7 +335,7 @@ def parse_known_args(
323335
# Adding it here just so it shows up in the help message. The default will be set in
324336
# the help string.
325337
self.add_argument(
326-
"--config_path",
338+
f"--{config_path_name}",
327339
type=Path,
328340
default=config_path,
329341
help="Path to a config file containing default values to use.",
@@ -1000,7 +1012,7 @@ def parse(
10001012
add_option_string_dash_variants: DashVariant = DashVariant.AUTO,
10011013
argument_generation_mode=ArgumentGenerationMode.FLAT,
10021014
formatter_class: type[HelpFormatter] = SimpleHelpFormatter,
1003-
add_config_path_arg: bool | None = None,
1015+
add_config_path_arg: bool | str | None = None,
10041016
**kwargs,
10051017
) -> DataclassT:
10061018
"""Parse the given dataclass from the command-line.
@@ -1010,10 +1022,12 @@ def parse(
10101022
10111023
If `config_path` is passed, loads the values from that file and uses them as defaults.
10121024
"""
1025+
if dest == add_config_path_arg:
1026+
raise ValueError("`add_config_path_arg` cannot be the same as `dest`.")
1027+
10131028
parser = ArgumentParser(
10141029
nested_mode=nested_mode,
10151030
add_help=True,
1016-
# add_config_path_arg=None,
10171031
config_path=config_path,
10181032
conflict_resolution=conflict_resolution,
10191033
add_option_string_dash_variants=add_option_string_dash_variants,

test/test_conf_path.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Tests for config-path option."""
2+
3+
import json
4+
from dataclasses import dataclass
5+
from pathlib import Path
6+
7+
import pytest
8+
9+
from simple_parsing.parsing import ArgumentParser, parse
10+
11+
12+
@dataclass
13+
class BarConf:
14+
foo: str
15+
16+
17+
@pytest.mark.parametrize(
18+
"conf_arg_name", ["config-file", "config_file", "foo.bar.baz?", "bob bob bob"]
19+
)
20+
def test_config_path_arg(tmp_path: Path, conf_arg_name: str):
21+
"""Test config_path with valid strings."""
22+
# Create config file
23+
conf_path = tmp_path / "foo.yml"
24+
with conf_path.open("w") as f:
25+
json.dump({"foo": "bee"}, f)
26+
27+
# with pytest.raises(ValueError):
28+
parser = ArgumentParser(BarConf, add_config_path_arg=conf_arg_name)
29+
args = parser.parse_args([f"--{conf_arg_name}", str(conf_path)])
30+
print(args)
31+
32+
33+
@pytest.mark.parametrize(
34+
"conf_arg_name",
35+
[
36+
"-------",
37+
],
38+
)
39+
def test_pass_invalid_value_to_add_config_path_arg(tmp_path: Path, conf_arg_name: str):
40+
"""Test config_path with invalid strings."""
41+
# Create config file
42+
conf_path = tmp_path / "foo.yml"
43+
with conf_path.open("w") as f:
44+
json.dump({"foo": "bee"}, f)
45+
46+
parser = ArgumentParser(BarConf, add_config_path_arg=conf_arg_name)
47+
with pytest.raises(ValueError):
48+
parser.parse_args([f"--{conf_arg_name}", str(conf_path)])
49+
50+
51+
def test_config_path_same_as_dst_error():
52+
"""Raise an error if add_config_path_arg and dest are the equal."""
53+
with pytest.raises(ValueError, match="`add_config_path_arg` cannot be the same as `dest`."):
54+
parse(BarConf, dest="boo", add_config_path_arg="boo")

0 commit comments

Comments
 (0)