Skip to content

Commit 6bbb840

Browse files
committed
Improve NewType inference for string forward refs
1 parent 3609469 commit 6bbb840

File tree

2 files changed

+263
-14
lines changed

2 files changed

+263
-14
lines changed

astroid/brain/brain_typing.py

+97-14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from astroid.builder import _extract_single_node
1515
from astroid.const import PY38_PLUS, PY39_PLUS
1616
from astroid.exceptions import (
17+
AstroidImportError,
1718
AttributeInferenceError,
1819
InferenceError,
1920
UseInferenceDefault,
@@ -171,27 +172,109 @@ def infer_typing_newtype(
171172
)
172173
new_node.parent = node.parent
173174

174-
# Base type arg is a normal reference, so no need to do special lookups
175-
if not isinstance(base, nodes.Const):
176-
new_node.postinit(
177-
bases=[base], body=new_node.body, decorators=new_node.decorators
178-
)
175+
new_bases: list[NodeNG] = []
179176

180-
# If the base type is given as a string (e.g. for a forward reference),
181-
# make a naive attempt to find the corresponding node.
182-
# Note that this will not work with imported types.
183-
if isinstance(base, nodes.Const) and isinstance(base.value, str):
177+
if not isinstance(base, nodes.Const):
178+
# Base type arg is a normal reference, so no need to do special lookups
179+
new_bases = [base]
180+
elif isinstance(base, nodes.Const) and isinstance(base.value, str):
181+
# If the base type is given as a string (e.g. for a forward reference),
182+
# make a naive attempt to find the corresponding node.
184183
_, resolved_base = node.frame().lookup(base_name)
185184
if resolved_base:
186-
new_node.postinit(
187-
bases=[resolved_base[0]],
188-
body=new_node.body,
189-
decorators=new_node.decorators,
190-
)
185+
base_node = resolved_base[0]
186+
187+
# If the value is from an "import from" statement, follow the import chain
188+
if isinstance(base_node, nodes.ImportFrom):
189+
ctx = ctx.clone() if ctx else context.InferenceContext()
190+
ctx.lookupname = base_name
191+
base_node = next(base_node.infer(context=ctx))
192+
193+
new_bases = [base_node]
194+
elif "." in base.value:
195+
possible_base = _try_find_imported_object_from_str(node, base.value, ctx)
196+
if possible_base:
197+
new_bases = [possible_base]
198+
199+
if new_bases:
200+
new_node.postinit(
201+
bases=new_bases, body=new_node.body, decorators=new_node.decorators
202+
)
191203

192204
return new_node.infer(context=ctx)
193205

194206

207+
def _try_find_imported_object_from_str(
208+
node: nodes.Call,
209+
name: str,
210+
ctx: context.InferenceContext | None,
211+
) -> nodes.NodeNG | None:
212+
for statement_mod_name, _ in _possible_module_object_splits(name):
213+
# Find import statements that may pull in the appropriate modules
214+
# The name used to find this statement may not correspond to the name of the module actually being imported
215+
# For example, "import email.charset" is found by lookup("email")
216+
_, resolved_bases = node.frame().lookup(statement_mod_name)
217+
if not resolved_bases:
218+
continue
219+
220+
resolved_base = resolved_bases[0]
221+
if isinstance(resolved_base, nodes.Import):
222+
# Extract the names of the module as they are accessed from actual code
223+
scope_names = {(alias or name) for (name, alias) in resolved_base.names}
224+
aliases = {alias: name for (name, alias) in resolved_base.names if alias}
225+
226+
# Find potential mod_name, obj_name splits that work with the available names
227+
# for the module in this scope
228+
import_targets = [
229+
(mod_name, obj_name)
230+
for (mod_name, obj_name) in _possible_module_object_splits(name)
231+
if mod_name in scope_names
232+
]
233+
if not import_targets:
234+
continue
235+
236+
import_target, name_in_mod = import_targets[0]
237+
import_target = aliases.get(import_target, import_target)
238+
239+
# Try to import the module and find the object in it
240+
try:
241+
resolved_mod: nodes.Module = resolved_base.do_import_module(
242+
import_target
243+
)
244+
except AstroidImportError:
245+
# If the module doesn't actually exist, try the next option
246+
continue
247+
248+
# Try to find the appropriate ClassDef or other such node in the target module
249+
_, object_results_in_mod = resolved_mod.lookup(name_in_mod)
250+
if not object_results_in_mod:
251+
continue
252+
253+
base_node = object_results_in_mod[0]
254+
255+
# If the value is from an "import from" statement, follow the import chain
256+
if isinstance(base_node, nodes.ImportFrom):
257+
ctx = ctx.clone() if ctx else context.InferenceContext()
258+
ctx.lookupname = name_in_mod
259+
base_node = next(base_node.infer(context=ctx))
260+
261+
return base_node
262+
263+
return None
264+
265+
266+
def _possible_module_object_splits(
267+
dot_str: str,
268+
) -> Iterator[tuple[str, str]]:
269+
components = dot_str.split(".")
270+
popped = []
271+
272+
while components:
273+
popped.append(components.pop())
274+
275+
yield ".".join(components), ".".join(reversed(popped))
276+
277+
195278
def _looks_like_typing_subscript(node):
196279
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
197280
if isinstance(node, Name):

tests/unittest_brain.py

+166
Original file line numberDiff line numberDiff line change
@@ -1837,6 +1837,172 @@ def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
18371837
with self.assertRaises(AttributeInferenceError):
18381838
inferred.getattr("bad_attr")
18391839

1840+
def test_typing_newtype_forward_reference_imported(self) -> None:
1841+
all_ast_nodes = builder.extract_node(
1842+
"""
1843+
from typing import NewType
1844+
1845+
A = NewType("A", "decimal.Decimal")
1846+
B = NewType("B", "decimal_mod_alias.Decimal")
1847+
C = NewType("C", "Decimal")
1848+
D = NewType("D", "DecimalAlias")
1849+
1850+
import decimal
1851+
import decimal as decimal_mod_alias
1852+
from decimal import Decimal
1853+
from decimal import Decimal as DecimalAlias
1854+
1855+
Decimal #@
1856+
1857+
a = A(decimal.Decimal(2))
1858+
a #@
1859+
b = B(decimal_mod_alias.Decimal(2))
1860+
b #@
1861+
c = C(Decimal(2))
1862+
c #@
1863+
d = D(DecimalAlias(2))
1864+
d #@
1865+
"""
1866+
)
1867+
assert isinstance(all_ast_nodes, list)
1868+
1869+
real_dec, *ast_nodes = all_ast_nodes
1870+
1871+
real_quantize = next(real_dec.infer()).getattr("quantize")
1872+
1873+
for node in ast_nodes:
1874+
all_inferred = list(node.infer())
1875+
assert len(all_inferred) == 1
1876+
inferred = all_inferred[0]
1877+
assert isinstance(inferred, astroid.Instance)
1878+
1879+
assert inferred.getattr("quantize") == real_quantize
1880+
1881+
def test_typing_newtype_forward_ref_bad_base(self) -> None:
1882+
ast_nodes = builder.extract_node(
1883+
"""
1884+
from typing import NewType
1885+
1886+
A = NewType("A", "DoesntExist")
1887+
1888+
a = A()
1889+
a #@
1890+
1891+
# Valid name, but not actually imported
1892+
B = NewType("B", "decimal.Decimal")
1893+
1894+
b = B()
1895+
b #@
1896+
1897+
# AST works out, but can't import the module
1898+
import not_a_real_module
1899+
1900+
C = NewType("C", "not_a_real_module.SomeClass")
1901+
c = C()
1902+
c #@
1903+
1904+
# Real module, fake base class name
1905+
import email.charset
1906+
1907+
D = NewType("D", "email.charset.BadClassRef")
1908+
d = D()
1909+
d #@
1910+
1911+
# Real module, but aliased differently than used
1912+
import email.header as header_mod
1913+
1914+
E = NewType("E", "email.header.Header")
1915+
e = E(header_mod.Header())
1916+
e #@
1917+
"""
1918+
)
1919+
assert isinstance(ast_nodes, list)
1920+
1921+
for ast_node in ast_nodes:
1922+
inferred = next(ast_node.infer())
1923+
1924+
with self.assertRaises(astroid.AttributeInferenceError):
1925+
inferred.getattr("value")
1926+
1927+
def test_typing_newtype_forward_ref_nested_module(self) -> None:
1928+
ast_nodes = builder.extract_node(
1929+
"""
1930+
from typing import NewType
1931+
1932+
A = NewType("A", "email.charset.Charset")
1933+
B = NewType("B", "charset.Charset")
1934+
1935+
# header is unused in both cases, but verifies that module name is properly checked
1936+
import email.header, email.charset
1937+
from email import header, charset
1938+
1939+
real = charset.Charset()
1940+
real #@
1941+
1942+
a = A(email.charset.Charset())
1943+
a #@
1944+
1945+
b = B(charset.Charset())
1946+
"""
1947+
)
1948+
assert isinstance(ast_nodes, list)
1949+
1950+
real, *newtypes = ast_nodes
1951+
1952+
real_inferred_all = list(real.infer())
1953+
assert len(real_inferred_all) == 1
1954+
real_inferred = real_inferred_all[0]
1955+
1956+
real_method = real_inferred.getattr("get_body_encoding")
1957+
1958+
for newtype_node in newtypes:
1959+
newtype_inferred_all = list(newtype_node.infer())
1960+
assert len(newtype_inferred_all) == 1
1961+
newtype_inferred = newtype_inferred_all[0]
1962+
1963+
newtype_method = newtype_inferred.getattr("get_body_encoding")
1964+
1965+
assert real_method == newtype_method
1966+
1967+
def test_typing_newtype_forward_ref_nested_class(self) -> None:
1968+
ast_nodes = builder.extract_node(
1969+
"""
1970+
from typing import NewType
1971+
1972+
A = NewType("A", "SomeClass.Nested")
1973+
1974+
class SomeClass:
1975+
class Nested:
1976+
def method(self) -> None:
1977+
pass
1978+
1979+
real = SomeClass.Nested()
1980+
real #@
1981+
1982+
a = A(SomeClass.Nested())
1983+
a #@
1984+
"""
1985+
)
1986+
assert isinstance(ast_nodes, list)
1987+
1988+
real, newtype = ast_nodes
1989+
1990+
real_all_inferred = list(real.infer())
1991+
assert len(real_all_inferred) == 1
1992+
real_inferred = real_all_inferred[0]
1993+
real_method = real_inferred.getattr("method")
1994+
1995+
newtype_all_inferred = list(newtype.infer())
1996+
assert len(newtype_all_inferred) == 1
1997+
newtype_inferred = newtype_all_inferred[0]
1998+
1999+
# This could theoretically work, but for now just here to check that
2000+
# the "forward-declared module" inference doesn't totally break things
2001+
with self.assertRaises(astroid.AttributeInferenceError):
2002+
newtype_method = newtype_inferred.getattr("method")
2003+
2004+
assert real_method == newtype_method
2005+
18402006
def test_namedtuple_nested_class(self):
18412007
result = builder.extract_node(
18422008
"""

0 commit comments

Comments
 (0)