Skip to content

Subscribing to Quote data overwhelms the single asyncio call in the websocket. Here are some fixes that could work. #849

Open
@Vetti420

Description

@Vetti420

The client currently reads from the socket with:
cmsg = await asyncio.wait_for(s.recv(), timeout=1)
which means that messages are received one at a time. Here’s what happens in practice if many messages arrive:

Message Queuing:
The underlying websockets library (and the TCP connection it uses) will buffer incoming messages. So, if messages arrive in rapid succession, they will be queued up internally.

Sequential Processing:
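Once a message is received, the code processes it by optionally decoding it and then passing it to the processor callback. The next message is not read until that callback has finished, so a slow processor delays every message queued behind it.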

Here is the updated code.
In this revision, we add an internal asyncio.Queue and a pool of consumer tasks that call your message processor. This decouples receiving messages from processing them, so bursts of incoming messages are queued (with optional back‐pressure via a maximum queue size) and processed concurrently (up to a fixed number of tasks).

import os
from enum import Enum
from typing import Optional, Union, List, Set, Callable, Awaitable, Any
import logging
import json
import asyncio
import ssl
import certifi
from .models import *
from websockets.client import connect, WebSocketClientProtocol
from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError
from ..logging import get_logger
import logging
from ..exceptions import AuthError

# Name of the environment variable that supplies the default API key for
# WebSocketClient (also quoted in the AuthError raised when no key is given).
env_key = "POLYGON_API_KEY"
# Module-wide logger, obtained from the package's get_logger helper; the
# client raises it to DEBUG when constructed with verbose=True.
logger = get_logger("WebSocketClient")


class WebSocketClient:
    """
    Streaming client that authenticates against a websocket feed, keeps a set
    of topic subscriptions in sync with the server, and hands every received
    message to a user callback through a bounded queue and a pool of worker
    tasks.

    Subscription state is tracked in two sets: ``subs`` holds what has been
    sent to the server on the current connection, ``scheduled_subs`` holds
    what the caller wants. Whenever ``schedule_resub`` is set, the receive
    loop sends the difference between the two and then copies
    ``scheduled_subs`` into ``subs``.
    """

    def __init__(
        self,
        # NOTE: this default is evaluated once, when the class is defined, so
        # an environment variable set after import is not picked up.
        api_key: Optional[str] = os.getenv(env_key),
        feed: Union[str, Feed] = Feed.RealTime,
        market: Union[str, Market] = Market.Stocks,
        raw: bool = False,
        verbose: bool = False,
        subscriptions: Optional[List[str]] = None,
        max_reconnects: Optional[int] = 5,
        secure: bool = True,
        custom_json: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initialize a Polygon WebSocketClient.

        :param api_key: API key sent in the auth request. Defaults to the
            value of the ``POLYGON_API_KEY`` environment variable.
        :param feed: Host part of the websocket URL (enum members are
            replaced by their ``.value``).
        :param market: Path part of the websocket URL (enum members are
            replaced by their ``.value``).
        :param raw: If true, messages are handed to the processor exactly as
            received instead of being JSON-decoded and parsed.
        :param verbose: If true, the module logger is set to DEBUG.
        :param subscriptions: Topics to subscribe to once connected.
        :param max_reconnects: How many abnormal disconnects are tolerated
            before ``connect`` returns; ``None`` means no limit.
        :param secure: Use ``wss://`` when true, ``ws://`` otherwise.
        :param custom_json: Object providing ``loads``/``dumps`` to use in
            place of the standard ``json`` module.
        :param kwargs: Stored on ``self.websocket_cfg``.
        :raises AuthError: If no API key is available.
        """
        if api_key is None:
            raise AuthError(
                f"Must specify env var {env_key} or pass api_key in constructor"
            )
        self.api_key = api_key
        self.feed = feed
        self.market = market
        self.raw = raw
        if verbose:
            logger.setLevel(logging.DEBUG)
        self.websocket_cfg = kwargs
        if isinstance(feed, Enum):
            feed = feed.value
        if isinstance(market, Enum):
            market = market.value
        self.url = f"ws{'s' if secure else ''}://{feed}/{market}"
        self.subscribed = False
        # Topics already sent to the server on the current connection.
        self.subs: Set[str] = set()
        self.max_reconnects = max_reconnects
        self.websocket: Optional[WebSocketClientProtocol] = None
        # Topics the caller wants; reconciled against self.subs by connect().
        self.scheduled_subs: Set[str] = set(subscriptions or [])
        self.schedule_resub = True
        self.json = custom_json if custom_json else json

    async def _message_consumer(
        self,
        processor: Union[
            Callable[[List[WebSocketMessage]], Awaitable],
            Callable[[Union[str, bytes]], Awaitable],
        ],
        message_queue: asyncio.Queue,
    ):
        """
        Worker loop: take one message at a time from the queue and await the
        processor on it. Exceptions raised by the processor are logged and
        swallowed so that one bad message does not stop the worker. The loop
        only ends when the task is cancelled.
        """
        while True:
            msg = await message_queue.get()
            try:
                await processor(msg)
            except Exception as e:
                logger.exception("Error processing message: %s", e)
            finally:
                # Always acknowledge, otherwise Queue.join() would never return.
                message_queue.task_done()

    async def _authenticate(self, s: WebSocketClientProtocol) -> None:
        """
        Read the server greeting, send the auth request and check the reply.

        :raises AuthError: If the server answers with status ``auth_failed``.
        """
        msg = await s.recv()
        logger.debug("connected: %s", msg)
        logger.debug("authing...")
        await s.send(self.json.dumps({"action": "auth", "params": self.api_key}))
        auth_msg = await s.recv()
        auth_msg_parsed = self.json.loads(auth_msg)
        logger.debug("authed: %s", auth_msg)
        if auth_msg_parsed[0]["status"] == "auth_failed":
            raise AuthError(auth_msg_parsed[0]["message"])

    async def _reconcile_subscriptions(self) -> None:
        """
        Send the subscribe/unsubscribe requests needed to move the server
        from ``self.subs`` to ``self.scheduled_subs``, then record the new
        state and clear the resubscribe flag.
        """
        logger.debug(
            "Reconciling subscriptions: current: %s, scheduled: %s",
            self.subs,
            self.scheduled_subs,
        )
        new_subs = self.scheduled_subs.difference(self.subs)
        await self._subscribe(new_subs)
        old_subs = self.subs.difference(self.scheduled_subs)
        await self._unsubscribe(old_subs)
        self.subs = set(self.scheduled_subs)
        self.schedule_resub = False

    def _decode(self, raw_msg: Union[str, bytes]) -> Any:
        """
        Turn one frame into the value handed to the processor.

        In raw mode the frame is returned untouched. Otherwise it is
        JSON-decoded, ``status`` events are logged and dropped, and the rest
        is passed through ``parse``. Returns ``None`` when the frame held
        nothing but status events.
        """
        if self.raw:
            return raw_msg
        data_events = []
        for event in self.json.loads(raw_msg):
            if event.get("ev") == "status":
                logger.debug("status: %s", event.get("message"))
            else:
                data_events.append(event)
        if not data_events:
            return None
        return parse(data_events, logger)

    async def _pump_messages(
        self,
        s: WebSocketClientProtocol,
        processor: Union[
            Callable[[List[WebSocketMessage]], Awaitable],
            Callable[[Union[str, bytes]], Awaitable],
        ],
        concurrency: int,
        queue_size: int,
    ) -> None:
        """
        Receive frames from ``s`` until the connection fails, feeding them to
        ``concurrency`` worker tasks through a queue of at most ``queue_size``
        items. Never returns normally; the worker tasks are always cancelled
        before the exception that ended the loop propagates.
        """
        message_queue: asyncio.Queue = asyncio.Queue(maxsize=queue_size)
        consumer_tasks = [
            asyncio.create_task(self._message_consumer(processor, message_queue))
            for _ in range(concurrency)
        ]
        try:
            while True:
                if self.schedule_resub:
                    await self._reconcile_subscriptions()

                # The short timeout lets the loop notice subscription changes
                # even when no data is arriving.
                try:
                    raw_msg = await asyncio.wait_for(s.recv(), timeout=1)
                except asyncio.TimeoutError:
                    continue

                processed_msg = self._decode(raw_msg)
                if processed_msg is None:
                    continue
                # Blocks when the queue is full, which applies back-pressure
                # to the socket instead of buffering without bound.
                await message_queue.put(processed_msg)
        except (ConnectionClosedOK, ConnectionClosedError):
            # The socket is gone, but frames already queued are still valid
            # data: let the workers finish them before they are cancelled.
            await message_queue.join()
            raise
        finally:
            # Runs on every exit path so a reconnect never leaves a previous
            # worker pool running in the background.
            for task in consumer_tasks:
                task.cancel()
            await asyncio.gather(*consumer_tasks, return_exceptions=True)

    # Updated connect method using a message queue and consumer pool.
    async def connect(
        self,
        processor: Union[
            Callable[[List[WebSocketMessage]], Awaitable],
            Callable[[Union[str, bytes]], Awaitable],
        ],
        close_timeout: int = 1,
        concurrency: int = 10,
        queue_size: int = 1000,
        **kwargs,
    ):
        """
        Connect to the websocket server and use a pool of tasks to process messages concurrently.

        Returns when the server closes the connection cleanly, or when the
        connection has failed more than ``max_reconnects`` times.

        :param processor: Callback to process each message.
        :param close_timeout: How long to wait for handshake when calling .close.
        :param concurrency: Number of concurrent consumer tasks to process messages.
        :param queue_size: Maximum number of messages that can be waiting for processing.
        :param kwargs: Extra keyword arguments forwarded to ``websockets`` ``connect``.
        :raises AuthError: If invalid API key is supplied.
        :raises ValueError: If ``concurrency`` is less than 1.
        """
        if concurrency < 1:
            # With no workers the queue would fill up and the receive loop
            # would block forever on put().
            raise ValueError("concurrency must be at least 1")

        reconnects = 0
        logger.debug("connect: %s", self.url)
        ssl_context = None
        if self.url.startswith("wss://"):
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ssl_context.load_verify_locations(certifi.where())

        async for s in connect(
            self.url, close_timeout=close_timeout, ssl=ssl_context, **kwargs
        ):
            self.websocket = s
            try:
                await self._authenticate(s)
                await self._pump_messages(s, processor, concurrency, queue_size)
            except ConnectionClosedOK as e:
                logger.debug("connection closed (OK): %s", e)
                return
            except ConnectionClosedError as e:
                logger.debug("connection closed (ERR): %s", e)
                reconnects += 1
                # Nothing is subscribed on the next connection yet, so clear
                # the active set and force a resubscribe. scheduled_subs is
                # deliberately left alone: it already holds the desired state,
                # including changes that had not been sent before the drop.
                self.subs = set()
                self.schedule_resub = True
                if self.max_reconnects is not None and reconnects > self.max_reconnects:
                    return
                continue

    def run(
        self,
        handle_msg: Union[
            Callable[[List[WebSocketMessage]], None],
            Callable[[Union[str, bytes]], None],
        ],
        close_timeout: int = 1,
        **kwargs,
    ):
        """
        Synchronous version of .connect.

        Wraps ``handle_msg`` in a coroutine and runs ``connect`` with
        ``asyncio.run``. Extra keyword arguments (for example ``concurrency``
        or ``queue_size``) are forwarded to ``connect``.
        """

        async def handle_msg_wrapper(msgs):
            handle_msg(msgs)

        asyncio.run(self.connect(handle_msg_wrapper, close_timeout, **kwargs))

    async def _subscribe(self, topics: Union[List[str], Set[str]]):
        """Send one subscribe request for ``topics``; no-op if empty or not connected."""
        if self.websocket is None or not topics:
            return
        subs = ",".join(topics)
        logger.debug("subscribing: %s", subs)
        await self.websocket.send(
            self.json.dumps({"action": "subscribe", "params": subs})
        )

    async def _unsubscribe(self, topics: Union[List[str], Set[str]]):
        """Send one unsubscribe request for ``topics``; no-op if empty or not connected."""
        if self.websocket is None or not topics:
            return
        subs = ",".join(topics)
        logger.debug("unsubscribing: %s", subs)
        await self.websocket.send(
            self.json.dumps({"action": "unsubscribe", "params": subs})
        )

    @staticmethod
    def _parse_subscription(s: str):
        """
        Split ``"TOPIC.SYMBOL"`` at the first period after stripping
        surrounding whitespace. Returns ``[topic, symbol]``, or
        ``[None, None]`` (after logging a warning) when there is no period.
        """
        s = s.strip()
        split = s.split(".", 1)  # Split at the first period.
        if len(split) != 2:
            logger.warning("invalid subscription: %s", s)
            return [None, None]
        return split

    def _discard_topic_subs(self, topic: str, keep: Optional[str] = None) -> None:
        """
        Remove every scheduled subscription belonging to ``topic`` except
        ``keep``. The match is on ``topic + "."`` so that, for example,
        topic ``A`` does not also remove ``AM.*`` entries.
        """
        prefix = f"{topic}."
        stale = {
            t for t in self.scheduled_subs if t.startswith(prefix) and t != keep
        }
        self.scheduled_subs -= stale

    def subscribe(self, *subscriptions: str):
        """
        Subscribe to given subscriptions.

        Each entry must look like ``TOPIC.SYMBOL``; invalid entries are
        skipped. Subscribing to ``TOPIC.*`` replaces any individual
        ``TOPIC.<symbol>`` subscriptions. The change is sent to the server by
        the receive loop in ``connect``.
        """
        for s in subscriptions:
            topic, sym = self._parse_subscription(s)
            if topic is None:
                continue
            # Store the whitespace-stripped form, which is what was validated.
            normalized = f"{topic}.{sym}"
            logger.debug("Desired subscription: %s", normalized)
            self.scheduled_subs.add(normalized)
            # If subscribing to X.*, remove other X.<something> subscriptions.
            if sym == "*":
                self._discard_topic_subs(topic, keep=normalized)
        self.schedule_resub = True

    def unsubscribe(self, *subscriptions: str):
        """
        Unsubscribe from given subscriptions.

        Each entry must look like ``TOPIC.SYMBOL``; invalid entries are
        skipped. Unsubscribing from ``TOPIC.*`` also removes every individual
        ``TOPIC.<symbol>`` subscription. The change is sent to the server by
        the receive loop in ``connect``.
        """
        for s in subscriptions:
            topic, sym = self._parse_subscription(s)
            if topic is None:
                continue
            normalized = f"{topic}.{sym}"
            logger.debug("Unsubscribe request: %s", normalized)
            self.scheduled_subs.discard(normalized)
            # If unsubscribing from X.*, remove other X.<something> subscriptions.
            if sym == "*":
                self._discard_topic_subs(topic)
        self.schedule_resub = True

    def unsubscribe_all(self):
        """
        Unsubscribe from all subscriptions.
        """
        self.scheduled_subs = set()
        self.schedule_resub = True

    async def close(self):
        """
        Close the websocket connection.

        Logs a warning instead of raising when no connection is open.
        """
        logger.debug("closing connection")
        if self.websocket:
            await self.websocket.close()
            self.websocket = None
        else:
            logger.warning("no websocket open to close")

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions