Skip to content

Commit 4355cc8

Browse files
dberenbaumefiopskshetry
authored andcommitted
hydra: support plugins (iterative#10240)
* hydra: support plugins * Update dvc/utils/hydra.py * fix tests * load plugins after some checks run --------- Co-authored-by: Ruslan Kuprieiev <[email protected]> Co-authored-by: Saugat Pachhai (सौगात) <[email protected]> Co-authored-by: skshetry <[email protected]>
1 parent 253d18f commit 4355cc8

File tree

4 files changed

+54
-4
lines changed

4 files changed

+54
-4
lines changed

dvc/config_schema.py

+1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def __call__(self, data):
346346
Exclusive("config_dir", "config_source"): str,
347347
Exclusive("config_module", "config_source"): str,
348348
"config_name": str,
349+
"plugins_path": str,
349350
},
350351
"studio": {
351352
"token": str,

dvc/repo/experiments/queue/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,15 @@ def _update_params(self, params: Dict[str, List[str]]):
486486
else:
487487
config_dir = None
488488
config_name = hydra_config.get("config_name", "config")
489+
plugins_path = os.path.join(
490+
self.repo.root_dir, hydra_config.get("plugins_path", "")
491+
)
489492
compose_and_dump(
490493
path,
491494
config_dir,
492495
config_module,
493496
config_name,
497+
plugins_path,
494498
overrides,
495499
)
496500
else:

dvc/utils/hydra.py

+15
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,24 @@
1313
logger = logger.getChild(__name__)
1414

1515

16+
def load_hydra_plugins(plugins_path: str):
17+
import sys
18+
19+
from hydra.core.plugins import Plugins
20+
21+
sys.path.append(plugins_path)
22+
try:
23+
Plugins.instance()
24+
finally:
25+
sys.path.remove(plugins_path)
26+
27+
1628
def compose_and_dump(
1729
output_file: "StrPath",
1830
config_dir: Optional[str],
1931
config_module: Optional[str],
2032
config_name: str,
33+
plugins_path: str,
2134
overrides: List[str],
2235
) -> None:
2336
"""Compose Hydra config and dumpt it to `output_file`.
@@ -30,6 +43,7 @@ def compose_and_dump(
3043
Ignored if `config_dir` is not `None`.
3144
config_name: Name of the config file containing defaults,
3245
without the .yaml extension.
46+
plugins_path: Path to auto discover Hydra plugins.
3347
overrides: List of `Hydra Override`_ patterns.
3448
3549
.. _Hydra Override:
@@ -47,6 +61,7 @@ def compose_and_dump(
4761
initialize_config_dir if config_dir else initialize_config_module
4862
)
4963

64+
load_hydra_plugins(plugins_path)
5065
with initialize_config( # type: ignore[attr-defined]
5166
config_source, version_base=None
5267
):

tests/func/utils/test_hydra.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ def test_compose_and_dump_overrides(tmp_dir, suffix, overrides, expected):
176176
output_file = tmp_dir / f"params.{suffix}"
177177
config_dir = hydra_setup(tmp_dir, "conf", "config")
178178
config_module = None
179-
compose_and_dump(output_file, config_dir, config_module, config_name, overrides)
179+
compose_and_dump(
180+
output_file, config_dir, config_module, config_name, str(tmp_dir), overrides
181+
)
180182
assert output_file.parse() == expected
181183

182184

@@ -229,7 +231,9 @@ def test_compose_and_dump_dir_module(
229231
)
230232

231233
with error_context:
232-
compose_and_dump(output_file, config_dir, config_module, config_name, [])
234+
compose_and_dump(
235+
output_file, config_dir, config_module, config_name, str(tmp_dir), []
236+
)
233237
assert output_file.parse() == config_content
234238

235239

@@ -241,7 +245,7 @@ def test_compose_and_dump_yaml_handles_string(tmp_dir):
241245
config.parent.mkdir()
242246
config.write_text("foo: 'no'\n")
243247
output_file = tmp_dir / "params.yaml"
244-
compose_and_dump(output_file, str(config.parent), None, "config", [])
248+
compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), [])
245249
assert output_file.read_text() == "foo: 'no'\n"
246250

247251

@@ -253,12 +257,38 @@ def test_compose_and_dump_resolves_interpolation(tmp_dir):
253257
config.parent.mkdir()
254258
config.dump({"data": {"root": "path/to/root", "raw": "${.root}/raw"}})
255259
output_file = tmp_dir / "params.yaml"
256-
compose_and_dump(output_file, str(config.parent), None, "config", [])
260+
compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), [])
257261
assert output_file.parse() == {
258262
"data": {"root": "path/to/root", "raw": "path/to/root/raw"}
259263
}
260264

261265

266+
def test_compose_and_dump_plugins(tmp_dir):
267+
"""Ensure Hydra plugins are loaded."""
268+
from hydra.core.plugins import Plugins
269+
270+
from dvc.utils.hydra import compose_and_dump
271+
272+
# clear cached plugins
273+
Plugins._instances.pop(Plugins, None)
274+
275+
config = tmp_dir / "conf" / "config.yaml"
276+
config.parent.mkdir()
277+
config.write_text("foo: '${plus_10:1}'\n")
278+
279+
plugins = tmp_dir / "hydra_plugins"
280+
plugins.mkdir()
281+
(plugins / "resolver.py").write_text(
282+
"""\
283+
from omegaconf import OmegaConf
284+
OmegaConf.register_new_resolver('plus_10', lambda x: x + 10)"""
285+
)
286+
287+
output_file = tmp_dir / "params.yaml"
288+
compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), [])
289+
assert output_file.read_text() == "foo: 11\n"
290+
291+
262292
@pytest.mark.parametrize(
263293
"overrides, expected",
264294
[

0 commit comments

Comments
 (0)