[RFC]: Add support for WebSocketSpec #467

Draft · wants to merge 15 commits into main
Conversation

@bhimrazy (Contributor) commented Apr 13, 2025

What does this PR do?

This PR introduces an experimental WebSocketSpec for enabling communication using WebSockets in LitServe.

It aims to address a limitation of the current server-sent events (SSE) based streaming, which can only stream output, by supporting streaming input and output. Bidirectional streaming is critical for real-time AI use cases such as:

  • Speech-to-text (live transcription)
  • Real-time video inference
  • Other similar streaming workloads

This PR sets the foundation for broader real-time interaction patterns in LitServe, inspired by community discussions in #320.

Fixes #320
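
At a glance, the spec plugs into LitServer like any other spec. Here is a minimal sketch distilled from the usage example below (the EchoAPI class and the "input" field are purely illustrative):

from litserve import LitAPI, LitServer
from litserve.specs.websocket import WebSocketSpec


class EchoAPI(LitAPI):
    def setup(self, device):
        pass  # load your model here

    def decode_request(self, request: dict):
        return request["input"]  # each incoming WebSocket message is decoded here

    def predict(self, x):
        return x  # run inference; echo for illustration

    def encode_response(self, output):
        return {"output": output}  # sent back over the same connection


if __name__ == "__main__":
    LitServer(EchoAPI(), spec=WebSocketSpec()).run()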



🔧 Usage

Here’s how to use WebSocketSpec for real-time object detection with an RFDETR model.


Install requirements

pip install -r https://raw.githubusercontent.com/bhimrazy/litserve-examples/main/rfdetr-object-detection/requirements.txt

🖥️ server.py

import base64
import io

import supervision as sv
from PIL import Image
from rfdetr import RFDETRBase
from rfdetr.util.coco_classes import COCO_CLASSES

from litserve import LitAPI, LitServer
from litserve.specs.websocket import WebSocketSpec


class TestAPI(LitAPI):
    def setup(self, device):
        self.model = RFDETRBase()
        # RFDETRBase nests the underlying torch module; move it to the target device
        self.model.model.model.to(device)
        self.coco_classes = COCO_CLASSES

    def decode_request(self, request: dict):
        image = request.get("image")
        if not image:
            raise ValueError("No image found in request.")
        image_bytes = base64.b64decode(image)
        image = Image.open(io.BytesIO(image_bytes))
        return image.convert("RGB")

    def predict(self, image) -> sv.Detections:
        return self.model.predict(image)

    def encode_response(self, detections: sv.Detections):
        return {
            "detections": [
                {
                    "class_id": int(class_id),
                    "class_name": self.coco_classes[int(class_id)],
                    "confidence": float(confidence),
                    "bbox": bbox.tolist(),
                }
                for class_id, confidence, bbox in zip(
                    detections.class_id,
                    detections.confidence,
                    detections.xyxy,
                )
            ]
        }


if __name__ == "__main__":
    server = LitServer(TestAPI(), spec=WebSocketSpec())
    server.run()

📡 client.py

import asyncio
import base64
import io
import json

import requests
import websockets
from PIL import Image


async def websocket_client():
    uri = "ws://localhost:8000/predict"
    image_url = "https://images.pexels.com/photos/439818/pexels-photo-439818.jpeg"
    async with websockets.connect(uri) as websocket:
        # Download the test image once and encode it as base64
        image = Image.open(requests.get(image_url, stream=True).raw)
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        request = json.dumps({"image": img_str})

        # Send multiple requests over the same connection
        for i in range(5):
            await websocket.send(request)
            print(f"Sent request {i + 1}")

        # Receive responses
        for i in range(5):
            response = await websocket.recv()
            response = json.loads(response)
            print(f"Received response: {response.keys()}")

asyncio.run(websocket_client())
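
Because the connection is full duplex, the send and receive loops can also be overlapped instead of running back-to-back. A minimal variant of the client above (send_frames/recv_results/duplex_client are hypothetical helper names):

import asyncio
import json


async def send_frames(websocket, request: str, n: int):
    # Producer: push requests as fast as the socket accepts them
    for _ in range(n):
        await websocket.send(request)


async def recv_results(websocket, n: int):
    # Consumer: drain responses concurrently with the sender
    for _ in range(n):
        response = json.loads(await websocket.recv())
        print(f"Received response: {response.keys()}")


async def duplex_client(websocket, request: str, n: int = 5):
    # Run both loops on the same connection at the same time
    await asyncio.gather(send_frames(websocket, request, n), recv_results(websocket, n))

Calling duplex_client from inside the async with websockets.connect(...) block above keeps everything on one connection while the server pipeline stays busy.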


⚡ Benchmarks

I ran a basic performance comparison between WebSocketSpec and the existing HTTP-based inference path on a video inference workload.

The benchmark is intentionally minimal and can certainly be improved through further discussion and community input.

Details:

  • 🎥 Test video — 150 frames (~6 seconds)
  • 💻 Tested on a base-model MacBook Pro (CPU only; MPS didn't work with rfdetr)
1. Using HTTP Requests

Start the server (the plain HTTP variant), then run the benchmark script below:

import os
import time
from typing import Dict, List, Tuple

import cv2
import numpy as np
import requests
import supervision as sv

# Video URL
VIDEO_URL = "https://videos.pexels.com/video-files/1860079/1860079-uhd_2560_1440_25fps.mp4"
# https://videos.pexels.com/video-files/2103099/2103099-uhd_2560_1440_30fps.mp4
# https://videos.pexels.com/video-files/3779621/3779621-hd_1920_1080_30fps.mp4
OUTPUT_DIR = "output_videos"
os.makedirs(OUTPUT_DIR, exist_ok=True)

API_URL = "http://localhost:8000/predict"


# Helper function to download the video
def download_video(video_url: str, output_path: str):
    response = requests.get(video_url, stream=True)
    with open(output_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)
    print(f"Video downloaded to {output_path}")


# Helper function to extract frames from the video
def extract_frames(video_path: str) -> Tuple[List, float]:
    cap = cv2.VideoCapture(video_path)
    frames = []
    fps = cap.get(cv2.CAP_PROP_FPS)  # Get the frames per second (FPS) of the video
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    print(f"Extracted {len(frames)} frames from the video at {fps:.2f} FPS.")
    return frames, fps


# Helper function to write output video
def write_output_video(frames: List, output_path: str, fps: int):
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    for frame in frames:
        out.write(frame)
    out.release()
    print(f"Output video written to {output_path}")


# Helper function to draw bounding boxes on frames
def draw_bounding_boxes(frame, detections: List[Dict]):
    sv_detections = sv.Detections(
        class_id=np.array([detection["class_id"] for detection in detections]),
        confidence=np.array([detection["confidence"] for detection in detections]),
        xyxy=np.array([detection["bbox"] for detection in detections]),
    )
    labels = [f"{detection['class_name']} {detection['confidence']:.2f}" for detection in detections]

    annotated_frame = frame.copy()
    annotated_frame = sv.BoxAnnotator().annotate(annotated_frame, sv_detections)
    return sv.LabelAnnotator().annotate(annotated_frame, sv_detections, labels)


# Benchmark using normal HTTP requests
def benchmark_normal_http(frames: List):
    print("Benchmarking with normal HTTP requests...")
    start_time = time.time()
    processed_frames = []
    for frame in frames:
        # Encode the frame in memory and post it as a multipart upload
        _, img_encoded = cv2.imencode(".jpg", frame)
        response = requests.post(API_URL, files={"request": ("frame.jpg", img_encoded.tobytes(), "image/jpeg")})
        detections = response.json().get("detections", [])
        processed_frames.append(draw_bounding_boxes(frame, detections))
    end_time = time.time()
    total_time = end_time - start_time
    fps = len(frames) / total_time
    print(f"Normal HTTP requests took {total_time:.2f} seconds ({fps:.2f} FPS).")
    return processed_frames


# Benchmark using shared HTTP session
def benchmark_shared_session(frames: List):
    print("Benchmarking with shared HTTP session...")
    session = requests.Session()
    start_time = time.time()
    processed_frames = []
    for frame in frames:
        _, img_encoded = cv2.imencode(".jpg", frame)
        response = session.post(API_URL, files={"request": ("frame.jpg", img_encoded.tobytes(), "image/jpeg")})
        detections = response.json().get("detections", [])
        processed_frames.append(draw_bounding_boxes(frame, detections))
    end_time = time.time()
    total_time = end_time - start_time
    fps = len(frames) / total_time
    print(f"Shared HTTP session took {total_time:.2f} seconds ({fps:.2f} FPS).")
    return processed_frames


# Main function to run benchmarks
def main():
    # Download the video
    video_path = "input_video.mp4"
    download_video(VIDEO_URL, video_path)

    # Extract frames
    frames, fps = extract_frames(video_path)

    # Benchmark normal HTTP requests
    processed_frames_http = benchmark_normal_http(frames)
    write_output_video(processed_frames_http, f"{OUTPUT_DIR}/output_http.mp4", fps=fps)

    # Benchmark shared HTTP session
    processed_frames_session = benchmark_shared_session(frames)
    write_output_video(processed_frames_session, f"{OUTPUT_DIR}/output_session.mp4", fps=fps)


if __name__ == "__main__":
    main()

Extracted 150 frames from the video at 25.00 FPS.
Benchmarking with normal HTTP requests...
Normal HTTP requests took 43.59 seconds (3.44 FPS).
Output video written to output_videos/output_http.mp4
Benchmarking with shared HTTP session...
Shared HTTP session took 42.58 seconds (3.52 FPS).
Output video written to output_videos/output_session.mp4
2. Using WebSocket Requests

Start the server from the Usage section above (which uses WebSocketSpec), then run the benchmark script below:

import asyncio
import base64
import json
import os
import time
from typing import Dict, List, Tuple

import cv2
import numpy as np
import requests
import supervision as sv
import websockets

# Video URL
VIDEO_URL = "https://videos.pexels.com/video-files/1860079/1860079-uhd_2560_1440_25fps.mp4"
OUTPUT_DIR = "output_videos"
os.makedirs(OUTPUT_DIR, exist_ok=True)

WS_URL = "ws://localhost:8000/predict"  # Replace with your WebSocket endpoint


# Helper function to download the video
def download_video(video_url: str, output_path: str):
    response = requests.get(video_url, stream=True)
    with open(output_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)
    print(f"Video downloaded to {output_path}")


# Helper function to extract frames from the video
def extract_frames(video_path: str) -> Tuple[List, float]:
    cap = cv2.VideoCapture(video_path)
    frames = []
    fps = cap.get(cv2.CAP_PROP_FPS)  # Get the frames per second (FPS) of the video
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    print(f"Extracted {len(frames)} frames from the video at {fps:.2f} FPS.")
    return frames, fps


# Helper function to write output video
def write_output_video(frames: List, output_path: str, fps: int):
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    for frame in frames:
        out.write(frame)
    out.release()
    print(f"Output video written to {output_path}")


# Helper function to draw bounding boxes on frames
def draw_bounding_boxes(frame, detections: List[Dict]):
    sv_detections = sv.Detections(
        class_id=np.array([detection["class_id"] for detection in detections]),
        confidence=np.array([detection["confidence"] for detection in detections]),
        xyxy=np.array([detection["bbox"] for detection in detections]),
    )
    labels = [f"{detection['class_name']} {detection['confidence']:.2f}" for detection in detections]

    annotated_frame = frame.copy()
    annotated_frame = sv.BoxAnnotator().annotate(annotated_frame, sv_detections)
    return sv.LabelAnnotator().annotate(annotated_frame, sv_detections, labels)


# Benchmark using WebSocket
async def benchmark_websocket(frames: List):
    print("Benchmarking with WebSocket...")
    async with websockets.connect(WS_URL) as websocket:
        start_time = time.time()
        processed_frames = []

        for frame in frames:
            # Encode the frame as a JPEG image
            _, img_encoded = cv2.imencode(".jpg", frame)
            # Convert the image to base64 string
            img_str = base64.b64encode(img_encoded).decode("utf-8")
            # Create the request payload
            request_payload = {"image": img_str}

            # Send the frame to the WebSocket server
            await websocket.send(json.dumps(request_payload))
        print("Sent all frames to WebSocket server.")
        for frame in frames:
            # Receive and parse the response from the server
            response = json.loads(await websocket.recv())  # parse JSON; never eval() untrusted data

            # Draw bounding boxes on the frame
            annotated_frame = draw_bounding_boxes(frame, response.get("detections", []))
            processed_frames.append(annotated_frame)

        end_time = time.time()
        total_time = end_time - start_time
        fps = len(frames) / total_time
        print(f"WebSocket benchmark took {total_time:.2f} seconds ({fps:.2f} FPS).")
        return processed_frames


# Main function to run WebSocket benchmark
def main():
    # Download the video
    video_path = "input_video.mp4"
    download_video(VIDEO_URL, video_path)

    # Extract frames
    frames, fps = extract_frames(video_path)

    # Benchmark WebSocket
    processed_frames_ws = asyncio.run(benchmark_websocket(frames))
    write_output_video(processed_frames_ws, f"{OUTPUT_DIR}/output_ws.mp4", fps=fps)


if __name__ == "__main__":
    main()

Extracted 150 frames from the video at 25.00 FPS.
Benchmarking with WebSocket...
Sent all frames to WebSocket server.
WebSocket benchmark took 33.86 seconds (4.43 FPS).
Output video written to output_videos/output_ws.mp4
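
Summary of the three runs (same 150-frame video, CPU only):

  Method                 Total time    Throughput
  Normal HTTP requests   43.59 s       3.44 FPS
  Shared HTTP session    42.58 s       3.52 FPS
  WebSocket              33.86 s       4.43 FPS
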
🎥 Output Video

output_ws.mp4


PR review

Community feedback is welcome! This is an experimental addition and we’d love help shaping it to best serve real-time ML use cases.

Did you have fun?

Yes 😄 — and more to come!

codecov bot commented Apr 13, 2025

Codecov Report

Attention: Patch coverage is 35.21127% with 46 lines in your changes missing coverage. Please review.

Project coverage is 87%. Comparing base (489b746) to head (0e7f8a1).

Additional details and impacted files
@@         Coverage Diff         @@
##           main   #467   +/-   ##
===================================
- Coverage    89%    87%   -2%     
===================================
  Files        37     39    +2     
  Lines      2158   2227   +69     
===================================
+ Hits       1913   1932   +19     
- Misses      245    295   +50     

@bhimrazy bhimrazy changed the title [wip]: Experiment with websocket spec [RFC]: Add support for WebSocketSpec Apr 16, 2025
@bhimrazy bhimrazy marked this pull request as ready for review April 16, 2025 16:00
@bhimrazy bhimrazy marked this pull request as draft April 22, 2025 19:24
@aniketmaurya (Collaborator) commented:

Hey @bhimrazy, great work adding WebSocket support! 🙌

That said, I noticed that the current implementation still follows the same old flow:
decode request → run prediction → encode response, all in one go.

Right now, we’re still treating the WebSocket like a regular HTTP call, just over a different pipe. It works, but we’re not unlocking the real benefits yet.

@bhimrazy (Contributor, Author) commented:

Thanks for the thoughtful feedback, @aniketmaurya! 🙏
You're absolutely right — there's definitely more we can do to take full advantage of WebSockets.
I’ll do some more research and also see if we can tweak the Loops flow. Appreciate the nudge! 🚀
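
One direction worth exploring (purely a sketch, not what this PR currently implements) is letting predict consume and yield streams, so decode/predict/encode can overlap within a single connection:

from litserve import LitAPI


class StreamingAPI(LitAPI):
    """Hypothetical streaming-first flow; NOT the current WebSocketSpec API."""

    def setup(self, device):
        self.model = ...  # load your model here

    async def predict(self, inputs):
        # Consume an async iterator of decoded messages and yield results
        # as soon as they are ready, instead of one call per message.
        async for image in inputs:
            yield self.model.predict(image)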
