@@ -218,6 +218,7 @@ def test_async_and_sync(
218218 websockets_only = False ,
219219 num_retries = 1 ,
220220 use_testnet = False ,
221+ async_only = False ,
221222):
222223 def decorator (test_function ):
223224 lines = _get_non_decorator_code (test_function )
@@ -235,15 +236,18 @@ def decorator(test_function):
235236 first_line = lines [0 ]
236237 sync_code += first_line .replace (" async def " , "" ).replace (":" , "" )
237238
238- sync_modules_to_import = {}
239- if modules is not None :
240- for module_str in modules :
241- function = module_str .split ("." )[- 1 ]
242- location = module_str [: - 1 * len (function ) - 1 ]
243- module = getattr (importlib .import_module (location ), function )
244- sync_modules_to_import [function ] = module
245-
246- all_modules = {** original_globals , ** globals (), ** sync_modules_to_import }
239+ if not async_only :
240+ sync_modules_to_import = {}
241+ if modules is not None :
242+ for module_str in modules :
243+ function = module_str .split ("." )[- 1 ]
244+ location = module_str [: - 1 * len (function ) - 1 ]
245+ module = getattr (importlib .import_module (location ), function )
246+ sync_modules_to_import [function ] = module
247+
248+ all_modules = {** original_globals , ** globals (), ** sync_modules_to_import }
249+ else :
250+ all_modules = {** original_globals , ** globals ()}
247251 # NOTE: passing `globals()` into `exec` is really bad practice and not safe at
248252 # all, but in this case it's fine because it's only running test code
249253
@@ -290,14 +294,16 @@ def modified_test(self):
290294 asyncio .run (
291295 _run_async_test (self , _get_client (True , True , use_testnet ), 1 )
292296 )
293- with self .subTest (version = "sync" , client = "json" ):
294- _run_sync_test (self , _get_client (False , True , use_testnet ), 2 )
297+ if not async_only :
298+ with self .subTest (version = "sync" , client = "json" ):
299+ _run_sync_test (self , _get_client (False , True , use_testnet ), 2 )
295300 with self .subTest (version = "async" , client = "websocket" ):
296301 asyncio .run (
297302 _run_async_test (self , _get_client (True , False , use_testnet ), 3 )
298303 )
299- with self .subTest (version = "sync" , client = "websocket" ):
300- _run_sync_test (self , _get_client (False , False , use_testnet ), 4 )
304+ if not async_only :
305+ with self .subTest (version = "sync" , client = "websocket" ):
306+ _run_sync_test (self , _get_client (False , False , use_testnet ), 4 )
301307
302308 return modified_test
303309
@@ -315,6 +321,7 @@ def _get_non_decorator_code(function):
315321
316322def create_amm_pool (
317323 client : SyncClient = JSON_RPC_CLIENT ,
324+ enable_amm_clawback : bool = False ,
318325) -> Dict [str , Any ]:
319326 issuer_wallet = Wallet .create ()
320327 fund_wallet (issuer_wallet )
@@ -331,6 +338,16 @@ def create_amm_pool(
331338 issuer_wallet ,
332339 )
333340
341+ # The below flag is required for AMMClawback tests
342+ if enable_amm_clawback :
343+ sign_and_reliable_submission (
344+ AccountSet (
345+ account = issuer_wallet .classic_address ,
346+ set_flag = AccountSetAsfFlag .ASF_ALLOW_TRUSTLINE_CLAWBACK ,
347+ ),
348+ issuer_wallet ,
349+ )
350+
334351 sign_and_reliable_submission (
335352 TrustSet (
336353 account = lp_wallet .classic_address ,
@@ -382,11 +399,13 @@ def create_amm_pool(
382399 "asset" : asset ,
383400 "asset2" : asset2 ,
384401 "issuer_wallet" : issuer_wallet ,
402+ "lp_wallet" : lp_wallet ,
385403 }
386404
387405
388406async def create_amm_pool_async (
389407 client : AsyncClient = ASYNC_JSON_RPC_CLIENT ,
408+ enable_amm_clawback : bool = False ,
390409) -> Dict [str , Any ]:
391410 issuer_wallet = Wallet .create ()
392411 await fund_wallet_async (issuer_wallet )
@@ -403,6 +422,16 @@ async def create_amm_pool_async(
403422 issuer_wallet ,
404423 )
405424
425+ # The below flag is required for AMMClawback tests
426+ if enable_amm_clawback :
427+ await sign_and_reliable_submission_async (
428+ AccountSet (
429+ account = issuer_wallet .classic_address ,
430+ set_flag = AccountSetAsfFlag .ASF_ALLOW_TRUSTLINE_CLAWBACK ,
431+ ),
432+ issuer_wallet ,
433+ )
434+
406435 await sign_and_reliable_submission_async (
407436 TrustSet (
408437 account = lp_wallet .classic_address ,
@@ -454,6 +483,7 @@ async def create_amm_pool_async(
454483 "asset" : asset ,
455484 "asset2" : asset2 ,
456485 "issuer_wallet" : issuer_wallet ,
486+ "lp_wallet" : lp_wallet ,
457487 }
458488
459489
0 commit comments