Skip to content

Commit d5009aa

Browse files
committed
✅ Add regression tests for hydra
1 parent c480bb1 commit d5009aa

1 file changed

Lines changed: 106 additions & 0 deletions

File tree

tests/config/test_hydra.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import inspect
2+
from importlib.resources import files
3+
4+
import pytest
5+
from hydra import compose, initialize_config_dir
6+
from hydra.core.global_hydra import GlobalHydra
7+
from omegaconf import DictConfig
8+
9+
from icenet_mp.cli.hydra import hydra_adaptor
10+
11+
12+
class TestHydraConfigLoading:
13+
"""Regression tests for icenet-mp config composition via hydra."""
14+
15+
CONFIG_DIR = str(files("icenet_mp.config"))
16+
17+
def setup_method(self) -> None:
18+
GlobalHydra.instance().clear()
19+
20+
def teardown_method(self) -> None:
21+
GlobalHydra.instance().clear()
22+
23+
def load_config(self, overrides: list[str] | None = None) -> DictConfig:
24+
"""Compose the sample config from the icenet-mp config directory with any overrides."""
25+
with initialize_config_dir(config_dir=self.CONFIG_DIR, version_base=None):
26+
return compose(config_name="sample", overrides=overrides or [])
27+
28+
def test_sample_config_has_expected_top_level_keys(self) -> None:
29+
cfg = self.load_config()
30+
for key in ("data", "model", "train", "loss", "predict", "evaluate"):
31+
assert key in cfg, f"Key '{key}' missing from composed config"
32+
33+
def test_model_group_overridden_by_sample(self) -> None:
34+
# sample.yaml uses `override /model: quick_test`, replacing the base default
35+
cfg = self.load_config()
36+
assert cfg.model.name == "quick_test"
37+
assert cfg.model._target_ == "icenet_mp.models.EncodeProcessDecode"
38+
39+
def test_loss_defaults_resolved_from_base(self) -> None:
40+
cfg = self.load_config()
41+
assert cfg.loss._target_ == "torch.nn.HuberLoss"
42+
assert cfg.loss.delta == pytest.approx(0.5)
43+
44+
def test_scalar_override_applied(self) -> None:
45+
cfg = self.load_config(overrides=["loss.delta=1.0"])
46+
assert cfg.loss.delta == pytest.approx(1.0)
47+
48+
def test_config_group_override_swaps_loss(self) -> None:
49+
cfg = self.load_config(overrides=["loss=mse"])
50+
assert cfg.loss._target_ == "torch.nn.MSELoss"
51+
52+
53+
class TestHydraAdaptor:
54+
"""Regression tests for icenet-mp's hydra_adaptor signature rewriter."""
55+
56+
def test_signature_rewriting(self) -> None:
57+
def fn(config: DictConfig) -> None:
58+
pass
59+
60+
params = inspect.signature(hydra_adaptor(fn)).parameters
61+
assert "config" not in params
62+
assert "config_name" in params
63+
assert "overrides" in params
64+
65+
def test_preserves_positional_params(self) -> None:
66+
def fn(x: int, config: DictConfig) -> None:
67+
del x, config
68+
69+
params = list(inspect.signature(hydra_adaptor(fn)).parameters)
70+
assert "x" in params
71+
assert "config" not in params
72+
73+
def test_preserves_keyword_only_params(self) -> None:
74+
def fn(config: DictConfig, *, flag: bool = False) -> None:
75+
del config, flag
76+
77+
params = inspect.signature(hydra_adaptor(fn)).parameters
78+
assert "flag" in params
79+
assert params["flag"].kind == inspect.Parameter.KEYWORD_ONLY
80+
81+
def test_preserves_name_and_doc(self) -> None:
82+
def fn(config: DictConfig) -> None:
83+
"""My docstring."""
84+
85+
assert hydra_adaptor(fn).__name__ == "fn"
86+
assert hydra_adaptor(fn).__doc__ == "My docstring."
87+
88+
def test_wrapped_function_receives_dictconfig(self) -> None:
89+
received: list[DictConfig] = []
90+
91+
def fn(config: DictConfig) -> None:
92+
received.append(config)
93+
94+
hydra_adaptor(fn)(config_name="sample") # type: ignore[arg-type]
95+
assert len(received) == 1
96+
assert isinstance(received[0], DictConfig)
97+
98+
def test_wrapped_function_forwards_kwargs(self) -> None:
99+
received: list[tuple] = []
100+
101+
def fn(x: int, config: DictConfig) -> None:
102+
received.append((x, config))
103+
104+
hydra_adaptor(fn)(x=42, config_name="sample") # type: ignore[arg-type]
105+
assert received[0][0] == 42
106+
assert isinstance(received[0][1], DictConfig)

0 commit comments

Comments
 (0)