Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import copy
import functools
import hashlib
from enum import Enum
from textwrap import dedent
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from functools import wraps

from omegaconf import OmegaConf, SCMode
from omegaconf._utils import is_structured_config
Expand All @@ -22,6 +24,8 @@ class _Keys(str, Enum):
RECURSIVE = "_recursive_"
ARGS = "_args_"
PARTIAL = "_partial_"
ONCE = "_once_"
KEY = "_key_"


def _is_target(x: Any) -> bool:
Expand Down Expand Up @@ -171,10 +175,41 @@ def _deep_copy_full_config(subconfig: Any) -> Any:
return OmegaConf.select(full_config_copy, full_key)


_ONCE_STORAGE: Dict[str, Any] = {}


# could be exposed in public API if useful
def clear_instantiate_cache():
_ONCE_STORAGE.clear()


def _once_storage_swap(func):
@wraps(func)
def wrapper(*args, cache=None, **kwargs):
global _ONCE_STORAGE

if cache is None:
cache = _ONCE_STORAGE

OLD = _ONCE_STORAGE
_ONCE_STORAGE = cache
try:
# Call the original function
result = func(*args, **kwargs)
finally:
# Restore the original _ONCE_STORAGE
_ONCE_STORAGE = OLD
# Return the result of the original function
return result

return wrapper

@_once_storage_swap
def instantiate(
config: Any,
*args: Any,
_skip_instantiate_full_deepcopy_: bool = False,
cache: Union[Dict[str, Any], None] = None, #implemented in decorator
**kwargs: Any,
) -> Any:
"""
Expand All @@ -199,10 +234,16 @@ def instantiate(
are converted to dicts / lists too.
_partial_: If True, return functools.partial wrapped method or object
False by default. Configure per target.
_once_: If True, instantiate the target only once and return the same
instance on subsequent calls.
_key_: If set, this used to identify the target in the 'once' cache.
Note required in most cases.
:param _skip_instantiate_full_deepcopy_: If True, deep copy just the input config instead
of full config before resolving omegaconf interpolations, which may
potentially modify the config's parent/sibling configs in place.
False by default.
:param cache: Optional cache to use for once storage. Pass '{}' to discard the cache
between different calls to instantiate.
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
Expand Down Expand Up @@ -317,13 +358,14 @@ def _convert_node(node: Any, convert: Union[ConvertMode, str]) -> Any:
)
return node


@_once_storage_swap
def instantiate_node(
node: Any,
*args: Any,
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
cache: Union[Dict[str, Any], None] = None, # implemented in decorator
) -> Any:
# Return None if config is None
if node is None or (OmegaConf.is_config(node) and node._is_none()):
Expand All @@ -349,7 +391,7 @@ def instantiate_node(
raise TypeError(msg)

if not isinstance(partial, bool):
msg = f"Instantiation: _partial_ flag must be a bool, got {type( partial )}"
msg = f"Instantiation: _partial_ flag must be a bool, got {type(partial)}"
if node and full_key:
msg += f"\nfull_key: {full_key}"
raise TypeError(msg)
Expand All @@ -371,7 +413,34 @@ def instantiate_node(
return lst

elif OmegaConf.is_dict(node):
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
# Use cached return if once is True and it exists in the cache.
if "_once_" in node:
once = node.pop(_Keys.ONCE)
if _Keys.KEY in node:
once_key = node.pop(_Keys.KEY)
elif once is not True:
once_key = once
else:
once_key = OmegaConf.to_yaml(node)
if recursive != True:
once_key = f"recursive: ${recursive}\n\n{once_key}"
if convert != ConvertMode.NONE:
once_key = f"convert: ${convert}\n\n{once_key}"
if partial != True:
once_key = f"partial: ${partial}\n\n{once_key}"
once_key = hashlib.md5(once_key.encode()).hexdigest()

if once_key in _ONCE_STORAGE:
return _ONCE_STORAGE[once_key]
else:
_ONCE_STORAGE[once_key] = instantiate_node(
node, *args, convert=convert, recursive=recursive, partial=partial
)
return _ONCE_STORAGE[once_key]

exclude_keys = set(
{"_target_", "_convert_", "_recursive_", "_partial_", "_once_", "_key_"}
)
if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
kwargs = {}
Expand Down
11 changes: 11 additions & 0 deletions tests/instantiate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,14 @@ def recisinstance(got: Any, expected: Any) -> bool:


an_object = object()

_counts = {}

def counter_function(key = None):
_counts[key] = _counts.get(key, 0) + 1
return _counts[key], key


def counter_function2(key=None):
_counts[key] = _counts.get(key, 0) + 1
return _counts[key], key, "counter_function2"
228 changes: 228 additions & 0 deletions tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,3 +2195,231 @@ class DictValuesConf:
cfg = OmegaConf.structured(DictValuesConf)
obj = instantiate_func(config=cfg)
assert obj.d is None


def test_instantiated_once_standard(
instantiate_func: Any,
) -> None:
# standard behiavior calls target every time
cfg = {
"_target_": "tests.instantiate.counter_function",
"key": "test1", # key is pass through as second return by counter_function
}
assert instantiate_func(cfg) == (1, "test1")
assert instantiate_func(cfg) == (2, "test1")


def test_instantiated_once_keyword(
instantiate_func: Any,
) -> None:
# once behavior calls target only once
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test2",
}

assert instantiate_func(cfg) == (1, "test2")
assert instantiate_func(cfg) == (1, "test2")


def test_instantiated_once_manual_key1(
instantiate_func: Any,
) -> None:
# With manual key, gives same value, even if config changes.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"_key_": "key1",
"key": "test3", # reusing key does not
}
assert instantiate_func(cfg) == (1, "test3")
assert instantiate_func(cfg) == (1, "test3")

cfg["key"] = "test3-changed"
cfg["disallowed_arg"] = "broken"
assert instantiate_func(cfg) == (1, "test3")
assert instantiate_func(cfg) == (1, "test3")


def test_instantiated_once_manual_key2(
instantiate_func: Any,
) -> None:
# With manual key, gives same value, even if config changes.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": "key_in_once_variable",
"key": "test3.1", # reusing key does not
}
assert instantiate_func(cfg) == (1, "test3.1")
assert instantiate_func(cfg) == (1, "test3.1")

cfg["key"] = "test3-changed"
cfg["disallowed_arg"] = "broken"
assert instantiate_func(cfg) == (1, "test3.1")
assert instantiate_func(cfg) == (1, "test3.1")


def test_instantiated_once_auto_key(
instantiate_func: Any,
) -> None:
# With auto key, change to the config makes new call.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test4", # reusing key does not
}
assert instantiate_func(cfg) == (1, "test4")
assert instantiate_func(cfg) == (1, "test4")

cfg["key"] = "test4-changed"
assert instantiate_func(cfg) == (1, "test4-changed")
assert instantiate_func(cfg) == (1, "test4-changed")


def test_instantiated_once_partial_change(
instantiate_func: Any,
) -> None:
# Changing _partial_ makes new signature for auto key.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test5",
}
instance = instantiate_func(cfg)
assert instantiate_func(cfg) == (1, "test5")

# setting as default does nothing
# this behaivior can be removed in the future if code base changes.
cfg["_partial_"] = False
assert instantiate_func(cfg) is instance

# changing busts the key
cfg["_partial_"] = True
assert instantiate_func(cfg) is not instance

instance = instantiate_func(cfg)
assert instance() == (2, "test5") # return is being called now!
assert instance() == (3, "test5")


def test_instantiated_once_recursive_change(
instantiate_func: Any,
) -> None:
# Changing _recursive_ makes new signature for auto key.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test6",
}
instance = instantiate_func(cfg)
assert instance is instantiate_func(cfg)

# setting as default does nothing
# this behaivior can be removed in the future if code base changes.
cfg["_recursive_"] = True
assert instance is instantiate_func(cfg)

# changing busts the key
cfg["_recursive_"] = False
assert instance is not instantiate_func(cfg)


def test_instantiated_once_convert_change(
instantiate_func: Any,
) -> None:
# Changing _convert_ makes new signature for auto key.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test7",
}
instance = instantiate_func(cfg)
assert instance is instantiate_func(cfg)

# setting as default does nothing
# this behaivior can be removed in the future if code base changes.
cfg["_convert_"] = "none"
assert instance is instantiate_func(cfg)

# changing busts the key
cfg["_convert_"] = "partial"
assert instance is not instantiate_func(cfg)


def test_instantiated_once_target_change(
instantiate_func: Any,
) -> None:
# Changing _recursive_ makes new signature for auto key.
cfg = {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test8",
}
assert instantiate_func(cfg) == (1, "test8")

cfg["_target_"] = "tests.instantiate.counter_function2"
assert instantiate_func(cfg) == (2, "test8", "counter_function2")
assert instantiate_func(cfg) == (2, "test8", "counter_function2")


def test_instantiated_once_nested(
instantiate_func: Any,
) -> None:
# Changing _recursive_ makes new signature for auto key.
cfg = {
"base": {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test9",
},
"ref1": "${base}",
"ref2": "${base}",
}
x = instantiate_func(cfg)
assert x.base == (1, "test9")
assert x.ref1 == (1, "test9")
assert x.ref2 == (1, "test9")

x = instantiate_func(cfg)
assert x.base == (1, "test9")
assert x.ref1 == (1, "test9")
assert x.ref2 == (1, "test9")


def test_instantiated_once_custom_cache(
instantiate_func: Any,
) -> None:
# Changing _recursive_ makes new signature for auto key.
cfg = {
"base": {
"_target_": "tests.instantiate.counter_function",
"_once_": True,
"key": "test10",
},
"ref1": "${base}",
"ref2": "${base}",
}

# you can specify a custom cache too.
cache = {}
x = instantiate_func(cfg, cache=cache)
assert x.base == (1, "test10")
assert x.ref1 == (1, "test10")
assert x.ref2 == (1, "test10")

x = instantiate_func(cfg, cache=cache)
assert x.base == (1, "test10")
assert x.ref1 == (1, "test10")
assert x.ref2 == (1, "test10")

# this way, the is cache is ephermeral.
x = instantiate_func(cfg, cache={})
assert x.base == (2, "test10")
assert x.ref1 == (2, "test10")
assert x.ref2 == (2, "test10")

x = instantiate_func(cfg, cache={})
assert x.base == (3, "test10")
assert x.ref1 == (3, "test10")
assert x.ref2 == (3, "test10")
Loading