Skip to content

Commit 97b3aa0

Browse files
authored
Implement allow_force_rerun for once (#24)
* Implement allow_force_rerun for once * Black formatter * Type ignore function attribute * Switch force_rerun to reset * Black formatter * fix mypy * Reformat to sign commit
1 parent b9f0765 commit 97b3aa0

File tree

3 files changed

+221
-23
lines changed

3 files changed

+221
-23
lines changed

once/__init__.py

+111-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utility for initialization ensuring functions are called only once."""
2+
23
import abc
34
import asyncio
45
import collections.abc
@@ -51,7 +52,7 @@ def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionTy
5152

5253

5354
class _OnceBase:
54-
def __init__(self, is_async: bool) -> None:
55+
def __init__(self, is_async: bool, allow_reset: bool = False) -> None:
5556
self.is_async = is_async
5657
# We are going to be extra pedantic about these next two variables only being read or set
5758
# with a lock by defining getters and setters which enforce that the lock is held. If this
@@ -60,6 +61,7 @@ def __init__(self, is_async: bool) -> None:
6061
# this is python :)
6162
self._called = False
6263
self._return_value: typing.Any = None
64+
self.allow_reset = allow_reset
6365
if self.is_async:
6466
self.async_lock = asyncio.Lock()
6567
else:
@@ -125,7 +127,11 @@ async def wrapped(*args, **kwargs) -> typing.Any:
125127
async with once_base.async_lock:
126128
if not once_base.called:
127129
once_base.return_value = _iterator_wrappers.AsyncGeneratorWrapper(
128-
retry_exceptions, func, *args, **kwargs
130+
retry_exceptions,
131+
func,
132+
*args,
133+
allow_reset=once_base.allow_reset,
134+
**kwargs,
129135
)
130136
once_base.called = True
131137
return_value = once_base.return_value
@@ -180,7 +186,11 @@ def wrapped(*args, **kwargs) -> typing.Any:
180186
with once_base.lock:
181187
if not once_base.called:
182188
once_base.return_value = _iterator_wrappers.GeneratorWrapper(
183-
retry_exceptions, func, *args, **kwargs
189+
retry_exceptions,
190+
func,
191+
*args,
192+
allow_reset=once_base.allow_reset,
193+
**kwargs,
184194
)
185195
once_base.called = True
186196
iterator = once_base.return_value
@@ -189,13 +199,56 @@ def wrapped(*args, **kwargs) -> typing.Any:
189199
else:
190200
raise NotImplementedError()
191201

202+
def reset():
203+
once_base: _OnceBase = once_factory()
204+
with once_base.lock:
205+
if not once_base.called:
206+
return
207+
if fn_type == _WrappedFunctionType.SYNC_GENERATOR:
208+
iterator = once_base.return_value
209+
with iterator.lock:
210+
iterator.reset()
211+
else:
212+
once_base.called = False
213+
214+
async def async_reset():
215+
once_base: _OnceBase = once_factory()
216+
async with once_base.async_lock:
217+
if not once_base.called:
218+
return
219+
if fn_type == _WrappedFunctionType.ASYNC_GENERATOR:
220+
iterator = once_base.return_value
221+
async with iterator.lock:
222+
iterator.reset()
223+
else:
224+
once_base.called = False
225+
226+
def not_allowed_reset():
227+
# This doesn't need to be awaitable even in the async case because it will
228+
# raise the error before an `await` has a chance to do anything.
229+
raise RuntimeError(
230+
f"reset() is not allowed to be called on onced function {func}.\n"
231+
"Did you mean to add `allow_reset=True` to your once.once() annotation?"
232+
)
233+
234+
# No need for the lock here since we're the only thread that could be running,
235+
# since we haven't even finished wrapping the func yet.
236+
once_base: _OnceBase = once_factory()
237+
if not once_base.allow_reset:
238+
wrapped.reset = not_allowed_reset # type: ignore
239+
else:
240+
if once_base.is_async:
241+
wrapped.reset = async_reset # type: ignore
242+
else:
243+
wrapped.reset = reset # type: ignore
244+
192245
functools.update_wrapper(wrapped, func)
193246
return wrapped
194247

195248

196-
def _once_factory(is_async: bool, per_thread: bool) -> _ONCE_FACTORY_TYPE:
249+
def _once_factory(is_async: bool, per_thread: bool, allow_reset: bool) -> _ONCE_FACTORY_TYPE:
197250
if not per_thread:
198-
singleton_once = _OnceBase(is_async)
251+
singleton_once = _OnceBase(is_async, allow_reset=allow_reset)
199252
return lambda: singleton_once
200253

201254
per_thread_onces = threading.local()
@@ -206,13 +259,15 @@ def _get_once_per_thread():
206259
# itself!
207260
if once := getattr(per_thread_onces, "once", None):
208261
return once
209-
per_thread_onces.once = _OnceBase(is_async)
262+
per_thread_onces.once = _OnceBase(is_async, allow_reset=allow_reset)
210263
return per_thread_onces.once
211264

212265
return _get_once_per_thread
213266

214267

215-
def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Callable:
268+
def once(
269+
*args, per_thread=False, retry_exceptions=False, allow_reset=False
270+
) -> collections.abc.Callable:
216271
"""Decorator to ensure a function is only called once.
217272
218273
The restriction of only one call also holds across threads. However, this
@@ -233,6 +288,21 @@ def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Cal
233288
collection until after the decorated function itself has been deleted. For
234289
module and class level functions (i.e. non-closures), this means the return
235290
value will never be deleted.
291+
292+
per_thread:
293+
If true, the decorated function should be allowed to run once-per-thread
294+
as opposed to once per process.
295+
retry_exceptions:
296+
If true, exceptions in the onced function will allow the function to be
297+
called again. Otherwise, the exceptions are cached and re-raised on
298+
subsequent executions.
299+
allow_reset:
300+
If true, the returned wrapped function will have an attribute
301+
`.reset(*args, **kwargs)` which will reset the cache to allow a
302+
rerun of the underlying callable. This only resets the cache in the
303+
same scope as it would have used otherwise, e.g. resetting a callable
304+
wrapped in once_per_instance will reset the cache only for that instance,
305+
once_per_thread only for that thread, etc.
236306
"""
237307
if len(args) == 1:
238308
func: collections.abc.Callable = args[0]
@@ -242,14 +312,23 @@ def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Cal
242312
# This trick lets this function be a decorator directly, or be called
243313
# to create a decorator.
244314
# Both @once and @once() will function correctly.
245-
return functools.partial(once, per_thread=per_thread, retry_exceptions=retry_exceptions)
315+
return functools.partial(
316+
once,
317+
per_thread=per_thread,
318+
retry_exceptions=retry_exceptions,
319+
allow_reset=allow_reset,
320+
)
246321
if _is_method(func):
247322
raise SyntaxError(
248323
"Attempting to use @once.once decorator on method "
249324
"instead of @once.once_per_class or @once.once_per_instance"
250325
)
251326
fn_type = _wrapped_function_type(func)
252-
once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=per_thread)
327+
once_factory = _once_factory(
328+
is_async=fn_type in _ASYNC_FN_TYPES,
329+
per_thread=per_thread,
330+
allow_reset=allow_reset,
331+
)
253332
return _wrap(func, once_factory, fn_type, retry_exceptions)
254333

255334

@@ -260,19 +339,27 @@ class once_per_class: # pylint: disable=invalid-name
260339
is_staticmethod: bool
261340

262341
@classmethod
263-
def with_options(cls, per_thread: bool = False, retry_exceptions=False):
264-
return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions)
342+
def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False):
343+
return lambda func: cls(
344+
func,
345+
per_thread=per_thread,
346+
retry_exceptions=retry_exceptions,
347+
allow_reset=allow_reset,
348+
)
265349

266350
def __init__(
267351
self,
268352
func: collections.abc.Callable,
269353
per_thread: bool = False,
270354
retry_exceptions: bool = False,
355+
allow_reset: bool = False,
271356
) -> None:
272357
self.func = self._inspect_function(func)
273358
self.fn_type = _wrapped_function_type(self.func)
274359
self.once_factory = _once_factory(
275-
is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=per_thread
360+
is_async=self.fn_type in _ASYNC_FN_TYPES,
361+
per_thread=per_thread,
362+
allow_reset=allow_reset,
276363
)
277364
self.retry_exceptions = retry_exceptions
278365

@@ -310,31 +397,37 @@ class once_per_instance: # pylint: disable=invalid-name
310397
"""A version of once for class methods which runs once per instance."""
311398

312399
@classmethod
313-
def with_options(cls, per_thread: bool = False, retry_exceptions=False):
314-
return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions)
400+
def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False):
401+
return lambda func: cls(
402+
func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_reset=False
403+
)
315404

316405
def __init__(
317406
self,
318407
func: collections.abc.Callable,
319408
per_thread: bool = False,
320409
retry_exceptions: bool = False,
410+
allow_reset: bool = False,
321411
) -> None:
322412
self.func = self._inspect_function(func)
323413
self.fn_type = _wrapped_function_type(self.func)
324414
self.is_async_fn = self.fn_type in _ASYNC_FN_TYPES
325415
self.callables_lock = threading.Lock()
326-
self.callables: weakref.WeakKeyDictionary[
327-
typing.Any, collections.abc.Callable
328-
] = weakref.WeakKeyDictionary()
416+
self.callables: weakref.WeakKeyDictionary[typing.Any, collections.abc.Callable] = (
417+
weakref.WeakKeyDictionary()
418+
)
329419
self.per_thread = per_thread
330420
self.retry_exceptions = retry_exceptions
421+
self.allow_reset = allow_reset
331422

332423
def once_factory(self) -> _ONCE_FACTORY_TYPE:
333424
"""Generate a new once factory.
334425
335426
A once factory factory if you will.
336427
"""
337-
return _once_factory(self.is_async_fn, per_thread=self.per_thread)
428+
return _once_factory(
429+
self.is_async_fn, per_thread=self.per_thread, allow_reset=self.allow_reset
430+
)
338431

339432
def _inspect_function(self, func: collections.abc.Callable):
340433
if isinstance(func, (classmethod, staticmethod)):

once/_iterator_wrappers.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,19 @@ class _GeneratorWrapperBase(abc.ABC):
6464
"""
6565

6666
def __init__(
67-
self, reset_on_exception: bool, func: collections.abc.Callable, *args, **kwargs
67+
self,
68+
reset_on_exception: bool,
69+
func: collections.abc.Callable,
70+
allow_reset: bool = False,
71+
*args,
72+
**kwargs,
6873
) -> None:
6974
self.callable: collections.abc.Callable | None = functools.partial(func, *args, **kwargs)
7075
self.generator = self.callable()
7176
self.result = IteratorResults()
7277
self.generating = False
7378
self.reset_on_exception = reset_on_exception
79+
self.allow_reset = allow_reset
7480

7581
# Why do we make the generating boolean property abstract?
7682
# This makes the code when the iterator state is WAITING more efficient. If this was simply
@@ -129,24 +135,30 @@ def record_successful_completion(self, result: IteratorResults):
129135
result.finished = True
130136
self.generating = False
131137
self.generator = None # Allow this to be GCed.
132-
self.callable = None # Allow this to be GCed.
138+
if not self.allow_reset:
139+
# Allow this to be GCed as long as we know we'll never need it again.
140+
self.callable = None
133141

134142
def record_item(self, result: IteratorResults, item: typing.Any):
135143
"""Must be called with lock."""
136144
self.generating = False
137145
result.items.append(item)
138146

147+
def reset(self):
148+
"""Must be called with lock."""
149+
self.result = IteratorResults()
150+
assert self.callable is not None
151+
self.generator = self.callable() # Reset the iterator for the next call.
152+
139153
def record_exception(self, result: IteratorResults, exception: Exception):
140154
"""Must be called with lock."""
141155
result.finished = True
142156
# We need to keep track of the exception so that we can raise it in the same
143157
# position every time the iterator is called.
144158
result.exception = exception
145159
self.generating = False
146-
assert self.callable is not None
147-
self.generator = self.callable() # Reset the iterator for the next call.
148160
if self.reset_on_exception:
149-
self.result = IteratorResults()
161+
self.reset()
150162
else:
151163
self.generator = None # allow this to be GCed
152164

0 commit comments

Comments
 (0)