Skip to content

Commit 957d223

Browse files
committed
Treat NewTypes like normal subclasses
NewTypes are assumed not to inherit any members from their base classes. This results in incorrect inference results. Avoid this by changing the transformation for NewTypes to treat them like any other subclass. pylint-dev/pylint#3162 pylint-dev/pylint#2296
1 parent 841401d commit 957d223

File tree

3 files changed

+203
-12
lines changed

3 files changed

+203
-12
lines changed

ChangeLog

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ What's New in astroid 2.10.1?
1313
Release date: TBA
1414

1515

16+
* Treat ``typing.NewType()`` values as normal subclasses.
17+
18+
Closes PyCQA/pylint#2296
19+
Closes PyCQA/pylint#3162
20+
21+
1622

1723
What's New in astroid 2.10.0?
1824
=============================

astroid/brain/brain_typing.py

+74-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import typing
1919
from functools import partial
2020

21-
from astroid import context, extract_node, inference_tip
21+
from astroid import context, extract_node, inference_tip, nodes
2222
from astroid.builder import _extract_single_node
2323
from astroid.const import PY37_PLUS, PY38_PLUS, PY39_PLUS
2424
from astroid.exceptions import (
@@ -43,8 +43,6 @@
4343
from astroid.util import Uninferable
4444

4545
TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
46-
TYPING_TYPEVARS = {"TypeVar", "NewType"}
47-
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
4846
TYPING_TYPE_TEMPLATE = """
4947
class Meta(type):
5048
def __getitem__(self, item):
@@ -57,6 +55,13 @@ def __args__(self):
5755
class {0}(metaclass=Meta):
5856
pass
5957
"""
58+
# PEP484 suggests NewType is equivalent to this for typing purposes
59+
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
60+
TYPING_NEWTYPE_TEMPLATE = """
61+
class {derived}({base}):
62+
def __init__(self, val: {base}) -> None:
63+
...
64+
"""
6065
TYPING_MEMBERS = set(getattr(typing, "__all__", []))
6166

6267
TYPING_ALIAS = frozenset(
@@ -111,23 +116,34 @@ def __class_getitem__(cls, item):
111116
"""
112117

113118

114-
def looks_like_typing_typevar_or_newtype(node):
119+
def looks_like_typing_typevar(node: nodes.Call) -> bool:
120+
func = node.func
121+
if isinstance(func, Attribute):
122+
return func.attrname == "TypeVar"
123+
if isinstance(func, Name):
124+
return func.name == "TypeVar"
125+
return False
126+
127+
128+
def looks_like_typing_newtype(node: nodes.Call) -> bool:
115129
func = node.func
116130
if isinstance(func, Attribute):
117-
return func.attrname in TYPING_TYPEVARS
131+
return func.attrname == "NewType"
118132
if isinstance(func, Name):
119-
return func.name in TYPING_TYPEVARS
133+
return func.name == "NewType"
120134
return False
121135

122136

123-
def infer_typing_typevar_or_newtype(node, context_itton=None):
124-
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
137+
def infer_typing_typevar(
138+
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
139+
) -> typing.Iterator[nodes.ClassDef]:
140+
"""Infer a typing.TypeVar(...) call"""
125141
try:
126142
func = next(node.func.infer(context=context_itton))
127143
except (InferenceError, StopIteration) as exc:
128144
raise UseInferenceDefault from exc
129145

130-
if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
146+
if func.qname() != "typing.TypeVar":
131147
raise UseInferenceDefault
132148
if not node.args:
133149
raise UseInferenceDefault
@@ -140,6 +156,48 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):
140156
return node.infer(context=context_itton)
141157

142158

159+
def infer_typing_newtype(
160+
node: nodes.Call, context_itton: typing.Optional[context.InferenceContext] = None
161+
) -> typing.Iterator[nodes.ClassDef]:
162+
"""Infer a typing.NewType(...) call"""
163+
try:
164+
func = next(node.func.infer(context=context_itton))
165+
except (InferenceError, StopIteration) as exc:
166+
raise UseInferenceDefault from exc
167+
168+
if func.qname() != "typing.NewType":
169+
raise UseInferenceDefault
170+
if len(node.args) != 2:
171+
raise UseInferenceDefault
172+
173+
# Cannot infer from a dynamic class name (f-string)
174+
if isinstance(node.args[0], JoinedStr):
175+
raise UseInferenceDefault
176+
177+
derived, base = node.args
178+
derived_name = derived.as_string().strip("'")
179+
base_name = base.as_string().strip("'")
180+
181+
new_node: ClassDef = extract_node(
182+
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
183+
)
184+
new_node.parent = node.parent
185+
186+
# Base type arg is a normal reference, so no need to do special lookups
187+
if not isinstance(base, nodes.Const):
188+
new_node.bases = [base]
189+
190+
# If the base type is given as a string (e.g. for a forward reference),
191+
# make a naive attempt to find the corresponding node.
192+
# Note that this will not work with imported types.
193+
if isinstance(base, nodes.Const) and isinstance(base.value, str):
194+
_, resolved_base = node.frame().lookup(base_name)
195+
if resolved_base:
196+
new_node.bases = [resolved_base[0]]
197+
198+
return new_node.infer(context=context_itton)
199+
200+
143201
def _looks_like_typing_subscript(node):
144202
"""Try to figure out if a Subscript node *might* be a typing-related subscript"""
145203
if isinstance(node, Name):
@@ -417,8 +475,13 @@ def infer_typing_cast(
417475

418476
AstroidManager().register_transform(
419477
Call,
420-
inference_tip(infer_typing_typevar_or_newtype),
421-
looks_like_typing_typevar_or_newtype,
478+
inference_tip(infer_typing_typevar),
479+
looks_like_typing_typevar,
480+
)
481+
AstroidManager().register_transform(
482+
Call,
483+
inference_tip(infer_typing_newtype),
484+
looks_like_typing_newtype,
422485
)
423486
AstroidManager().register_transform(
424487
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript

tests/unittest_brain.py

+123-1
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,129 @@ def make_new_type(t):
16781678
"""
16791679
)
16801680
with self.assertRaises(UseInferenceDefault):
1681-
astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value)
1681+
astroid.brain.brain_typing.infer_typing_newtype(node.value)
1682+
1683+
def test_typing_newtype_attrs(self) -> None:
1684+
ast_nodes = builder.extract_node(
1685+
"""
1686+
from typing import NewType
1687+
import decimal
1688+
from decimal import Decimal
1689+
1690+
NewType("Foo", str) #@
1691+
NewType("Bar", "int") #@
1692+
NewType("Baz", Decimal) #@
1693+
NewType("Qux", decimal.Decimal) #@
1694+
"""
1695+
)
1696+
assert isinstance(ast_nodes, list)
1697+
1698+
# Base type given by reference
1699+
foo_node = ast_nodes[0]
1700+
foo_inferred = next(foo_node.infer())
1701+
self.assertIsInstance(foo_inferred, astroid.ClassDef)
1702+
1703+
# Check base type method is inferred by accessing one of its methods
1704+
foo_base_class_method = foo_inferred.getattr("endswith")[0]
1705+
self.assertIsInstance(foo_base_class_method, astroid.FunctionDef)
1706+
self.assertEqual("builtins.str.endswith", foo_base_class_method.qname())
1707+
1708+
# Base type given by string (i.e. "int")
1709+
bar_node = ast_nodes[1]
1710+
bar_inferred = next(bar_node.infer())
1711+
self.assertIsInstance(bar_inferred, astroid.ClassDef)
1712+
1713+
bar_base_class_method = bar_inferred.getattr("bit_length")[0]
1714+
self.assertIsInstance(bar_base_class_method, astroid.FunctionDef)
1715+
self.assertEqual("builtins.int.bit_length", bar_base_class_method.qname())
1716+
1717+
# Decimal may be reexported from an implementation-defined module. For
1718+
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
1719+
# _pydecimal. So the expected qname needs to be grabbed dynamically.
1720+
decimal_quant_node = builder.extract_node(
1721+
"""
1722+
from decimal import Decimal
1723+
Decimal.quantize #@
1724+
"""
1725+
)
1726+
assert isinstance(decimal_quant_node, nodes.NodeNG)
1727+
decimal_quant_qname = next(decimal_quant_node.infer()).qname()
1728+
1729+
# Base type is from an "import from"
1730+
baz_node = ast_nodes[2]
1731+
baz_inferred = next(baz_node.infer())
1732+
self.assertIsInstance(baz_inferred, astroid.ClassDef)
1733+
1734+
baz_base_class_method = baz_inferred.getattr("quantize")[0]
1735+
self.assertIsInstance(baz_base_class_method, astroid.FunctionDef)
1736+
self.assertEqual(decimal_quant_qname, baz_base_class_method.qname())
1737+
1738+
# Base type is from an import
1739+
qux_node = ast_nodes[3]
1740+
qux_inferred = next(qux_node.infer())
1741+
self.assertIsInstance(qux_inferred, astroid.ClassDef)
1742+
1743+
qux_base_class_method = qux_inferred.getattr("quantize")[0]
1744+
self.assertIsInstance(qux_base_class_method, astroid.FunctionDef)
1745+
self.assertEqual(decimal_quant_qname, qux_base_class_method.qname())
1746+
1747+
def test_typing_newtype_user_defined(self) -> None:
1748+
ast_nodes = builder.extract_node(
1749+
"""
1750+
from typing import NewType
1751+
1752+
class A:
1753+
def __init__(self, value: int):
1754+
self.value = value
1755+
1756+
a = A(5)
1757+
a #@
1758+
1759+
B = NewType("B", A)
1760+
b = B(5)
1761+
b #@
1762+
"""
1763+
)
1764+
assert isinstance(ast_nodes, list)
1765+
1766+
for node in ast_nodes:
1767+
self._verify_node_has_expected_attr(node)
1768+
1769+
def test_typing_newtype_forward_reference(self) -> None:
1770+
# Similar to the test above, but using a forward reference for "A"
1771+
ast_nodes = builder.extract_node(
1772+
"""
1773+
from typing import NewType
1774+
1775+
B = NewType("B", "A")
1776+
1777+
class A:
1778+
def __init__(self, value: int):
1779+
self.value = value
1780+
1781+
a = A(5)
1782+
a #@
1783+
1784+
b = B(5)
1785+
b #@
1786+
"""
1787+
)
1788+
assert isinstance(ast_nodes, list)
1789+
1790+
for node in ast_nodes:
1791+
self._verify_node_has_expected_attr(node)
1792+
1793+
def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
1794+
inferred = next(node.infer())
1795+
self.assertIsInstance(inferred, astroid.Instance)
1796+
1797+
# Should be able to infer that the "value" attr is present on both types
1798+
val = inferred.getattr("value")[0]
1799+
self.assertIsInstance(val, astroid.AssignAttr)
1800+
1801+
# Sanity check: nonexistent attr is not inferred
1802+
with self.assertRaises(AttributeInferenceError):
1803+
inferred.getattr("bad_attr")
16821804

16831805
def test_namedtuple_nested_class(self):
16841806
result = builder.extract_node(

0 commit comments

Comments
 (0)