Skip to content

Commit a1370de

Browse files
committed
Improve NewType inference for string forward refs
1 parent 83eebfc commit a1370de

File tree

2 files changed

+265
-2
lines changed

2 files changed

+265
-2
lines changed

astroid/brain/brain_typing.py

+99-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from astroid.builder import _extract_single_node
2323
from astroid.const import PY37_PLUS, PY38_PLUS, PY39_PLUS
2424
from astroid.exceptions import (
25+
AstroidImportError,
2526
AttributeInferenceError,
2627
InferenceError,
2728
UseInferenceDefault,
@@ -191,19 +192,115 @@ def infer_typing_newtype(
191192

192193
# If the base type is given as a string (e.g. for a forward reference),
193194
# make a naive attempt to find the corresponding node.
194-
# Note that this will not work with imported types.
195195
if isinstance(base, nodes.Const) and isinstance(base.value, str):
196196
_, resolved_base = node.frame().lookup(base_name)
197197
if resolved_base:
198+
base_node = resolved_base[0]
199+
200+
# If the value is from an "import from" statement, follow the import chain
201+
if isinstance(base_node, nodes.ImportFrom):
202+
ctx = (
203+
context_itton.clone()
204+
if context_itton
205+
else context.InferenceContext()
206+
)
207+
ctx.lookupname = base_name
208+
base_node = next(base_node.infer(context=ctx))
209+
198210
new_node.postinit(
199-
bases=[resolved_base[0]],
211+
bases=[base_node],
200212
body=new_node.body,
201213
decorators=new_node.decorators,
202214
)
215+
elif "." in base.value:
216+
possible_base = _try_find_imported_object_from_str(
217+
node, base.value, context_itton
218+
)
219+
if possible_base:
220+
new_node.postinit(
221+
bases=[possible_base],
222+
body=new_node.body,
223+
decorators=new_node.decorators,
224+
)
203225

204226
return new_node.infer(context=context_itton)
205227

206228

229+
def _try_find_imported_object_from_str(
230+
node: nodes.Call,
231+
name: str,
232+
context_itton: typing.Optional[context.InferenceContext],
233+
) -> typing.Optional[nodes.NodeNG]:
234+
for statement_mod_name, _ in _possible_module_object_splits(name):
235+
# Find import statements that may pull in the appropriate modules
236+
# The name used to find this statement may not correspond to the name of the module actually being imported
237+
# For example, "import email.charset" is found by lookup("email")
238+
_, resolved_bases = node.frame().lookup(statement_mod_name)
239+
if not resolved_bases:
240+
continue
241+
242+
resolved_base = resolved_bases[0]
243+
if isinstance(resolved_base, nodes.Import):
244+
# Extract the names of the module as they are accessed from actual code
245+
scope_names = {(alias or name) for (name, alias) in resolved_base.names}
246+
aliases = {alias: name for (name, alias) in resolved_base.names if alias}
247+
248+
# Find potential mod_name, obj_name splits that work with the available names
249+
# for the module in this scope
250+
import_targets = [
251+
(mod_name, obj_name)
252+
for (mod_name, obj_name) in _possible_module_object_splits(name)
253+
if mod_name in scope_names
254+
]
255+
if not import_targets:
256+
continue
257+
258+
import_target, name_in_mod = import_targets[0]
259+
import_target = aliases.get(import_target, import_target)
260+
261+
# Try to import the module and find the object in it
262+
try:
263+
resolved_mod: nodes.Module = resolved_base.do_import_module(
264+
import_target
265+
)
266+
except AstroidImportError:
267+
# If the module doesn't actually exist, try the next option
268+
continue
269+
270+
# Try to find the appropriate ClassDef or other such node in the target module
271+
_, object_results_in_mod = resolved_mod.lookup(name_in_mod)
272+
if not object_results_in_mod:
273+
continue
274+
275+
base_node = object_results_in_mod[0]
276+
277+
# If the value is from an "import from" statement, follow the import chain
278+
if isinstance(base_node, nodes.ImportFrom):
279+
ctx = (
280+
context_itton.clone()
281+
if context_itton
282+
else context.InferenceContext()
283+
)
284+
ctx.lookupname = name_in_mod
285+
base_node = next(base_node.infer(context=ctx))
286+
287+
return base_node
288+
289+
return None
290+
291+
292+
def _possible_module_object_splits(
293+
dot_str: str,
294+
) -> typing.Iterator[typing.Tuple[str, str]]:
295+
components = dot_str.split(".")
296+
popped = []
297+
298+
while components:
299+
popped.append(components.pop())
300+
301+
yield ".".join(components), ".".join(reversed(popped))
302+
303+
207304
def _looks_like_typing_subscript(node):
208305
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
209306
if isinstance(node, Name):

tests/unittest_brain.py

+166
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,172 @@ def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
18171817
with self.assertRaises(AttributeInferenceError):
18181818
inferred.getattr("bad_attr")
18191819

1820+
def test_typing_newtype_forward_reference_imported(self) -> None:
1821+
all_ast_nodes = builder.extract_node(
1822+
"""
1823+
from typing import NewType
1824+
1825+
A = NewType("A", "decimal.Decimal")
1826+
B = NewType("B", "decimal_mod_alias.Decimal")
1827+
C = NewType("C", "Decimal")
1828+
D = NewType("D", "DecimalAlias")
1829+
1830+
import decimal
1831+
import decimal as decimal_mod_alias
1832+
from decimal import Decimal
1833+
from decimal import Decimal as DecimalAlias
1834+
1835+
Decimal #@
1836+
1837+
a = A(decimal.Decimal(2))
1838+
a #@
1839+
b = B(decimal_mod_alias.Decimal(2))
1840+
b #@
1841+
c = C(Decimal(2))
1842+
c #@
1843+
d = D(DecimalAlias(2))
1844+
d #@
1845+
"""
1846+
)
1847+
assert isinstance(all_ast_nodes, list)
1848+
1849+
real_dec, *ast_nodes = all_ast_nodes
1850+
1851+
real_quantize = next(real_dec.infer()).getattr("quantize")
1852+
1853+
for node in ast_nodes:
1854+
all_inferred = list(node.infer())
1855+
assert len(all_inferred) == 1
1856+
inferred = all_inferred[0]
1857+
self.assertIsInstance(inferred, astroid.Instance)
1858+
1859+
assert inferred.getattr("quantize") == real_quantize
1860+
1861+
def test_typing_newtype_forward_ref_bad_base(self) -> None:
1862+
ast_nodes = builder.extract_node(
1863+
"""
1864+
from typing import NewType
1865+
1866+
A = NewType("A", "DoesntExist")
1867+
1868+
a = A()
1869+
a #@
1870+
1871+
# Valid name, but not actually imported
1872+
B = NewType("B", "decimal.Decimal")
1873+
1874+
b = B()
1875+
b #@
1876+
1877+
# AST works out, but can't import the module
1878+
import not_a_real_module
1879+
1880+
C = NewType("C", "not_a_real_module.SomeClass")
1881+
c = C()
1882+
c #@
1883+
1884+
# Real module, fake base class name
1885+
import email.charset
1886+
1887+
D = NewType("D", "email.charset.BadClassRef")
1888+
d = D()
1889+
d #@
1890+
1891+
# Real module, but aliased differently than used
1892+
import email.header as header_mod
1893+
1894+
E = NewType("E", "email.header.Header")
1895+
e = E(header_mod.Header())
1896+
e #@
1897+
"""
1898+
)
1899+
assert isinstance(ast_nodes, list)
1900+
1901+
for ast_node in ast_nodes:
1902+
inferred = next(ast_node.infer())
1903+
1904+
with self.assertRaises(astroid.AttributeInferenceError):
1905+
inferred.getattr("value")
1906+
1907+
def test_typing_newtype_forward_ref_nested_module(self) -> None:
1908+
ast_nodes = builder.extract_node(
1909+
"""
1910+
from typing import NewType
1911+
1912+
A = NewType("A", "email.charset.Charset")
1913+
B = NewType("B", "charset.Charset")
1914+
1915+
# header is unused in both cases, but verifies that module name is properly checked
1916+
import email.header, email.charset
1917+
from email import header, charset
1918+
1919+
real = charset.Charset()
1920+
real #@
1921+
1922+
a = A(email.charset.Charset())
1923+
a #@
1924+
1925+
b = B(charset.Charset())
1926+
"""
1927+
)
1928+
assert isinstance(ast_nodes, list)
1929+
1930+
real, *newtypes = ast_nodes
1931+
1932+
real_inferred_all = list(real.infer())
1933+
assert len(real_inferred_all) == 1
1934+
real_inferred = real_inferred_all[0]
1935+
1936+
real_method = real_inferred.getattr("get_body_encoding")
1937+
1938+
for newtype_node in newtypes:
1939+
newtype_inferred_all = list(newtype_node.infer())
1940+
assert len(newtype_inferred_all) == 1
1941+
newtype_inferred = newtype_inferred_all[0]
1942+
1943+
newtype_method = newtype_inferred.getattr("get_body_encoding")
1944+
1945+
assert real_method == newtype_method
1946+
1947+
def test_typing_newtype_forward_ref_nested_class(self) -> None:
1948+
ast_nodes = builder.extract_node(
1949+
"""
1950+
from typing import NewType
1951+
1952+
A = NewType("A", "SomeClass.Nested")
1953+
1954+
class SomeClass:
1955+
class Nested:
1956+
def method(self) -> None:
1957+
pass
1958+
1959+
real = SomeClass.Nested()
1960+
real #@
1961+
1962+
a = A(SomeClass.Nested())
1963+
a #@
1964+
"""
1965+
)
1966+
assert isinstance(ast_nodes, list)
1967+
1968+
real, newtype = ast_nodes
1969+
1970+
real_all_inferred = list(real.infer())
1971+
assert len(real_all_inferred) == 1
1972+
real_inferred = real_all_inferred[0]
1973+
real_method = real_inferred.getattr("method")
1974+
1975+
newtype_all_inferred = list(newtype.infer())
1976+
assert len(newtype_all_inferred) == 1
1977+
newtype_inferred = newtype_all_inferred[0]
1978+
1979+
# This could theoretically work, but for now just here to check that
1980+
# the "forward-declared module" inference doesn't totally break things
1981+
with self.assertRaises(astroid.AttributeInferenceError):
1982+
newtype_method = newtype_inferred.getattr("method")
1983+
1984+
assert real_method == newtype_method
1985+
18201986
def test_namedtuple_nested_class(self):
18211987
result = builder.extract_node(
18221988
"""

0 commit comments

Comments
 (0)