-
-
Notifications
You must be signed in to change notification settings - Fork 27
fix!: defer socket.connect() from __init__ to connect() #570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| import errno | ||
| import io | ||
| import logging | ||
| import socket | ||
| import traceback | ||
| from collections.abc import Callable | ||
|
|
||
|
|
@@ -131,6 +133,34 @@ def dispatch(self, callback, user_data): | |
| return GLib.SOURCE_CONTINUE | ||
|
|
||
|
|
||
| class _ConnectSource(_GLibSource): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe keep glib the same as it was before and only fix asyncio to contain scope. We could pass a flag to defer connect for asyncio. I don't think any of the maintainers use the glib version so we try not to touch it |
||
| """GLib source to wait for async socket connection to complete.""" | ||
|
|
||
| def __init__(self, sock, address): | ||
| self.sock = sock | ||
| self.address = address | ||
| self._connected = False | ||
| self._error = None | ||
|
|
||
| def prepare(self): | ||
| return (False, -1) | ||
|
|
||
| def check(self): | ||
| return False | ||
|
|
||
| def dispatch(self, callback, user_data): | ||
| # Check if connection completed | ||
| err = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) | ||
| if err != 0: | ||
| self._error = OSError(err, errno.errorcode.get(err, "Unknown error")) | ||
| callback(self._error) | ||
| return GLib.SOURCE_REMOVE | ||
|
|
||
| self._connected = True | ||
| callback(None) | ||
| return GLib.SOURCE_REMOVE | ||
|
|
||
|
|
||
| class MessageBus(BaseMessageBus): | ||
| """The message bus implementation for use with the GLib main loop. | ||
|
|
||
|
|
@@ -241,7 +271,38 @@ def on_hello(reply, err): | |
| self._stream.write(hello_msg._marshall(False)) | ||
| self._stream.flush() | ||
|
|
||
| self._authenticate(authenticate_notify) | ||
| def on_socket_connect(err): | ||
| if err is not None: | ||
| if connect_notify is not None: | ||
| connect_notify(None, err) | ||
| return | ||
| self._authenticate(authenticate_notify) | ||
|
|
||
| # Start async socket connection | ||
| try: | ||
| self._sock.connect(self._sock_connect_address) | ||
| # Connected immediately (e.g., local socket) | ||
| self._authenticate(authenticate_notify) | ||
| except BlockingIOError: | ||
| # Connection in progress, wait for it to complete | ||
| connect_source = _ConnectSource(self._sock, self._sock_connect_address) | ||
| connect_source.set_callback(on_socket_connect) | ||
| connect_source.add_unix_fd(self._fd, GLib.IO_OUT) | ||
| connect_source.attach(self._main_context) | ||
| # Keep a reference to prevent garbage collection | ||
| self._connect_source = connect_source | ||
| except OSError as e: | ||
| if e.errno == errno.EINPROGRESS: | ||
| # Connection in progress, wait for it to complete | ||
| connect_source = _ConnectSource(self._sock, self._sock_connect_address) | ||
| connect_source.set_callback(on_socket_connect) | ||
| connect_source.add_unix_fd(self._fd, GLib.IO_OUT) | ||
| connect_source.attach(self._main_context) | ||
| self._connect_source = connect_source | ||
| elif connect_notify is not None: | ||
| connect_notify(None, e) | ||
| else: | ||
| raise | ||
|
|
||
| def connect_sync(self) -> "MessageBus": | ||
| """Connect this message bus to the DBus daemon. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,6 +118,7 @@ class BaseMessageBus: | |
| "_path_exports", | ||
| "_serial", | ||
| "_sock", | ||
| "_sock_connect_address", | ||
| "_stream", | ||
| "_user_disconnect", | ||
| "_user_message_handlers", | ||
|
|
@@ -172,6 +173,7 @@ def __init__( | |
| self._sock: socket.socket | None = None | ||
| self._fd: int | None = None | ||
| self._stream: io.BufferedRWPair | None = None | ||
| self._sock_connect_address: bytes | str | tuple[str, int] | None = None | ||
|
|
||
| self._setup_socket() | ||
|
|
||
|
|
@@ -681,8 +683,15 @@ def _introspect_export_path(self, path: str) -> intr.Node: | |
| return node | ||
|
|
||
| def _setup_socket(self) -> None: | ||
| last_err: Exception | None = None | ||
| """Create and configure the socket without connecting. | ||
|
|
||
| This method creates the socket and prepares the connection address, | ||
| but does not perform the actual connection. Call _connect_socket() | ||
| to complete the connection synchronously, or use async socket connect | ||
| in async implementations. | ||
|
|
||
| Sets self._sock, self._stream, self._fd, and self._sock_connect_address. | ||
| """ | ||
| for transport, options in self._bus_address: | ||
| filename: bytes | str | None = None | ||
| ip_addr = "" | ||
|
|
@@ -705,16 +714,13 @@ def _setup_socket(self) -> None: | |
| "got unix transport with unknown path specifier" | ||
| ) | ||
|
|
||
| try: | ||
| self._sock.connect(filename) | ||
| self._sock.setblocking(False) | ||
| except Exception as e: | ||
| last_err = e | ||
| else: | ||
| stack.pop_all() # responsibility to close sockets is deferred | ||
| return | ||
| # Store connect address for later; don't connect yet | ||
| self._sock_connect_address: bytes | str | tuple[str, int] = filename | ||
| self._sock.setblocking(False) | ||
| stack.pop_all() # responsibility to close sockets is deferred | ||
| return | ||
|
|
||
| elif transport == "tcp": | ||
| if transport == "tcp": | ||
| self._sock = stack.enter_context( | ||
| socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
| ) | ||
|
|
@@ -726,25 +732,31 @@ def _setup_socket(self) -> None: | |
| if "port" in options: | ||
| ip_port = int(options["port"]) | ||
|
|
||
| try: | ||
| self._sock.connect((ip_addr, ip_port)) | ||
| self._sock.setblocking(False) | ||
| except Exception as e: | ||
| last_err = e | ||
| else: | ||
| stack.pop_all() | ||
| return | ||
| # Store connect address for later; don't connect yet | ||
| self._sock_connect_address = (ip_addr, ip_port) | ||
| self._sock.setblocking(False) | ||
| stack.pop_all() | ||
| return | ||
|
|
||
| else: | ||
| raise InvalidAddressError( | ||
| f"got unknown address transport: {transport}" | ||
| ) | ||
| raise InvalidAddressError(f"got unknown address transport: {transport}") | ||
|
|
||
| # Should not normally happen, but just in case | ||
| raise TypeError("empty list of bus addresses given") # pragma: no cover | ||
|
Comment on lines
717
to
744
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the |
||
|
|
||
| if last_err is None: # pragma: no branch | ||
| # Should not normally happen, but just in case | ||
| raise TypeError("empty list of bus addresses given") # pragma: no cover | ||
| def _connect_socket(self) -> None: | ||
| """Perform the blocking socket connection. | ||
|
|
||
| raise last_err | ||
| This is used by synchronous implementations (like glib's connect_sync). | ||
| Async implementations should use their event loop's async socket connect | ||
| (e.g., loop.sock_connect) instead. | ||
|
|
||
| :raises: Connection errors from socket.connect() | ||
| """ | ||
| self._sock.setblocking(True) | ||
| try: | ||
| self._sock.connect(self._sock_connect_address) | ||
| finally: | ||
| self._sock.setblocking(False) | ||
|
|
||
| def _reply_notify( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docs for this method need a
..versionchanged: X.Y.Zand to have the:raises:section updated.And we should add
..versionchanged: X.Y.Ztoclass MessageBusas well.