1
1
"""Utility for initialization ensuring functions are called only once."""
2
+
2
3
import abc
3
4
import asyncio
4
5
import collections .abc
@@ -51,7 +52,7 @@ def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionTy
51
52
52
53
53
54
class _OnceBase :
54
- def __init__ (self , is_async : bool ) -> None :
55
+ def __init__ (self , is_async : bool , allow_reset : bool = False ) -> None :
55
56
self .is_async = is_async
56
57
# We are going to be extra pedantic about these next two variables only being read or set
57
58
# with a lock by defining getters and setters which enforce that the lock is held. If this
@@ -60,6 +61,7 @@ def __init__(self, is_async: bool) -> None:
60
61
# this is python :)
61
62
self ._called = False
62
63
self ._return_value : typing .Any = None
64
+ self .allow_reset = allow_reset
63
65
if self .is_async :
64
66
self .async_lock = asyncio .Lock ()
65
67
else :
@@ -125,7 +127,11 @@ async def wrapped(*args, **kwargs) -> typing.Any:
125
127
async with once_base .async_lock :
126
128
if not once_base .called :
127
129
once_base .return_value = _iterator_wrappers .AsyncGeneratorWrapper (
128
- retry_exceptions , func , * args , ** kwargs
130
+ retry_exceptions ,
131
+ func ,
132
+ * args ,
133
+ allow_reset = once_base .allow_reset ,
134
+ ** kwargs ,
129
135
)
130
136
once_base .called = True
131
137
return_value = once_base .return_value
@@ -180,7 +186,11 @@ def wrapped(*args, **kwargs) -> typing.Any:
180
186
with once_base .lock :
181
187
if not once_base .called :
182
188
once_base .return_value = _iterator_wrappers .GeneratorWrapper (
183
- retry_exceptions , func , * args , ** kwargs
189
+ retry_exceptions ,
190
+ func ,
191
+ * args ,
192
+ allow_reset = once_base .allow_reset ,
193
+ ** kwargs ,
184
194
)
185
195
once_base .called = True
186
196
iterator = once_base .return_value
@@ -189,13 +199,56 @@ def wrapped(*args, **kwargs) -> typing.Any:
189
199
else :
190
200
raise NotImplementedError ()
191
201
202
+ def reset ():
203
+ once_base : _OnceBase = once_factory ()
204
+ with once_base .lock :
205
+ if not once_base .called :
206
+ return
207
+ if fn_type == _WrappedFunctionType .SYNC_GENERATOR :
208
+ iterator = once_base .return_value
209
+ with iterator .lock :
210
+ iterator .reset ()
211
+ else :
212
+ once_base .called = False
213
+
214
+ async def async_reset ():
215
+ once_base : _OnceBase = once_factory ()
216
+ async with once_base .async_lock :
217
+ if not once_base .called :
218
+ return
219
+ if fn_type == _WrappedFunctionType .ASYNC_GENERATOR :
220
+ iterator = once_base .return_value
221
+ async with iterator .lock :
222
+ iterator .reset ()
223
+ else :
224
+ once_base .called = False
225
+
226
+ def not_allowed_reset ():
227
+ # This doesn't need to be awaitable even in the async case because it will
228
+ # raise the error before an `await` has a chance to do anything.
229
+ raise RuntimeError (
230
+ f"reset() is not allowed to be called on onced function { func } .\n "
231
+ "Did you mean to add `allow_reset=True` to your once.once() annotation?"
232
+ )
233
+
234
+ # No need for the lock here since we're the only thread that could be running,
235
+ # since we haven't even finished wrapping the func yet.
236
+ once_base : _OnceBase = once_factory ()
237
+ if not once_base .allow_reset :
238
+ wrapped .reset = not_allowed_reset # type: ignore
239
+ else :
240
+ if once_base .is_async :
241
+ wrapped .reset = async_reset # type: ignore
242
+ else :
243
+ wrapped .reset = reset # type: ignore
244
+
192
245
functools .update_wrapper (wrapped , func )
193
246
return wrapped
194
247
195
248
196
- def _once_factory (is_async : bool , per_thread : bool ) -> _ONCE_FACTORY_TYPE :
249
+ def _once_factory (is_async : bool , per_thread : bool , allow_reset : bool ) -> _ONCE_FACTORY_TYPE :
197
250
if not per_thread :
198
- singleton_once = _OnceBase (is_async )
251
+ singleton_once = _OnceBase (is_async , allow_reset = allow_reset )
199
252
return lambda : singleton_once
200
253
201
254
per_thread_onces = threading .local ()
@@ -206,13 +259,15 @@ def _get_once_per_thread():
206
259
# itself!
207
260
if once := getattr (per_thread_onces , "once" , None ):
208
261
return once
209
- per_thread_onces .once = _OnceBase (is_async )
262
+ per_thread_onces .once = _OnceBase (is_async , allow_reset = allow_reset )
210
263
return per_thread_onces .once
211
264
212
265
return _get_once_per_thread
213
266
214
267
215
- def once (* args , per_thread = False , retry_exceptions = False ) -> collections .abc .Callable :
268
+ def once (
269
+ * args , per_thread = False , retry_exceptions = False , allow_reset = False
270
+ ) -> collections .abc .Callable :
216
271
"""Decorator to ensure a function is only called once.
217
272
218
273
The restriction of only one call also holds across threads. However, this
@@ -233,6 +288,21 @@ def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Cal
233
288
collection until after the decorated function itself has been deleted. For
234
289
module and class level functions (i.e. non-closures), this means the return
235
290
value will never be deleted.
291
+
292
+ per_thread:
293
+ If true, the decorated function should be allowed to run once-per-thread
294
+ as opposed to once per process.
295
+ retry_exceptions:
296
+ If true, exceptions in the onced function will allow the function to be
297
+ called again. Otherwise, the exceptions are cached and re-raised on
298
+ subsequent executions.
299
+ allow_reset:
300
+ If true, the returned wrapped function will have an attribute
301
+ `.reset(*args, **kwargs)` which will reset the cache to allow a
302
+ rerun of the underlying callable. This only resets the cache in the
303
+ same scope as it would have used otherwise, e.g. resetting a callable
304
+ wrapped in once_per_instance will reset the cache only for that instance,
305
+ once_per_thread only for that thread, etc.
236
306
"""
237
307
if len (args ) == 1 :
238
308
func : collections .abc .Callable = args [0 ]
@@ -242,14 +312,23 @@ def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Cal
242
312
# This trick lets this function be a decorator directly, or be called
243
313
# to create a decorator.
244
314
# Both @once and @once() will function correctly.
245
- return functools .partial (once , per_thread = per_thread , retry_exceptions = retry_exceptions )
315
+ return functools .partial (
316
+ once ,
317
+ per_thread = per_thread ,
318
+ retry_exceptions = retry_exceptions ,
319
+ allow_reset = allow_reset ,
320
+ )
246
321
if _is_method (func ):
247
322
raise SyntaxError (
248
323
"Attempting to use @once.once decorator on method "
249
324
"instead of @once.once_per_class or @once.once_per_instance"
250
325
)
251
326
fn_type = _wrapped_function_type (func )
252
- once_factory = _once_factory (is_async = fn_type in _ASYNC_FN_TYPES , per_thread = per_thread )
327
+ once_factory = _once_factory (
328
+ is_async = fn_type in _ASYNC_FN_TYPES ,
329
+ per_thread = per_thread ,
330
+ allow_reset = allow_reset ,
331
+ )
253
332
return _wrap (func , once_factory , fn_type , retry_exceptions )
254
333
255
334
@@ -260,19 +339,27 @@ class once_per_class: # pylint: disable=invalid-name
260
339
is_staticmethod : bool
261
340
262
341
@classmethod
263
- def with_options (cls , per_thread : bool = False , retry_exceptions = False ):
264
- return lambda func : cls (func , per_thread = per_thread , retry_exceptions = retry_exceptions )
342
+ def with_options (cls , per_thread : bool = False , retry_exceptions = False , allow_reset = False ):
343
+ return lambda func : cls (
344
+ func ,
345
+ per_thread = per_thread ,
346
+ retry_exceptions = retry_exceptions ,
347
+ allow_reset = allow_reset ,
348
+ )
265
349
266
350
def __init__ (
267
351
self ,
268
352
func : collections .abc .Callable ,
269
353
per_thread : bool = False ,
270
354
retry_exceptions : bool = False ,
355
+ allow_reset : bool = False ,
271
356
) -> None :
272
357
self .func = self ._inspect_function (func )
273
358
self .fn_type = _wrapped_function_type (self .func )
274
359
self .once_factory = _once_factory (
275
- is_async = self .fn_type in _ASYNC_FN_TYPES , per_thread = per_thread
360
+ is_async = self .fn_type in _ASYNC_FN_TYPES ,
361
+ per_thread = per_thread ,
362
+ allow_reset = allow_reset ,
276
363
)
277
364
self .retry_exceptions = retry_exceptions
278
365
@@ -310,31 +397,37 @@ class once_per_instance: # pylint: disable=invalid-name
310
397
"""A version of once for class methods which runs once per instance."""
311
398
312
399
@classmethod
313
- def with_options (cls , per_thread : bool = False , retry_exceptions = False ):
314
- return lambda func : cls (func , per_thread = per_thread , retry_exceptions = retry_exceptions )
400
+ def with_options (cls , per_thread : bool = False , retry_exceptions = False , allow_reset = False ):
401
+ return lambda func : cls (
402
+ func , per_thread = per_thread , retry_exceptions = retry_exceptions , allow_reset = False
403
+ )
315
404
316
405
def __init__ (
317
406
self ,
318
407
func : collections .abc .Callable ,
319
408
per_thread : bool = False ,
320
409
retry_exceptions : bool = False ,
410
+ allow_reset : bool = False ,
321
411
) -> None :
322
412
self .func = self ._inspect_function (func )
323
413
self .fn_type = _wrapped_function_type (self .func )
324
414
self .is_async_fn = self .fn_type in _ASYNC_FN_TYPES
325
415
self .callables_lock = threading .Lock ()
326
- self .callables : weakref .WeakKeyDictionary [
327
- typing . Any , collections . abc . Callable
328
- ] = weakref . WeakKeyDictionary ( )
416
+ self .callables : weakref .WeakKeyDictionary [typing . Any , collections . abc . Callable ] = (
417
+ weakref . WeakKeyDictionary ()
418
+ )
329
419
self .per_thread = per_thread
330
420
self .retry_exceptions = retry_exceptions
421
+ self .allow_reset = allow_reset
331
422
332
423
def once_factory (self ) -> _ONCE_FACTORY_TYPE :
333
424
"""Generate a new once factory.
334
425
335
426
A once factory factory if you will.
336
427
"""
337
- return _once_factory (self .is_async_fn , per_thread = self .per_thread )
428
+ return _once_factory (
429
+ self .is_async_fn , per_thread = self .per_thread , allow_reset = self .allow_reset
430
+ )
338
431
339
432
def _inspect_function (self , func : collections .abc .Callable ):
340
433
if isinstance (func , (classmethod , staticmethod )):
0 commit comments