|
1 |
| -from __future__ import annotations |
2 |
| - |
3 | 1 | import ast
|
| 2 | +import copy |
| 3 | +import importlib |
4 | 4 | import inspect
|
5 | 5 | import tokenize
|
6 | 6 | from collections import defaultdict
|
7 | 7 | from dataclasses import is_dataclass
|
8 | 8 | from enum import Enum
|
9 | 9 | from types import ModuleType
|
10 |
| -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast |
| 10 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def as_literal(p: Union[str, int, float, bool, None]) -> ast.Constant:
|
@@ -252,6 +252,13 @@ def rewrite_func_as_lambda(f: ast.FunctionDef) -> ast.Lambda:
|
252 | 252 | return ast.Lambda(args, ret.value) # type: ignore
|
253 | 253 |
|
254 | 254 |
|
| 255 | +class _mark_ignore_name(ast.NodeTransformer): |
| 256 | + def visit_Name(self, node: ast.Name) -> Any: |
| 257 | + new_node = cast(ast.Expr, ast.parse(node.id).body[0]).value |
| 258 | + new_node._ignore = True # type: ignore |
| 259 | + return new_node |
| 260 | + |
| 261 | + |
255 | 262 | class _rewrite_captured_vars(ast.NodeTransformer):
|
256 | 263 | def __init__(self, cv: inspect.ClosureVars):
|
257 | 264 | self._lookup_dict: Dict[str, Any] = dict(cv.nonlocals)
|
@@ -297,25 +304,19 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
|
297 | 304 | # for secret info here, and then prepend if necessary.
|
298 | 305 | # But no one else knows about this, so we need to mark this
|
299 | 306 | # as "ignore".
|
300 |
| - import importlib |
301 |
| - |
302 | 307 | enum_mod = importlib.import_module(new_value.__module__)
|
303 |
| - additional_ns = getattr(enum_mod, "_object_cpp_as_py_namespace", "") |
304 |
| - if len(additional_ns) == 0: |
| 308 | + additional_ns = getattr(enum_mod, "_object_cpp_as_py_namespace", None) |
| 309 | + if additional_ns is None: |
305 | 310 | return node
|
306 |
| - ns_node = cast( |
307 |
| - ast.Expr, ast.parse(f"{additional_ns}.{ast.unparse(node)}").body[0] |
308 |
| - ).value |
309 |
| - |
310 |
| - class mark_ignore(ast.NodeTransformer): |
311 |
| - def visit_Name(self, node: ast.Name) -> Any: |
312 |
| - if node.id == additional_ns: |
313 |
| - new_node = cast(ast.Expr, ast.parse(node.id).body[0]).value |
314 |
| - new_node._ignore = True # type: ignore |
315 |
| - return new_node |
316 |
| - return node |
317 |
| - |
318 |
| - return mark_ignore().visit(ns_node) |
| 311 | + |
| 312 | + if len(additional_ns) > 0: |
| 313 | + ns_node = cast( |
| 314 | + ast.Expr, ast.parse(f"{additional_ns}.{ast.unparse(node)}").body[0] |
| 315 | + ).value |
| 316 | + else: |
| 317 | + ns_node = copy.copy(node) |
| 318 | + |
| 319 | + return _mark_ignore_name().visit(ns_node) |
319 | 320 | return ast.Constant(value=new_value)
|
320 | 321 |
|
321 | 322 | # If we fail, then just move on.
|
@@ -470,9 +471,7 @@ def find_identifier(
|
470 | 471 | break
|
471 | 472 | return None, None
|
472 | 473 |
|
473 |
| - def tokens_till( |
474 |
| - self, stop_condition: Dict[int, List[str]] |
475 |
| - ) -> Generator[tokenize.Token, None, None]: # type: ignore |
| 474 | + def tokens_till(self, stop_condition: Dict[int, List[str]]): |
476 | 475 | """Yield tokens until we find a stop condition.
|
477 | 476 |
|
478 | 477 | * Properly tracks parentheses, etc.
|
|
0 commit comments