Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable instantiation from globals #2518

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 61 additions & 12 deletions tests/torchtune/config/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,73 @@
}


# Test local function
def local_fn():
return "hello world"


class TestUtils:
def test_get_component_from_path(self):
good_paths = [
"torchtune", # Test single module without dot
"torchtune.models", # Test dotpath for a module
"torchtune.models.llama2.llama2_7b", # Test dotpath for an object
# Test valid paths with three components
valid_paths = [
"torchtune",
"os.path.join",
]
for path in valid_paths:
result = _get_component_from_path(path)
assert result is not None, f"Failed to resolve valid path '{path}'"

# test callable
path = "torchtune.models.llama2.llama2_7b"
result = _get_component_from_path(path)
assert callable(result), f"Resolved '{path}' is not callable"

# simulate call from globals
fn = _get_component_from_path("local_fn")
output = fn()
assert output == "hello world", f"Got {output=}. Expected 'hello world'."

# Test empty path
with pytest.raises(InstantiationError, match="Invalid path: ''"):
_get_component_from_path("")

# Test non-string path
with pytest.raises(InstantiationError, match="Invalid path: '123'"):
_get_component_from_path(123)

# Test relative imports
relative_paths = [
".test.module", # Leading dot
"test.module.", # Trailing dot
"test..module", # Consecutive dots
]
for path in good_paths:
_ = _get_component_from_path(path)
for path in relative_paths:
with pytest.raises(
ValueError,
match="Invalid dotstring. Relative imports are not supported.",
):
_get_component_from_path(path)

# Test non-existent components
# Single-part path not found
with pytest.raises(
InstantiationError,
match=r"Could not resolve 'nonexistent': not a module and not found in the caller's globals\.",
):
_get_component_from_path("nonexistent")

# Multi-part path with import failure
with pytest.raises(
InstantiationError, match=r"Could not import module 'os\.nonexistent': .*"
):
_get_component_from_path("os.nonexistent.attr")

# Test that a relative path fails
with pytest.raises(ValueError, match="Relative imports are not supported"):
_ = _get_component_from_path(".test")
# Test that a non-existent path fails
# Multi-part path with attribute error
with pytest.raises(
InstantiationError, match="Error loading 'torchtune.models.dummy'"
InstantiationError,
match=r"Module 'os\.path' has no attribute 'nonexistent'",
):
_ = _get_component_from_path("torchtune.models.dummy")
_get_component_from_path("os.path.nonexistent")

@mock.patch(
"torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
Expand Down
135 changes: 72 additions & 63 deletions torchtune/config/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from argparse import Namespace
from importlib import import_module
from types import ModuleType
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from omegaconf import DictConfig, OmegaConf

Expand All @@ -34,79 +34,88 @@ def _has_component(node: Union[Dict[str, Any], DictConfig]) -> bool:
return (OmegaConf.is_dict(node) or isinstance(node, dict)) and "_component_" in node


def _get_component_from_path(path: str) -> Any:
def _get_component_from_path(
path: str, caller_globals: Optional[Dict[str, Any]] = None
) -> Any:
"""
Return an object by name or dotted path, importing as necessary.
The base functionality relies on ``getattr()`` and handles all
possible exceptions accordingly.
Resolve a Python object from a dotted path or simple name.

Based on Hydra's `_locate` from Facebook Research:
https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/utils.py#L614
Retrieves a module, class, or function from a string like `"os.path.join"` or `"os"`. For dotted paths,
it imports the module and gets the final attribute. For simple names, it imports the module or checks
`caller_globals` (defaults to `__main__` globals if not provided).

Args:
path (str): Dotted path of the object
path (str): Dotted path (e.g., "os.path.join") or simple name (e.g., "os").
caller_globals (Optional[Dict[str, Any]]): The caller's global namespace. Defaults to __main__ if None.

Returns:
Any: The object
Any: The resolved object (module, class, function, etc.).

Raises:
InstantiationError: If there is an exception loading the
object from the provided path
ValueError: If a relative or invalid dotpath is passed in
InstantiationError: If the path is empty, not a string, or if the module/attribute cannot be resolved.
ValueError: If the path contains invalid dotstrings (e.g., relative imports like ".test" or "test..path").

Examples:
>>> _get_component_from_path("torch.nn.Linear")
<class 'torch.nn.modules.linear.Linear'>
>>> _get_component_from_path("torch")
<module 'torch' from '...'>
>>> # Assuming FooBar is in caller's globals
>>> _get_component_from_path("FooBar")
<class 'FooBar'>
"""
if path == "":
raise ValueError("Empty path")
if not path or not isinstance(path, str):
raise InstantiationError(f"Invalid path: '{path}'")

# Check for ".test", "test..path", "test..", etc.
parts = path.split(".")
if any(not part for part in parts):
raise ValueError(
f"Invalid dotstring. Relative imports are not supported. Got {path=}."
)

# single part, e.g. "torch" or "my_local_fn"
if len(parts) == 1:
name = parts[0]
try:
# try to import as a module, e.g. "torch"
return import_module(name)
except ImportError:
# if caller_globals is None, collect __main__ globals of the caller
search_globals = caller_globals if caller_globals is not None else {}
if caller_globals is None:
current_frame = inspect.currentframe()
if current_frame and current_frame.f_back:
search_globals = current_frame.f_back.f_globals

# check if local_fn is in caller_globals, e.g. "my_local_fn"
if name in search_globals:
return search_globals[name]
else:
# scope to differentiate between provided globals and caller's globals in error message
scope = (
"the provided globals"
if caller_globals is not None
else "the caller's globals"
)
raise InstantiationError(
f"Could not resolve '{name}': not a module and not found in {scope}."
) from None

parts = [part for part in path.split(".")]
for part in parts:
# If a relative path is passed in, the first part will be empty
if not len(part):
raise ValueError(
f"Error loading '{path}': invalid dotstring."
+ "\nRelative imports are not supported."
)
# First module requires trying to import to validate
part0 = parts[0]
# multiple parts, e.g. "torch.nn.Linear"
module_path = ".".join(parts[:-1])
try:
obj = import_module(part0)
except ImportError as exc_import:
module = import_module(module_path)
component = getattr(module, parts[-1])
return component
except ImportError as e:
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that module '{part0}' is installed?"
) from exc_import
# Subsequent components can be checked via getattr() on first module
# It can either be an attribute that we can return or a submodule that we
# can import and continue searching
for m in range(1, len(parts)):
part = parts[m]
try:
obj = getattr(obj, part)
# If getattr fails, check to see if it's a module we can import and
# continue down the path
except AttributeError as exc_attr:
parent_dotpath = ".".join(parts[:m])
if isinstance(obj, ModuleType):
mod = ".".join(parts[: m + 1])
try:
obj = import_module(mod)
continue
except ModuleNotFoundError as exc_import:
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
) from exc_import
# Any other error trying to import module can be raised as
# InstantiationError
except Exception as exc_import:
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_import)}"
) from exc_import
# If the component is not an attribute nor a module, it doesn't exist
raise InstantiationError(
f"Error loading '{path}':\n{repr(exc_attr)}"
+ f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?"
) from exc_attr
return obj
f"Could not import module '{module_path}': {str(e)}."
) from e
except AttributeError as e:
raise InstantiationError(
f"Module '{module_path}' has no attribute '{parts[-1]}'."
) from e


def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictConfig:
Expand Down