Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable instantiation from globals #2518

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/torchtune/config/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
}


def local_test_func():
return "hello world"


class TestUtils:
def test_get_component_from_path(self):
good_paths = [
Expand All @@ -52,6 +56,11 @@ def test_get_component_from_path(self):
):
_ = _get_component_from_path("torchtune.models.dummy")

# test that a local function instantiates
my_fn = _get_component_from_path("local_test_func")
output = my_fn()
assert output == "hello world", f"output == {output}, not hello world"

@mock.patch(
"torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
)
Expand Down
121 changes: 58 additions & 63 deletions torchtune/config/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
from argparse import Namespace
from importlib import import_module
from types import ModuleType
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from omegaconf import DictConfig, OmegaConf

Expand All @@ -34,79 +34,74 @@ def _has_component(node: Union[Dict[str, Any], DictConfig]) -> bool:
return (OmegaConf.is_dict(node) or isinstance(node, dict)) and "_component_" in node


def _get_component_from_path(path: str) -> Any:
def _get_component_from_path(
path: str, caller_globals: Optional[Dict[str, Any]] = None
) -> Any:
"""
Return an object by name or dotted path, importing as necessary.
The base functionality relies on ``getattr()`` and handles all
possible exceptions accordingly.
Resolve a Python object from a dotted path or simple name.

Based on Hydra's `_locate` from Facebook Research:
https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/utils.py#L614
Retrieves a module, class, or function from a string like `"os.path.join"` or `"os"`. For dotted paths,
it imports the module and gets the final attribute. For simple names, it imports the module or checks
`caller_globals` (defaults to `__main__` globals if not provided).

Args:
path (str): Dotted path of the object
path (str): Dotted path (e.g., "os.path.join") or simple name (e.g., "os").
caller_globals (Optional[Dict[str, Any]]): The caller's global namespace. Defaults to __main__ if None.

Returns:
Any: The object
Any: The resolved object (module, class, function, etc.).

Raises:
InstantiationError: If there is an exception loading the
object from the provided path
ValueError: If a relative or invalid dotpath is passed in
InstantiationError: If the path is empty or not a string or if the module or attribute cannot be found.

Examples:
>>> # Importing a class from a module
>>> _get_component_from_path("torch.nn.Linear")
<class 'torch.nn.modules.linear.Linear'>

>>> _get_component_from_path("torch")
<module 'torch' from '...'>

>>> # Assuming FooBar is defined in caller's globals
>>> _get_component_from_path("FooBar", globals())
<class 'SetupDataset'>

If globals() is not provided, it will default to __main__ globals.


"""
if path == "":
raise ValueError("Empty path")
if not path or not isinstance(path, str):
raise InstantiationError(f"Invalid path: '{path}'")

parts = [part for part in path.split(".")]
for part in parts:
# If a relative path is passed in, the first part will be empty
if not len(part):
raise ValueError(
f"Error loading '{path}': invalid dotstring."
+ "\nRelative imports are not supported."
)
# First module requires trying to import to validate
part0 = parts[0]
try:
obj = import_module(part0)
except ImportError as exc_import:
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that module '{part0}' is installed?"
) from exc_import
# Subsequent components can be checked via getattr() on first module
# It can either be an attribute that we can return or a submodule that we
# can import and continue searching
for m in range(1, len(parts)):
part = parts[m]
parts = path.split(".")
search_globals = caller_globals or sys.modules["__main__"].__dict__

if len(parts) == 1:
name = parts[0]
# Try as a module first
try:
obj = getattr(obj, part)
# If getattr fails, check to see if it's a module we can import and
# continue down the path
except AttributeError as exc_attr:
parent_dotpath = ".".join(parts[:m])
if isinstance(obj, ModuleType):
mod = ".".join(parts[: m + 1])
try:
obj = import_module(mod)
continue
except ModuleNotFoundError as exc_import:
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
) from exc_import
# Any other error trying to import module can be raised as
# InstantiationError
except Exception as exc_import:
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_import)}"
) from exc_import
# If the component is not an attribute nor a module, it doesn't exist
return import_module(name)
except ImportError:
# Fall back to globals
if name in search_globals:
return search_globals[name]
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_attr)}"
+ f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?"
) from exc_attr
return obj
f"Could not resolve '{name}': not a module and not found in globals"
) from None

module_path = ".".join(parts[:-1])
try:
module = import_module(module_path)
component = getattr(module, parts[-1])
return component
except ImportError as e:
raise InstantiationError(
f"Could not import module '{module_path}': {str(e)}"
) from e
except AttributeError:
raise InstantiationError(
f"Module '{module_path}' has no attribute '{parts[-1]}'"
) from None


def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictConfig:
Expand Down
Loading