1- import weakref
21from sys import getrefcount
3- from enum import Enum
4- from time import sleep as sync_sleep
5- from asyncio import sleep as async_sleep
62from abc import ABC , abstractmethod
73from threading import RLock
8- from dataclasses import dataclass
94from typing import List , Dict , Awaitable , Optional , Union , Any
10- from types import TracebackType
11- from collections .abc import Coroutine
125
13- from cantok .errors import CancellationError
14-
15-
16- class CancelCause (Enum ):
17- CANCELLED = 1
18- SUPERPOWER = 2
19- NOT_CANCELLED = 3
20-
21- class WaitCoroutineWrapper (Coroutine ): # type: ignore[type-arg]
22- def __init__ (self , step : Union [int , float ], token_for_wait : 'AbstractToken' , token_for_check : 'AbstractToken' ) -> None :
23- self .step = step
24- self .token_for_wait = token_for_wait
25- self .token_for_check = token_for_check
26-
27- self .flags : Dict [str , bool ] = {}
28- self .coroutine = self .async_wait (step , self .flags , token_for_wait , token_for_check )
29-
30- weakref .finalize (self , self .sync_wait , step , self .flags , token_for_wait , token_for_check , self .coroutine )
31-
32- def __await__ (self ) -> Any :
33- return self .coroutine .__await__ ()
34-
35- def send (self , value : Any ) -> Any :
36- return self .coroutine .send (value )
37-
38- def throw (self , exception_type : Any , value : Optional [Any ] = None , traceback : Optional [TracebackType ] = None ) -> Any :
39- pass # pragma: no cover
40-
41- def close (self ) -> None :
42- pass # pragma: no cover
43-
44- @staticmethod
45- def sync_wait (step : Union [int , float ], flags : Dict [str , bool ], token_for_wait : 'AbstractToken' , token_for_check : 'AbstractToken' , wrapped_coroutine : Coroutine ) -> None : # type: ignore[type-arg]
46- if not flags .get ('used' , False ):
47- if getrefcount (wrapped_coroutine ) < 5 :
48- wrapped_coroutine .close ()
49-
50- while token_for_wait :
51- sync_sleep (step )
52-
53- token_for_check .check ()
54-
55- @staticmethod
56- async def async_wait (step : Union [int , float ], flags : Dict [str , bool ], token_for_wait : 'AbstractToken' , token_for_check : 'AbstractToken' ) -> None :
57- flags ['used' ] = True
586
59- while token_for_wait :
60- await async_sleep (step )
61-
62- await async_sleep (0 )
63-
64- token_for_check .check ()
7+ from cantok .errors import CancellationError
8+ from cantok .tokens .abstract .cancel_cause import CancelCause
9+ from cantok .tokens .abstract .report import CancellationReport
10+ from cantok .tokens .abstract .coroutine_wrapper import WaitCoroutineWrapper
6511
66- @dataclass
67- class CancellationReport :
68- cause : CancelCause
69- from_token : 'AbstractToken'
7012
7113class AbstractToken (ABC ):
7214 exception = CancellationError
@@ -75,9 +17,10 @@ class AbstractToken(ABC):
7517 def __init__ (self , * tokens : 'AbstractToken' , cancelled : bool = False ) -> None :
7618 from cantok import DefaultToken
7719
78- self .tokens = [token for token in tokens if not isinstance (token , DefaultToken )]
79- self ._cancelled = cancelled
80- self .lock = RLock ()
20+ self .cached_report : Optional [CancellationReport ] = None
21+ self .tokens : List [AbstractToken ] = [token for token in tokens if not isinstance (token , DefaultToken )]
22+ self ._cancelled : bool = cancelled
23+ self .lock : RLock = RLock ()
8124
8225 def __repr__ (self ) -> str :
8326 chunks = []
@@ -113,7 +56,15 @@ def __add__(self, item: 'AbstractToken') -> 'AbstractToken':
11356
11457 from cantok import SimpleToken
11558
116- return SimpleToken (self , item )
59+ nested_tokens = []
60+
61+ for token in self , item :
62+ if isinstance (token , SimpleToken ) and getrefcount (token ) < 6 :
63+ nested_tokens .extend (token .tokens )
64+ else :
65+ nested_tokens .append (token )
66+
67+ return SimpleToken (* nested_tokens )
11768
11869 def __bool__ (self ) -> bool :
11970 return self .keep_on ()
@@ -124,11 +75,12 @@ def cancelled(self) -> bool:
12475
12576 @cancelled .setter
12677 def cancelled (self , new_value : bool ) -> None :
127- if new_value == True :
128- self ._cancelled = True
129- else :
130- if self ._cancelled == True :
131- raise ValueError ('You cannot restore a cancelled token.' )
78+ with self .lock :
79+ if new_value == True :
80+ self ._cancelled = True
81+ else :
82+ if self .is_cancelled ():
83+ raise ValueError ('You cannot restore a cancelled token.' )
13284
13385 def keep_on (self ) -> bool :
13486 return not self .is_cancelled ()
@@ -159,16 +111,18 @@ def get_report(self, direct: bool = True) -> CancellationReport:
159111 cause = CancelCause .CANCELLED ,
160112 from_token = self ,
161113 )
162- else :
163- if self .check_superpower (direct ):
164- return CancellationReport (
165- cause = CancelCause .SUPERPOWER ,
166- from_token = self ,
167- )
114+ elif self .check_superpower (direct ):
115+ return CancellationReport (
116+ cause = CancelCause .SUPERPOWER ,
117+ from_token = self ,
118+ )
119+ elif self .cached_report is not None :
120+ return self .cached_report
168121
169122 for token in self .tokens :
170123 report = token .get_report (direct = False )
171124 if report .cause != CancelCause .NOT_CANCELLED :
125+ self .cached_report = report
172126 return report
173127
174128 return CancellationReport (
0 commit comments