Commit cb023ea

Merge pull request #12 from EPFL-AI-Team/marcus-server-implementation
Implement Server Websocket Endpoint (#1)
2 parents e4d52fb + 321e4b7 commit cb023ea


5 files changed: 136 additions, 0 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -679,3 +679,4 @@ tags
 .ionide

 # End of https://www.toptal.com/developers/gitignore/api/vim,latex,linux,macos,synology,jetbrains+all,visualstudiocode,python,jupyternotebooks
+inference_results.txt

src/iris/cli/server.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from iris.server.app import main
+
+if __name__ == "__main__":
+    main()
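
This thin wrapper exposes the server's entry point under the CLI package, so the server can presumably also be started with the module runner (assuming the package is installed):

    python -m iris.cli.server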

src/iris/server/__init__.py

Whitespace-only changes.

src/iris/server/app.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+"""IRIS Inference Server - receives frames, runs VLM inference."""
+
+import base64
+import logging
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+from io import BytesIO
+
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from PIL import Image
+
+from iris.server.dependencies import get_server_state
+from iris.vlm.inference.model_loader import load_model_and_processor
+from iris.vlm.inference.queue.jobs import SingleFrameJob
+from iris.vlm.inference.queue.queue import InferenceQueue
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+    """Manage startup and shutdown."""
+    # Startup
+    state = get_server_state()
+
+    logger.info("Loading model...")
+    state.model, state.processor = load_model_and_processor("smolvlm2")
+
+    logger.info("Starting inference queue...")
+    state.queue = InferenceQueue(max_queue_size=10, num_workers=1)
+    await state.queue.start()
+
+    state.model_loaded = True
+    logger.info("Server ready!")
+
+    yield
+
+    # Shutdown
+    if state.queue:
+        await state.queue.stop()
+    logger.info("Server stopped.")
+
+
+app = FastAPI(title="IRIS Inference Server", lifespan=lifespan)
+
+
+@app.get("/health")
+async def health() -> dict[str, str | bool]:
+    """Health check endpoint."""
+    state = get_server_state()
+    return {
+        "status": "healthy" if state.model_loaded else "loading",
+        "model_loaded": state.model_loaded,
+    }
+
+
+@app.websocket("/ws/stream")
+async def inference_endpoint(websocket: WebSocket) -> None:
+    """Receive frames and return inference results."""
+    await websocket.accept()
+    state = get_server_state()
+    logger.info("Client connected")
+
+    try:
+        while True:
+            data = await websocket.receive_json()
+
+            frame_b64 = data["frame"]
+            frame_id = data["frame_id"]
+
+            image_data = base64.b64decode(frame_b64)
+            image = Image.open(BytesIO(image_data))
+
+            job = SingleFrameJob(
+                job_id=f"frame-{frame_id}",
+                frame=image,
+                model=state.model,
+                processor=state.processor,
+                prompt="Describe what you see in one sentence.",
+                executor=state.queue.executor,
+            )
+
+            await state.queue.submit(job)
+            result_job = await state.queue.get_result(timeout=30.0)
+
+            if result_job:
+                await websocket.send_json({
+                    "job_id": result_job.job_id,
+                    "status": result_job.status.value,
+                    "result": result_job.result,
+                    "processing_time": result_job.processing_time,
+                })
+
+    except WebSocketDisconnect:
+        logger.info("Client disconnected")
+    except Exception as e:
+        logger.error(f"Error: {e}", exc_info=True)
+
+
+def main() -> None:
+    """Entry point for server."""
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8001)
+
+
+if __name__ == "__main__":
+    main()
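
For reference, a minimal client sketch for the /ws/stream protocol above. This is illustrative, not part of this commit: it assumes the server is running locally on port 8001, that the third-party websockets package is installed, and frame.jpg is a placeholder file name.

    """Illustrative client sketch for /ws/stream (not part of this commit)."""

    import asyncio
    import base64
    import json

    import websockets  # third-party client library, assumed installed


    async def send_one_frame() -> None:
        async with websockets.connect("ws://localhost:8001/ws/stream") as ws:
            # Encode one frame as base64 under the keys the endpoint reads.
            with open("frame.jpg", "rb") as f:
                frame_b64 = base64.b64encode(f.read()).decode()
            await ws.send(json.dumps({"frame": frame_b64, "frame_id": 0}))

            # The server replies with job_id, status, result, and processing_time.
            reply = json.loads(await ws.recv())
            print(reply["status"], reply["result"])


    asyncio.run(send_one_frame())

Note that when get_result times out, the endpoint sends nothing back for that frame, so a real client would want a receive timeout of its own.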

src/iris/server/dependencies.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+"""Server-side state management with dependency injection."""
+
+from iris.vlm.inference.queue.queue import InferenceQueue
+
+
+class ServerState:
+    """Server application state."""
+
+    def __init__(self):
+        self.model = None
+        self.processor = None
+        self.queue: InferenceQueue | None = None
+        self.model_loaded = False
+
+
+# Singleton
+_server_state = ServerState()
+
+
+def get_server_state() -> ServerState:
+    """Get server state for dependency injection."""
+    return _server_state
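
The module-level singleton keeps the model, processor, and queue handles in one place. Since the docstring advertises dependency injection, the same accessor can also be wired through FastAPI's standard Depends mechanism; a minimal sketch follows (illustrative only; the routes in this commit call get_server_state() directly):

    from fastapi import Depends, FastAPI

    from iris.server.dependencies import ServerState, get_server_state

    app = FastAPI()


    @app.get("/status")
    async def status(state: ServerState = Depends(get_server_state)) -> dict[str, bool]:
        # Depends(get_server_state) resolves to the singleton on every request.
        return {"model_loaded": state.model_loaded}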
