Skip to content

Commit 3609469

Browse files
committed
Treat NewTypes like normal subclasses
NewTypes are assumed not to inherit any members from their base classes. This results in incorrect inference results. Avoid this by changing the transformation for NewTypes to treat them like any other subclass. pylint-dev/pylint#3162 pylint-dev/pylint#2296
1 parent 347ea0c commit 3609469

File tree

3 files changed

+269
-15
lines changed

3 files changed

+269
-15
lines changed

ChangeLog

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ Release date: TBA
5555
* Fix test for Python ``3.11``. In some instances ``err.__traceback__`` will
5656
be uninferable now.
5757

58+
* Treat ``typing.NewType()`` values as normal subclasses.
59+
60+
Closes PyCQA/pylint#2296
61+
Closes PyCQA/pylint#3162
62+
5863
What's New in astroid 2.11.6?
5964
=============================
6065
Release date: TBA

astroid/brain/brain_typing.py

+79-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Iterator
1111
from functools import partial
1212

13-
from astroid import context, extract_node, inference_tip
13+
from astroid import context, extract_node, inference_tip, nodes
1414
from astroid.builder import _extract_single_node
1515
from astroid.const import PY38_PLUS, PY39_PLUS
1616
from astroid.exceptions import (
@@ -35,8 +35,6 @@
3535
from astroid.util import Uninferable
3636

3737
TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
38-
TYPING_TYPEVARS = {"TypeVar", "NewType"}
39-
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
4038
TYPING_TYPE_TEMPLATE = """
4139
class Meta(type):
4240
def __getitem__(self, item):
@@ -49,6 +47,13 @@ def __args__(self):
4947
class {0}(metaclass=Meta):
5048
pass
5149
"""
50+
# PEP484 suggests NewType is equivalent to this for typing purposes
51+
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
52+
TYPING_NEWTYPE_TEMPLATE = """
53+
class {derived}({base}):
54+
def __init__(self, val: {base}) -> None:
55+
...
56+
"""
5257
TYPING_MEMBERS = set(getattr(typing, "__all__", []))
5358

5459
TYPING_ALIAS = frozenset(
@@ -103,24 +108,33 @@ def __class_getitem__(cls, item):
103108
"""
104109

105110

106-
def looks_like_typing_typevar_or_newtype(node):
111+
def looks_like_typing_typevar(node: nodes.Call) -> bool:
112+
func = node.func
113+
if isinstance(func, Attribute):
114+
return func.attrname == "TypeVar"
115+
if isinstance(func, Name):
116+
return func.name == "TypeVar"
117+
return False
118+
119+
120+
def looks_like_typing_newtype(node: nodes.Call) -> bool:
107121
func = node.func
108122
if isinstance(func, Attribute):
109-
return func.attrname in TYPING_TYPEVARS
123+
return func.attrname == "NewType"
110124
if isinstance(func, Name):
111-
return func.name in TYPING_TYPEVARS
125+
return func.name == "NewType"
112126
return False
113127

114128

115-
def infer_typing_typevar_or_newtype(node, context_itton=None):
116-
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
129+
def infer_typing_typevar(
130+
node: nodes.Call, ctx: context.InferenceContext | None = None
131+
) -> Iterator[nodes.ClassDef]:
132+
"""Infer a typing.TypeVar(...) call"""
117133
try:
118-
func = next(node.func.infer(context=context_itton))
134+
next(node.func.infer(context=ctx))
119135
except (InferenceError, StopIteration) as exc:
120136
raise UseInferenceDefault from exc
121137

122-
if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
123-
raise UseInferenceDefault
124138
if not node.args:
125139
raise UseInferenceDefault
126140
# Cannot infer from a dynamic class name (f-string)
@@ -129,7 +143,53 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):
129143

130144
typename = node.args[0].as_string().strip("'")
131145
node = extract_node(TYPING_TYPE_TEMPLATE.format(typename))
132-
return node.infer(context=context_itton)
146+
return node.infer(context=ctx)
147+
148+
149+
def infer_typing_newtype(
150+
node: nodes.Call, ctx: context.InferenceContext | None = None
151+
) -> Iterator[nodes.ClassDef]:
152+
"""Infer a typing.NewType(...) call"""
153+
try:
154+
next(node.func.infer(context=ctx))
155+
except (InferenceError, StopIteration) as exc:
156+
raise UseInferenceDefault from exc
157+
158+
if len(node.args) != 2:
159+
raise UseInferenceDefault
160+
161+
# Cannot infer from a dynamic class name (f-string)
162+
if isinstance(node.args[0], JoinedStr) or isinstance(node.args[1], JoinedStr):
163+
raise UseInferenceDefault
164+
165+
derived, base = node.args
166+
derived_name = derived.as_string().strip("'")
167+
base_name = base.as_string().strip("'")
168+
169+
new_node: ClassDef = extract_node(
170+
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
171+
)
172+
new_node.parent = node.parent
173+
174+
# Base type arg is a normal reference, so no need to do special lookups
175+
if not isinstance(base, nodes.Const):
176+
new_node.postinit(
177+
bases=[base], body=new_node.body, decorators=new_node.decorators
178+
)
179+
180+
# If the base type is given as a string (e.g. for a forward reference),
181+
# make a naive attempt to find the corresponding node.
182+
# Note that this will not work with imported types.
183+
if isinstance(base, nodes.Const) and isinstance(base.value, str):
184+
_, resolved_base = node.frame().lookup(base_name)
185+
if resolved_base:
186+
new_node.postinit(
187+
bases=[resolved_base[0]],
188+
body=new_node.body,
189+
decorators=new_node.decorators,
190+
)
191+
192+
return new_node.infer(context=ctx)
133193

134194

135195
def _looks_like_typing_subscript(node):
@@ -403,8 +463,13 @@ def infer_typing_cast(
403463

404464
AstroidManager().register_transform(
405465
Call,
406-
inference_tip(infer_typing_typevar_or_newtype),
407-
looks_like_typing_typevar_or_newtype,
466+
inference_tip(infer_typing_typevar),
467+
looks_like_typing_typevar,
468+
)
469+
AstroidManager().register_transform(
470+
Call,
471+
inference_tip(infer_typing_newtype),
472+
looks_like_typing_newtype,
408473
)
409474
AstroidManager().register_transform(
410475
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript

tests/unittest_brain.py

+185-1
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,26 @@ def test_typing_types(self) -> None:
16401640
inferred = next(node.infer())
16411641
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())
16421642

1643+
def test_typing_typevar_bad_args(self) -> None:
1644+
ast_nodes = builder.extract_node(
1645+
"""
1646+
from typing import TypeVar
1647+
1648+
T = TypeVar()
1649+
T #@
1650+
1651+
U = TypeVar(f"U")
1652+
U #@
1653+
"""
1654+
)
1655+
assert isinstance(ast_nodes, list)
1656+
1657+
no_args_node = ast_nodes[0]
1658+
assert list(no_args_node.infer()) == [util.Uninferable]
1659+
1660+
fstr_node = ast_nodes[1]
1661+
assert list(fstr_node.infer()) == [util.Uninferable]
1662+
16431663
def test_typing_type_without_tip(self):
16441664
"""Regression test for https://github.com/PyCQA/pylint/issues/5770"""
16451665
node = builder.extract_node(
@@ -1651,7 +1671,171 @@ def make_new_type(t):
16511671
"""
16521672
)
16531673
with self.assertRaises(UseInferenceDefault):
1654-
astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value)
1674+
astroid.brain.brain_typing.infer_typing_newtype(node.value)
1675+
1676+
def test_typing_newtype_attrs(self) -> None:
1677+
ast_nodes = builder.extract_node(
1678+
"""
1679+
from typing import NewType
1680+
import decimal
1681+
from decimal import Decimal
1682+
1683+
NewType("Foo", str) #@
1684+
NewType("Bar", "int") #@
1685+
NewType("Baz", Decimal) #@
1686+
NewType("Qux", decimal.Decimal) #@
1687+
"""
1688+
)
1689+
assert isinstance(ast_nodes, list)
1690+
1691+
# Base type given by reference
1692+
foo_node = ast_nodes[0]
1693+
1694+
# Should be unambiguous
1695+
foo_inferred_all = list(foo_node.infer())
1696+
assert len(foo_inferred_all) == 1
1697+
1698+
foo_inferred = foo_inferred_all[0]
1699+
assert isinstance(foo_inferred, astroid.ClassDef)
1700+
1701+
# Check base type method is inferred by accessing one of its methods
1702+
foo_base_class_method = foo_inferred.getattr("endswith")[0]
1703+
assert isinstance(foo_base_class_method, astroid.FunctionDef)
1704+
assert foo_base_class_method.qname() == "builtins.str.endswith"
1705+
1706+
# Base type given by string (i.e. "int")
1707+
bar_node = ast_nodes[1]
1708+
bar_inferred_all = list(bar_node.infer())
1709+
assert len(bar_inferred_all) == 1
1710+
bar_inferred = bar_inferred_all[0]
1711+
assert isinstance(bar_inferred, astroid.ClassDef)
1712+
1713+
bar_base_class_method = bar_inferred.getattr("bit_length")[0]
1714+
assert isinstance(bar_base_class_method, astroid.FunctionDef)
1715+
assert bar_base_class_method.qname() == "builtins.int.bit_length"
1716+
1717+
# Decimal may be reexported from an implementation-defined module. For
1718+
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
1719+
# _pydecimal. So the expected qname needs to be grabbed dynamically.
1720+
decimal_quant_node = builder.extract_node(
1721+
"""
1722+
from decimal import Decimal
1723+
Decimal.quantize #@
1724+
"""
1725+
)
1726+
assert isinstance(decimal_quant_node, nodes.NodeNG)
1727+
1728+
# Just grab the first result, since infer() may return values for both
1729+
# _decimal and _pydecimal
1730+
decimal_quant_qname = next(decimal_quant_node.infer()).qname()
1731+
1732+
# Base type is from an "import from"
1733+
baz_node = ast_nodes[2]
1734+
baz_inferred_all = list(baz_node.infer())
1735+
assert len(baz_inferred_all) == 1
1736+
baz_inferred = baz_inferred_all[0]
1737+
assert isinstance(baz_inferred, astroid.ClassDef)
1738+
1739+
baz_base_class_method = baz_inferred.getattr("quantize")[0]
1740+
assert isinstance(baz_base_class_method, astroid.FunctionDef)
1741+
assert decimal_quant_qname == baz_base_class_method.qname()
1742+
1743+
# Base type is from an import
1744+
qux_node = ast_nodes[3]
1745+
qux_inferred_all = list(qux_node.infer())
1746+
qux_inferred = qux_inferred_all[0]
1747+
assert isinstance(qux_inferred, astroid.ClassDef)
1748+
1749+
qux_base_class_method = qux_inferred.getattr("quantize")[0]
1750+
assert isinstance(qux_base_class_method, astroid.FunctionDef)
1751+
assert decimal_quant_qname == qux_base_class_method.qname()
1752+
1753+
def test_typing_newtype_bad_args(self) -> None:
1754+
ast_nodes = builder.extract_node(
1755+
"""
1756+
from typing import NewType
1757+
1758+
NoArgs = NewType()
1759+
NoArgs #@
1760+
1761+
OneArg = NewType("OneArg")
1762+
OneArg #@
1763+
1764+
ThreeArgs = NewType("ThreeArgs", int, str)
1765+
ThreeArgs #@
1766+
1767+
DynamicArg = NewType(f"DynamicArg", int)
1768+
DynamicArg #@
1769+
1770+
DynamicBase = NewType("DynamicBase", f"int")
1771+
DynamicBase #@
1772+
"""
1773+
)
1774+
assert isinstance(ast_nodes, list)
1775+
1776+
node: nodes.NodeNG
1777+
for node in ast_nodes:
1778+
assert list(node.infer()) == [util.Uninferable]
1779+
1780+
def test_typing_newtype_user_defined(self) -> None:
1781+
ast_nodes = builder.extract_node(
1782+
"""
1783+
from typing import NewType
1784+
1785+
class A:
1786+
def __init__(self, value: int):
1787+
self.value = value
1788+
1789+
a = A(5)
1790+
a #@
1791+
1792+
B = NewType("B", A)
1793+
b = B(5)
1794+
b #@
1795+
"""
1796+
)
1797+
assert isinstance(ast_nodes, list)
1798+
1799+
for node in ast_nodes:
1800+
self._verify_node_has_expected_attr(node)
1801+
1802+
def test_typing_newtype_forward_reference(self) -> None:
1803+
# Similar to the test above, but using a forward reference for "A"
1804+
ast_nodes = builder.extract_node(
1805+
"""
1806+
from typing import NewType
1807+
1808+
B = NewType("B", "A")
1809+
1810+
class A:
1811+
def __init__(self, value: int):
1812+
self.value = value
1813+
1814+
a = A(5)
1815+
a #@
1816+
1817+
b = B(5)
1818+
b #@
1819+
"""
1820+
)
1821+
assert isinstance(ast_nodes, list)
1822+
1823+
for node in ast_nodes:
1824+
self._verify_node_has_expected_attr(node)
1825+
1826+
def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
1827+
inferred_all = list(node.infer())
1828+
assert len(inferred_all) == 1
1829+
inferred = inferred_all[0]
1830+
assert isinstance(inferred, astroid.Instance)
1831+
1832+
# Should be able to infer that the "value" attr is present on both types
1833+
val = inferred.getattr("value")[0]
1834+
assert isinstance(val, astroid.AssignAttr)
1835+
1836+
# Sanity check: nonexistent attr is not inferred
1837+
with self.assertRaises(AttributeInferenceError):
1838+
inferred.getattr("bad_attr")
16551839

16561840
def test_namedtuple_nested_class(self):
16571841
result = builder.extract_node(

0 commit comments

Comments
 (0)