Skip to content

Commit e7fe8b3

Browse files
committed
Improve NewType inference for string forward refs
1 parent 1f4ac46 commit e7fe8b3

File tree

2 files changed

+263
-14
lines changed

2 files changed

+263
-14
lines changed

astroid/brain/brain_typing.py

+97-14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from astroid.builder import _extract_single_node
1515
from astroid.const import PY38_PLUS, PY39_PLUS
1616
from astroid.exceptions import (
17+
AstroidImportError,
1718
AttributeInferenceError,
1819
InferenceError,
1920
UseInferenceDefault,
@@ -175,27 +176,109 @@ def infer_typing_newtype(
175176
)
176177
new_node.parent = node.parent
177178

178-
# Base type arg is a normal reference, so no need to do special lookups
179-
if not isinstance(base, nodes.Const):
180-
new_node.postinit(
181-
bases=[base], body=new_node.body, decorators=new_node.decorators
182-
)
179+
new_bases: list[NodeNG] = []
183180

184-
# If the base type is given as a string (e.g. for a forward reference),
185-
# make a naive attempt to find the corresponding node.
186-
# Note that this will not work with imported types.
187-
if isinstance(base, nodes.Const) and isinstance(base.value, str):
181+
if not isinstance(base, nodes.Const):
182+
# Base type arg is a normal reference, so no need to do special lookups
183+
new_bases = [base]
184+
elif isinstance(base, nodes.Const) and isinstance(base.value, str):
185+
# If the base type is given as a string (e.g. for a forward reference),
186+
# make a naive attempt to find the corresponding node.
188187
_, resolved_base = node.frame().lookup(base_name)
189188
if resolved_base:
190-
new_node.postinit(
191-
bases=[resolved_base[0]],
192-
body=new_node.body,
193-
decorators=new_node.decorators,
194-
)
189+
base_node = resolved_base[0]
190+
191+
# If the value is from an "import from" statement, follow the import chain
192+
if isinstance(base_node, nodes.ImportFrom):
193+
ctx = ctx.clone() if ctx else context.InferenceContext()
194+
ctx.lookupname = base_name
195+
base_node = next(base_node.infer(context=ctx))
196+
197+
new_bases = [base_node]
198+
elif "." in base.value:
199+
possible_base = _try_find_imported_object_from_str(node, base.value, ctx)
200+
if possible_base:
201+
new_bases = [possible_base]
202+
203+
if new_bases:
204+
new_node.postinit(
205+
bases=new_bases, body=new_node.body, decorators=new_node.decorators
206+
)
195207

196208
return new_node.infer(context=ctx)
197209

198210

211+
def _try_find_imported_object_from_str(
212+
node: nodes.Call,
213+
name: str,
214+
ctx: context.InferenceContext | None,
215+
) -> nodes.NodeNG | None:
216+
for statement_mod_name, _ in _possible_module_object_splits(name):
217+
# Find import statements that may pull in the appropriate modules
218+
# The name used to find this statement may not correspond to the name of the module actually being imported
219+
# For example, "import email.charset" is found by lookup("email")
220+
_, resolved_bases = node.frame().lookup(statement_mod_name)
221+
if not resolved_bases:
222+
continue
223+
224+
resolved_base = resolved_bases[0]
225+
if isinstance(resolved_base, nodes.Import):
226+
# Extract the names of the module as they are accessed from actual code
227+
scope_names = {(alias or name) for (name, alias) in resolved_base.names}
228+
aliases = {alias: name for (name, alias) in resolved_base.names if alias}
229+
230+
# Find potential mod_name, obj_name splits that work with the available names
231+
# for the module in this scope
232+
import_targets = [
233+
(mod_name, obj_name)
234+
for (mod_name, obj_name) in _possible_module_object_splits(name)
235+
if mod_name in scope_names
236+
]
237+
if not import_targets:
238+
continue
239+
240+
import_target, name_in_mod = import_targets[0]
241+
import_target = aliases.get(import_target, import_target)
242+
243+
# Try to import the module and find the object in it
244+
try:
245+
resolved_mod: nodes.Module = resolved_base.do_import_module(
246+
import_target
247+
)
248+
except AstroidImportError:
249+
# If the module doesn't actually exist, try the next option
250+
continue
251+
252+
# Try to find the appropriate ClassDef or other such node in the target module
253+
_, object_results_in_mod = resolved_mod.lookup(name_in_mod)
254+
if not object_results_in_mod:
255+
continue
256+
257+
base_node = object_results_in_mod[0]
258+
259+
# If the value is from an "import from" statement, follow the import chain
260+
if isinstance(base_node, nodes.ImportFrom):
261+
ctx = ctx.clone() if ctx else context.InferenceContext()
262+
ctx.lookupname = name_in_mod
263+
base_node = next(base_node.infer(context=ctx))
264+
265+
return base_node
266+
267+
return None
268+
269+
270+
def _possible_module_object_splits(
271+
dot_str: str,
272+
) -> Iterator[tuple[str, str]]:
273+
components = dot_str.split(".")
274+
popped = []
275+
276+
while components:
277+
popped.append(components.pop())
278+
279+
yield ".".join(components), ".".join(reversed(popped))
280+
281+
199282
def _looks_like_typing_subscript(node):
200283
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
201284
if isinstance(node, Name):

tests/unittest_brain.py

+166
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,172 @@ def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
17761776
with self.assertRaises(AttributeInferenceError):
17771777
inferred.getattr("bad_attr")
17781778

1779+
def test_typing_newtype_forward_reference_imported(self) -> None:
1780+
all_ast_nodes = builder.extract_node(
1781+
"""
1782+
from typing import NewType
1783+
1784+
A = NewType("A", "decimal.Decimal")
1785+
B = NewType("B", "decimal_mod_alias.Decimal")
1786+
C = NewType("C", "Decimal")
1787+
D = NewType("D", "DecimalAlias")
1788+
1789+
import decimal
1790+
import decimal as decimal_mod_alias
1791+
from decimal import Decimal
1792+
from decimal import Decimal as DecimalAlias
1793+
1794+
Decimal #@
1795+
1796+
a = A(decimal.Decimal(2))
1797+
a #@
1798+
b = B(decimal_mod_alias.Decimal(2))
1799+
b #@
1800+
c = C(Decimal(2))
1801+
c #@
1802+
d = D(DecimalAlias(2))
1803+
d #@
1804+
"""
1805+
)
1806+
assert isinstance(all_ast_nodes, list)
1807+
1808+
real_dec, *ast_nodes = all_ast_nodes
1809+
1810+
real_quantize = next(real_dec.infer()).getattr("quantize")
1811+
1812+
for node in ast_nodes:
1813+
all_inferred = list(node.infer())
1814+
assert len(all_inferred) == 1
1815+
inferred = all_inferred[0]
1816+
assert isinstance(inferred, astroid.Instance)
1817+
1818+
assert inferred.getattr("quantize") == real_quantize
1819+
1820+
def test_typing_newtype_forward_ref_bad_base(self) -> None:
1821+
ast_nodes = builder.extract_node(
1822+
"""
1823+
from typing import NewType
1824+
1825+
A = NewType("A", "DoesntExist")
1826+
1827+
a = A()
1828+
a #@
1829+
1830+
# Valid name, but not actually imported
1831+
B = NewType("B", "decimal.Decimal")
1832+
1833+
b = B()
1834+
b #@
1835+
1836+
# AST works out, but can't import the module
1837+
import not_a_real_module
1838+
1839+
C = NewType("C", "not_a_real_module.SomeClass")
1840+
c = C()
1841+
c #@
1842+
1843+
# Real module, fake base class name
1844+
import email.charset
1845+
1846+
D = NewType("D", "email.charset.BadClassRef")
1847+
d = D()
1848+
d #@
1849+
1850+
# Real module, but aliased differently than used
1851+
import email.header as header_mod
1852+
1853+
E = NewType("E", "email.header.Header")
1854+
e = E(header_mod.Header())
1855+
e #@
1856+
"""
1857+
)
1858+
assert isinstance(ast_nodes, list)
1859+
1860+
for ast_node in ast_nodes:
1861+
inferred = next(ast_node.infer())
1862+
1863+
with self.assertRaises(astroid.AttributeInferenceError):
1864+
inferred.getattr("value")
1865+
1866+
def test_typing_newtype_forward_ref_nested_module(self) -> None:
1867+
ast_nodes = builder.extract_node(
1868+
"""
1869+
from typing import NewType
1870+
1871+
A = NewType("A", "email.charset.Charset")
1872+
B = NewType("B", "charset.Charset")
1873+
1874+
# header is unused in both cases, but verifies that module name is properly checked
1875+
import email.header, email.charset
1876+
from email import header, charset
1877+
1878+
real = charset.Charset()
1879+
real #@
1880+
1881+
a = A(email.charset.Charset())
1882+
a #@
1883+
1884+
b = B(charset.Charset())
1885+
"""
1886+
)
1887+
assert isinstance(ast_nodes, list)
1888+
1889+
real, *newtypes = ast_nodes
1890+
1891+
real_inferred_all = list(real.infer())
1892+
assert len(real_inferred_all) == 1
1893+
real_inferred = real_inferred_all[0]
1894+
1895+
real_method = real_inferred.getattr("get_body_encoding")
1896+
1897+
for newtype_node in newtypes:
1898+
newtype_inferred_all = list(newtype_node.infer())
1899+
assert len(newtype_inferred_all) == 1
1900+
newtype_inferred = newtype_inferred_all[0]
1901+
1902+
newtype_method = newtype_inferred.getattr("get_body_encoding")
1903+
1904+
assert real_method == newtype_method
1905+
1906+
def test_typing_newtype_forward_ref_nested_class(self) -> None:
1907+
ast_nodes = builder.extract_node(
1908+
"""
1909+
from typing import NewType
1910+
1911+
A = NewType("A", "SomeClass.Nested")
1912+
1913+
class SomeClass:
1914+
class Nested:
1915+
def method(self) -> None:
1916+
pass
1917+
1918+
real = SomeClass.Nested()
1919+
real #@
1920+
1921+
a = A(SomeClass.Nested())
1922+
a #@
1923+
"""
1924+
)
1925+
assert isinstance(ast_nodes, list)
1926+
1927+
real, newtype = ast_nodes
1928+
1929+
real_all_inferred = list(real.infer())
1930+
assert len(real_all_inferred) == 1
1931+
real_inferred = real_all_inferred[0]
1932+
real_method = real_inferred.getattr("method")
1933+
1934+
newtype_all_inferred = list(newtype.infer())
1935+
assert len(newtype_all_inferred) == 1
1936+
newtype_inferred = newtype_all_inferred[0]
1937+
1938+
# This could theoretically work, but for now just here to check that
1939+
# the "forward-declared module" inference doesn't totally break things
1940+
with self.assertRaises(astroid.AttributeInferenceError):
1941+
newtype_method = newtype_inferred.getattr("method")
1942+
1943+
assert real_method == newtype_method
1944+
17791945
def test_namedtuple_nested_class(self):
17801946
result = builder.extract_node(
17811947
"""

0 commit comments

Comments
 (0)