Skip to content

Commit b50c077

Browse files
committed
Also mark other known enums.
1 parent d28abfe commit b50c077

File tree

2 files changed

+56
-23
lines changed

2 files changed

+56
-23
lines changed

func_adl/util_ast.py

+22-23
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from __future__ import annotations
2-
31
import ast
2+
import copy
3+
import importlib
44
import inspect
55
import tokenize
66
from collections import defaultdict
77
from dataclasses import is_dataclass
88
from enum import Enum
99
from types import ModuleType
10-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
10+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
1111

1212

1313
def as_literal(p: Union[str, int, float, bool, None]) -> ast.Constant:
@@ -252,6 +252,13 @@ def rewrite_func_as_lambda(f: ast.FunctionDef) -> ast.Lambda:
252252
return ast.Lambda(args, ret.value) # type: ignore
253253

254254

255+
class _mark_ignore_name(ast.NodeTransformer):
256+
def visit_Name(self, node: ast.Name) -> Any:
257+
new_node = cast(ast.Expr, ast.parse(node.id).body[0]).value
258+
new_node._ignore = True # type: ignore
259+
return new_node
260+
261+
255262
class _rewrite_captured_vars(ast.NodeTransformer):
256263
def __init__(self, cv: inspect.ClosureVars):
257264
self._lookup_dict: Dict[str, Any] = dict(cv.nonlocals)
@@ -297,25 +304,19 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
297304
# for secret info here, and then prepend if necessary.
298305
# But no one else knows about this, so we need to mark this
299306
# as "ignore".
300-
import importlib
301-
302307
enum_mod = importlib.import_module(new_value.__module__)
303-
additional_ns = getattr(enum_mod, "_object_cpp_as_py_namespace", "")
304-
if len(additional_ns) == 0:
308+
additional_ns = getattr(enum_mod, "_object_cpp_as_py_namespace", None)
309+
if additional_ns is None:
305310
return node
306-
ns_node = cast(
307-
ast.Expr, ast.parse(f"{additional_ns}.{ast.unparse(node)}").body[0]
308-
).value
309-
310-
class mark_ignore(ast.NodeTransformer):
311-
def visit_Name(self, node: ast.Name) -> Any:
312-
if node.id == additional_ns:
313-
new_node = cast(ast.Expr, ast.parse(node.id).body[0]).value
314-
new_node._ignore = True # type: ignore
315-
return new_node
316-
return node
317-
318-
return mark_ignore().visit(ns_node)
311+
312+
if len(additional_ns) > 0:
313+
ns_node = cast(
314+
ast.Expr, ast.parse(f"{additional_ns}.{ast.unparse(node)}").body[0]
315+
).value
316+
else:
317+
ns_node = copy.copy(node)
318+
319+
return _mark_ignore_name().visit(ns_node)
319320
return ast.Constant(value=new_value)
320321

321322
# If we fail, then just move on.
@@ -470,9 +471,7 @@ def find_identifier(
470471
break
471472
return None, None
472473

473-
def tokens_till(
474-
self, stop_condition: Dict[int, List[str]]
475-
) -> Generator[tokenize.Token, None, None]: # type: ignore
474+
def tokens_till(self, stop_condition: Dict[int, List[str]]):
476475
"""Yield tokens until we find a stop condition.
477476
478477
* Properly tracks parentheses, etc.

tests/test_util_ast.py

+34
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,40 @@ def visit_Name(self, node: ast.Name):
330330
del _object_cpp_as_py_namespace
331331

332332

333+
def test_parse_lambda_with_implied_ns_empty():
334+
"Test adding the special attribute to the module to prefix a namespace"
335+
# Add the attribute to the module
336+
global _object_cpp_as_py_namespace
337+
_object_cpp_as_py_namespace = ""
338+
339+
try:
340+
341+
class forkit:
342+
class MyEnum(Enum):
343+
VALUE = 20
344+
345+
r = parse_as_ast(lambda x: x > forkit.MyEnum.VALUE)
346+
347+
assert "forkit.MyEnum.VALUE" in ast.unparse(r)
348+
349+
found_it = False
350+
351+
class check_it(ast.NodeVisitor):
352+
def visit_Name(self, node: ast.Name):
353+
nonlocal found_it
354+
if node.id == "forkit":
355+
found_it = True
356+
assert hasattr(node, "_ignore")
357+
assert node._ignore # type: ignore
358+
359+
check_it().visit(r)
360+
assert found_it
361+
362+
finally:
363+
# Remove the attribute from the module
364+
del _object_cpp_as_py_namespace
365+
366+
333367
def test_parse_lambda_class_constant_in_module():
334368
from . import xAOD
335369

0 commit comments

Comments
 (0)