Skip to content

Commit 83eebfc

Browse files
committed
Treat NewTypes like normal subclasses
NewTypes are assumed not to inherit any members from their base classes. This results in incorrect inference results. Avoid this by changing the transformation for NewTypes to treat them like any other subclass. pylint-dev/pylint#3162 pylint-dev/pylint#2296
1 parent 7f4d62b commit 83eebfc

File tree

3 files changed

+223
-12
lines changed

3 files changed

+223
-12
lines changed

ChangeLog

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ Release date: TBA
1818

1919
Closes #1410
2020

21+
* Treat ``typing.NewType()`` values as normal subclasses.
22+
23+
Closes PyCQA/pylint#2296
24+
Closes PyCQA/pylint#3162
25+
2126

2227
What's New in astroid 2.10.1?
2328
=============================

astroid/brain/brain_typing.py

+80-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import typing
1919
from functools import partial
2020

21-
from astroid import context, extract_node, inference_tip
21+
from astroid import context, extract_node, inference_tip, nodes
2222
from astroid.builder import _extract_single_node
2323
from astroid.const import PY37_PLUS, PY38_PLUS, PY39_PLUS
2424
from astroid.exceptions import (
@@ -43,8 +43,6 @@
4343
from astroid.util import Uninferable
4444

4545
TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
46-
TYPING_TYPEVARS = {"TypeVar", "NewType"}
47-
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
4846
TYPING_TYPE_TEMPLATE = """
4947
class Meta(type):
5048
def __getitem__(self, item):
@@ -57,6 +55,13 @@ def __args__(self):
5755
class {0}(metaclass=Meta):
5856
pass
5957
"""
58+
# PEP484 suggests NewType is equivalent to this for typing purposes
59+
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
60+
TYPING_NEWTYPE_TEMPLATE = """
61+
class {derived}({base}):
62+
def __init__(self, val: {base}) -> None:
63+
...
64+
"""
6065
TYPING_MEMBERS = set(getattr(typing, "__all__", []))
6166

6267
TYPING_ALIAS = frozenset(
@@ -111,23 +116,34 @@ def __class_getitem__(cls, item):
111116
"""
112117

113118

114-
def looks_like_typing_typevar_or_newtype(node):
119+
def looks_like_typing_typevar(node: nodes.Call) -> bool:
120+
func = node.func
121+
if isinstance(func, Attribute):
122+
return func.attrname == "TypeVar"
123+
if isinstance(func, Name):
124+
return func.name == "TypeVar"
125+
return False
126+
127+
128+
def looks_like_typing_newtype(node: nodes.Call) -> bool:
115129
func = node.func
116130
if isinstance(func, Attribute):
117-
return func.attrname in TYPING_TYPEVARS
131+
return func.attrname == "NewType"
118132
if isinstance(func, Name):
119-
return func.name in TYPING_TYPEVARS
133+
return func.name == "NewType"
120134
return False
121135

122136

123-
def infer_typing_typevar_or_newtype(node, context_itton=None):
124-
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
137+
def infer_typing_typevar(
138+
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
139+
) -> typing.Iterator[nodes.ClassDef]:
140+
"""Infer a typing.TypeVar(...) call"""
125141
try:
126142
func = next(node.func.infer(context=context_itton))
127143
except (InferenceError, StopIteration) as exc:
128144
raise UseInferenceDefault from exc
129145

130-
if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
146+
if func.qname() != "typing.TypeVar":
131147
raise UseInferenceDefault
132148
if not node.args:
133149
raise UseInferenceDefault
@@ -140,6 +156,54 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):
140156
return node.infer(context=context_itton)
141157

142158

159+
def infer_typing_newtype(
160+
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
161+
) -> typing.Iterator[nodes.ClassDef]:
162+
"""Infer a typing.NewType(...) call"""
163+
try:
164+
func = next(node.func.infer(context=context_itton))
165+
except (InferenceError, StopIteration) as exc:
166+
raise UseInferenceDefault from exc
167+
168+
if func.qname() != "typing.NewType":
169+
raise UseInferenceDefault
170+
if len(node.args) != 2:
171+
raise UseInferenceDefault
172+
173+
# Cannot infer from a dynamic class name (f-string)
174+
if isinstance(node.args[0], JoinedStr):
175+
raise UseInferenceDefault
176+
177+
derived, base = node.args
178+
derived_name = derived.as_string().strip("'")
179+
base_name = base.as_string().strip("'")
180+
181+
new_node: ClassDef = extract_node(
182+
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
183+
)
184+
new_node.parent = node.parent
185+
186+
# Base type arg is a normal reference, so no need to do special lookups
187+
if not isinstance(base, nodes.Const):
188+
new_node.postinit(
189+
bases=[base], body=new_node.body, decorators=new_node.decorators
190+
)
191+
192+
# If the base type is given as a string (e.g. for a forward reference),
193+
# make a naive attempt to find the corresponding node.
194+
# Note that this will not work with imported types.
195+
if isinstance(base, nodes.Const) and isinstance(base.value, str):
196+
_, resolved_base = node.frame().lookup(base_name)
197+
if resolved_base:
198+
new_node.postinit(
199+
bases=[resolved_base[0]],
200+
body=new_node.body,
201+
decorators=new_node.decorators,
202+
)
203+
204+
return new_node.infer(context=context_itton)
205+
206+
143207
def _looks_like_typing_subscript(node):
144208
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
145209
if isinstance(node, Name):
@@ -417,8 +481,13 @@ def infer_typing_cast(
417481

418482
AstroidManager().register_transform(
419483
Call,
420-
inference_tip(infer_typing_typevar_or_newtype),
421-
looks_like_typing_typevar_or_newtype,
484+
inference_tip(infer_typing_typevar),
485+
looks_like_typing_typevar,
486+
)
487+
AstroidManager().register_transform(
488+
Call,
489+
inference_tip(infer_typing_newtype),
490+
looks_like_typing_newtype,
422491
)
423492
AstroidManager().register_transform(
424493
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript

tests/unittest_brain.py

+138-1
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,144 @@ def make_new_type(t):
16781678
"""
16791679
)
16801680
with self.assertRaises(UseInferenceDefault):
1681-
astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value)
1681+
astroid.brain.brain_typing.infer_typing_newtype(node.value)
1682+
1683+
def test_typing_newtype_attrs(self) -> None:
1684+
ast_nodes = builder.extract_node(
1685+
"""
1686+
from typing import NewType
1687+
import decimal
1688+
from decimal import Decimal
1689+
1690+
NewType("Foo", str) #@
1691+
NewType("Bar", "int") #@
1692+
NewType("Baz", Decimal) #@
1693+
NewType("Qux", decimal.Decimal) #@
1694+
"""
1695+
)
1696+
assert isinstance(ast_nodes, list)
1697+
1698+
# Base type given by reference
1699+
foo_node = ast_nodes[0]
1700+
1701+
# Should be unambiguous
1702+
foo_inferred_all = list(foo_node.infer())
1703+
assert len(foo_inferred_all) == 1
1704+
1705+
foo_inferred = foo_inferred_all[0]
1706+
self.assertIsInstance(foo_inferred, astroid.ClassDef)
1707+
1708+
# Check base type method is inferred by accessing one of its methods
1709+
foo_base_class_method = foo_inferred.getattr("endswith")[0]
1710+
self.assertIsInstance(foo_base_class_method, astroid.FunctionDef)
1711+
self.assertEqual("builtins.str.endswith", foo_base_class_method.qname())
1712+
1713+
# Base type given by string (i.e. "int")
1714+
bar_node = ast_nodes[1]
1715+
bar_inferred_all = list(bar_node.infer())
1716+
assert len(bar_inferred_all) == 1
1717+
bar_inferred = bar_inferred_all[0]
1718+
self.assertIsInstance(bar_inferred, astroid.ClassDef)
1719+
1720+
bar_base_class_method = bar_inferred.getattr("bit_length")[0]
1721+
self.assertIsInstance(bar_base_class_method, astroid.FunctionDef)
1722+
self.assertEqual("builtins.int.bit_length", bar_base_class_method.qname())
1723+
1724+
# Decimal may be reexported from an implementation-defined module. For
1725+
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
1726+
# _pydecimal. So the expected qname needs to be grabbed dynamically.
1727+
decimal_quant_node = builder.extract_node(
1728+
"""
1729+
from decimal import Decimal
1730+
Decimal.quantize #@
1731+
"""
1732+
)
1733+
assert isinstance(decimal_quant_node, nodes.NodeNG)
1734+
1735+
# Just grab the first result, since infer() may return values for both
1736+
# _decimal and _pydecimal
1737+
decimal_quant_qname = next(decimal_quant_node.infer()).qname()
1738+
1739+
# Base type is from an "import from"
1740+
baz_node = ast_nodes[2]
1741+
baz_inferred_all = list(baz_node.infer())
1742+
assert len(baz_inferred_all) == 1
1743+
baz_inferred = baz_inferred_all[0]
1744+
self.assertIsInstance(baz_inferred, astroid.ClassDef)
1745+
1746+
baz_base_class_method = baz_inferred.getattr("quantize")[0]
1747+
self.assertIsInstance(baz_base_class_method, astroid.FunctionDef)
1748+
self.assertEqual(decimal_quant_qname, baz_base_class_method.qname())
1749+
1750+
# Base type is from an import
1751+
qux_node = ast_nodes[3]
1752+
qux_inferred_all = list(qux_node.infer())
1753+
qux_inferred = qux_inferred_all[0]
1754+
self.assertIsInstance(qux_inferred, astroid.ClassDef)
1755+
1756+
qux_base_class_method = qux_inferred.getattr("quantize")[0]
1757+
self.assertIsInstance(qux_base_class_method, astroid.FunctionDef)
1758+
self.assertEqual(decimal_quant_qname, qux_base_class_method.qname())
1759+
1760+
def test_typing_newtype_user_defined(self) -> None:
1761+
ast_nodes = builder.extract_node(
1762+
"""
1763+
from typing import NewType
1764+
1765+
class A:
1766+
def __init__(self, value: int):
1767+
self.value = value
1768+
1769+
a = A(5)
1770+
a #@
1771+
1772+
B = NewType("B", A)
1773+
b = B(5)
1774+
b #@
1775+
"""
1776+
)
1777+
assert isinstance(ast_nodes, list)
1778+
1779+
for node in ast_nodes:
1780+
self._verify_node_has_expected_attr(node)
1781+
1782+
def test_typing_newtype_forward_reference(self) -> None:
1783+
# Similar to the test above, but using a forward reference for "A"
1784+
ast_nodes = builder.extract_node(
1785+
"""
1786+
from typing import NewType
1787+
1788+
B = NewType("B", "A")
1789+
1790+
class A:
1791+
def __init__(self, value: int):
1792+
self.value = value
1793+
1794+
a = A(5)
1795+
a #@
1796+
1797+
b = B(5)
1798+
b #@
1799+
"""
1800+
)
1801+
assert isinstance(ast_nodes, list)
1802+
1803+
for node in ast_nodes:
1804+
self._verify_node_has_expected_attr(node)
1805+
1806+
def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
1807+
inferred_all = list(node.infer())
1808+
assert len(inferred_all) == 1
1809+
inferred = inferred_all[0]
1810+
self.assertIsInstance(inferred, astroid.Instance)
1811+
1812+
# Should be able to infer that the "value" attr is present on both types
1813+
val = inferred.getattr("value")[0]
1814+
self.assertIsInstance(val, astroid.AssignAttr)
1815+
1816+
# Sanity check: nonexistent attr is not inferred
1817+
with self.assertRaises(AttributeInferenceError):
1818+
inferred.getattr("bad_attr")
16821819

16831820
def test_namedtuple_nested_class(self):
16841821
result = builder.extract_node(

0 commit comments

Comments
 (0)