Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/dbus_fast/aio/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ async def connect(self) -> MessageBus:
the DBus daemon failed.
- :class:`Exception` - If there was a connection error.
"""
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs for this method need a ..versionchanged: X.Y.Z and to have the :raises: section updated.

And we should add ..versionchanged: X.Y.Z to class MessageBus as well.

await self._loop.sock_connect(self._sock, self._sock_connect_address)
except Exception:
self._stream.close()
self._sock.close()
raise
await self._authenticate()

future = self._loop.create_future()
Expand Down
63 changes: 62 additions & 1 deletion src/dbus_fast/glib/message_bus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import errno
import io
import logging
import socket
import traceback
from collections.abc import Callable

Expand Down Expand Up @@ -131,6 +133,34 @@ def dispatch(self, callback, user_data):
return GLib.SOURCE_CONTINUE


class _ConnectSource(_GLibSource):
Copy link
Member

@bdraco bdraco Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe keep glib the same as it was before and only fix asyncio to contain scope. We could pass a flag to defer connect for asyncio.

I don't think any of the maintainers use the glib version so we try not to touch it

"""GLib source to wait for async socket connection to complete."""

def __init__(self, sock, address):
self.sock = sock
self.address = address
self._connected = False
self._error = None

def prepare(self):
return (False, -1)

def check(self):
return False

def dispatch(self, callback, user_data):
# Check if connection completed
err = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
self._error = OSError(err, errno.errorcode.get(err, "Unknown error"))
callback(self._error)
return GLib.SOURCE_REMOVE

self._connected = True
callback(None)
return GLib.SOURCE_REMOVE


class MessageBus(BaseMessageBus):
"""The message bus implementation for use with the GLib main loop.

Expand Down Expand Up @@ -241,7 +271,38 @@ def on_hello(reply, err):
self._stream.write(hello_msg._marshall(False))
self._stream.flush()

self._authenticate(authenticate_notify)
def on_socket_connect(err):
if err is not None:
if connect_notify is not None:
connect_notify(None, err)
return
self._authenticate(authenticate_notify)

# Start async socket connection
try:
self._sock.connect(self._sock_connect_address)
# Connected immediately (e.g., local socket)
self._authenticate(authenticate_notify)
except BlockingIOError:
# Connection in progress, wait for it to complete
connect_source = _ConnectSource(self._sock, self._sock_connect_address)
connect_source.set_callback(on_socket_connect)
connect_source.add_unix_fd(self._fd, GLib.IO_OUT)
connect_source.attach(self._main_context)
# Keep a reference to prevent garbage collection
self._connect_source = connect_source
except OSError as e:
if e.errno == errno.EINPROGRESS:
# Connection in progress, wait for it to complete
connect_source = _ConnectSource(self._sock, self._sock_connect_address)
connect_source.set_callback(on_socket_connect)
connect_source.add_unix_fd(self._fd, GLib.IO_OUT)
connect_source.attach(self._main_context)
self._connect_source = connect_source
elif connect_notify is not None:
connect_notify(None, e)
else:
raise

def connect_sync(self) -> "MessageBus":
"""Connect this message bus to the DBus daemon.
Expand Down
3 changes: 3 additions & 0 deletions src/dbus_fast/message_bus.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cdef class BaseMessageBus:
cdef public object _machine_id
cdef public bint _negotiate_unix_fd
cdef public object _sock
cdef public object _sock_connect_address
cdef public object _stream
cdef public object _fd

Expand All @@ -62,6 +63,8 @@ cdef class BaseMessageBus:

cdef _setup_socket(self)

cdef _connect_socket(self)

cpdef _call(self, Message msg, object callback)

cpdef next_serial(self)
Expand Down
64 changes: 38 additions & 26 deletions src/dbus_fast/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class BaseMessageBus:
"_path_exports",
"_serial",
"_sock",
"_sock_connect_address",
"_stream",
"_user_disconnect",
"_user_message_handlers",
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self._sock: socket.socket | None = None
self._fd: int | None = None
self._stream: io.BufferedRWPair | None = None
self._sock_connect_address: bytes | str | tuple[str, int] | None = None

self._setup_socket()

Expand Down Expand Up @@ -681,8 +683,15 @@ def _introspect_export_path(self, path: str) -> intr.Node:
return node

def _setup_socket(self) -> None:
last_err: Exception | None = None
"""Create and configure the socket without connecting.

This method creates the socket and prepares the connection address,
but does not perform the actual connection. Call _connect_socket()
to complete the connection synchronously, or use async socket connect
in async implementations.

Sets self._sock, self._stream, self._fd, and self._sock_connect_address.
"""
for transport, options in self._bus_address:
filename: bytes | str | None = None
ip_addr = ""
Expand All @@ -705,16 +714,13 @@ def _setup_socket(self) -> None:
"got unix transport with unknown path specifier"
)

try:
self._sock.connect(filename)
self._sock.setblocking(False)
except Exception as e:
last_err = e
else:
stack.pop_all() # responsibility to close sockets is deferred
return
# Store connect address for later; don't connect yet
self._sock_connect_address: bytes | str | tuple[str, int] = filename
self._sock.setblocking(False)
stack.pop_all() # responsibility to close sockets is deferred
return

elif transport == "tcp":
if transport == "tcp":
self._sock = stack.enter_context(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
)
Expand All @@ -726,25 +732,31 @@ def _setup_socket(self) -> None:
if "port" in options:
ip_port = int(options["port"])

try:
self._sock.connect((ip_addr, ip_port))
self._sock.setblocking(False)
except Exception as e:
last_err = e
else:
stack.pop_all()
return
# Store connect address for later; don't connect yet
self._sock_connect_address = (ip_addr, ip_port)
self._sock.setblocking(False)
stack.pop_all()
return

else:
raise InvalidAddressError(
f"got unknown address transport: {transport}"
)
raise InvalidAddressError(f"got unknown address transport: {transport}")

# Should not normally happen, but just in case
raise TypeError("empty list of bus addresses given") # pragma: no cover
Comment on lines 717 to 744
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the try/excepts here breaks the automatic fallback to "tcp" if connecting the "unix" transport fails. This entire method will need to be moved to the connect method.


if last_err is None: # pragma: no branch
# Should not normally happen, but just in case
raise TypeError("empty list of bus addresses given") # pragma: no cover
def _connect_socket(self) -> None:
"""Perform the blocking socket connection.

raise last_err
This is used by synchronous implementations (like glib's connect_sync).
Async implementations should use their event loop's async socket connect
(e.g., loop.sock_connect) instead.

:raises: Connection errors from socket.connect()
"""
self._sock.setblocking(True)
try:
self._sock.connect(self._sock_connect_address)
finally:
self._sock.setblocking(False)

def _reply_notify(
self,
Expand Down
26 changes: 10 additions & 16 deletions tests/test_message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
@pytest.mark.asyncio
async def test_tcp_socket_cleanup_on_connect_fail() -> None:
"""Test that socket resources are cleaned up on a failed TCP connection."""

# A bit ugly, but we need to access members of the class after __init__()
# raises, so we need to split __new__() and __init__().
bus = MessageBus.__new__(MessageBus)
bus = MessageBus("tcp:host=127.0.0.1,port=1")

with pytest.raises(ConnectionRefusedError):
bus.__init__("tcp:host=127.0.0.1,port=1")
await bus.connect()

assert bus._stream.closed
assert bus._sock._closed
Expand All @@ -22,13 +19,10 @@ async def test_tcp_socket_cleanup_on_connect_fail() -> None:
@pytest.mark.asyncio
async def test_unix_socket_cleanup_on_connect_fail() -> None:
"""Test that socket resources are cleaned up on a failed Unix socket connection."""

# A bit ugly, but we need to access members of the class after __init__()
# raises, so we need to split __new__() and __init__().
bus = MessageBus.__new__(MessageBus)
bus = MessageBus("unix:path=/there-is-no-way-that-this-file-should-exist")

with pytest.raises(FileNotFoundError):
bus.__init__("unix:path=/there-is-no-way-that-this-file-should-exist")
await bus.connect()

assert bus._stream.closed
assert bus._sock._closed
Expand All @@ -37,11 +31,11 @@ async def test_unix_socket_cleanup_on_connect_fail() -> None:
@pytest.mark.asyncio
async def test_tcp_socket_cleanup_with_host_only() -> None:
"""Test TCP connection with host option only (no port)."""
bus = MessageBus.__new__(MessageBus)
bus = MessageBus("tcp:host=127.0.0.1")

with pytest.raises(OSError):
# Port defaults to 0, which will fail
bus.__init__("tcp:host=127.0.0.1")
await bus.connect()

assert bus._stream.closed
assert bus._sock._closed
Expand All @@ -50,11 +44,11 @@ async def test_tcp_socket_cleanup_with_host_only() -> None:
@pytest.mark.asyncio
async def test_tcp_socket_cleanup_with_port_only() -> None:
"""Test TCP connection with port option only (no host)."""
bus = MessageBus.__new__(MessageBus)
bus = MessageBus("tcp:port=1")

with pytest.raises(OSError):
# Host defaults to empty string, which will fail
bus.__init__("tcp:port=1")
await bus.connect()

assert bus._stream.closed
assert bus._sock._closed
Expand All @@ -63,11 +57,11 @@ async def test_tcp_socket_cleanup_with_port_only() -> None:
@pytest.mark.asyncio
async def test_unix_socket_abstract_cleanup_on_connect_fail() -> None:
"""Test that socket resources are cleaned up on a failed abstract Unix socket connection."""
bus = MessageBus.__new__(MessageBus)
bus = MessageBus("unix:abstract=/tmp/nonexistent-abstract-socket")

# On Linux: ConnectionRefusedError, on macOS: FileNotFoundError
with pytest.raises((FileNotFoundError, ConnectionRefusedError)):
bus.__init__("unix:abstract=/tmp/nonexistent-abstract-socket")
await bus.connect()

assert bus._stream.closed
assert bus._sock._closed
Expand Down
Loading