Skip to content

Commit 8c80ae5

Browse files
authored
Add tests to confirm that #160 and #161 are fixed (#234)
* Add test to confirm that #160 is fixed Fixes #160 * Add test to assert that #161 is fixed Fixes #161 --------- Signed-off-by: Fabrice Normandin <[email protected]>
1 parent dfe0347 commit 8c80ae5

11 files changed

+157
-19
lines changed

examples/subgroups/subgroups_example.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
from simple_parsing import ArgumentParser, subgroups
7+
from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode
78

89

910
@dataclass
@@ -48,17 +49,19 @@ class Config:
4849
# Which model to use
4950
model: ModelConfig = subgroups(
5051
{"model_a": ModelAConfig, "model_b": ModelBConfig},
51-
default=ModelAConfig(),
52+
default_factory=ModelAConfig,
5253
)
5354

5455
# Which dataset to use
5556
dataset: DatasetConfig = subgroups(
5657
{"dataset_1": Dataset1Config, "dataset_2": Dataset2Config},
57-
default=Dataset2Config(),
58+
default_factory=Dataset2Config,
5859
)
5960

6061

61-
parser = ArgumentParser()
62+
parser = ArgumentParser(
63+
argument_generation_mode=ArgumentGenerationMode.NESTED, nested_mode=NestedMode.WITHOUT_ROOT
64+
)
6265
parser.add_arguments(Config, dest="config")
6366
args = parser.parse_args()
6467

simple_parsing/parsing.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections import defaultdict
1515
from logging import getLogger
1616
from pathlib import Path
17-
from typing import Any, Callable, Sequence, Type, TypeVar, overload
17+
from typing import Any, Callable, Sequence, Type, overload
1818

1919
from simple_parsing.helpers.subgroups import SubgroupKey
2020
from simple_parsing.wrappers.dataclass_wrapper import DataclassWrapperType
@@ -971,14 +971,12 @@ def _fill_constructor_arguments_with_fields(
971971
return leftover_args, constructor_arguments
972972

973973

974-
T = TypeVar("T")
975-
976-
974+
# TODO: Change the order of arguments to put `args` as the second argument.
977975
def parse(
978-
config_class: type[Dataclass],
976+
config_class: type[DataclassT],
979977
config_path: Path | str | None = None,
980978
args: str | Sequence[str] | None = None,
981-
default: Dataclass | None = None,
979+
default: DataclassT | None = None,
982980
dest: str = "config",
983981
*,
984982
prefix: str = "",
@@ -989,7 +987,7 @@ def parse(
989987
formatter_class: type[HelpFormatter] = SimpleHelpFormatter,
990988
add_config_path_arg: bool | None = None,
991989
**kwargs,
992-
) -> Dataclass:
990+
) -> DataclassT:
993991
"""Parse the given dataclass from the command-line.
994992
995993
See the `ArgumentParser` constructor for more details on the arguments (they are the same here

test/test_subgroups.py

+138-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from pytest_regressions.file_regression import FileRegressionFixture
1515
from typing_extensions import Annotated
1616

17-
from simple_parsing import ArgumentParser, subgroups
17+
from simple_parsing import ArgumentParser, parse, subgroups
18+
from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode
1819

1920
from .test_choice import Color
2021
from .testutils import TestSetup, raises_invalid_choice, raises_missing_required_arg
@@ -794,3 +795,139 @@ def test_help(
794795
# assert Config.setup("--model small").model == SmallModel()
795796
# assert Config.setup("--model big").model == BigModel()
796797
# assert Config.setup("--num_layers 123").model == Model(num_layers=123, hidden_dim=32)
798+
799+
800+
@pytest.mark.parametrize("frozen", [True, False])
801+
def test_nested_subgroups(frozen: bool):
802+
"""Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160"""
803+
804+
@dataclass(frozen=frozen)
805+
class FooConfig:
806+
...
807+
808+
@dataclass(frozen=frozen)
809+
class BarConfig:
810+
foo: FooConfig
811+
812+
@dataclass(frozen=frozen)
813+
class FooAConfig(FooConfig):
814+
foo_param_a: float = 0.0
815+
816+
@dataclass(frozen=frozen)
817+
class FooBConfig(FooConfig):
818+
foo_param_b: str = "foo_b"
819+
820+
@dataclass(frozen=frozen)
821+
class Bar1Config(BarConfig):
822+
foo: FooConfig = subgroups(
823+
{"foo_a": FooAConfig, "foo_b": FooBConfig},
824+
default_factory=FooAConfig,
825+
)
826+
827+
@dataclass(frozen=frozen)
828+
class Bar2Config(BarConfig):
829+
foo: FooConfig = subgroups(
830+
{"foo_a": FooAConfig, "foo_b": FooBConfig},
831+
default_factory=FooBConfig,
832+
)
833+
834+
@dataclass(frozen=frozen)
835+
class Config(TestSetup):
836+
bar: Bar1Config | Bar2Config = subgroups(
837+
{"bar_1": Bar1Config, "bar_2": Bar2Config},
838+
default_factory=Bar2Config,
839+
)
840+
841+
assert Config.setup("") == Config(bar=Bar2Config(foo=FooBConfig()))
842+
assert Config.setup("--bar=bar_1 --foo=foo_a") == Config(bar=Bar1Config(foo=FooAConfig()))
843+
844+
845+
@dataclass
846+
class ModelConfig:
847+
...
848+
849+
850+
@dataclass
851+
class DatasetConfig:
852+
...
853+
854+
855+
@dataclass
856+
class ModelAConfig(ModelConfig):
857+
lr: float = 3e-4
858+
optimizer: str = "Adam"
859+
betas: tuple[float, float] = (0.9, 0.999)
860+
861+
862+
@dataclass
863+
class ModelBConfig(ModelConfig):
864+
lr: float = 1e-3
865+
optimizer: str = "SGD"
866+
momentum: float = 1.234
867+
868+
869+
@dataclass
870+
class Dataset1Config(DatasetConfig):
871+
data_dir: str | Path = "data/foo"
872+
foo: bool = False
873+
874+
875+
@dataclass
876+
class Dataset2Config(DatasetConfig):
877+
data_dir: str | Path = "data/bar"
878+
bar: float = 1.2
879+
880+
881+
@dataclass
882+
class Config(TestSetup):
883+
884+
# Which model to use
885+
model: ModelConfig = subgroups(
886+
{"model_a": ModelAConfig, "model_b": ModelBConfig},
887+
default_factory=ModelAConfig,
888+
)
889+
890+
# Which dataset to use
891+
dataset: DatasetConfig = subgroups(
892+
{"dataset_1": Dataset1Config, "dataset_2": Dataset2Config},
893+
default_factory=Dataset2Config,
894+
)
895+
896+
897+
def _parse_config(args: str) -> Config:
898+
return parse(
899+
Config,
900+
args=args,
901+
argument_generation_mode=ArgumentGenerationMode.NESTED,
902+
nested_mode=NestedMode.WITHOUT_ROOT,
903+
)
904+
905+
906+
def test_ordering_of_args_doesnt_matter():
907+
"""Test to confirm that #160 is fixed:"""
908+
909+
# $ python issue.py --model model_a --model.lr 1e-2
910+
assert _parse_config(args="--model model_a --model.lr 1e-2") == Config(
911+
model=ModelAConfig(lr=0.01, optimizer="Adam", betas=(0.9, 0.999)),
912+
dataset=Dataset2Config(data_dir="data/bar", bar=1.2),
913+
)
914+
915+
# I was expecting this to work given that both model configs have `lr` attribute
916+
# $ python issue.py --model.lr 1e-2.
917+
assert _parse_config(args="--model.lr 1e-2") == Config(
918+
model=ModelAConfig(lr=1e-2, optimizer="Adam", betas=(0.9, 0.999)),
919+
dataset=Dataset2Config(data_dir="data/bar", bar=1.2),
920+
)
921+
922+
# $ python issue.py --model model_a --model.betas 0. 1.
923+
assert _parse_config(args="--model model_a --model.betas 0. 1.") == Config(
924+
model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)),
925+
dataset=Dataset2Config(data_dir="data/bar", bar=1.2),
926+
)
927+
928+
# % ModelA being the default, I was expecting this two work
929+
# $ python issue.py --model.betas 0. 1.
930+
assert _parse_config(args="--model.betas 0. 1.") == Config(
931+
model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)),
932+
dataset=Dataset2Config(data_dir="data/bar", bar=1.2),
933+
)

test/test_subgroups/test_help[Config---help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[Config---model=model_a --help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[Config---model=model_b --help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---help].md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:724)
1+
# Regression file for [this test](test/test_subgroups.py:725)
22

33
Given Source code:
44

0 commit comments

Comments
 (0)