diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index fe7da9f5c8..cf61989248 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -2,9 +2,11 @@ import copy import functools +import hashlib from enum import Enum from textwrap import dedent from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from functools import wraps from omegaconf import OmegaConf, SCMode from omegaconf._utils import is_structured_config @@ -22,6 +24,8 @@ class _Keys(str, Enum): RECURSIVE = "_recursive_" ARGS = "_args_" PARTIAL = "_partial_" + ONCE = "_once_" + KEY = "_key_" def _is_target(x: Any) -> bool: @@ -171,10 +175,41 @@ def _deep_copy_full_config(subconfig: Any) -> Any: return OmegaConf.select(full_config_copy, full_key) +_ONCE_STORAGE: Dict[str, Any] = {} + + +# could be exposed in public API if useful +def clear_instantiate_cache(): + _ONCE_STORAGE.clear() + + +def _once_storage_swap(func): + @wraps(func) + def wrapper(*args, cache=None, **kwargs): + global _ONCE_STORAGE + + if cache is None: + cache = _ONCE_STORAGE + + OLD = _ONCE_STORAGE + _ONCE_STORAGE = cache + try: + # Call the original function + result = func(*args, **kwargs) + finally: + # Restore the original _ONCE_STORAGE + _ONCE_STORAGE = OLD + # Return the result of the original function + return result + + return wrapper + +@_once_storage_swap def instantiate( config: Any, *args: Any, _skip_instantiate_full_deepcopy_: bool = False, + cache: Union[Dict[str, Any], None] = None, #implemented in decorator **kwargs: Any, ) -> Any: """ @@ -199,10 +234,16 @@ def instantiate( are converted to dicts / lists too. _partial_: If True, return functools.partial wrapped method or object False by default. Configure per target. + _once_: If True, instantiate the target only once and return the same + instance on subsequent calls. + _key_: If set, this used to identify the target in the 'once' cache. + Note required in most cases. :param _skip_instantiate_full_deepcopy_: If True, deep copy just the input config instead of full config before resolving omegaconf interpolations, which may potentially modify the config's parent/sibling configs in place. False by default. + :param cache: Optional cache to use for once storage. Pass '{}' to discard the cache + between different calls to instantiate. :param args: Optional positional parameters pass-through :param kwargs: Optional named parameters to override parameters in the config object. Parameters not present @@ -317,13 +358,14 @@ def _convert_node(node: Any, convert: Union[ConvertMode, str]) -> Any: ) return node - +@_once_storage_swap def instantiate_node( node: Any, *args: Any, convert: Union[str, ConvertMode] = ConvertMode.NONE, recursive: bool = True, partial: bool = False, + cache: Union[Dict[str, Any], None] = None, # implemented in decorator ) -> Any: # Return None if config is None if node is None or (OmegaConf.is_config(node) and node._is_none()): @@ -349,7 +391,7 @@ def instantiate_node( raise TypeError(msg) if not isinstance(partial, bool): - msg = f"Instantiation: _partial_ flag must be a bool, got {type( partial )}" + msg = f"Instantiation: _partial_ flag must be a bool, got {type(partial)}" if node and full_key: msg += f"\nfull_key: {full_key}" raise TypeError(msg) @@ -371,7 +413,34 @@ def instantiate_node( return lst elif OmegaConf.is_dict(node): - exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) + # Use cached return if once is True and it exists in the cache. + if "_once_" in node: + once = node.pop(_Keys.ONCE) + if _Keys.KEY in node: + once_key = node.pop(_Keys.KEY) + elif once is not True: + once_key = once + else: + once_key = OmegaConf.to_yaml(node) + if recursive != True: + once_key = f"recursive: ${recursive}\n\n{once_key}" + if convert != ConvertMode.NONE: + once_key = f"convert: ${convert}\n\n{once_key}" + if partial != True: + once_key = f"partial: ${partial}\n\n{once_key}" + once_key = hashlib.md5(once_key.encode()).hexdigest() + + if once_key in _ONCE_STORAGE: + return _ONCE_STORAGE[once_key] + else: + _ONCE_STORAGE[once_key] = instantiate_node( + node, *args, convert=convert, recursive=recursive, partial=partial + ) + return _ONCE_STORAGE[once_key] + + exclude_keys = set( + {"_target_", "_convert_", "_recursive_", "_partial_", "_once_", "_key_"} + ) if _is_target(node): _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) kwargs = {} diff --git a/tests/instantiate/__init__.py b/tests/instantiate/__init__.py index e4afaec733..6a2a4b3951 100644 --- a/tests/instantiate/__init__.py +++ b/tests/instantiate/__init__.py @@ -451,3 +451,14 @@ def recisinstance(got: Any, expected: Any) -> bool: an_object = object() + +_counts = {} + +def counter_function(key = None): + _counts[key] = _counts.get(key, 0) + 1 + return _counts[key], key + + +def counter_function2(key=None): + _counts[key] = _counts.get(key, 0) + 1 + return _counts[key], key, "counter_function2" diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index f311271fba..051c6d4eeb 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -2195,3 +2195,231 @@ class DictValuesConf: cfg = OmegaConf.structured(DictValuesConf) obj = instantiate_func(config=cfg) assert obj.d is None + + +def test_instantiated_once_standard( + instantiate_func: Any, +) -> None: + # standard behiavior calls target every time + cfg = { + "_target_": "tests.instantiate.counter_function", + "key": "test1", # key is pass through as second return by counter_function + } + assert instantiate_func(cfg) == (1, "test1") + assert instantiate_func(cfg) == (2, "test1") + + +def test_instantiated_once_keyword( + instantiate_func: Any, +) -> None: + # once behavior calls target only once + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test2", + } + + assert instantiate_func(cfg) == (1, "test2") + assert instantiate_func(cfg) == (1, "test2") + + +def test_instantiated_once_manual_key1( + instantiate_func: Any, +) -> None: + # With manual key, gives same value, even if config changes. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "_key_": "key1", + "key": "test3", # reusing key does not + } + assert instantiate_func(cfg) == (1, "test3") + assert instantiate_func(cfg) == (1, "test3") + + cfg["key"] = "test3-changed" + cfg["disallowed_arg"] = "broken" + assert instantiate_func(cfg) == (1, "test3") + assert instantiate_func(cfg) == (1, "test3") + + +def test_instantiated_once_manual_key2( + instantiate_func: Any, +) -> None: + # With manual key, gives same value, even if config changes. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": "key_in_once_variable", + "key": "test3.1", # reusing key does not + } + assert instantiate_func(cfg) == (1, "test3.1") + assert instantiate_func(cfg) == (1, "test3.1") + + cfg["key"] = "test3-changed" + cfg["disallowed_arg"] = "broken" + assert instantiate_func(cfg) == (1, "test3.1") + assert instantiate_func(cfg) == (1, "test3.1") + + +def test_instantiated_once_auto_key( + instantiate_func: Any, +) -> None: + # With auto key, change to the config makes new call. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test4", # reusing key does not + } + assert instantiate_func(cfg) == (1, "test4") + assert instantiate_func(cfg) == (1, "test4") + + cfg["key"] = "test4-changed" + assert instantiate_func(cfg) == (1, "test4-changed") + assert instantiate_func(cfg) == (1, "test4-changed") + + +def test_instantiated_once_partial_change( + instantiate_func: Any, +) -> None: + # Changing _partial_ makes new signature for auto key. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test5", + } + instance = instantiate_func(cfg) + assert instantiate_func(cfg) == (1, "test5") + + # setting as default does nothing + # this behaivior can be removed in the future if code base changes. + cfg["_partial_"] = False + assert instantiate_func(cfg) is instance + + # changing busts the key + cfg["_partial_"] = True + assert instantiate_func(cfg) is not instance + + instance = instantiate_func(cfg) + assert instance() == (2, "test5") # return is being called now! + assert instance() == (3, "test5") + + +def test_instantiated_once_recursive_change( + instantiate_func: Any, +) -> None: + # Changing _recursive_ makes new signature for auto key. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test6", + } + instance = instantiate_func(cfg) + assert instance is instantiate_func(cfg) + + # setting as default does nothing + # this behaivior can be removed in the future if code base changes. + cfg["_recursive_"] = True + assert instance is instantiate_func(cfg) + + # changing busts the key + cfg["_recursive_"] = False + assert instance is not instantiate_func(cfg) + + +def test_instantiated_once_convert_change( + instantiate_func: Any, +) -> None: + # Changing _convert_ makes new signature for auto key. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test7", + } + instance = instantiate_func(cfg) + assert instance is instantiate_func(cfg) + + # setting as default does nothing + # this behaivior can be removed in the future if code base changes. + cfg["_convert_"] = "none" + assert instance is instantiate_func(cfg) + + # changing busts the key + cfg["_convert_"] = "partial" + assert instance is not instantiate_func(cfg) + + +def test_instantiated_once_target_change( + instantiate_func: Any, +) -> None: + # Changing _recursive_ makes new signature for auto key. + cfg = { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test8", + } + assert instantiate_func(cfg) == (1, "test8") + + cfg["_target_"] = "tests.instantiate.counter_function2" + assert instantiate_func(cfg) == (2, "test8", "counter_function2") + assert instantiate_func(cfg) == (2, "test8", "counter_function2") + + +def test_instantiated_once_nested( + instantiate_func: Any, +) -> None: + # Changing _recursive_ makes new signature for auto key. + cfg = { + "base": { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test9", + }, + "ref1": "${base}", + "ref2": "${base}", + } + x = instantiate_func(cfg) + assert x.base == (1, "test9") + assert x.ref1 == (1, "test9") + assert x.ref2 == (1, "test9") + + x = instantiate_func(cfg) + assert x.base == (1, "test9") + assert x.ref1 == (1, "test9") + assert x.ref2 == (1, "test9") + + +def test_instantiated_once_custom_cache( + instantiate_func: Any, +) -> None: + # Changing _recursive_ makes new signature for auto key. + cfg = { + "base": { + "_target_": "tests.instantiate.counter_function", + "_once_": True, + "key": "test10", + }, + "ref1": "${base}", + "ref2": "${base}", + } + + # you can specify a custom cache too. + cache = {} + x = instantiate_func(cfg, cache=cache) + assert x.base == (1, "test10") + assert x.ref1 == (1, "test10") + assert x.ref2 == (1, "test10") + + x = instantiate_func(cfg, cache=cache) + assert x.base == (1, "test10") + assert x.ref1 == (1, "test10") + assert x.ref2 == (1, "test10") + + # this way, the is cache is ephermeral. + x = instantiate_func(cfg, cache={}) + assert x.base == (2, "test10") + assert x.ref1 == (2, "test10") + assert x.ref2 == (2, "test10") + + x = instantiate_func(cfg, cache={}) + assert x.base == (3, "test10") + assert x.ref1 == (3, "test10") + assert x.ref2 == (3, "test10") diff --git a/website/docs/advanced/instantiate_objects/overview.md b/website/docs/advanced/instantiate_objects/overview.md index 5030d3d27f..4cdb2ec6e6 100644 --- a/website/docs/advanced/instantiate_objects/overview.md +++ b/website/docs/advanced/instantiate_objects/overview.md @@ -380,6 +380,42 @@ assert bar1.foo is bar2.foo # the `Foo` instance is re-used here This does not apply if `_partial_=False`, in which case a new `Foo` instance would be created with each call to `instantiate`. +### Once-only instantiation + +If you want to ensure that a given object is instantiated only once and reused, +set the `_once_` key to `True` in the config. Hydra will cache the object based +on an md5 hash of the config yaml. +```yaml +# config.yaml +foo: + _target_: my_app.Foo + _once_: true + +bar: + foo1: ${foo} + foo2: ${foo} +``` +Now, instantiate will render `foo`, `bar.foo1` and `bar.foo2` as the same object, +even when calling instantiate multiple times or in nested invocations. + +To use a manaully specified key, set the `_key_` to a unique string value. +This usually should not be necessary. + +```yaml +# config.yaml +foo: + _target_: my_app.Foo + _once_: true + _key_: my_unique_key +bar: + foo1: ${foo} + foo2: ${foo} +``` + +Here also, +All the other instantiation parameters are still valid, including `_recursive_`, + `_partial_`, and `_convert_`. + ### Instantiation of builtins