Skip to content

Support resolving via "from ... import ..." in normal python modules #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions typeshed_client/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import sys
from functools import lru_cache
from importlib.util import find_spec
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -103,6 +104,12 @@ def get_stub_file(
return get_stub_file_name(ModulePath(tuple(module_name.split("."))), search_context)


def get_py_module_file(module_name: ModulePath) -> Optional[Path]:
"""Return the path to the python file for this module, if any."""
spec = find_spec(".".join(module_name))
return Path(spec.origin) if spec and spec.origin else None


def get_stub_ast(
module_name: str, *, search_context: Optional[SearchContext] = None
) -> Optional[ast.Module]:
Expand Down
54 changes: 54 additions & 0 deletions typeshed_client/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
NamedTuple,
NoReturn,
Optional,
Set,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -438,3 +439,56 @@ def _warn(message: str, ctx: SearchContext) -> None:
raise InvalidStub(message)
else:
log.warning(message)


def get_py_imported_name_sources(
module_name: ModulePath,
) -> Tuple[Dict[str, str], Set[str]]:
"""Given a python module path, return named and star import sources."""
path = finder.get_py_module_file(module_name)
if path is None:
return {}, set()
ast = parse_stub_file(path)
visitor = _PyImportedSourcesExtractor(module_name)
visitor.visit(ast)
return visitor.named_imports, visitor.star_imports


class _PyImportedSourcesExtractor(ast.NodeVisitor):
"""Extract imported sources from a normal python module."""

def __init__(self, module_path: ModulePath) -> None:
self.module_path = module_path
self.named_imports: Dict[str, str] = {}
self.star_imports: Set[str] = set()

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.level:
# Relative import
module_path = self.module_path[: len(self.module_path) - node.level + 1]
if node.module:
module_path = module_path + tuple(node.module.split("."))
import_module_path: str = ".".join(module_path)
else:
# Absolute import
import_module_path = node.module # type: ignore[assignment] # always str when level=0

if len(node.names) == 1 and node.names[0].name == "*":
# Star import
self.star_imports.add(import_module_path)
else:
# Named import
for alias in node.names:
self.named_imports[alias.asname or alias.name] = import_module_path

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
# Skip visiting function definitions
pass

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
# Skip visiting async function definitions
pass

def visit_ClassDef(self, node: ast.ClassDef) -> None:
# Skip visiting class definitions
pass
34 changes: 32 additions & 2 deletions typeshed_client/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ class ImportedInfo(NamedTuple):


class Resolver:
def __init__(self, search_context: Optional[SearchContext] = None) -> None:
def __init__(
self,
search_context: Optional[SearchContext] = None,
search_py_files: bool = False,
) -> None:
if search_context is None:
search_context = get_search_context()
self.ctx = search_context
self.search_py_files = search_py_files
self._module_cache: Dict[ModulePath, Module] = {}

def get_module(self, module_name: ModulePath) -> "Module":
Expand All @@ -34,7 +39,32 @@ def get_module(self, module_name: ModulePath) -> "Module":

def get_name(self, module_name: ModulePath, name: str) -> ResolvedName:
module = self.get_module(module_name)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the way I'd want to do this is more in get_module. Currently, that will only look for pyi files, but it should look for py files too as a fallback. Basically, the library should follow https://typing.python.org/en/latest/spec/distributing.html#import-resolution-ordering and look at all the places data can come from, both pyi and py files.

Then when we parse the module, we should probably do it in a different 'mode' than we do parsing .pyi files, because when parsing .py files we need to be more forgiving about unrecognized AST nodes.

We may also need to put this new functionality behind an off-by-default flag so we don't interfere with existing users that look for .py files separately already.

Copy link
Author

@mauvilsa mauvilsa Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also need to put this new functionality behind an off-by-default flag so we don't interfere with existing users that look for .py files separately already.

Done.

I think the way I'd want to do this is more in get_module. Currently, that will only look for pyi files, but it should look for py files too as a fallback. Basically, the library should follow https://typing.python.org/en/latest/spec/distributing.html#import-resolution-ordering and look at all the places data can come from, both pyi and py files.

I looked at this but unfortunately did not understand. After get_module is called, it is not known yet if name is defined in that module. This is only known after module.get_name(name, self) is called. Please look again and explain a bit more how you think it should be changed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then when we parse the module, we should probably do it in a different 'mode' than we do parsing .pyi files, because when parsing .py files we need to be more forgiving about unrecognized AST nodes.

This I am already doing with the _PyImportedSourcesExtractor visitor I implemented. Does this look okay?

return module.get_name(name, self)
resolved_name = module.get_name(name, self)
if resolved_name is None and self.search_py_files:
resolved_name = self._get_py_imported_name(module_name, name)
return resolved_name

def _get_py_imported_name(self, module_name: ModulePath, name: str) -> ResolvedName:
resolved_name = None
try:
named_imports, star_imports = parser.get_py_imported_name_sources(
module_name
)
for star_import in star_imports:
import_source = ModulePath(tuple(star_import.split(".")))
if import_source != module_name:
# Try to resolve stub from star import
resolved_name = self.get_name(import_source, name)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can happen that a name imported from one module, is also imported in that module. So this needs to be recursive.

if resolved_name is not None:
break
if resolved_name is None and name in named_imports:
import_source = ModulePath(tuple(named_imports[name].split(".")))
if import_source != module_name:
# Try to resolve stub from named import
resolved_name = self.get_name(import_source, name)
except Exception:
pass
return resolved_name

def get_fully_qualified_name(self, name: str) -> ResolvedName:
"""Public API."""
Expand Down