Skip to content

Commit 4e2a4f5

Browse files
authored
Enum pass through (#186)
* pass through enums as-is so the C++ layer can deal with them. * Detect a special namespace addition when working with enums. * Fix up pylance error * Ignore the injected namespace. * Also mark other known enums. * Fix 3.9 support issue
1 parent ed624ca commit 4e2a4f5

File tree

4 files changed

+113
-13
lines changed

4 files changed

+113
-13
lines changed

func_adl/type_based_replacement.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,8 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
906906
elif node.id in _global_functions:
907907
self._found_types[node] = Callable
908908
else:
909-
logging.getLogger(__name__).warning(f"Unknown type for name {node.id}")
909+
if not getattr(node, "_ignore", False):
910+
logging.getLogger(__name__).warning(f"Unknown type for name {node.id}")
910911
self._found_types[node] = Any
911912
return node
912913

func_adl/util_ast.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from __future__ import annotations
2-
31
import ast
2+
import copy
3+
import importlib
44
import inspect
55
import tokenize
66
from collections import defaultdict
77
from dataclasses import is_dataclass
88
from enum import Enum
99
from types import ModuleType
10-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
10+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
1111

1212

1313
def as_literal(p: Union[str, int, float, bool, None]) -> ast.Constant:
@@ -252,6 +252,13 @@ def rewrite_func_as_lambda(f: ast.FunctionDef) -> ast.Lambda:
252252
return ast.Lambda(args, ret.value) # type: ignore
253253

254254

255+
class _mark_ignore_name(ast.NodeTransformer):
256+
def visit_Name(self, node: ast.Name) -> Any:
257+
new_node = cast(ast.Expr, ast.parse(node.id).body[0]).value
258+
new_node._ignore = True # type: ignore
259+
return new_node
260+
261+
255262
class _rewrite_captured_vars(ast.NodeTransformer):
256263
def __init__(self, cv: inspect.ClosureVars):
257264
self._lookup_dict: Dict[str, Any] = dict(cv.nonlocals)
@@ -293,7 +300,23 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
293300
new_value = getattr(value.value, node.attr)
294301
# When 3.10 is not supported, replace with EnumType
295302
if isinstance(value.value, Enum.__class__):
296-
new_value = new_value.value
303+
# Sometimes we need to prepend a namespace. We look
304+
# for secret info here, and then prepend if necessary.
305+
# But no one else knows about this, so we need to mark this
306+
# as "ignore".
307+
enum_mod = importlib.import_module(new_value.__module__)
308+
additional_ns = getattr(enum_mod, "_object_cpp_as_py_namespace", None)
309+
if additional_ns is None:
310+
return node
311+
312+
if len(additional_ns) > 0:
313+
ns_node = cast(
314+
ast.Expr, ast.parse(f"{additional_ns}.{ast.unparse(node)}").body[0]
315+
).value
316+
else:
317+
ns_node = copy.copy(node)
318+
319+
return _mark_ignore_name().visit(ns_node)
297320
return ast.Constant(value=new_value)
298321

299322
# If we fail, then just move on.
@@ -448,9 +471,7 @@ def find_identifier(
448471
break
449472
return None, None
450473

451-
def tokens_till(
452-
self, stop_condition: Dict[int, List[str]]
453-
) -> Generator[tokenize.Token, None, None]: # type: ignore
474+
def tokens_till(self, stop_condition: Dict[int, List[str]]):
454475
"""Yield tokens until we find a stop condition.
455476
456477
* Properly tracks parentheses, etc.

tests/test_type_based_replacement.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import logging
55
from inspect import isclass
6-
from typing import Any, Callable, Iterable, Optional, Tuple, Type, TypeVar, cast
6+
from typing import Any, Callable, Iterable, Optional, Tuple, Type, TypeVar, Union, cast
77

88
import pytest
99

@@ -128,8 +128,8 @@ def EventNumber(self) -> int: ... # noqa
128128
def MyLambdaCallback(self, cb: Callable) -> int: ... # noqa
129129

130130

131-
def return_type_test(expr: str, arg_type: type, expected_type: type):
132-
s = ast_lambda(expr)
131+
def return_type_test(expr: Union[str, ast.expr], arg_type: type, expected_type: type):
132+
s = expr if isinstance(expr, ast.expr) else ast_lambda(expr)
133133
objs = ObjectStream(ast.Name(id="e", ctx=ast.Load()), arg_type)
134134

135135
_, _, expr_type = remap_by_types(objs, {"e": arg_type}, s)
@@ -251,6 +251,17 @@ def test_subscript_any():
251251
return_type_test("e[0]", Any, Any) # type: ignore
252252

253253

254+
def test_ast_marked_ignore(caplog):
255+
"Make sure that a ast marked ignore does not generate a warning"
256+
a = ast.parse("donttalk").body[0].value # type: ignore
257+
a._ignore = True
258+
259+
with caplog.at_level(logging.WARNING):
260+
return_type_test(a, Any, Any) # type: ignore
261+
262+
assert len(caplog.messages) == 0
263+
264+
254265
def test_collection():
255266
"A simple collection"
256267
s = ast_lambda("e.Jets('default')")

tests/test_util_ast.py

+69-2
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,76 @@ class MyEnum(Enum):
292292
VALUE = 20
293293

294294
r = parse_as_ast(lambda x: x > forkit.MyEnum.VALUE)
295-
r_true = parse_as_ast(lambda x: x > 20)
296295

297-
assert ast.unparse(r) == ast.unparse(r_true)
296+
assert "VALUE" in ast.unparse(r)
297+
298+
299+
def test_parse_lambda_with_implied_ns():
300+
"Test adding the special attribute to the module to prefix a namespace"
301+
# Add the attribute to the module
302+
global _object_cpp_as_py_namespace
303+
_object_cpp_as_py_namespace = "aweful"
304+
305+
try:
306+
307+
class forkit:
308+
class MyEnum(Enum):
309+
VALUE = 20
310+
311+
r = parse_as_ast(lambda x: x > forkit.MyEnum.VALUE)
312+
313+
assert "aweful.forkit.MyEnum.VALUE" in ast.unparse(r)
314+
315+
found_it = False
316+
317+
class check_it(ast.NodeVisitor):
318+
def visit_Name(self, node: ast.Name):
319+
nonlocal found_it
320+
if node.id == "aweful":
321+
found_it = True
322+
assert hasattr(node, "_ignore")
323+
assert node._ignore # type: ignore
324+
325+
check_it().visit(r)
326+
assert found_it
327+
328+
finally:
329+
# Remove the attribute from the module
330+
del _object_cpp_as_py_namespace
331+
332+
333+
def test_parse_lambda_with_implied_ns_empty():
334+
"Test adding the special attribute to the module to prefix a namespace"
335+
# Add the attribute to the module
336+
global _object_cpp_as_py_namespace
337+
_object_cpp_as_py_namespace = ""
338+
339+
try:
340+
341+
class forkit:
342+
class MyEnum(Enum):
343+
VALUE = 20
344+
345+
r = parse_as_ast(lambda x: x > forkit.MyEnum.VALUE)
346+
347+
assert "forkit.MyEnum.VALUE" in ast.unparse(r)
348+
349+
found_it = False
350+
351+
class check_it(ast.NodeVisitor):
352+
def visit_Name(self, node: ast.Name):
353+
nonlocal found_it
354+
if node.id == "forkit":
355+
found_it = True
356+
assert hasattr(node, "_ignore")
357+
assert node._ignore # type: ignore
358+
359+
check_it().visit(r)
360+
assert found_it
361+
362+
finally:
363+
# Remove the attribute from the module
364+
del _object_cpp_as_py_namespace
298365

299366

300367
def test_parse_lambda_class_constant_in_module():

0 commit comments

Comments
 (0)