diff --git a/README.md b/README.md index 26c2eaf2..f46a5e40 100644 --- a/README.md +++ b/README.md @@ -82,31 +82,31 @@ import asyncio async def main(): - bus = await MessageBus().connect() - # the introspection xml would normally be included in your project, but - # this is convenient for development - introspection = await bus.introspect('org.mpris.MediaPlayer2.vlc', '/org/mpris/MediaPlayer2') + async with MessageBus() as bus: + # the introspection xml would normally be included in your project, but + # this is convenient for development + introspection = await bus.introspect('org.mpris.MediaPlayer2.vlc', '/org/mpris/MediaPlayer2') - obj = bus.get_proxy_object('org.mpris.MediaPlayer2.vlc', '/org/mpris/MediaPlayer2', introspection) - player = obj.get_interface('org.mpris.MediaPlayer2.Player') - properties = obj.get_interface('org.freedesktop.DBus.Properties') + obj = bus.get_proxy_object('org.mpris.MediaPlayer2.vlc', '/org/mpris/MediaPlayer2', introspection) + player = obj.get_interface('org.mpris.MediaPlayer2.Player') + properties = obj.get_interface('org.freedesktop.DBus.Properties') - # call methods on the interface (this causes the media player to play) - await player.call_play() + # call methods on the interface (this causes the media player to play) + await player.call_play() - volume = await player.get_volume() - print(f'current volume: {volume}, setting to 0.5') + volume = await player.get_volume() + print(f'current volume: {volume}, setting to 0.5') - await player.set_volume(0.5) + await player.set_volume(0.5) - # listen to signals - def on_properties_changed(interface_name, changed_properties, invalidated_properties): - for changed, variant in changed_properties.items(): - print(f'property changed: {changed} - {variant.value}') + # listen to signals + def on_properties_changed(interface_name, changed_properties, invalidated_properties): + for changed, variant in changed_properties.items(): + print(f'property changed: {changed} - {variant.value}') - properties.on_properties_changed(on_properties_changed) + properties.on_properties_changed(on_properties_changed) - await asyncio.Event().wait() + await asyncio.Event().wait() asyncio.run(main()) ``` @@ -155,13 +155,13 @@ class ExampleInterface(ServiceInterface): return 'hello' async def main(): - bus = await MessageBus().connect() - interface = ExampleInterface('test.interface') - bus.export('/test/path', interface) - # now that we are ready to handle requests, we can request name from D-Bus - await bus.request_name('test.name') - # wait indefinitely - await asyncio.Event().wait() + async with MessageBus() as bus: + interface = ExampleInterface('test.interface') + bus.export('/test/path', interface) + # now that we are ready to handle requests, we can request name from D-Bus + await bus.request_name('test.name') + # wait indefinitely + await asyncio.Event().wait() asyncio.run(main()) ``` @@ -181,18 +181,17 @@ import json async def main(): - bus = await MessageBus().connect() + async with MessageBus() as bus: + reply = await bus.call( + Message(destination='org.freedesktop.DBus', + path='/org/freedesktop/DBus', + interface='org.freedesktop.DBus', + member='ListNames')) - reply = await bus.call( - Message(destination='org.freedesktop.DBus', - path='/org/freedesktop/DBus', - interface='org.freedesktop.DBus', - member='ListNames')) + if reply.message_type == MessageType.ERROR: + raise Exception(reply.body[0]) - if reply.message_type == MessageType.ERROR: - raise Exception(reply.body[0]) - - print(json.dumps(reply.body[0], indent=2)) + print(json.dumps(reply.body[0], indent=2)) asyncio.run(main()) diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 5b5fe0ba..bf37e837 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -219,6 +219,18 @@ def __init__( self._disconnect_future = self._loop.create_future() self._pending_futures: set[asyncio.Future] = set() + async def __aenter__(self) -> MessageBus: + try: + return await self.connect() + except BaseException: + self.disconnect() + await self.wait_for_disconnect() + raise + + async def __aexit__(self, *args, **kwargs) -> None: + self.disconnect() + await self.wait_for_disconnect() + async def connect(self) -> MessageBus: """Connect this message bus to the DBus daemon. diff --git a/src/dbus_fast/glib/message_bus.py b/src/dbus_fast/glib/message_bus.py index 9b4db5ea..d61ee08e 100644 --- a/src/dbus_fast/glib/message_bus.py +++ b/src/dbus_fast/glib/message_bus.py @@ -178,6 +178,12 @@ def __init__( else: self._auth = auth + def __enter__(self) -> "MessageBus": + return self.connect_sync() + + def __exit__(self, *args, **kwargs) -> None: + self.disconnect() + def _on_message(self, msg: Message) -> None: try: self._process_message(msg) diff --git a/tests/test_disconnect.py b/tests/test_disconnect.py index 057c9cd8..e3d13bbe 100644 --- a/tests/test_disconnect.py +++ b/tests/test_disconnect.py @@ -74,3 +74,14 @@ def send(self, *args, **kwargs): bus.disconnect() with pytest.raises(OSError): await asyncio.wait_for(bus.wait_for_disconnect(), timeout=1) + + +@pytest.mark.asyncio +async def test_context_manager(): + bus = MessageBus() + + assert not bus.connected + async with bus: + assert bus.connected + + assert not bus.connected diff --git a/tests/test_glib_low_level.py b/tests/test_glib_low_level.py index 49078219..f2fef901 100644 --- a/tests/test_glib_low_level.py +++ b/tests/test_glib_low_level.py @@ -178,3 +178,14 @@ def message_handler(signal): bus1.disconnect() bus2.disconnect() + + +@pytest.mark.skipif(not has_gi, reason=skip_reason_no_gi) +def test_context_manager(): + bus = MessageBus() + + assert not bus.connected + with bus: + assert bus.connected + + assert not bus.connected