diff --git a/docs_src/src/pages/documentation/en/api_reference/websockets.mdx b/docs_src/src/pages/documentation/en/api_reference/websockets.mdx index 39eeaf737..acdf866d3 100644 --- a/docs_src/src/pages/documentation/en/api_reference/websockets.mdx +++ b/docs_src/src/pages/documentation/en/api_reference/websockets.mdx @@ -60,13 +60,18 @@ To handle real-time bidirectional communication, Batman learned how to work with --- -## Receiving Messages {{ tag: 'receive_text', label: 'receive_text' }} +## Receiving Messages {{ tag: 'receive', label: 'receive' }} - The `receive_text()` method blocks until the next message arrives from the client. It is backed by a Rust `tokio::mpsc` channel, so the Python handler genuinely suspends without holding the GIL. + Robyn provides several methods for receiving WebSocket messages, all backed by a Rust `tokio::mpsc` channel so the Python handler genuinely suspends without holding the GIL. - When the client disconnects, `receive_text()` raises `WebSocketDisconnect`. You can either catch it explicitly or let the internal wrapper handle it silently. + - `receive()` returns the next frame as `str` (text frame) or `bytes` (binary frame). + - `receive_text()` returns the next text frame, raising `TypeError` if a binary frame arrives. + - `receive_bytes()` returns the next binary frame, raising `TypeError` if a text frame arrives. + - `receive_json()` receives a text frame and JSON-decodes it. + + All receive methods raise `WebSocketDisconnect` when the client disconnects. You can either catch it explicitly or let the internal wrapper handle it silently. @@ -82,6 +87,26 @@ To handle real-time bidirectional communication, Batman learned how to work with print(f"Client {websocket.id} disconnected") ``` + ```python {{ title: 'Binary Messages' }} + @app.websocket("/ws") + async def handler(websocket): + while True: + data = await websocket.receive_bytes() + result = process(data) + await websocket.send_bytes(result) + ``` + + ```python {{ title: 'Mixed Frames' }} + @app.websocket("/ws") + async def handler(websocket): + while True: + msg = await websocket.receive() + if isinstance(msg, str): + await websocket.send_text(f"Got text: {msg}") + else: + await websocket.send_bytes(msg) # echo binary + ``` + ```python {{ title: 'JSON Messages' }} @app.websocket("/api") async def handler(websocket): @@ -96,11 +121,16 @@ To handle real-time bidirectional communication, Batman learned how to work with --- -## Sending Messages {{ tag: 'send_text', label: 'send_text' }} +## Sending Messages {{ tag: 'send', label: 'send' }} - To send a message to the current client, use `send_text()` or `send_json()`. All send methods are async. + To send a message to the current client, use any of the send methods. All are async. + + - `send(data)` accepts `str` or `bytes` — a `str` sends a text frame, `bytes` sends a binary frame. + - `send_text(data)` sends a text frame. + - `send_bytes(data)` sends a binary frame. + - `send_json(data)` serializes to JSON and sends a text frame. @@ -113,6 +143,15 @@ To handle real-time bidirectional communication, Batman learned how to work with await websocket.send_text(f"Echo: {msg}") ``` + ```python {{ title: 'Send Binary' }} + @app.websocket("/ws") + async def handler(websocket): + while True: + data = await websocket.receive_bytes() + compressed = zlib.compress(data) + await websocket.send_bytes(compressed) + ``` + ```python {{ title: 'Send JSON' }} @app.websocket("/ws") async def handler(websocket): @@ -130,12 +169,12 @@ To handle real-time bidirectional communication, Batman learned how to work with - To send a message to all connected clients on the same WebSocket endpoint, use the `broadcast()` method. + To send a message to all connected clients on the same WebSocket endpoint, use the `broadcast()` method. It accepts both `str` (text frame) and `bytes` (binary frame). - ```python {{ title: 'Broadcast' }} + ```python {{ title: 'Broadcast Text' }} @app.websocket("/chat") async def handler(websocket): while True: @@ -145,6 +184,15 @@ To handle real-time bidirectional communication, Batman learned how to work with # Also send a confirmation to this client only await websocket.send_text("Your message was sent") ``` + + ```python {{ title: 'Broadcast Binary' }} + @app.websocket("/stream") + async def handler(websocket): + while True: + data = await websocket.receive_bytes() + # Broadcast binary data to all clients + await websocket.broadcast(data) + ``` @@ -225,7 +273,7 @@ To handle real-time bidirectional communication, Batman learned how to work with To programmatically close a WebSocket connection from the server side, use `websocket.close()`. This will: 1. Close the WebSocket connection. 2. Remove the client from the WebSocket registry. -3. Cause any pending `receive_text()` to raise `WebSocketDisconnect`. +3. Cause any pending `receive()` / `receive_text()` / `receive_bytes()` to raise `WebSocketDisconnect`. @@ -289,13 +337,15 @@ To handle real-time bidirectional communication, Batman learned how to work with | Method / Property | Description | |---|---| - | `await websocket.receive_text()` | Block until next message; raises `WebSocketDisconnect` on close | - | `await websocket.receive_bytes()` | Block until next binary message; raises `WebSocketDisconnect` on close | - | `await websocket.receive_json()` | Same as `receive_text()` but JSON-decoded | - | `await websocket.send_text(data)` | Send string to this client | - | `await websocket.send_bytes(data)` | Send binary data to this client | - | `await websocket.send_json(data)` | Send JSON to this client | - | `await websocket.broadcast(data)` | Send to all clients on this endpoint | + | `await websocket.receive()` | Block until next frame; returns `str` for text frames, `bytes` for binary frames; raises `WebSocketDisconnect` on close | + | `await websocket.receive_text()` | Block until next text frame; raises `TypeError` if a binary frame arrives; raises `WebSocketDisconnect` on close | + | `await websocket.receive_bytes()` | Block until next binary frame; raises `TypeError` if a text frame arrives; raises `WebSocketDisconnect` on close | + | `await websocket.receive_json()` | Receive a text frame and JSON-decode it | + | `await websocket.send(data)` | Send `str` as a text frame or `bytes` as a binary frame | + | `await websocket.send_text(data)` | Send a text frame to this client | + | `await websocket.send_bytes(data)` | Send a binary frame to this client | + | `await websocket.send_json(data)` | Send JSON as a text frame to this client | + | `await websocket.broadcast(data)` | Broadcast `str` (text) or `bytes` (binary) to all clients on this endpoint | | `await websocket.close()` | Close the connection server-side | | `websocket.id` | Connection UUID string | | `websocket.query_params` | Query parameters from the connection URL | diff --git a/docs_src/src/pages/documentation/zh/api_reference/websockets.mdx b/docs_src/src/pages/documentation/zh/api_reference/websockets.mdx index cf9cb8fc5..75b14a116 100644 --- a/docs_src/src/pages/documentation/zh/api_reference/websockets.mdx +++ b/docs_src/src/pages/documentation/zh/api_reference/websockets.mdx @@ -57,13 +57,18 @@ export const description = --- -## 接收消息 {{ tag: 'receive_text', label: 'receive_text' }} +## 接收消息 {{ tag: 'receive', label: 'receive' }} - `receive_text()` 方法会阻塞直到下一条消息到达。它由 Rust 的 `tokio::mpsc` 通道支持,因此 Python 处理程序在等待时不会持有 GIL。 + Robyn 提供了多种接收 WebSocket 消息的方法,所有方法都由 Rust 的 `tokio::mpsc` 通道支持,因此 Python 处理程序在等待时不会持有 GIL。 - 当客户端断开连接时,`receive_text()` 会抛出 `WebSocketDisconnect` 异常。您可以显式捕获它,也可以让内部包装器静默处理。 + - `receive()` 返回下一帧,文本帧返回 `str`,二进制帧返回 `bytes`。 + - `receive_text()` 返回下一个文本帧,如果收到二进制帧则抛出 `TypeError`。 + - `receive_bytes()` 返回下一个二进制帧,如果收到文本帧则抛出 `TypeError`。 + - `receive_json()` 接收文本帧并进行 JSON 解码。 + + 所有接收方法在客户端断开连接时都会抛出 `WebSocketDisconnect` 异常。您可以显式捕获它,也可以让内部包装器静默处理。 @@ -79,6 +84,26 @@ export const description = print(f"客户端 {websocket.id} 已断开") ``` + ```python {{ title: '二进制消息' }} + @app.websocket("/ws") + async def handler(websocket): + while True: + data = await websocket.receive_bytes() + result = process(data) + await websocket.send_bytes(result) + ``` + + ```python {{ title: '混合帧' }} + @app.websocket("/ws") + async def handler(websocket): + while True: + msg = await websocket.receive() + if isinstance(msg, str): + await websocket.send_text(f"收到文本: {msg}") + else: + await websocket.send_bytes(msg) # 回显二进制 + ``` + ```python {{ title: 'JSON 消息' }} @app.websocket("/api") async def handler(websocket): @@ -93,11 +118,16 @@ export const description = --- -## 发送消息 {{ tag: 'send_text', label: 'send_text' }} +## 发送消息 {{ tag: 'send', label: 'send' }} - 使用 `send_text()` 或 `send_json()` 向当前客户端发送消息。所有发送方法都是异步的。 + 使用以下方法向当前客户端发送消息。所有方法都是异步的。 + + - `send(data)` 接受 `str` 或 `bytes` — `str` 发送文本帧,`bytes` 发送二进制帧。 + - `send_text(data)` 发送文本帧。 + - `send_bytes(data)` 发送二进制帧。 + - `send_json(data)` 序列化为 JSON 并发送文本帧。 @@ -110,6 +140,15 @@ export const description = await websocket.send_text(f"Echo: {msg}") ``` + ```python {{ title: '发送二进制' }} + @app.websocket("/ws") + async def handler(websocket): + while True: + data = await websocket.receive_bytes() + compressed = zlib.compress(data) + await websocket.send_bytes(compressed) + ``` + ```python {{ title: '发送 JSON' }} @app.websocket("/ws") async def handler(websocket): @@ -127,12 +166,12 @@ export const description = - 使用 `broadcast()` 方法向同一 WebSocket 端点上的所有已连接客户端发送消息。 + 使用 `broadcast()` 方法向同一 WebSocket 端点上的所有已连接客户端发送消息。它同时接受 `str`(文本帧)和 `bytes`(二进制帧)。 - ```python {{ title: '广播' }} + ```python {{ title: '广播文本' }} @app.websocket("/chat") async def handler(websocket): while True: @@ -142,6 +181,15 @@ export const description = # 仅向当前客户端发送确认 await websocket.send_text("您的消息已发送") ``` + + ```python {{ title: '广播二进制' }} + @app.websocket("/stream") + async def handler(websocket): + while True: + data = await websocket.receive_bytes() + # 向所有客户端广播二进制数据 + await websocket.broadcast(data) + ``` @@ -222,7 +270,7 @@ export const description = 使用 `websocket.close()` 从服务端关闭 WebSocket 连接。该方法将: 1. 关闭 WebSocket 连接。 2. 从 WebSocket 注册表中移除客户端。 -3. 使任何挂起的 `receive_text()` 抛出 `WebSocketDisconnect` 异常。 +3. 使任何挂起的 `receive()` / `receive_text()` / `receive_bytes()` 抛出 `WebSocketDisconnect` 异常。 @@ -286,13 +334,15 @@ export const description = | 方法 / 属性 | 描述 | |---|---| - | `await websocket.receive_text()` | 阻塞直到下一条消息;连接关闭时抛出 `WebSocketDisconnect` | - | `await websocket.receive_bytes()` | 阻塞直到下一条二进制消息;连接关闭时抛出 `WebSocketDisconnect` | - | `await websocket.receive_json()` | 与 `receive_text()` 相同,但返回 JSON 解码后的数据 | - | `await websocket.send_text(data)` | 向当前客户端发送文本 | - | `await websocket.send_bytes(data)` | 向当前客户端发送二进制数据 | - | `await websocket.send_json(data)` | 向当前客户端发送 JSON | - | `await websocket.broadcast(data)` | 向此端点的所有客户端广播 | + | `await websocket.receive()` | 阻塞直到下一帧;文本帧返回 `str`,二进制帧返回 `bytes`;连接关闭时抛出 `WebSocketDisconnect` | + | `await websocket.receive_text()` | 阻塞直到下一个文本帧;收到二进制帧时抛出 `TypeError`;连接关闭时抛出 `WebSocketDisconnect` | + | `await websocket.receive_bytes()` | 阻塞直到下一个二进制帧;收到文本帧时抛出 `TypeError`;连接关闭时抛出 `WebSocketDisconnect` | + | `await websocket.receive_json()` | 接收文本帧并进行 JSON 解码 | + | `await websocket.send(data)` | 发送 `str` 为文本帧或 `bytes` 为二进制帧 | + | `await websocket.send_text(data)` | 向当前客户端发送文本帧 | + | `await websocket.send_bytes(data)` | 向当前客户端发送二进制帧 | + | `await websocket.send_json(data)` | 向当前客户端发送 JSON 文本帧 | + | `await websocket.broadcast(data)` | 向此端点的所有客户端广播 `str`(文本)或 `bytes`(二进制) | | `await websocket.close()` | 从服务端关闭连接 | | `websocket.id` | 连接 UUID 字符串 | | `websocket.query_params` | 连接 URL 中的查询参数 | diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index 554a85aa5..93287cf17 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -534,38 +534,38 @@ class WebSocketConnector: id: str query_params: QueryParams - async def async_broadcast(self, message: str) -> None: + async def async_broadcast(self, message: str | bytes) -> None: """ Broadcasts a message to all clients. Args: - message (str): The message to broadcast + message (str | bytes): The message to broadcast. str sends a text frame, bytes sends a binary frame. """ pass - async def async_send_to(self, sender_id: str, message: str) -> None: + async def async_send_to(self, recipient_id: str, message: str | bytes) -> None: """ Sends a message to a specific client. Args: - sender_id (str): The id of the sender - message (str): The message to send + recipient_id (str): The id of the recipient + message (str | bytes): The message to send. str sends a text frame, bytes sends a binary frame. """ pass - def sync_broadcast(self, message: str) -> None: + def sync_broadcast(self, message: str | bytes) -> None: """ Broadcasts a message to all clients. Args: - message (str): The message to broadcast + message (str | bytes): The message to broadcast. str sends a text frame, bytes sends a binary frame. """ pass - def sync_send_to(self, sender_id: str, message: str) -> None: + def sync_send_to(self, recipient_id: str, message: str | bytes) -> None: """ Sends a message to a specific client. Args: - sender_id (str): The id of the sender - message (str): The message to send + recipient_id (str): The id of the recipient + message (str | bytes): The message to send. str sends a text frame, bytes sends a binary frame. """ pass def close(self) -> None: diff --git a/robyn/ws.py b/robyn/ws.py index 44f5bd3b2..c6d9377b1 100644 --- a/robyn/ws.py +++ b/robyn/ws.py @@ -33,8 +33,8 @@ def __init__(self, websocket_connector: WebSocketConnector, channel=None): self._connector = websocket_connector self._channel = channel - async def receive_text(self) -> str: - """Receive the next text message. Blocks until a message arrives. + async def receive(self) -> str | bytes: + """Receive the next message. Returns str for text frames, bytes for binary frames. Raises WebSocketDisconnect when the connection is closed.""" if self._channel is None: raise WebSocketDisconnect(reason="No message channel available") @@ -43,31 +43,46 @@ async def receive_text(self) -> str: raise WebSocketDisconnect() return result + async def receive_text(self) -> str: + """Receive the next text frame. Raises TypeError if a binary frame arrives. + Raises WebSocketDisconnect when the connection is closed.""" + msg = await self.receive() + if not isinstance(msg, str): + raise TypeError(f"Expected text frame, got {type(msg).__name__}") + return msg + async def receive_bytes(self) -> bytes: - """Receive binary data (decoded from text).""" - text = await self.receive_text() - return text.encode("utf-8") + """Receive the next binary frame. Raises TypeError if a text frame arrives. + Raises WebSocketDisconnect when the connection is closed.""" + msg = await self.receive() + if not isinstance(msg, bytes): + raise TypeError(f"Expected binary frame, got {type(msg).__name__}") + return msg async def receive_json(self): - """Receive and decode JSON data. + """Receive and decode JSON data from a text frame. Raises WebSocketDisconnect when the connection is closed.""" text = await self.receive_text() return orjson.loads(text) + async def send(self, data: str | bytes): + """Send data to this WebSocket client. str sends a text frame, bytes sends a binary frame.""" + await self._connector.async_send_to(self._connector.id, data) + async def send_text(self, data: str): - """Send text data to this WebSocket client.""" + """Send a text frame to this WebSocket client.""" await self._connector.async_send_to(self._connector.id, data) async def send_bytes(self, data: bytes): - """Send binary data (as text) to this WebSocket client.""" - await self._connector.async_send_to(self._connector.id, data.decode("utf-8")) + """Send a binary frame to this WebSocket client.""" + await self._connector.async_send_to(self._connector.id, data) async def send_json(self, data): - """Send JSON data to this WebSocket client.""" + """Send JSON data as a text frame to this WebSocket client.""" await self.send_text(orjson.dumps(data).decode()) - async def broadcast(self, data: str): - """Broadcast text data to all connected WebSocket clients on this endpoint.""" + async def broadcast(self, data: str | bytes): + """Broadcast data to all connected WebSocket clients on this endpoint.""" await self._connector.async_broadcast(data) async def close(self): diff --git a/src/executors/web_socket_executors.rs b/src/executors/web_socket_executors.rs index 82bbc2dd2..68331f918 100644 --- a/src/executors/web_socket_executors.rs +++ b/src/executors/web_socket_executors.rs @@ -4,7 +4,20 @@ use pyo3::prelude::*; use pyo3_async_runtimes::TaskLocals; use crate::types::function_info::FunctionInfo; -use crate::websockets::WebSocketConnector; +use crate::websockets::{WebSocketConnector, WsPayload}; + +fn extract_ws_return(_py: Python, output: &Bound<'_, PyAny>) -> Option { + if output.is_none() { + return None; + } + if let Ok(s) = output.extract::() { + Some(WsPayload::Text(s)) + } else if let Ok(b) = output.extract::>() { + Some(WsPayload::Binary(b)) + } else { + None + } +} pub fn execute_ws_function( function: &FunctionInfo, @@ -26,13 +39,7 @@ pub fn execute_ws_function( }; let f = async move { match fut.await { - Ok(output) => Python::with_gil(|py| match output.extract::>(py) { - Ok(msg) => msg, - Err(e) => { - log::error!("Failed to extract WebSocket handler result: {}", e); - None - } - }), + Ok(output) => Python::with_gil(|py| extract_ws_return(py, output.bind(py))), Err(e) => { log::error!("Async WebSocket handler failed: {}", e); None @@ -40,10 +47,10 @@ pub fn execute_ws_function( } } .into_actor(ws) - .map(|res, _, ctx| { - if let Some(msg) = res { - ctx.text(msg); - } + .map(|res, _, ctx| match res { + Some(WsPayload::Text(s)) => ctx.text(s), + Some(WsPayload::Binary(b)) => ctx.binary(b), + None => {} }); ctx.spawn(f); } else { @@ -56,10 +63,10 @@ pub fn execute_ws_function( } }; match handler.call1((ws.clone(),)) { - Ok(result) => match result.extract::>() { - Ok(Some(op)) => ctx.text(op), - Ok(None) => {} - Err(e) => log::error!("Failed to extract WebSocket handler result: {}", e), + Ok(result) => match extract_ws_return(py, &result) { + Some(WsPayload::Text(s)) => ctx.text(s), + Some(WsPayload::Binary(b)) => ctx.binary(b), + None => {} }, Err(e) => log::error!("Sync WebSocket handler call failed: {}", e), } diff --git a/src/websockets/mod.rs b/src/websockets/mod.rs index b133eb980..1cf27a1e4 100644 --- a/src/websockets/mod.rs +++ b/src/websockets/mod.rs @@ -3,7 +3,7 @@ pub mod registry; use crate::executors::web_socket_executors::execute_ws_function; use crate::types::function_info::FunctionInfo; use crate::types::multimap::QueryParams; -use registry::{Close, SendMessageToAll, SendText}; +use registry::{Close, SendMessage, SendMessageToAll}; use actix::prelude::*; use actix::{Actor, AsyncContext, StreamHandler}; @@ -13,6 +13,7 @@ use log::debug; use once_cell::sync::OnceCell; use parking_lot::RwLock; use pyo3::prelude::*; +use pyo3::types::PyBytes; use pyo3::IntoPyObject; use pyo3_async_runtimes::TaskLocals; use std::sync::Arc; @@ -24,24 +25,47 @@ use crate::runtime; use registry::{Register, WebSocketRegistry}; use std::collections::HashMap; +#[derive(Clone)] +pub enum WsPayload { + Text(String), + Binary(Vec), +} + +fn extract_payload(message: &Bound<'_, PyAny>) -> PyResult { + if let Ok(s) = message.extract::() { + Ok(WsPayload::Text(s)) + } else if let Ok(b) = message.extract::>() { + Ok(WsPayload::Binary(b)) + } else { + Err(pyo3::exceptions::PyTypeError::new_err( + "message must be str or bytes", + )) + } +} + /// A Rust-backed channel receiver exposed to Python. /// Python handlers call `await channel.receive()` to get the next message. -/// Returns the message string, or None when the connection is closed. +/// Returns str for text frames, bytes for binary frames, or None when closed. #[pyclass] pub struct WebSocketChannel { - receiver: Arc>>>, + receiver: Arc>>>, } #[pymethods] impl WebSocketChannel { /// Await the next message from the WebSocket. - /// Returns the message string, or None if the connection was closed. + /// Returns str for text frames, bytes for binary frames, or None if closed. fn receive<'py>(&self, py: Python<'py>) -> PyResult> { let receiver = self.receiver.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { let mut rx = receiver.lock().await; match rx.recv().await { - Some(Some(msg)) => Ok(Some(msg)), + Some(Some(WsPayload::Text(s))) => Python::with_gil(|py| { + Ok(Some(s.into_pyobject(py).unwrap().into_any().unbind())) + }), + Some(Some(WsPayload::Binary(b))) => { + Python::with_gil(|py| Ok(Some(PyBytes::new(py, &b).into_any().unbind()))) + } Some(None) | None => Ok(None), } }) @@ -56,9 +80,7 @@ pub struct WebSocketConnector { pub task_locals: TaskLocals, pub registry_addr: Addr, pub query_params: QueryParams, - /// Sender side of the message channel (stays in the Actix actor). - pub message_sender: Option>>, - /// Receiver side exposed to Python via WebSocketChannel. + pub message_sender: Option>>, pub message_channel: Option>, } @@ -73,7 +95,7 @@ impl Actor for WebSocketConnector { addr: addr.clone(), }); - let (tx, rx) = mpsc::unbounded_channel::>(); + let (tx, rx) = mpsc::unbounded_channel::>(); self.message_sender = Some(tx); self.message_channel = Python::with_gil(|py| { Some( @@ -124,15 +146,21 @@ impl Clone for WebSocketConnector { } } -impl Handler for WebSocketConnector { +impl Handler for WebSocketConnector { type Result = (); - fn handle(&mut self, msg: SendText, ctx: &mut Self::Context) { + fn handle(&mut self, msg: SendMessage, ctx: &mut Self::Context) { if self.id == msg.recipient_id { - ctx.text(msg.message.clone()); - if msg.message == "Connection closed" { - // Close the WebSocket connection - ctx.stop(); + match &msg.payload { + WsPayload::Text(s) => { + ctx.text(s.clone()); + if s == "Connection closed" { + ctx.stop(); + } + } + WsPayload::Binary(b) => { + ctx.binary(b.clone()); + } } } } @@ -152,28 +180,17 @@ impl StreamHandler> for WebSocketConnecto Ok(ws::Message::Text(text)) => { debug!("Text message received {:?}", text); if let Some(ref sender) = self.message_sender { - let _ = sender.send(Some(text.to_string())); + let _ = sender.send(Some(WsPayload::Text(text.to_string()))); } } Ok(ws::Message::Binary(bin)) => { + debug!("Binary message received ({} bytes)", bin.len()); if let Some(ref sender) = self.message_sender { - match String::from_utf8(bin.to_vec()) { - Ok(text) => { - let _ = sender.send(Some(text)); - } - Err(_) => { - debug!("Received non-UTF-8 binary WebSocket frame, echoing back"); - ctx.binary(bin); - } - } - } else { - ctx.binary(bin); + let _ = sender.send(Some(WsPayload::Binary(bin.to_vec()))); } } Ok(ws::Message::Close(_close_reason)) => { debug!("Socket was closed"); - // Drop sender to signal channel closure so receive() returns None. - // The close handler is called once from stopped(). self.message_sender.take(); ctx.stop(); } @@ -184,79 +201,80 @@ impl StreamHandler> for WebSocketConnecto #[pymethods] impl WebSocketConnector { - pub fn sync_send_to(&self, recipient_id: String, message: String) { - let recipient_id = match Uuid::parse_str(&recipient_id) { - Ok(id) => id, - Err(e) => { - log::error!("Invalid recipient_id '{}': {}", recipient_id, e); - return; - } - }; + pub fn sync_send_to(&self, recipient_id: String, message: &Bound<'_, PyAny>) -> PyResult<()> { + let payload = extract_payload(message)?; + let recipient_id = Uuid::parse_str(&recipient_id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Invalid recipient_id '{}': {}", + recipient_id, e + )) + })?; - match self.registry_addr.try_send(SendText { - message, + match self.registry_addr.try_send(SendMessage { + payload, sender_id: self.id, recipient_id, }) { Ok(_) => log::debug!("Message sent successfully"), Err(e) => log::error!("Failed to send message: {}", e), } + Ok(()) } pub fn async_send_to( &self, py: Python, recipient_id: String, - message: String, + message: &Bound<'_, PyAny>, ) -> PyResult> { + let payload = extract_payload(message)?; let registry = self.registry_addr.clone(); - let recipient_id = match Uuid::parse_str(&recipient_id) { - Ok(id) => id, - Err(e) => { - return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Invalid recipient_id '{}': {}", - recipient_id, e - ))); - } - }; + let recipient_id = Uuid::parse_str(&recipient_id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Invalid recipient_id '{}': {}", + recipient_id, e + )) + })?; let sender_id = self.id; let awaitable = runtime::future_into_py(py, async move { - match registry.try_send(SendText { - message, - sender_id, - recipient_id, - }) { - Ok(_) => log::debug!("Message sent successfully"), - Err(e) => log::error!("Failed to send message: {}", e), - } - Ok(()) + registry + .try_send(SendMessage { + payload, + sender_id, + recipient_id, + }) + .map_err(|e| { + anyhow::anyhow!("Failed to enqueue message to registry: {e}") + }) })?; Ok(awaitable.into_pyobject(py)?.into_any().into()) } - pub fn sync_broadcast(&self, message: String) { - let registry = self.registry_addr.clone(); - match registry.try_send(SendMessageToAll { - message, + pub fn sync_broadcast(&self, message: &Bound<'_, PyAny>) -> PyResult<()> { + let payload = extract_payload(message)?; + match self.registry_addr.try_send(SendMessageToAll { + payload, sender_id: self.id, }) { Ok(_) => log::debug!("Broadcast sent successfully"), Err(e) => log::error!("Failed to broadcast message: {}", e), } + Ok(()) } - pub fn async_broadcast(&self, py: Python, message: String) -> PyResult> { + pub fn async_broadcast(&self, py: Python, message: &Bound<'_, PyAny>) -> PyResult> { + let payload = extract_payload(message)?; let registry = self.registry_addr.clone(); let sender_id = self.id; let awaitable = runtime::future_into_py(py, async move { - match registry.try_send(SendMessageToAll { message, sender_id }) { - Ok(_) => log::debug!("Broadcast sent successfully"), - Err(e) => log::error!("Failed to broadcast message: {}", e), - } - Ok(()) + registry + .try_send(SendMessageToAll { payload, sender_id }) + .map_err(|e| { + anyhow::anyhow!("Failed to enqueue broadcast to registry: {e}") + }) })?; Ok(awaitable.into_pyobject(py)?.into_any().into()) @@ -276,7 +294,6 @@ impl WebSocketConnector { self.query_params.clone() } - /// Get the message channel for WebSocket handlers. #[getter] pub fn get_message_channel(&self, py: Python) -> Option> { self.message_channel.as_ref().map(|c| c.clone_ref(py)) diff --git a/src/websockets/registry.rs b/src/websockets/registry.rs index 702a50ac1..c71ec005e 100644 --- a/src/websockets/registry.rs +++ b/src/websockets/registry.rs @@ -5,12 +5,12 @@ use uuid::Uuid; use std::collections::HashMap; +use super::WsPayload; use crate::websockets::WebSocketConnector; #[derive(Default)] #[pyclass] pub struct WebSocketRegistry { - // A map of client IDs to their Actor addresses. clients: HashMap>, } @@ -39,14 +39,13 @@ impl Handler for WebSocketRegistry { } } -// New message for sending text to a specific client -pub struct SendText { +pub struct SendMessage { pub recipient_id: Uuid, - pub message: String, + pub payload: WsPayload, pub sender_id: Uuid, } -impl Message for SendText { +impl Message for SendMessage { type Result = (); } @@ -62,10 +61,10 @@ impl WebSocketRegistry { } } -impl Handler for WebSocketRegistry { +impl Handler for WebSocketRegistry { type Result = (); - fn handle(&mut self, msg: SendText, _ctx: &mut Self::Context) { + fn handle(&mut self, msg: SendMessage, _ctx: &mut Self::Context) { let recipient_id = msg.recipient_id; if let Some(client_addr) = self.clients.get(&recipient_id) { @@ -77,7 +76,7 @@ impl Handler for WebSocketRegistry { } pub struct SendMessageToAll { - pub message: String, + pub payload: WsPayload, pub sender_id: Uuid, } @@ -90,9 +89,9 @@ impl Handler for WebSocketRegistry { fn handle(&mut self, msg: SendMessageToAll, _ctx: &mut Self::Context) { for (id, client) in &self.clients { - client.do_send(SendText { + client.do_send(SendMessage { recipient_id: *id, - message: msg.message.clone(), + payload: msg.payload.clone(), sender_id: msg.sender_id, }); } @@ -112,10 +111,9 @@ impl Handler for WebSocketRegistry { fn handle(&mut self, msg: Close, _ctx: &mut Self::Context) { if let Some(client) = self.clients.remove(&msg.id) { - // Send a close message to the client before removing it - client.do_send(SendText { + client.do_send(SendMessage { recipient_id: msg.id, - message: "Connection closed".to_string(), + payload: WsPayload::Text("Connection closed".to_string()), sender_id: msg.id, }); }