Skip to content

Commit 1f4ac46

Browse files
committed
Treat NewTypes like normal subclasses
NewTypes are assumed not to inherit any members from their base classes. This results in incorrect inference results. Avoid this by changing the transformation for NewTypes to treat them like any other subclass. pylint-dev/pylint#3162 pylint-dev/pylint#2296
1 parent 1520834 commit 1f4ac46

File tree

3 files changed

+225
-14
lines changed

3 files changed

+225
-14
lines changed

ChangeLog

+5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ Release date: TBA
5151
* Fix test for Python ``3.11``. In some instances ``err.__traceback__`` will
5252
be uninferable now.
5353

54+
* Treat ``typing.NewType()`` values as normal subclasses.
55+
56+
Closes PyCQA/pylint#2296
57+
Closes PyCQA/pylint#3162
58+
5459
What's New in astroid 2.11.6?
5560
=============================
5661
Release date: TBA

astroid/brain/brain_typing.py

+82-13
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Iterator
1111
from functools import partial
1212

13-
from astroid import context, extract_node, inference_tip
13+
from astroid import context, extract_node, inference_tip, nodes
1414
from astroid.builder import _extract_single_node
1515
from astroid.const import PY38_PLUS, PY39_PLUS
1616
from astroid.exceptions import (
@@ -35,8 +35,6 @@
3535
from astroid.util import Uninferable
3636

3737
TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
38-
TYPING_TYPEVARS = {"TypeVar", "NewType"}
39-
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
4038
TYPING_TYPE_TEMPLATE = """
4139
class Meta(type):
4240
def __getitem__(self, item):
@@ -49,6 +47,13 @@ def __args__(self):
4947
class {0}(metaclass=Meta):
5048
pass
5149
"""
50+
# PEP484 suggests NewType is equivalent to this for typing purposes
51+
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
52+
TYPING_NEWTYPE_TEMPLATE = """
53+
class {derived}({base}):
54+
def __init__(self, val: {base}) -> None:
55+
...
56+
"""
5257
TYPING_MEMBERS = set(getattr(typing, "__all__", []))
5358

5459
TYPING_ALIAS = frozenset(
@@ -103,23 +108,34 @@ def __class_getitem__(cls, item):
103108
"""
104109

105110

106-
def looks_like_typing_typevar_or_newtype(node):
111+
def looks_like_typing_typevar(node: nodes.Call) -> bool:
112+
func = node.func
113+
if isinstance(func, Attribute):
114+
return func.attrname == "TypeVar"
115+
if isinstance(func, Name):
116+
return func.name == "TypeVar"
117+
return False
118+
119+
120+
def looks_like_typing_newtype(node: nodes.Call) -> bool:
107121
func = node.func
108122
if isinstance(func, Attribute):
109-
return func.attrname in TYPING_TYPEVARS
123+
return func.attrname == "NewType"
110124
if isinstance(func, Name):
111-
return func.name in TYPING_TYPEVARS
125+
return func.name == "NewType"
112126
return False
113127

114128

115-
def infer_typing_typevar_or_newtype(node, context_itton=None):
116-
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
129+
def infer_typing_typevar(
130+
node: nodes.Call, ctx: context.InferenceContext | None = None
131+
) -> Iterator[nodes.ClassDef]:
132+
"""Infer a typing.TypeVar(...) call"""
117133
try:
118-
func = next(node.func.infer(context=context_itton))
134+
func = next(node.func.infer(context=ctx))
119135
except (InferenceError, StopIteration) as exc:
120136
raise UseInferenceDefault from exc
121137

122-
if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
138+
if func.qname() != "typing.TypeVar":
123139
raise UseInferenceDefault
124140
if not node.args:
125141
raise UseInferenceDefault
@@ -129,7 +145,55 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):
129145

130146
typename = node.args[0].as_string().strip("'")
131147
node = extract_node(TYPING_TYPE_TEMPLATE.format(typename))
132-
return node.infer(context=context_itton)
148+
return node.infer(context=ctx)
149+
150+
151+
def infer_typing_newtype(
152+
node: nodes.Call, ctx: context.InferenceContext | None = None
153+
) -> Iterator[nodes.ClassDef]:
154+
"""Infer a typing.NewType(...) call"""
155+
try:
156+
func = next(node.func.infer(context=ctx))
157+
except (InferenceError, StopIteration) as exc:
158+
raise UseInferenceDefault from exc
159+
160+
if func.qname() != "typing.NewType":
161+
raise UseInferenceDefault
162+
if len(node.args) != 2:
163+
raise UseInferenceDefault
164+
165+
# Cannot infer from a dynamic class name (f-string)
166+
if isinstance(node.args[0], JoinedStr):
167+
raise UseInferenceDefault
168+
169+
derived, base = node.args
170+
derived_name = derived.as_string().strip("'")
171+
base_name = base.as_string().strip("'")
172+
173+
new_node: ClassDef = extract_node(
174+
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
175+
)
176+
new_node.parent = node.parent
177+
178+
# Base type arg is a normal reference, so no need to do special lookups
179+
if not isinstance(base, nodes.Const):
180+
new_node.postinit(
181+
bases=[base], body=new_node.body, decorators=new_node.decorators
182+
)
183+
184+
# If the base type is given as a string (e.g. for a forward reference),
185+
# make a naive attempt to find the corresponding node.
186+
# Note that this will not work with imported types.
187+
if isinstance(base, nodes.Const) and isinstance(base.value, str):
188+
_, resolved_base = node.frame().lookup(base_name)
189+
if resolved_base:
190+
new_node.postinit(
191+
bases=[resolved_base[0]],
192+
body=new_node.body,
193+
decorators=new_node.decorators,
194+
)
195+
196+
return new_node.infer(context=ctx)
133197

134198

135199
def _looks_like_typing_subscript(node):
@@ -403,8 +467,13 @@ def infer_typing_cast(
403467

404468
AstroidManager().register_transform(
405469
Call,
406-
inference_tip(infer_typing_typevar_or_newtype),
407-
looks_like_typing_typevar_or_newtype,
470+
inference_tip(infer_typing_typevar),
471+
looks_like_typing_typevar,
472+
)
473+
AstroidManager().register_transform(
474+
Call,
475+
inference_tip(infer_typing_newtype),
476+
looks_like_typing_newtype,
408477
)
409478
AstroidManager().register_transform(
410479
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript

tests/unittest_brain.py

+138-1
Original file line numberDiff line numberDiff line change
@@ -1637,7 +1637,144 @@ def make_new_type(t):
16371637
"""
16381638
)
16391639
with self.assertRaises(UseInferenceDefault):
1640-
astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value)
1640+
astroid.brain.brain_typing.infer_typing_newtype(node.value)
1641+
1642+
def test_typing_newtype_attrs(self) -> None:
1643+
ast_nodes = builder.extract_node(
1644+
"""
1645+
from typing import NewType
1646+
import decimal
1647+
from decimal import Decimal
1648+
1649+
NewType("Foo", str) #@
1650+
NewType("Bar", "int") #@
1651+
NewType("Baz", Decimal) #@
1652+
NewType("Qux", decimal.Decimal) #@
1653+
"""
1654+
)
1655+
assert isinstance(ast_nodes, list)
1656+
1657+
# Base type given by reference
1658+
foo_node = ast_nodes[0]
1659+
1660+
# Should be unambiguous
1661+
foo_inferred_all = list(foo_node.infer())
1662+
assert len(foo_inferred_all) == 1
1663+
1664+
foo_inferred = foo_inferred_all[0]
1665+
assert isinstance(foo_inferred, astroid.ClassDef)
1666+
1667+
# Check base type method is inferred by accessing one of its methods
1668+
foo_base_class_method = foo_inferred.getattr("endswith")[0]
1669+
assert isinstance(foo_base_class_method, astroid.FunctionDef)
1670+
assert foo_base_class_method.qname() == "builtins.str.endswith"
1671+
1672+
# Base type given by string (i.e. "int")
1673+
bar_node = ast_nodes[1]
1674+
bar_inferred_all = list(bar_node.infer())
1675+
assert len(bar_inferred_all) == 1
1676+
bar_inferred = bar_inferred_all[0]
1677+
assert isinstance(bar_inferred, astroid.ClassDef)
1678+
1679+
bar_base_class_method = bar_inferred.getattr("bit_length")[0]
1680+
assert isinstance(bar_base_class_method, astroid.FunctionDef)
1681+
assert bar_base_class_method.qname() == "builtins.int.bit_length"
1682+
1683+
# Decimal may be reexported from an implementation-defined module. For
1684+
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
1685+
# _pydecimal. So the expected qname needs to be grabbed dynamically.
1686+
decimal_quant_node = builder.extract_node(
1687+
"""
1688+
from decimal import Decimal
1689+
Decimal.quantize #@
1690+
"""
1691+
)
1692+
assert isinstance(decimal_quant_node, nodes.NodeNG)
1693+
1694+
# Just grab the first result, since infer() may return values for both
1695+
# _decimal and _pydecimal
1696+
decimal_quant_qname = next(decimal_quant_node.infer()).qname()
1697+
1698+
# Base type is from an "import from"
1699+
baz_node = ast_nodes[2]
1700+
baz_inferred_all = list(baz_node.infer())
1701+
assert len(baz_inferred_all) == 1
1702+
baz_inferred = baz_inferred_all[0]
1703+
assert isinstance(baz_inferred, astroid.ClassDef)
1704+
1705+
baz_base_class_method = baz_inferred.getattr("quantize")[0]
1706+
assert isinstance(baz_base_class_method, astroid.FunctionDef)
1707+
assert decimal_quant_qname == baz_base_class_method.qname()
1708+
1709+
# Base type is from an import
1710+
qux_node = ast_nodes[3]
1711+
qux_inferred_all = list(qux_node.infer())
1712+
qux_inferred = qux_inferred_all[0]
1713+
assert isinstance(qux_inferred, astroid.ClassDef)
1714+
1715+
qux_base_class_method = qux_inferred.getattr("quantize")[0]
1716+
assert isinstance(qux_base_class_method, astroid.FunctionDef)
1717+
assert decimal_quant_qname == qux_base_class_method.qname()
1718+
1719+
def test_typing_newtype_user_defined(self) -> None:
1720+
ast_nodes = builder.extract_node(
1721+
"""
1722+
from typing import NewType
1723+
1724+
class A:
1725+
def __init__(self, value: int):
1726+
self.value = value
1727+
1728+
a = A(5)
1729+
a #@
1730+
1731+
B = NewType("B", A)
1732+
b = B(5)
1733+
b #@
1734+
"""
1735+
)
1736+
assert isinstance(ast_nodes, list)
1737+
1738+
for node in ast_nodes:
1739+
self._verify_node_has_expected_attr(node)
1740+
1741+
def test_typing_newtype_forward_reference(self) -> None:
1742+
# Similar to the test above, but using a forward reference for "A"
1743+
ast_nodes = builder.extract_node(
1744+
"""
1745+
from typing import NewType
1746+
1747+
B = NewType("B", "A")
1748+
1749+
class A:
1750+
def __init__(self, value: int):
1751+
self.value = value
1752+
1753+
a = A(5)
1754+
a #@
1755+
1756+
b = B(5)
1757+
b #@
1758+
"""
1759+
)
1760+
assert isinstance(ast_nodes, list)
1761+
1762+
for node in ast_nodes:
1763+
self._verify_node_has_expected_attr(node)
1764+
1765+
def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
1766+
inferred_all = list(node.infer())
1767+
assert len(inferred_all) == 1
1768+
inferred = inferred_all[0]
1769+
assert isinstance(inferred, astroid.Instance)
1770+
1771+
# Should be able to infer that the "value" attr is present on both types
1772+
val = inferred.getattr("value")[0]
1773+
assert isinstance(val, astroid.AssignAttr)
1774+
1775+
# Sanity check: nonexistent attr is not inferred
1776+
with self.assertRaises(AttributeInferenceError):
1777+
inferred.getattr("bad_attr")
16411778

16421779
def test_namedtuple_nested_class(self):
16431780
result = builder.extract_node(

0 commit comments

Comments
 (0)