|
14 | 14 | from pytest_regressions.file_regression import FileRegressionFixture
|
15 | 15 | from typing_extensions import Annotated
|
16 | 16 |
|
17 |
| -from simple_parsing import ArgumentParser, subgroups |
| 17 | +from simple_parsing import ArgumentParser, parse, subgroups |
| 18 | +from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode |
18 | 19 |
|
19 | 20 | from .test_choice import Color
|
20 | 21 | from .testutils import TestSetup, raises_invalid_choice, raises_missing_required_arg
|
@@ -794,3 +795,139 @@ def test_help(
|
794 | 795 | # assert Config.setup("--model small").model == SmallModel()
|
795 | 796 | # assert Config.setup("--model big").model == BigModel()
|
796 | 797 | # assert Config.setup("--num_layers 123").model == Model(num_layers=123, hidden_dim=32)
|
| 798 | + |
| 799 | + |
| 800 | +@pytest.mark.parametrize("frozen", [True, False]) |
| 801 | +def test_nested_subgroups(frozen: bool): |
| 802 | + """Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160""" |
| 803 | + |
| 804 | + @dataclass(frozen=frozen) |
| 805 | + class FooConfig: |
| 806 | + ... |
| 807 | + |
| 808 | + @dataclass(frozen=frozen) |
| 809 | + class BarConfig: |
| 810 | + foo: FooConfig |
| 811 | + |
| 812 | + @dataclass(frozen=frozen) |
| 813 | + class FooAConfig(FooConfig): |
| 814 | + foo_param_a: float = 0.0 |
| 815 | + |
| 816 | + @dataclass(frozen=frozen) |
| 817 | + class FooBConfig(FooConfig): |
| 818 | + foo_param_b: str = "foo_b" |
| 819 | + |
| 820 | + @dataclass(frozen=frozen) |
| 821 | + class Bar1Config(BarConfig): |
| 822 | + foo: FooConfig = subgroups( |
| 823 | + {"foo_a": FooAConfig, "foo_b": FooBConfig}, |
| 824 | + default_factory=FooAConfig, |
| 825 | + ) |
| 826 | + |
| 827 | + @dataclass(frozen=frozen) |
| 828 | + class Bar2Config(BarConfig): |
| 829 | + foo: FooConfig = subgroups( |
| 830 | + {"foo_a": FooAConfig, "foo_b": FooBConfig}, |
| 831 | + default_factory=FooBConfig, |
| 832 | + ) |
| 833 | + |
| 834 | + @dataclass(frozen=frozen) |
| 835 | + class Config(TestSetup): |
| 836 | + bar: Bar1Config | Bar2Config = subgroups( |
| 837 | + {"bar_1": Bar1Config, "bar_2": Bar2Config}, |
| 838 | + default_factory=Bar2Config, |
| 839 | + ) |
| 840 | + |
| 841 | + assert Config.setup("") == Config(bar=Bar2Config(foo=FooBConfig())) |
| 842 | + assert Config.setup("--bar=bar_1 --foo=foo_a") == Config(bar=Bar1Config(foo=FooAConfig())) |
| 843 | + |
| 844 | + |
| 845 | +@dataclass |
| 846 | +class ModelConfig: |
| 847 | + ... |
| 848 | + |
| 849 | + |
| 850 | +@dataclass |
| 851 | +class DatasetConfig: |
| 852 | + ... |
| 853 | + |
| 854 | + |
| 855 | +@dataclass |
| 856 | +class ModelAConfig(ModelConfig): |
| 857 | + lr: float = 3e-4 |
| 858 | + optimizer: str = "Adam" |
| 859 | + betas: tuple[float, float] = (0.9, 0.999) |
| 860 | + |
| 861 | + |
| 862 | +@dataclass |
| 863 | +class ModelBConfig(ModelConfig): |
| 864 | + lr: float = 1e-3 |
| 865 | + optimizer: str = "SGD" |
| 866 | + momentum: float = 1.234 |
| 867 | + |
| 868 | + |
| 869 | +@dataclass |
| 870 | +class Dataset1Config(DatasetConfig): |
| 871 | + data_dir: str | Path = "data/foo" |
| 872 | + foo: bool = False |
| 873 | + |
| 874 | + |
| 875 | +@dataclass |
| 876 | +class Dataset2Config(DatasetConfig): |
| 877 | + data_dir: str | Path = "data/bar" |
| 878 | + bar: float = 1.2 |
| 879 | + |
| 880 | + |
| 881 | +@dataclass |
| 882 | +class Config(TestSetup): |
| 883 | + |
| 884 | + # Which model to use |
| 885 | + model: ModelConfig = subgroups( |
| 886 | + {"model_a": ModelAConfig, "model_b": ModelBConfig}, |
| 887 | + default_factory=ModelAConfig, |
| 888 | + ) |
| 889 | + |
| 890 | + # Which dataset to use |
| 891 | + dataset: DatasetConfig = subgroups( |
| 892 | + {"dataset_1": Dataset1Config, "dataset_2": Dataset2Config}, |
| 893 | + default_factory=Dataset2Config, |
| 894 | + ) |
| 895 | + |
| 896 | + |
| 897 | +def _parse_config(args: str) -> Config: |
| 898 | + return parse( |
| 899 | + Config, |
| 900 | + args=args, |
| 901 | + argument_generation_mode=ArgumentGenerationMode.NESTED, |
| 902 | + nested_mode=NestedMode.WITHOUT_ROOT, |
| 903 | + ) |
| 904 | + |
| 905 | + |
| 906 | +def test_ordering_of_args_doesnt_matter(): |
| 907 | + """Test to confirm that #160 is fixed:""" |
| 908 | + |
| 909 | + # $ python issue.py --model model_a --model.lr 1e-2 |
| 910 | + assert _parse_config(args="--model model_a --model.lr 1e-2") == Config( |
| 911 | + model=ModelAConfig(lr=0.01, optimizer="Adam", betas=(0.9, 0.999)), |
| 912 | + dataset=Dataset2Config(data_dir="data/bar", bar=1.2), |
| 913 | + ) |
| 914 | + |
| 915 | + # I was expecting this to work given that both model configs have `lr` attribute |
| 916 | + # $ python issue.py --model.lr 1e-2. |
| 917 | + assert _parse_config(args="--model.lr 1e-2") == Config( |
| 918 | + model=ModelAConfig(lr=1e-2, optimizer="Adam", betas=(0.9, 0.999)), |
| 919 | + dataset=Dataset2Config(data_dir="data/bar", bar=1.2), |
| 920 | + ) |
| 921 | + |
| 922 | + # $ python issue.py --model model_a --model.betas 0. 1. |
| 923 | + assert _parse_config(args="--model model_a --model.betas 0. 1.") == Config( |
| 924 | + model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)), |
| 925 | + dataset=Dataset2Config(data_dir="data/bar", bar=1.2), |
| 926 | + ) |
| 927 | + |
| 928 | + # % ModelA being the default, I was expecting this two work |
| 929 | + # $ python issue.py --model.betas 0. 1. |
| 930 | + assert _parse_config(args="--model.betas 0. 1.") == Config( |
| 931 | + model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)), |
| 932 | + dataset=Dataset2Config(data_dir="data/bar", bar=1.2), |
| 933 | + ) |
0 commit comments