Description
cmsg = await asyncio.wait_for(s.recv(), timeout=1)
This means that messages are received one at a time. Here’s what happens in practice if many messages arrive:
Message Queuing:
The underlying websockets library (and the TCP connection it uses) will buffer incoming messages. So, if messages arrive in rapid succession, they will be queued up internally.
Sequential Processing:
Once a message is received, the code processes it by optionally decoding it and then passing it to the processor callback.
Here is the updated code:
In this revision, we add an internal asyncio.Queue and a pool of consumer tasks that call your message processor. This decouples receiving messages from processing them, so bursts of incoming messages are queued (with optional back-pressure via a maximum queue size) and processed concurrently (up to a fixed number of tasks).
import os
from enum import Enum
from typing import Optional, Union, List, Set, Callable, Awaitable, Any
import logging
import json
import asyncio
import ssl
import certifi
from .models import *
from websockets.client import connect, WebSocketClientProtocol
from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError
from ..logging import get_logger
import logging
from ..exceptions import AuthError
# Name of the environment variable read for the default `api_key` of
# WebSocketClient.
env_key = "POLYGON_API_KEY"
# Module-level logger used by WebSocketClient; passing verbose=True to the
# constructor switches this logger to DEBUG.
logger = get_logger("WebSocketClient")
class WebSocketClient:
    """
    Asynchronous client for a Polygon websocket feed.

    The client authenticates after connecting, keeps the server-side
    subscription set in sync with the set requested through
    :meth:`subscribe` / :meth:`unsubscribe`, and hands every received
    message to a user-supplied processor through a bounded queue served by
    a pool of consumer tasks.
    """

    def __init__(
        self,
        api_key: Optional[str] = os.getenv(env_key),
        feed: Union[str, Feed] = Feed.RealTime,
        market: Union[str, Market] = Market.Stocks,
        raw: bool = False,
        verbose: bool = False,
        subscriptions: Optional[List[str]] = None,
        max_reconnects: Optional[int] = 5,
        secure: bool = True,
        custom_json: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initialize a Polygon WebSocketClient.

        :param api_key: API key sent in the auth message. The default is the
            value of the ``POLYGON_API_KEY`` environment variable as it was
            when this module was imported.
        :param feed: Host part of the websocket URL (enum or plain string).
        :param market: Path part of the websocket URL (enum or plain string).
        :param raw: If True, messages are passed to the processor exactly as
            received instead of being JSON-decoded and parsed.
        :param verbose: If True, set the module logger to DEBUG.
        :param subscriptions: Topics to subscribe to once connected.
        :param max_reconnects: Number of abnormal disconnects tolerated before
            :meth:`connect` gives up; ``None`` means no limit.
        :param secure: Use ``wss://`` when True, ``ws://`` otherwise.
        :param custom_json: Object providing ``dumps`` / ``loads`` to use in
            place of the standard ``json`` module.
        :param kwargs: Stored on ``self.websocket_cfg``.
        :raises AuthError: If no API key is available.
        """
        if api_key is None:
            raise AuthError(
                f"Must specify env var {env_key} or pass api_key in constructor"
            )
        self.api_key = api_key
        self.feed = feed
        self.market = market
        self.raw = raw
        if verbose:
            logger.setLevel(logging.DEBUG)
        # NOTE(review): stored but not read anywhere in this class; connect()
        # only forwards its own **kwargs to the websockets library.
        self.websocket_cfg = kwargs
        if isinstance(feed, Enum):
            feed = feed.value
        if isinstance(market, Enum):
            market = market.value
        self.url = f"ws{'s' if secure else ''}://{feed}/{market}"
        self.subscribed = False
        # Topics the server has been told about on the current connection.
        self.subs: Set[str] = set()
        self.max_reconnects = max_reconnects
        self.websocket: Optional[WebSocketClientProtocol] = None
        # Topics the caller wants; reconciled against self.subs by connect().
        self.scheduled_subs: Set[str] = set(subscriptions or [])
        self.schedule_resub = True
        self.json = custom_json if custom_json else json

    async def _message_consumer(
        self,
        processor: Union[
            Callable[[List[WebSocketMessage]], Awaitable],
            Callable[[Union[str, bytes]], Awaitable],
        ],
        message_queue: asyncio.Queue,
    ):
        """
        Worker loop: take messages off the queue and await the processor.

        Exceptions raised by the processor are logged and swallowed so one bad
        message does not stop the worker. The loop only ends when the task is
        cancelled.
        """
        while True:
            msg = await message_queue.get()
            try:
                await processor(msg)
            except asyncio.CancelledError:
                # On Python 3.7 CancelledError derives from Exception; re-raise
                # explicitly so cancellation is never swallowed by the handler
                # below.
                raise
            except Exception as e:
                logger.exception("Error processing message: %s", e)
            finally:
                # Always balance get() so Queue.join() can complete.
                message_queue.task_done()

    async def _authenticate(self, s: WebSocketClientProtocol):
        """
        Consume the server greeting and authenticate with the API key.

        :raises AuthError: If the server answers with status ``auth_failed``.
        """
        msg = await s.recv()
        logger.debug("connected: %s", msg)
        logger.debug("authing...")
        await s.send(self.json.dumps({"action": "auth", "params": self.api_key}))
        auth_msg = await s.recv()
        auth_msg_parsed = self.json.loads(auth_msg)
        logger.debug("authed: %s", auth_msg)
        if auth_msg_parsed[0]["status"] == "auth_failed":
            raise AuthError(auth_msg_parsed[0]["message"])

    async def _reconcile_subscriptions(self):
        """
        Bring the server-side subscriptions in line with ``scheduled_subs``.
        """
        logger.debug(
            "Reconciling subscriptions: current: %s, scheduled: %s",
            self.subs,
            self.scheduled_subs,
        )
        # Clear the flag and snapshot the target BEFORE awaiting: consumer
        # tasks run concurrently and may call subscribe()/unsubscribe() while
        # we are sending. Such a call sets the flag again and is picked up on
        # the next pass instead of being lost.
        self.schedule_resub = False
        target = set(self.scheduled_subs)
        await self._subscribe(target.difference(self.subs))
        await self._unsubscribe(self.subs.difference(target))
        self.subs = target

    async def _receive_loop(
        self, s: WebSocketClientProtocol, message_queue: asyncio.Queue
    ):
        """
        Receive messages from ``s`` and enqueue them until the connection ends.

        Never returns normally; it exits through an exception raised by the
        websocket (for example a connection-closed error) or by decoding.
        """
        while True:
            if self.schedule_resub:
                await self._reconcile_subscriptions()
            try:
                # The short timeout lets the loop notice subscription changes
                # even when no traffic is arriving.
                raw_msg = await asyncio.wait_for(s.recv(), timeout=1)
            except asyncio.TimeoutError:
                continue
            if self.raw:
                processed_msg = raw_msg
            else:
                decoded = self.json.loads(raw_msg)
                data_msgs = []
                for m in decoded:
                    if m.get("ev") == "status":
                        logger.debug("status: %s", m.get("message"))
                    else:
                        data_msgs.append(m)
                if not data_msgs:
                    # Frame contained only status events; nothing to dispatch.
                    continue
                processed_msg = parse(data_msgs, logger)
            # Blocks when the queue is full, applying back-pressure to recv().
            await message_queue.put(processed_msg)

    async def connect(
        self,
        processor: Union[
            Callable[[List[WebSocketMessage]], Awaitable],
            Callable[[Union[str, bytes]], Awaitable],
        ],
        close_timeout: int = 1,
        concurrency: int = 10,
        queue_size: int = 1000,
        **kwargs,
    ):
        """
        Connect to the websocket server and use a pool of tasks to process
        messages concurrently.

        The queue and the consumer pool are created once and shared across
        reconnects. When the method finishes because the connection closed,
        messages still waiting in the queue are processed before the consumers
        are cancelled; on any other exit the consumers are cancelled
        immediately.

        :param processor: Callback to process each message.
        :param close_timeout: How long to wait for handshake when calling .close.
        :param concurrency: Number of concurrent consumer tasks to process
            messages. Must be at least 1.
        :param queue_size: Maximum number of messages that can be waiting for
            processing.
        :param kwargs: Extra keyword arguments forwarded to the websockets
            ``connect`` call.
        :raises AuthError: If invalid API key is supplied.
        :raises ValueError: If ``concurrency`` is smaller than 1.
        """
        if concurrency < 1:
            # With no consumers the queue would fill up and the receive loop
            # would block forever on put().
            raise ValueError("concurrency must be at least 1")

        reconnects = 0
        logger.debug("connect: %s", self.url)
        ssl_context = None
        if self.url.startswith("wss://"):
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ssl_context.load_verify_locations(certifi.where())

        message_queue: asyncio.Queue = asyncio.Queue(maxsize=queue_size)
        consumer_tasks = [
            asyncio.create_task(self._message_consumer(processor, message_queue))
            for _ in range(concurrency)
        ]
        try:
            async for s in connect(
                self.url, close_timeout=close_timeout, ssl=ssl_context, **kwargs
            ):
                self.websocket = s
                try:
                    await self._authenticate(s)
                    await self._receive_loop(s, message_queue)
                except ConnectionClosedOK as e:
                    logger.debug("connection closed (OK): %s", e)
                    break
                except ConnectionClosedError as e:
                    logger.debug("connection closed (ERR): %s", e)
                    reconnects += 1
                    # scheduled_subs already holds the desired topics. Only
                    # forget what the dead connection knew so everything is
                    # re-sent after reconnecting. (Copying self.subs into
                    # scheduled_subs here would drop topics whenever the
                    # connection fails before the first reconciliation.)
                    self.subs = set()
                    self.schedule_resub = True
                    if (
                        self.max_reconnects is not None
                        and reconnects > self.max_reconnects
                    ):
                        break
                    continue
            # Graceful exit: let the consumers finish what was already received.
            await message_queue.join()
        finally:
            # Runs on every exit path so consumer tasks never outlive connect().
            for task in consumer_tasks:
                task.cancel()
            await asyncio.gather(*consumer_tasks, return_exceptions=True)

    def run(
        self,
        handle_msg: Union[
            Callable[[List[WebSocketMessage]], None],
            Callable[[Union[str, bytes]], None],
        ],
        close_timeout: int = 1,
        **kwargs,
    ):
        """
        Synchronous version of .connect.

        Wraps the plain callable ``handle_msg`` in a coroutine and runs
        :meth:`connect` to completion with ``asyncio.run``. ``kwargs`` are
        passed through to :meth:`connect`.
        """

        async def handle_msg_wrapper(msgs):
            handle_msg(msgs)

        asyncio.run(self.connect(handle_msg_wrapper, close_timeout, **kwargs))

    async def _subscribe(self, topics: Union[List[str], Set[str]]):
        """Send a subscribe action for ``topics``; no-op if unconnected or empty."""
        if self.websocket is None or len(topics) == 0:
            return
        subs = ",".join(topics)
        logger.debug("subscribing: %s", subs)
        await self.websocket.send(
            self.json.dumps({"action": "subscribe", "params": subs})
        )

    async def _unsubscribe(self, topics: Union[List[str], Set[str]]):
        """Send an unsubscribe action for ``topics``; no-op if unconnected or empty."""
        if self.websocket is None or len(topics) == 0:
            return
        subs = ",".join(topics)
        logger.debug("unsubscribing: %s", subs)
        await self.websocket.send(
            self.json.dumps({"action": "unsubscribe", "params": subs})
        )

    @staticmethod
    def _parse_subscription(s: str):
        """
        Split ``"TOPIC.SYMBOL"`` at the first period.

        :return: ``[topic, symbol]``, or ``[None, None]`` (after logging a
            warning) when the string contains no period.
        """
        s = s.strip()
        split = s.split(".", 1)  # Split at the first period.
        if len(split) != 2:
            logger.warning("invalid subscription: %s", s)
            return [None, None]
        return split

    def _discard_topic_subscriptions(self, topic: str, keep: Optional[str] = None):
        """
        Remove every scheduled ``<topic>.<symbol>`` entry except ``keep``.

        The prefix includes the period so that, for example, topic ``A`` does
        not match ``AM.<symbol>`` entries.
        """
        prefix = f"{topic}."
        self.scheduled_subs -= {
            t for t in self.scheduled_subs if t != keep and t.startswith(prefix)
        }

    def subscribe(self, *subscriptions: str):
        """
        Subscribe to given subscriptions.

        The change is applied by the receive loop of :meth:`connect` on its
        next pass.
        """
        for s in subscriptions:
            topic, sym = self._parse_subscription(s)
            if topic is None:
                continue
            logger.debug("Desired subscription: %s", s)
            self.scheduled_subs.add(s)
            # If subscribing to X.*, remove other X.<something> subscriptions,
            # but keep the wildcard entry that was just added.
            if sym == "*":
                self._discard_topic_subscriptions(topic, keep=s)
        self.schedule_resub = True

    def unsubscribe(self, *subscriptions: str):
        """
        Unsubscribe from given subscriptions.

        The change is applied by the receive loop of :meth:`connect` on its
        next pass.
        """
        for s in subscriptions:
            topic, sym = self._parse_subscription(s)
            if topic is None:
                continue
            logger.debug("Unsubscribe request: %s", s)
            self.scheduled_subs.discard(s)
            # If unsubscribing from X.*, remove other X.<something> subscriptions.
            if sym == "*":
                self._discard_topic_subscriptions(topic)
        self.schedule_resub = True

    def unsubscribe_all(self):
        """
        Unsubscribe from all subscriptions.
        """
        self.scheduled_subs = set()
        self.schedule_resub = True

    async def close(self):
        """
        Close the websocket connection.

        Logs a warning when there is no open websocket.
        """
        logger.debug("closing connection")
        if self.websocket:
            await self.websocket.close()
            self.websocket = None
        else:
            logger.warning("no websocket open to close")