44import os
55import sys
66import traceback
7- from contextlib import AsyncExitStack , asynccontextmanager
7+ from contextlib import AsyncExitStack
88from inspect import Parameter , isawaitable , signature
99from types import TracebackType
1010from typing import (
@@ -101,8 +101,8 @@ async def __aexit__(
101101
102102 async def get (self , t : type , name : str = "" ) -> Any :
103103 injector = self .checker .resolve_injector (name , t )
104- args = await self ._exit_stack . enter_async_context (
105- self .checker . _inject_dependencies ( self . task , injector , None )
104+ args = await self .checker . _inject_dependencies (
105+ self .task , injector , self . _exit_stack
106106 )
107107 res = injector (* args )
108108 if isawaitable (res ):
@@ -243,13 +243,13 @@ def resolve_injector(self, name: str, t: type) -> Callable[..., Any]:
243243 return self ._dependency_injections [generic_key ]
244244 return self ._dependency_injections [key ]
245245
246- @asynccontextmanager
247246 async def _inject_dependencies (
248247 self ,
249248 task : BaseCheckerTaskMessage ,
250249 f : Callable [..., Any ],
250+ stack : AsyncExitStack ,
251251 dependencies : Optional [Set [Callable [..., Any ]]] = None ,
252- ) -> AsyncIterator [Any ]:
252+ ) -> List [Any ]:
253253 dependencies = dependencies or set ()
254254
255255 sig = signature (f )
@@ -271,23 +271,22 @@ async def _inject_dependencies(
271271 f"Detected circular dependency in { f } with injected type { v .annotation } "
272272 )
273273 else :
274- async with self ._inject_dependencies (
275- task , injector , dependencies .union ([injector ])
276- ) as args_ :
277- arg = injector (* args_ )
278- if isawaitable (arg ):
279- arg = await arg
280- args .append (arg )
281-
282- async with AsyncExitStack () as stack :
283- # new_args contains the return values of __(a)enter__, which would be the "x" in "(async) with ... as x:"
284- new_args = []
285- for arg in args :
286- if not hasattr (arg , "__enter__" ) and not hasattr (arg , "__aenter__" ):
287- new_args .append (arg )
288- continue
289- new_args .append (await stack .enter_async_context (arg ))
290- yield new_args
274+ args_ = await self ._inject_dependencies (
275+ task , injector , stack , dependencies .union ([injector ])
276+ )
277+ arg = injector (* args_ )
278+ if isawaitable (arg ):
279+ arg = await arg
280+ args .append (arg )
281+
282+ # new_args contains the return values of __(a)enter__, which would be the "x" in "(async) with ... as x:"
283+ new_args = []
284+ for arg in args :
285+ if not hasattr (arg , "__enter__" ) and not hasattr (arg , "__aenter__" ):
286+ new_args .append (arg )
287+ continue
288+ new_args .append (await stack .enter_async_context (arg ))
289+ return new_args
291290
292291 async def _call_method_raw (self , task : BaseCheckerTaskMessage ) -> Optional [str ]:
293292 variant_id = task .variant_id
@@ -299,8 +298,10 @@ async def _call_method_raw(self, task: BaseCheckerTaskMessage) -> Optional[str]:
299298 f"Variant_id { variant_id } not defined for method { method } "
300299 )
301300
302- async with self ._inject_dependencies (task , f ) as args :
303- return await f (* args )
301+ async with AsyncExitStack () as stack :
302+ args = await self ._inject_dependencies (task , f , stack )
303+ res = await f (* args )
304+ return res
304305
305306 async def _call_method (self , task : BaseCheckerTaskMessage ) -> Optional [str ]:
306307 try :
0 commit comments