11from sys import version_info
2- from typing import Optional , Union , List , Any
2+ from typing import Optional , Union , List , Dict , Any
33from types import MethodType , FunctionType
44from collections .abc import Callable
55from inspect import isfunction , iscoroutinefunction , getsource , getfile
6- from ast import parse , NodeTransformer , Expr , AST , FunctionDef , AsyncFunctionDef , increment_lineno , Await , Return , Name , Load , Assign , Constant , Store , arguments
6+ from ast import parse , NodeTransformer , AST , FunctionDef , AsyncFunctionDef , increment_lineno , Await , Call , With , Return , Name , Load , Assign , Constant , Store , arguments
77from functools import wraps , update_wrapper
88
9- from dill .source import getsource as dill_getsource
9+ from dill .source import getsource as dill_getsource # type: ignore[import-untyped]
1010
1111from transfunctions .errors import CallTransfunctionDirectlyError , DualUseOfDecoratorError , WrongDecoratorSyntaxError
1212
@@ -27,7 +27,7 @@ def __init__(self, function: Callable, decorator_lineno: int, decorator_name: st
2727 self .decorator_name = decorator_name
2828 self .extra_transformers = extra_transformers
2929 self .base_object = None
30- self .cache = {}
30+ self .cache : Dict [ str , Callable ] = {}
3131
3232 def __call__ (self , * args : Any , ** kwargs : Any ) -> None :
3333 raise CallTransfunctionDirectlyError ("You can't call a transfunction object directly, create a function, a generator function or a coroutine function from it." )
@@ -49,7 +49,7 @@ def get_async_function(self):
4949 original_function = self .function
5050
5151 class ConvertSyncFunctionToAsync (NodeTransformer ):
52- def visit_FunctionDef (self , node : Expr ) -> Optional [Union [AST , List [AST ]]]:
52+ def visit_FunctionDef (self , node : FunctionDef ) -> Optional [Union [AST , List [AST ]]]:
5353 if node .name == original_function .__name__ :
5454 return AsyncFunctionDef (
5555 name = original_function .__name__ ,
@@ -62,7 +62,7 @@ def visit_FunctionDef(self, node: Expr) -> Optional[Union[AST, List[AST]]]:
6262 return node
6363
6464 class ExtractAwaitExpressions (NodeTransformer ):
65- def visit_Call (self , node : Expr ) -> Optional [Union [AST , List [AST ]]]:
65+ def visit_Call (self , node : Call ) -> Optional [Union [AST , List [AST ]]]:
6666 if node .func .id == 'await_it' :
6767 return Await (
6868 value = node .args [0 ],
@@ -102,7 +102,7 @@ def extract_context(self, context_name: str, addictional_transformers: Optional[
102102 if context_name in self .cache :
103103 return self .cache [context_name ]
104104 try :
105- source_code = getsource (self .function )
105+ source_code : str = getsource (self .function )
106106 except OSError :
107107 source_code = dill_getsource (self .function )
108108
@@ -119,15 +119,15 @@ def extract_context(self, context_name: str, addictional_transformers: Optional[
119119 decorator_name = self .decorator_name
120120
121121 class RewriteContexts (NodeTransformer ):
122- def visit_With (self , node : Expr ) -> Optional [Union [AST , List [AST ]]]:
122+ def visit_With (self , node : With ) -> Optional [Union [AST , List [AST ]]]:
123123 if len (node .items ) == 1 and node .items [0 ].context_expr .id == context_name :
124124 return node .body
125125 elif len (node .items ) == 1 and node .items [0 ].context_expr .id != context_name and context_name in ('async_context' , 'sync_context' , 'generator_context' ):
126126 return None
127127 return node
128128
129129 class DeleteDecorator (NodeTransformer ):
130- def visit_FunctionDef (self , node : Expr ) -> Optional [Union [AST , List [AST ]]]:
130+ def visit_FunctionDef (self , node : FunctionDef ) -> Optional [Union [AST , List [AST ]]]:
131131 if node .name == original_function .__name__ :
132132 nonlocal transfunction_decorator
133133 transfunction_decorator = None
@@ -161,7 +161,7 @@ def visit_FunctionDef(self, node: Expr) -> Optional[Union[AST, List[AST]]]:
161161 increment_lineno (tree , n = (self .decorator_lineno - transfunction_decorator .lineno - 1 ))
162162
163163 code = compile (tree , filename = getfile (self .function ), mode = 'exec' )
164- namespace = {}
164+ namespace : Dict [ str , Callable ] = {}
165165 exec (code , namespace )
166166 function_factory = namespace ['wrapper' ]
167167 result = function_factory ()
0 commit comments