-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathhydra.py
More file actions
71 lines (59 loc) · 2.48 KB
/
hydra.py
File metadata and controls
71 lines (59 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import inspect
import itertools
from collections.abc import Callable
from typing import Annotated, ParamSpec, TypeVar
from hydra import compose, initialize
from omegaconf import DictConfig
from typer import Argument, Option
# Type parameters used to annotate the callable returned by hydra_adaptor
# (Param for its parameters, RetType for its return value).
Param = ParamSpec("Param")
RetType = TypeVar("RetType")
def hydra_adaptor(function: Callable) -> Callable[Param, RetType]:
    """Replace a function that takes a Hydra config with one that takes string arguments.

    The returned wrapper composes a Hydra config named ``config_name`` with the
    given ``overrides`` (using ``hydra.initialize(config_path="../config")``)
    and passes it to ``function`` in place of the parameter annotated as
    ``DictConfig``. The wrapper's ``__signature__`` lists the remaining
    parameters of ``function`` with ``overrides`` and ``config_name`` inserted
    after the positional parameters and before the keyword-only ones, so tools
    that introspect signatures (such as Typer) can build a command line from
    it. Direct Python calls are bound according to that same signature.

    Args:
        function: Callable(*args, config: DictConfig, **kwargs). The composed
            config is passed by the name of the parameter annotated exactly as
            ``DictConfig``, or as ``config`` if no parameter is annotated so.

    Returns:
        Callable(*args, overrides: list[str] | None, config_name: str | None, **kwargs)

    Raises:
        ValueError: If ``function`` declares a variadic positional parameter
            (``*args``), which cannot precede the positional ``overrides``
            parameter, or if it already has a parameter named ``overrides`` or
            ``config_name``.
    """
    fn_signature = inspect.signature(function, eval_str=True)

    # Keyword under which the composed config is handed to `function`.
    config_param_name = "config"

    # Separate the remaining parameters by kind so the adaptor's own
    # parameters can be spliced in at a position that keeps the order valid.
    positional_params = []
    keyword_only_params = []
    var_keyword_params = []
    for param in fn_signature.parameters.values():
        if param.annotation == DictConfig:
            # Supplied by the wrapper, so hidden from the public signature.
            config_param_name = param.name
            continue
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            raise ValueError(
                f"hydra_adaptor cannot wrap {function.__name__!r}: variadic "
                f"positional parameter *{param.name} cannot precede the "
                "'overrides' parameter"
            )
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            keyword_only_params.append(param)
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs must remain the last parameter of any valid signature.
            var_keyword_params.append(param)
        else:
            positional_params.append(param)

    # Parameters consumed by the wrapper itself; the Annotated metadata is
    # what Typer reads to expose them on the command line.
    adaptor_params = [
        inspect.Parameter(
            "overrides",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=None,
            annotation=Annotated[
                list[str] | None,
                Argument(
                    help="Apply space-separated Hydra config overrides (https://hydra.cc/docs/advanced/override_grammar/basic/)"
                ),
            ],
        ),
        inspect.Parameter(
            "config_name",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default="sample",
            annotation=Annotated[
                str | None,
                Option(help="Specify the name of a file to load from the config directory"),
            ],
        ),
    ]

    # Combine in a valid order: positional, adaptor, keyword-only, **kwargs.
    # Signature validation raises ValueError on duplicate names.
    wrapper_signature = fn_signature.replace(
        parameters=list(
            itertools.chain(
                positional_params,
                adaptor_params,
                keyword_only_params,
                var_keyword_params,
            )
        )
    )

    # *args/**kwargs are deliberately unannotated: the public signature is
    # provided by __signature__ below.
    def wrapper(*args, **kwargs) -> RetType:
        # Bind against the advertised signature so positional arguments reach
        # the parameters that __signature__ says they do.
        arguments = wrapper_signature.bind(*args, **kwargs).arguments
        overrides = arguments.pop("overrides", None)
        config_name = arguments.pop("config_name", "sample")

        with initialize(config_path="../config", version_base=None):
            config = compose(config_name=config_name, overrides=overrides)

        # Re-dispatch the remaining arguments to `function`; `arguments`
        # follows signature order, so positional-only values stay in order.
        call_args = []
        call_kwargs = {}
        for name, value in arguments.items():
            param_kind = wrapper_signature.parameters[name].kind
            if param_kind == inspect.Parameter.POSITIONAL_ONLY:
                call_args.append(value)
            elif param_kind == inspect.Parameter.VAR_KEYWORD:
                call_kwargs.update(value)
            else:
                call_kwargs[name] = value
        call_kwargs[config_param_name] = config
        return function(*call_args, **call_kwargs)

    wrapper.__signature__ = wrapper_signature  # type: ignore[attr-defined]
    wrapper.__name__ = function.__name__
    wrapper.__doc__ = function.__doc__
    return wrapper  # type: ignore[return-value]