28
28
},
29
29
"d" : 4 ,
30
30
"f" : 8 ,
31
+ "g" : "foo" ,
32
+ "h" : "${g}/bar" ,
31
33
}
32
34
33
35
@@ -50,7 +52,9 @@ def test_get_component_from_path(self):
50
52
):
51
53
_ = _get_component_from_path ("torchtune.models.dummy" )
52
54
53
- @mock .patch ("torchtune.config._parse.OmegaConf.load" , return_value = _CONFIG )
55
+ @mock .patch (
56
+ "torchtune.config._parse.OmegaConf.load" , return_value = OmegaConf .create (_CONFIG )
57
+ )
54
58
def test_merge_yaml_and_cli_args (self , mock_load ):
55
59
parser = TuneRecipeArgumentParser ("test parser" )
56
60
yaml_args , cli_args = parser .parse_known_args (
@@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
63
67
"d=6" , # Test overriding a flat param
64
68
"e=7" , # Test adding a new param
65
69
"~f" , # Test removing a param
70
+ "g=bazz" , # Test interpolation happens after override
66
71
]
67
72
)
68
73
conf = _merge_yaml_and_cli_args (yaml_args , cli_args )
@@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
75
80
assert conf .d == 6 , f"d == { conf .d } , not 6 as set in overrides."
76
81
assert conf .e == 7 , f"e == { conf .e } , not 7 as set in overrides."
77
82
assert "f" not in conf , f"f == { conf .f } , not removed as set in overrides."
83
+ assert conf .h == "bazz/bar" , f"h == { conf .h } , not bazz/bar as set in overrides."
78
84
mock_load .assert_called_once ()
79
85
80
86
yaml_args , cli_args = parser .parse_known_args (
@@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):
185
191
186
192
# Test removing non-existent param fails
187
193
cfg = copy .deepcopy (_CONFIG )
188
- with pytest .raises (KeyError , match = "'g '" ):
189
- _remove_key_by_dotpath (cfg , "g " )
194
+ with pytest .raises (KeyError , match = "'i '" ):
195
+ _remove_key_by_dotpath (cfg , "i " )
0 commit comments