Skip to content

Commit e79ab8b

Browse files
authored
[EZ] Fix config bug where interpolation happens too early (#2236)
1 parent dadba25 commit e79ab8b

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

tests/torchtune/config/test_config_utils.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
},
2929
"d": 4,
3030
"f": 8,
31+
"g": "foo",
32+
"h": "${g}/bar",
3133
}
3234

3335

@@ -50,7 +52,9 @@ def test_get_component_from_path(self):
5052
):
5153
_ = _get_component_from_path("torchtune.models.dummy")
5254

53-
@mock.patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
55+
@mock.patch(
56+
"torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
57+
)
5458
def test_merge_yaml_and_cli_args(self, mock_load):
5559
parser = TuneRecipeArgumentParser("test parser")
5660
yaml_args, cli_args = parser.parse_known_args(
@@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
6367
"d=6", # Test overriding a flat param
6468
"e=7", # Test adding a new param
6569
"~f", # Test removing a param
70+
"g=bazz", # Test interpolation happens after override
6671
]
6772
)
6873
conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
@@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
7580
assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides."
7681
assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides."
7782
assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides."
83+
assert conf.h == "bazz/bar", f"h == {conf.h}, not bazz/bar as set in overrides."
7884
mock_load.assert_called_once()
7985

8086
yaml_args, cli_args = parser.parse_known_args(
@@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):
185191

186192
# Test removing non-existent param fails
187193
cfg = copy.deepcopy(_CONFIG)
188-
with pytest.raises(KeyError, match="'g'"):
189-
_remove_key_by_dotpath(cfg, "g")
194+
with pytest.raises(KeyError, match="'i'"):
195+
_remove_key_by_dotpath(cfg, "i")

tests/torchtune/config/test_parse.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from torchtune.config._parse import TuneRecipeArgumentParser
1515

16-
_CONFIG = {"a": 1, "b": 2}
16+
_CONFIG = {"a": 1, "b": 2, "c": "foo", "d": "${c}/bar"}
1717

1818

1919
class TestParse:
@@ -41,7 +41,9 @@ def parser(self):
4141
parser = TuneRecipeArgumentParser("Test parser")
4242
return parser
4343

44-
@patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
44+
@patch(
45+
"torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
46+
)
4547
def test_parse_known_args(self, mock_load, parser):
4648
"""
4749
Test that the parser can load a config and override parameters provided on CLI.
@@ -65,3 +67,11 @@ def test_parse_known_args(self, mock_load, parser):
6567
_ = parser.parse_known_args(
6668
["--config", "test.yaml", "--b", "3"],
6769
)
70+
71+
# Test that parsing does not prematurely interpolate variables.
72+
config_args, cli_args = parser.parse_known_args(
73+
["--config", "test.yaml", "c=bazz"]
74+
)
75+
assert (
76+
config_args.d == "${c}/bar"
77+
), f"d == {config_args.d} not ${{c}}/bar as set in config."

torchtune/config/_parse.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]:
5757

5858
config = OmegaConf.load(namespace.config)
5959
assert "config" not in config, "Cannot use 'config' within a config file"
60-
self.set_defaults(**config)
60+
self.set_defaults(**OmegaConf.to_container(config, resolve=False))
6161

6262
namespace, unknown_args = super().parse_known_args(*args, **kwargs)
6363
del namespace.config

0 commit comments

Comments
 (0)