Skip to content

Commit 473a7cb

Browse files
committed
add a pubsub database implementation
1 parent c91443a commit 473a7cb

File tree

6 files changed

+400
-12
lines changed

6 files changed

+400
-12
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: minor
2+
3+
The :doc:`Hypothesis database <database>` now supports a pub-sub interface to efficiently listen for changes in the database, via ``.add_listener`` and ``.remove_listener``. While all databases that ship with Hypothesis support this interface, implementing it is not required for custom database subclasses. Hypothesis will warn when trying to listen on a database without support.
4+
5+
This feature is currently only used downstream in `hypofuzz <https://github.com/zac-hd/hypofuzz>`_.

hypothesis-python/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def local_file(name):
7070
# We also leave the choice of timezone library to the user, since it
7171
# might be zoneinfo or pytz depending on version and configuration.
7272
"django": ["django>=4.2"],
73+
"watchdog": ["watchdog>=4.0.0"],
7374
}
7475

7576
extras["all"] = sorted(set(sum(extras.values(), [])))

hypothesis-python/src/hypothesis/database.py

Lines changed: 169 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pathlib import Path, PurePath
2525
from queue import Queue
2626
from threading import Thread
27-
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
27+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
2828
from urllib.error import HTTPError, URLError
2929
from urllib.request import Request, urlopen
3030
from zipfile import BadZipFile, ZipFile
@@ -47,6 +47,7 @@
4747
from typing import TypeAlias
4848

4949
StrPathT: "TypeAlias" = Union[str, PathLike[str]]
50+
ListenerT: "TypeAlias" = Callable[[], Any]
5051

5152

5253
def _usable_dir(path: StrPathT) -> bool:
@@ -121,9 +122,13 @@ class ExampleDatabase(metaclass=_EDMeta):
121122
"""An abstract base class for storing examples in Hypothesis' internal format.
122123
123124
An ExampleDatabase maps each ``bytes`` key to many distinct ``bytes``
124-
values, like a ``Mapping[bytes, AbstractSet[bytes]]``.
125+
values, like a ``Mapping[bytes, set[bytes]]``.
125126
"""
126127

128+
def __init__(self) -> None:
    """Initialise the change-listener bookkeeping shared by all databases."""
    # True between a _start_listening call and the matching _stop_listening
    # call; maintained by _update_listening.
    self._listening = False
    # Callbacks registered through add_listener, in registration order.
    self._listeners: list[ListenerT] = []
131+
127132
@abc.abstractmethod
128133
def save(self, key: bytes, value: bytes) -> None:
129134
"""Save ``value`` under ``key``.
@@ -159,6 +164,77 @@ def move(self, src: bytes, dest: bytes, value: bytes) -> None:
159164
self.delete(src, value)
160165
self.save(dest, value)
161166

167+
def add_listener(self, f: ListenerT, /) -> None:
    """Register ``f`` as a change listener.

    ``f`` is appended to the listener list and the listening state is then
    re-evaluated, so registering the first listener starts listening.
    """
    registered = self._listeners
    registered.append(f)
    self._update_listening()
171+
172+
def remove_listener(self, f: ListenerT, /) -> None:
    """
    Remove a change listener. If the listener is not present, silently do
    nothing.
    """
    try:
        self._listeners.remove(f)
    except ValueError:
        # list.remove raises for a missing element, but the documented
        # contract is a silent no-op. Nothing changed, so there is no need
        # to re-evaluate the listening state either.
        return
    self._update_listening()
179+
180+
def clear_listeners(self) -> None:
    """Unregister every change listener at once."""
    # Empty the list in place so existing references to it stay valid.
    del self._listeners[:]
    self._update_listening()
184+
185+
def _update_listening(self) -> None:
    """Synchronise the listening state with the current listener list.

    Starts listening on the transition from zero to some listeners, stops
    listening on the transition from some to zero, and otherwise does
    nothing.
    """
    wanted = bool(self._listeners)
    active = bool(self._listening)
    if wanted == active:
        # Already in the right state; no transition to perform.
        return

    if wanted:
        self._start_listening()
        self._listening = True
    else:
        self._stop_listening()
        self._listening = False
194+
195+
def _broadcast_changed(self) -> None:
    """
    Called when a change has been made to the database that would cause
    .fetch to return something different than it did before. If your database
    implementation supports change listening, this method should be called
    whenever an item is added to or deleted from the underlying database store.

    Note that you should not assume you are the only reference to the underlying
    database store. For example, if two DirectoryBasedExampleDatabase reference
    the same directory, _broadcast_changed should be called whenever a file is
    added or removed from the directory, even if that database was not responsible
    for changing the file.
    """
    # Iterate over a snapshot: a listener may add or remove listeners
    # (including itself) while being called, and mutating a list during
    # iteration would cause other listeners to be skipped.
    for listener in tuple(self._listeners):
        listener()
210+
211+
def _start_listening(self) -> None:
    """
    Hook invoked when the first change listener is added to a database that
    previously had none, so that expensive listening machinery can be
    deferred until it is actually needed.

    Calls to _start_listening and _stop_listening strictly alternate, so an
    implementation never sees two consecutive _start_listening calls without
    a _stop_listening call in between.

    This default implementation only warns that listening is unsupported.
    """
    message = f"{self.__class__} does not support listening for changes"
    warnings.warn(message, stacklevel=4)
224+
225+
def _stop_listening(self) -> None:
    """
    Hook invoked once no change listeners remain on the database.

    Calls to _stop_listening and _start_listening strictly alternate, so an
    implementation never sees two consecutive _stop_listening calls without
    a _start_listening call in between.

    This default implementation only warns that stopping is unsupported.
    """
    message = f"{self.__class__} does not support stopping listening for changes"
    warnings.warn(message, stacklevel=4)
237+
162238

163239
class InMemoryExampleDatabase(ExampleDatabase):
164240
"""A non-persistent example database, implemented in terms of a dict of sets.
@@ -169,6 +245,7 @@ class InMemoryExampleDatabase(ExampleDatabase):
169245
"""
170246

171247
def __init__(self) -> None:
    """Create an empty in-memory store."""
    # The base initialiser sets up the listener bookkeeping.
    super().__init__()
    # Maps each key to the set of distinct values saved under it.
    self.data: dict[bytes, set[bytes]] = dict()
173250

174251
def __repr__(self) -> str:
@@ -178,10 +255,31 @@ def fetch(self, key: bytes) -> Iterable[bytes]:
178255
yield from self.data.get(key, ())
179256

180257
def save(self, key: bytes, value: bytes) -> None:
    """Store ``value`` under ``key`` and notify listeners if it is new."""
    item = bytes(value)
    bucket = self.data.setdefault(key, set())
    if item in bucket:
        # Already present: .fetch would return the same thing, so there is
        # no change to broadcast.
        return
    bucket.add(item)
    self._broadcast_changed()
182265

183266
def delete(self, key: bytes, value: bytes) -> None:
    """Remove ``value`` from ``key`` and notify listeners if it was stored."""
    item = bytes(value)
    bucket = self.data.get(key)
    if bucket is None or item not in bucket:
        # Nothing stored to remove, so nothing changed. An unknown key is
        # not inserted into .data.
        return
    bucket.discard(item)
    self._broadcast_changed()
274+
275+
def _start_listening(self) -> None:
    """Declare compatibility with the listener API; no setup is required.

    The actual notification happens in .save and .delete, since this
    database is the only writer to .data.
    """
    return None
280+
281+
def _stop_listening(self) -> None:
    """Do nothing: _start_listening acquired nothing that needs releasing."""
    return None
185283

186284

187285
def _hash(key: bytes) -> str:
@@ -208,8 +306,12 @@ class DirectoryBasedExampleDatabase(ExampleDatabase):
208306
"""
209307

210308
def __init__(self, path: StrPathT) -> None:
    """Create a database rooted at the directory ``path``.

    ``path`` is normalised to a ``pathlib.Path``.
    """
    super().__init__()
    # Holds the watchdog observer while listening, otherwise None. This
    # would be typed as watchdog.observers.Observer | None, but that would
    # need a conditional import purely for type checking.
    self._observer: Any = None
    self.keypaths: dict[bytes, Path] = {}
    self.path = Path(path)
213315

214316
def __repr__(self) -> str:
    """Return a constructor-style representation that includes the path."""
    location = repr(self.path)
    return f"DirectoryBasedExampleDatabase({location})"
@@ -278,6 +380,44 @@ def delete(self, key: bytes, value: bytes) -> None:
278380
except OSError:
279381
pass
280382

383+
def _start_listening(self) -> None:
    """Start a watchdog observer that broadcasts changes under ``self.path``.

    watchdog is imported lazily because it is an optional dependency. If it
    is not installed, a warning is emitted and no observer is created.
    """
    try:
        from watchdog.events import (
            FileCreatedEvent,
            FileDeletedEvent,
            FileSystemEvent,
            FileSystemEventHandler,
        )
        from watchdog.observers import Observer
    except ImportError:
        message = (
            f"listening for changes in a {self.__class__.__name__} "
            "requires the watchdog library. To install, run "
            "`pip install hypothesis[watchdog]`"
        )
        warnings.warn(message, stacklevel=4)
        return

    # Bind the notifier once so the handler does not need a reference to the
    # database through its own ``self``.
    notify = self._broadcast_changed

    class _ChangeHandler(FileSystemEventHandler):
        def on_any_event(self, event: FileSystemEvent) -> None:
            notify()

    self._observer = observer = Observer()
    observer.schedule(
        _ChangeHandler(),
        self.path,  # type: ignore  # upstream type is too narrow (str only)
        recursive=True,
        # Only file creation and deletion are forwarded to the handler.
        event_filter=[FileCreatedEvent, FileDeletedEvent],
    )
    observer.start()
415+
416+
def _stop_listening(self) -> None:
    """Stop and discard the watchdog observer, if one is running.

    _start_listening returns without creating an observer when watchdog is
    not installed, so ``self._observer`` may legitimately still be None here;
    in that case there is nothing to stop.
    """
    observer = self._observer
    if observer is None:
        return
    observer.stop()
    # Wait for the observer thread to finish before dropping the reference.
    observer.join()
    self._observer = None
420+
281421

282422
class ReadOnlyDatabase(ExampleDatabase):
283423
"""A wrapper to make the given database read-only.
@@ -291,6 +431,7 @@ class ReadOnlyDatabase(ExampleDatabase):
291431
"""
292432

293433
def __init__(self, db: ExampleDatabase) -> None:
    """Wrap ``db``, which must be an ExampleDatabase instance."""
    super().__init__()
    assert isinstance(db, ExampleDatabase)
    # The underlying database this wrapper delegates to.
    self._wrapped = db
296437

@@ -306,6 +447,13 @@ def save(self, key: bytes, value: bytes) -> None:
306447
def delete(self, key: bytes, value: bytes) -> None:
    """Ignore the request: this wrapper never deletes anything."""
    return None
308449

450+
def _start_listening(self) -> None:
    """Forward change notifications from the wrapped database.

    This wrapper never writes, but the wrapped database can still be
    modified through other references to it, which changes what .fetch
    returns here. Listeners are therefore relayed from the wrapped database
    rather than silently never being called.
    """
    self._wrapped.add_listener(self._broadcast_changed)

def _stop_listening(self) -> None:
    """Stop forwarding change notifications from the wrapped database."""
    self._wrapped.remove_listener(self._broadcast_changed)
456+
309457

310458
class MultiplexedDatabase(ExampleDatabase):
311459
"""A wrapper around multiple databases.
@@ -334,6 +482,7 @@ class MultiplexedDatabase(ExampleDatabase):
334482
"""
335483

336484
def __init__(self, *dbs: ExampleDatabase) -> None:
    """Wrap the given databases, each of which must be an ExampleDatabase."""
    super().__init__()
    assert all(isinstance(candidate, ExampleDatabase) for candidate in dbs)
    # Kept as the tuple of positional arguments, in the order given.
    self._wrapped = dbs
339488

@@ -360,6 +509,14 @@ def move(self, src: bytes, dest: bytes, value: bytes) -> None:
360509
for db in self._wrapped:
361510
db.move(src, dest, value)
362511

512+
def _start_listening(self) -> None:
    """Relay change notifications from every wrapped database."""
    relay = self._broadcast_changed
    for wrapped in self._wrapped:
        wrapped.add_listener(relay)
515+
516+
def _stop_listening(self) -> None:
    """Stop relaying change notifications from every wrapped database."""
    relay = self._broadcast_changed
    for wrapped in self._wrapped:
        wrapped.remove_listener(relay)
519+
363520

364521
class GitHubArtifactDatabase(ExampleDatabase):
365522
"""
@@ -439,6 +596,7 @@ def __init__(
439596
cache_timeout: timedelta = timedelta(days=1),
440597
path: Optional[StrPathT] = None,
441598
):
599+
super().__init__()
442600
self.owner = owner
443601
self.repo = repo
444602
self.artifact_name = artifact_name
@@ -699,6 +857,7 @@ class BackgroundWriteDatabase(ExampleDatabase):
699857
"""
700858

701859
def __init__(self, db: ExampleDatabase) -> None:
860+
super().__init__()
702861
self._db = db
703862
self._queue: Queue[tuple[str, tuple[bytes, ...]]] = Queue()
704863
self._thread = Thread(target=self._worker, daemon=True)
@@ -735,6 +894,12 @@ def delete(self, key: bytes, value: bytes) -> None:
735894
def move(self, src: bytes, dest: bytes, value: bytes) -> None:
    """Enqueue a ``move`` operation instead of performing it immediately."""
    arguments = (src, dest, value)
    job = ("move", arguments)
    self._queue.put(job)
737896

897+
def _start_listening(self) -> None:
    """Relay change notifications from the wrapped database."""
    relay = self._broadcast_changed
    self._db.add_listener(relay)
899+
900+
def _stop_listening(self) -> None:
    """Stop relaying change notifications from the wrapped database."""
    relay = self._broadcast_changed
    self._db.remove_listener(relay)
902+
738903

739904
def _pack_uleb128(value: int) -> bytes:
740905
"""

hypothesis-python/src/hypothesis/extra/redis.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections.abc import Iterable
1212
from contextlib import contextmanager
1313
from datetime import timedelta
14+
from typing import Any
1415

1516
from redis import Redis
1617

@@ -36,31 +37,42 @@ def __init__(
3637
*,
3738
expire_after: timedelta = timedelta(days=8),
3839
key_prefix: bytes = b"hypothesis-example:",
40+
listener_channel: str = "hypothesis-changes",
3941
):
42+
super().__init__()
4043
check_type(Redis, redis, "redis")
4144
check_type(timedelta, expire_after, "expire_after")
4245
check_type(bytes, key_prefix, "key_prefix")
46+
check_type(str, listener_channel, "listener_channel")
4347
self.redis = redis
4448
self._expire_after = expire_after
4549
self._prefix = key_prefix
50+
self.listener_channel = listener_channel
51+
self._pubsub: Any = None
4652

4753
def __repr__(self) -> str:
    """Return a constructor-style representation of this database.

    Only the redis client and the expiry are included.
    """
    client = self.redis
    ttl = self._expire_after
    return f"RedisExampleDatabase({client!r}, expire_after={ttl!r})"
5157

5258
@contextmanager
def _pipeline(self, *reset_expire_keys, execute_and_publish=True):
    """Context manager batching updates and expiry resets into one pipeline.

    Yields a redis pipeline for the caller to queue operations on. On exit,
    an ``expire`` is queued for each key in ``reset_expire_keys``. If
    ``execute_and_publish`` is true, the pipeline is then executed and a
    message is published on ``self.listener_channel`` when any of the
    caller's operations reported a change; otherwise the caller is
    responsible for executing the pipeline.
    """
    # Batching reduces the number of TCP roundtrips.
    pipe = self.redis.pipeline()
    yield pipe
    for key in reset_expire_keys:
        pipe.expire(self._prefix + key, self._expire_after)
    if not execute_and_publish:
        return

    # pipe.execute returns a value for each operation, which includes
    # whatever the caller queued as a prefix, and the n operations from
    # pipe.expire as a suffix. Remove that suffix to get just the prefix.
    # The stop index is computed explicitly because a negative-slice
    # `[:-n]` degenerates to `[:0]` (dropping everything) when n == 0.
    results = pipe.execute()
    values = results[: len(results) - len(reset_expire_keys)]
    # Only publish if anything changed.
    if any(value > 0 for value in values):
        self.redis.publish(self.listener_channel, "change")
6173

6274
def fetch(self, key: bytes) -> Iterable[bytes]:
    """Yield every value stored under ``key``.

    The set lookup is queued on a pipeline together with the expiry reset
    that ``_pipeline`` adds on exit, and both are sent in a single
    ``execute`` call made here rather than by ``_pipeline``.
    """
    redis_key = self._prefix + key
    with self._pipeline(key, execute_and_publish=False) as pipe:
        pipe.smembers(redis_key)
    # The first result belongs to smembers; later ones come from expire.
    results = pipe.execute()
    yield from results[0]
6678

@@ -76,3 +88,18 @@ def move(self, src: bytes, dest: bytes, value: bytes) -> None:
7688
with self._pipeline(src, dest) as pipe:
7789
pipe.srem(self._prefix + src, value)
7890
pipe.sadd(self._prefix + dest, value)
91+
92+
def _handle_message(self, message: dict) -> None:
    """Pubsub callback: broadcast a change for each channel message.

    Other message types include "subscribe" and "unsubscribe". Those are
    sent to the client, but not to the pubsub channel, so only "message"
    is expected here.
    """
    kind = message["type"]
    assert kind == "message"
    self._broadcast_changed()
97+
98+
def _start_listening(self) -> None:
    """Create a pubsub object and subscribe it to the listener channel.

    ``_handle_message`` is registered as the handler for
    ``self.listener_channel``.
    """
    self._pubsub = pubsub = self.redis.pubsub()
    # The channel name is dynamic, so it is passed as an unpacked mapping of
    # channel -> handler.
    handlers = {self.listener_channel: self._handle_message}
    pubsub.subscribe(**handlers)
101+
102+
def _stop_listening(self) -> None:
    """Unsubscribe from all channels, close the pubsub object and drop it."""
    pubsub = self._pubsub
    pubsub.unsubscribe()
    pubsub.close()
    self._pubsub = None

0 commit comments

Comments
 (0)