Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nested instantiation #2519

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/torchtune/config/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,29 @@ def test_tokenizer_config_with_null(self):

tokenizer = instantiate(config.tokenizer)
assert tokenizer.max_seq_len is None

def test_nested_instantiation(self) -> None:
class Foo:
def __init__(self, bar):
self.bar = bar

def __call__(self, x):
return self.bar(x)

class Bar:
def __call__(self, x):
return x + 1

s = dedent(
"""\
foo:
_component_: foo
bar:
_component_: bar
"""
)
config = OmegaConf.create(s)

foo = instantiate(config.foo)
output = foo(1)
assert output == 2, f"Foo should call bar and return 1+1. Got {output} instead."
144 changes: 93 additions & 51 deletions torchtune/config/_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,85 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import sys
from typing import Any, Callable, Dict, Tuple

from omegaconf import DictConfig, OmegaConf
from torchtune.config._errors import InstantiationError
from torchtune.config._utils import _get_component_from_path, _has_component
from torchtune.config._utils import _get_component_from_path


def _create_component(
_component_: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
"""Create an instance of a component with given arguments."""
return _component_(*args, **kwargs)


def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any:
def _instantiate_node(config_dict: Dict[str, Any], *args: Any) -> Any:
"""
Creates the object specified in _component_ field with provided positional args
and kwargs already merged. Raises an InstantiationError if _component_ is not specified.
Instantiate a component from a config dictionary.

If the dictionary has a '_component_' field, retrieve the component, process
any nested arguments, and create the object with the given positional args.

Args:
config_dict (Dict[str, Any]): Config dictionary with '_component_' and arguments.
*args (Any): Positional arguments for the component.

Returns:
Any: The instantiated object.

Examples:
>>> class Spice:
>>> def __init__(self, heat_level):
>>> self.heat_level = heat_level
>>> class Food:
>>> def __init__(self, seed, ingredient):
>>> self.seed = seed
>>> self.ingredient = ingredient
>>> config_dict = {'_component_': 'Food', 'seed': 42,
>>> 'ingredient': {'_component_': 'Spice', 'heat_level': 5}}
>>> food = _instantiate_node(config_dict)
>>> print(food.seed) # 42
>>> print(food.ingredient.heat_level) # 5

Raises:
InstantiationError: If '_component_' is missing.
"""
if _has_component(node):
_component_ = _get_component_from_path(node.get("_component_"))
kwargs = {k: v for k, v in node.items() if k != "_component_"}
if "_component_" in config_dict:
_component_ = _get_component_from_path(config_dict["_component_"])
kwargs = {
k: _instantiate_nested(v)
for k, v in config_dict.items()
if k != "_component_"
}
return _create_component(_component_, args, kwargs)
else:
raise InstantiationError(
"Cannot instantiate specified object."
+ "\nMake sure you've specified a _component_ field with a valid dotpath."
)
raise InstantiationError("Cannot instantiate: '_component_' field is missing.")


def _instantiate_nested(obj: Any) -> Any:
"""
Processes dictionaries and lists to recursively instantiate any nested '_component_' fields.

Args:
obj (Any): Object to process (dict, list, or other).

Returns:
Any: Object with nested components instantiated.
"""
if isinstance(obj, dict):
if "_component_" in obj:
config = OmegaConf.create(obj)
return instantiate(config)
return {k: _instantiate_nested(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_instantiate_nested(item) for item in obj]
return obj


def instantiate(
Expand All @@ -44,54 +90,53 @@ def instantiate(
**kwargs: Any,
) -> Any:
"""
Given a DictConfig with a _component_ field specifying the object to instantiate and
additional fields for keyword arguments, create an instance of the specified object.
You can use this function to create the exact instance of a torchtune object you want
to use in your recipe using the specification from the config.
Instantiate a component from a configuration, recursively handling nested components.

This function also supports passing in positional args and keyword args within the
function call. These are automatically merged with the provided config, with keyword
args taking precedence.
Given a DictConfig with a '_component_' field specifying the object to instantiate and
additional fields as keyword arguments, create an instance of the specified object.
Positional and keyword arguments passed in the call are merged with the config, with
keyword arguments taking precedence.

Based on Hydra's `instantiate` utility from Facebook Research:
https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/instantiate/_instantiate2.py#L148
Based on Hydra's `instantiate` utility.

Args:
config (DictConfig): a single field in the OmegaConf object parsed from the yaml file.
This is expected to have a _component_ field specifying the path of the object
to instantiate.
*args (Any): positional arguments to pass to the object to instantiate.
**kwargs (Any): keyword arguments to pass to the object to instantiate.

Examples:
>>> config.yaml:
>>> model:
>>> _component_: torchtune.models.llama2
>>> num_layers: 32
>>> num_heads: 32
>>> num_kv_heads: 32

>>> from torchtune import config
>>> vocab_size = 32000
>>> # Pass in vocab size as positional argument. Since it is positioned first
>>> # in llama2(), it must be specified first. Pass in other arguments as kwargs.
>>> # This will return an nn.Module directly for llama2 with specified args.
>>> model = config.instantiate(parsed_yaml.model, vocab_size, max_seq_len=4096, embed_dim=4096)
config (DictConfig): Configuration with '_component_' and optional arguments.
*args (Any): Positional arguments for the component.
**kwargs (Any): Keyword arguments to override or add to the config.

Returns:
Any: the instantiated object.
Any: The instantiated object, or None if config is None.

Examples:
>>> class Spice:
>>> def __init__(self, heat_level):
>>> self.heat_level = heat_level
>>> class Food:
>>> def __init__(self, seed, ingredient):
>>> self.seed = seed
>>> self.ingredient = ingredient
>>> config = OmegaConf.create({
>>> '_component_': 'Food',
>>> 'seed': 0,
>>> 'ingredient': {'_component_': 'Spice', 'heat_level': 5}
>>> })
>>> food = instantiate(config, seed=42)
>>> print(food.seed) # 42
>>> print(food.ingredient.heat_level) # 5
>>> new_spice = {'_component_': 'Spice', 'heat_level': 10}
>>> food = instantiate(config, ingredient=new_spice)
>>> print(food.ingredient.heat_level) # 10

Raises:
ValueError: if config is not a DictConfig.
"""
ValueError: If config is not a DictConfig.

# Return None if config is None
Note: Modifies sys.path to include the current working directory for local imports.
"""
if config is None:
return None
if not OmegaConf.is_dict(config):
raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}")

# Ensure local imports are able to be instantiated
if os.getcwd() not in sys.path:
sys.path.append(os.getcwd())

Expand All @@ -103,10 +148,7 @@ def instantiate(
config = config_copy

if kwargs:
# This overwrites any repeated fields in the config with kwargs
config = OmegaConf.merge(config, kwargs)

# Resolve all interpolations, or references to other fields within the same config
OmegaConf.resolve(config)

return _instantiate_node(OmegaConf.to_object(config), *args)
return _instantiate_node(OmegaConf.to_container(config, resolve=True), *args)
Loading