-
-
Notifications
You must be signed in to change notification settings - Fork 293
Treat NewTypes like normal subclasses #1301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -10,10 +10,11 @@ | |||||||||
from collections.abc import Iterator | ||||||||||
from functools import partial | ||||||||||
|
||||||||||
from astroid import context, extract_node, inference_tip | ||||||||||
from astroid import context, extract_node, inference_tip, nodes | ||||||||||
from astroid.builder import _extract_single_node | ||||||||||
from astroid.const import PY38_PLUS, PY39_PLUS | ||||||||||
from astroid.exceptions import ( | ||||||||||
AstroidImportError, | ||||||||||
AttributeInferenceError, | ||||||||||
InferenceError, | ||||||||||
UseInferenceDefault, | ||||||||||
|
@@ -35,8 +36,6 @@ | |||||||||
from astroid.util import Uninferable | ||||||||||
|
||||||||||
TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"} | ||||||||||
TYPING_TYPEVARS = {"TypeVar", "NewType"} | ||||||||||
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"} | ||||||||||
TYPING_TYPE_TEMPLATE = """ | ||||||||||
class Meta(type): | ||||||||||
def __getitem__(self, item): | ||||||||||
|
@@ -49,6 +48,13 @@ def __args__(self): | |||||||||
class {0}(metaclass=Meta): | ||||||||||
pass | ||||||||||
""" | ||||||||||
# PEP484 suggests NewType is equivalent to this for typing purposes | ||||||||||
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function | ||||||||||
TYPING_NEWTYPE_TEMPLATE = """ | ||||||||||
class {derived}({base}): | ||||||||||
def __init__(self, val: {base}) -> None: | ||||||||||
... | ||||||||||
""" | ||||||||||
TYPING_MEMBERS = set(getattr(typing, "__all__", [])) | ||||||||||
|
||||||||||
TYPING_ALIAS = frozenset( | ||||||||||
|
@@ -103,24 +109,33 @@ def __class_getitem__(cls, item): | |||||||||
""" | ||||||||||
|
||||||||||
|
||||||||||
def looks_like_typing_typevar_or_newtype(node): | ||||||||||
def looks_like_typing_typevar(node: nodes.Call) -> bool: | ||||||||||
func = node.func | ||||||||||
if isinstance(func, Attribute): | ||||||||||
return func.attrname in TYPING_TYPEVARS | ||||||||||
return func.attrname == "TypeVar" | ||||||||||
if isinstance(func, Name): | ||||||||||
return func.name in TYPING_TYPEVARS | ||||||||||
return func.name == "TypeVar" | ||||||||||
return False | ||||||||||
|
||||||||||
|
||||||||||
def infer_typing_typevar_or_newtype(node, context_itton=None): | ||||||||||
"""Infer a typing.TypeVar(...) or typing.NewType(...) call""" | ||||||||||
def looks_like_typing_newtype(node: nodes.Call) -> bool: | ||||||||||
func = node.func | ||||||||||
if isinstance(func, Attribute): | ||||||||||
return func.attrname == "NewType" | ||||||||||
if isinstance(func, Name): | ||||||||||
return func.name == "NewType" | ||||||||||
return False | ||||||||||
|
||||||||||
|
||||||||||
def infer_typing_typevar( | ||||||||||
node: nodes.Call, ctx: context.InferenceContext | None = None | ||||||||||
) -> Iterator[nodes.ClassDef]: | ||||||||||
"""Infer a typing.TypeVar(...) call""" | ||||||||||
try: | ||||||||||
func = next(node.func.infer(context=context_itton)) | ||||||||||
next(node.func.infer(context=ctx)) | ||||||||||
except (InferenceError, StopIteration) as exc: | ||||||||||
raise UseInferenceDefault from exc | ||||||||||
|
||||||||||
if func.qname() not in TYPING_TYPEVARS_QUALIFIED: | ||||||||||
raise UseInferenceDefault | ||||||||||
if not node.args: | ||||||||||
raise UseInferenceDefault | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. L127-128 has a drop in coverage. Could you re-create a test for it? |
||||||||||
# Cannot infer from a dynamic class name (f-string) | ||||||||||
|
@@ -129,7 +144,135 @@ def infer_typing_typevar_or_newtype(node, context_itton=None): | |||||||||
|
||||||||||
typename = node.args[0].as_string().strip("'") | ||||||||||
node = extract_node(TYPING_TYPE_TEMPLATE.format(typename)) | ||||||||||
return node.infer(context=context_itton) | ||||||||||
return node.infer(context=ctx) | ||||||||||
|
||||||||||
|
||||||||||
def infer_typing_newtype( | ||||||||||
node: nodes.Call, ctx: context.InferenceContext | None = None | ||||||||||
) -> Iterator[nodes.ClassDef]: | ||||||||||
"""Infer a typing.NewType(...) call""" | ||||||||||
try: | ||||||||||
next(node.func.infer(context=ctx)) | ||||||||||
except (InferenceError, StopIteration) as exc: | ||||||||||
raise UseInferenceDefault from exc | ||||||||||
|
||||||||||
if len(node.args) != 2: | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you create a test for this? It is currently uncovered. |
||||||||||
raise UseInferenceDefault | ||||||||||
|
||||||||||
# Cannot infer from a dynamic class name (f-string) | ||||||||||
if isinstance(node.args[0], JoinedStr) or isinstance(node.args[1], JoinedStr): | ||||||||||
raise UseInferenceDefault | ||||||||||
|
||||||||||
derived, base = node.args | ||||||||||
derived_name = derived.as_string().strip("'") | ||||||||||
base_name = base.as_string().strip("'") | ||||||||||
|
||||||||||
new_node: ClassDef = extract_node( | ||||||||||
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name) | ||||||||||
) | ||||||||||
new_node.parent = node.parent | ||||||||||
|
||||||||||
new_bases: list[NodeNG] = [] | ||||||||||
|
||||||||||
if not isinstance(base, nodes.Const): | ||||||||||
# Base type arg is a normal reference, so no need to do special lookups | ||||||||||
new_bases = [base] | ||||||||||
elif isinstance(base, nodes.Const) and isinstance(base.value, str): | ||||||||||
# If the base type is given as a string (e.g. for a forward reference), | ||||||||||
# make a naive attempt to find the corresponding node. | ||||||||||
_, resolved_base = node.frame().lookup(base_name) | ||||||||||
if resolved_base: | ||||||||||
base_node = resolved_base[0] | ||||||||||
|
||||||||||
# If the value is from an "import from" statement, follow the import chain | ||||||||||
if isinstance(base_node, nodes.ImportFrom): | ||||||||||
ctx = ctx.clone() if ctx else context.InferenceContext() | ||||||||||
ctx.lookupname = base_name | ||||||||||
base_node = next(base_node.infer(context=ctx)) | ||||||||||
|
||||||||||
new_bases = [base_node] | ||||||||||
elif "." in base.value: | ||||||||||
possible_base = _try_find_imported_object_from_str(node, base.value, ctx) | ||||||||||
if possible_base: | ||||||||||
new_bases = [possible_base] | ||||||||||
|
||||||||||
if new_bases: | ||||||||||
new_node.postinit( | ||||||||||
bases=new_bases, body=new_node.body, decorators=new_node.decorators | ||||||||||
) | ||||||||||
Comment on lines
+200
to
+202
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
return new_node.infer(context=ctx) | ||||||||||
|
||||||||||
|
||||||||||
def _try_find_imported_object_from_str( | ||||||||||
node: nodes.Call, | ||||||||||
name: str, | ||||||||||
ctx: context.InferenceContext | None, | ||||||||||
) -> nodes.NodeNG | None: | ||||||||||
for statement_mod_name, _ in _possible_module_object_splits(name): | ||||||||||
# Find import statements that may pull in the appropriate modules | ||||||||||
# The name used to find this statement may not correspond to the name of the module actually being imported | ||||||||||
# For example, "import email.charset" is found by lookup("email") | ||||||||||
_, resolved_bases = node.frame().lookup(statement_mod_name) | ||||||||||
if not resolved_bases: | ||||||||||
continue | ||||||||||
|
||||||||||
resolved_base = resolved_bases[0] | ||||||||||
if isinstance(resolved_base, nodes.Import): | ||||||||||
# Extract the names of the module as they are accessed from actual code | ||||||||||
scope_names = {(alias or name) for (name, alias) in resolved_base.names} | ||||||||||
aliases = {alias: name for (name, alias) in resolved_base.names if alias} | ||||||||||
|
||||||||||
# Find potential mod_name, obj_name splits that work with the available names | ||||||||||
# for the module in this scope | ||||||||||
import_targets = [ | ||||||||||
(mod_name, obj_name) | ||||||||||
for (mod_name, obj_name) in _possible_module_object_splits(name) | ||||||||||
if mod_name in scope_names | ||||||||||
] | ||||||||||
if not import_targets: | ||||||||||
continue | ||||||||||
|
||||||||||
import_target, name_in_mod = import_targets[0] | ||||||||||
import_target = aliases.get(import_target, import_target) | ||||||||||
|
||||||||||
# Try to import the module and find the object in it | ||||||||||
try: | ||||||||||
resolved_mod: nodes.Module = resolved_base.do_import_module( | ||||||||||
import_target | ||||||||||
) | ||||||||||
except AstroidImportError: | ||||||||||
# If the module doesn't actually exist, try the next option | ||||||||||
continue | ||||||||||
|
||||||||||
# Try to find the appropriate ClassDef or other such node in the target module | ||||||||||
_, object_results_in_mod = resolved_mod.lookup(name_in_mod) | ||||||||||
if not object_results_in_mod: | ||||||||||
continue | ||||||||||
|
||||||||||
base_node = object_results_in_mod[0] | ||||||||||
|
||||||||||
# If the value is from an "import from" statement, follow the import chain | ||||||||||
if isinstance(base_node, nodes.ImportFrom): | ||||||||||
ctx = ctx.clone() if ctx else context.InferenceContext() | ||||||||||
ctx.lookupname = name_in_mod | ||||||||||
base_node = next(base_node.infer(context=ctx)) | ||||||||||
|
||||||||||
return base_node | ||||||||||
|
||||||||||
return None | ||||||||||
|
||||||||||
|
||||||||||
def _possible_module_object_splits( | ||||||||||
dot_str: str, | ||||||||||
) -> Iterator[tuple[str, str]]: | ||||||||||
components = dot_str.split(".") | ||||||||||
popped = [] | ||||||||||
|
||||||||||
while components: | ||||||||||
popped.append(components.pop()) | ||||||||||
|
||||||||||
yield ".".join(components), ".".join(reversed(popped)) | ||||||||||
|
||||||||||
|
||||||||||
def _looks_like_typing_subscript(node): | ||||||||||
|
@@ -403,8 +546,13 @@ def infer_typing_cast( | |||||||||
|
||||||||||
AstroidManager().register_transform( | ||||||||||
Call, | ||||||||||
inference_tip(infer_typing_typevar_or_newtype), | ||||||||||
looks_like_typing_typevar_or_newtype, | ||||||||||
inference_tip(infer_typing_typevar), | ||||||||||
looks_like_typing_typevar, | ||||||||||
) | ||||||||||
AstroidManager().register_transform( | ||||||||||
Call, | ||||||||||
inference_tip(infer_typing_newtype), | ||||||||||
looks_like_typing_newtype, | ||||||||||
) | ||||||||||
AstroidManager().register_transform( | ||||||||||
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript | ||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.