-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserver.py
More file actions
135 lines (107 loc) · 4.25 KB
/
server.py
File metadata and controls
135 lines (107 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Modal GPU inference server for LeRobot policies.

Serves inference over HTTP POST and WebSocket on the same FastAPI app.
Currently runs a pretrained ACT policy (small, fast).

To swap in the π₀.₅ VLA (from Physical Intelligence):
1. Add `jax jaxlib openpi` to pip_install in the Modal image.
2. Replace ACTPolicy with π₀.₅ checkpoint loading (from openpi.policies.policy_config).
3. Update _infer() to handle the LIBERO observation format (256x256 images,
   8D state, text prompt).
4. Update protocol.py to serialize text prompts.

Note: π₀.₅ is ~8GB and JAX-compiled, so the first inference will be slow
(30-60s of JIT compilation). Consider using Modal's snapshot feature to cache
the compiled model.

Usage:
    modal serve server.py    # dev mode (auto-reload)
    modal deploy server.py   # production
"""
import time
import modal

# The Modal app that owns PolicyServer (see Usage in the module docstring).
app = modal.App("robot-inference")

# Container image for the GPU worker. Package versions are unpinned, so each
# image build resolves whatever releases are current at build time.
# NOTE(review): PolicyServer imports ACT from `lerobot.policies.act...`, a
# path that is specific to some LeRobot releases; consider pinning `lerobot`.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "torch",
        "torchvision",
        "lerobot",
        "msgpack",
        "msgpack-numpy",
        "numpy",
        "fastapi[standard]",
    )
    # Ship the msgpack wire-format helpers; the server puts /root on sys.path
    # and imports `protocol` from there at runtime.
    .add_local_file("protocol.py", "/root/protocol.py")
)
@app.cls(gpu="A10G", image=image, min_containers=1)
@modal.concurrent(max_inputs=4)
class PolicyServer:
    """GPU worker hosting one LeRobot policy, served over HTTP and WebSocket.

    A single policy instance is shared by every request this container
    handles. LeRobot's ACT ``select_action`` is stateful: it returns actions
    from an internally queued action chunk that ``policy.reset()`` clears.
    Therefore:

    * policy calls are serialized with ``self._policy_lock`` (inference runs
      in worker threads so the event loop stays responsive);
    * the policy state is reset at the start of every WebSocket session and
      on ``POST /reset`` so a new episode never receives stale actions.

    NOTE(review): with ``max_inputs=4`` several clients can be connected at
    once, and they still share that one policy state, so concurrent control
    loops will interleave each other's queued actions. Run one control loop
    per container (or one policy per session) if that matters.
    """

    @modal.enter()
    def load_model(self):
        """Load the pretrained ACT policy onto the GPU once per container."""
        import sys
        import threading

        # protocol.py is shipped to /root; avoid stacking duplicate entries.
        if "/root" not in sys.path:
            sys.path.insert(0, "/root")
        import torch
        from lerobot.policies.act.modeling_act import ACTPolicy

        self.device = torch.device("cuda")
        self.policy = ACTPolicy.from_pretrained(
            "lerobot/act_aloha_sim_transfer_cube_human"
        )
        # NOTE(review): newer LeRobot releases may move input/output
        # normalization out of the policy into separate processor pipelines;
        # confirm the installed version normalizes inside select_action.
        self.policy.to(self.device)
        self.policy.eval()
        self.policy.reset()
        # Guards the stateful policy; inference is called from worker threads.
        self._policy_lock = threading.Lock()
        print(f"Model loaded on {self.device}")

    def _reset_policy(self) -> None:
        """Clear the policy's per-episode state (e.g. its queued action chunk)."""
        with self._policy_lock:
            self.policy.reset()

    def _infer(self, obs_dict: dict) -> tuple:
        """Run inference on an observation dict. Returns (action_np, inference_ms).

        Expects:
        - image arrays as HWC uint8 (480, 640, 3) — converted to CHW float [0,1]
        - state arrays as float32 (14,)
        Any 3-D array is treated as an image; everything else as a state vector.

        ``inference_ms`` covers ``select_action`` plus a CUDA sync only — not
        preprocessing and not time spent waiting for the policy lock.

        Safe to call from worker threads: access to the shared policy is
        serialized by ``self._policy_lock``.
        """
        import torch

        batch = {}
        for key, arr in obs_dict.items():
            # Copy the raw array to the GPU first (a uint8 image is 4x smaller
            # than its float32 form), then convert on the device.
            t = torch.from_numpy(arr).to(self.device)
            if t.ndim == 3:
                # Image: HWC uint8 -> CHW float [0, 1]
                t = t.permute(2, 0, 1).float() / 255.0
            else:
                # State vector
                t = t.float()
            batch[key] = t.unsqueeze(0)  # add batch dim

        with self._policy_lock:
            start = time.perf_counter()
            with torch.inference_mode():
                action = self.policy.select_action(batch)
            # Kernel launches are asynchronous; wait so the timing reflects
            # the actual GPU work.
            torch.cuda.synchronize()
            inference_ms = (time.perf_counter() - start) * 1000
        action_np = action.squeeze(0).cpu().numpy()
        return action_np, inference_ms

    @modal.asgi_app()
    def serve(self):
        """Build the FastAPI app exposing /health, /infer, /reset and /ws."""
        import asyncio
        import contextlib
        import logging
        import sys

        if "/root" not in sys.path:
            sys.path.insert(0, "/root")
        from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
        from fastapi.responses import Response

        logger = logging.getLogger(__name__)
        web_app = FastAPI()

        @web_app.get("/health")
        def health():
            return {"status": "ok"}

        @web_app.post("/infer")
        async def http_infer(request: Request):
            """HTTP POST: send msgpack obs, get msgpack action back."""
            from protocol import pack_response, unpack_obs

            body = await request.body()
            obs_dict = unpack_obs(body)
            # The GPU call blocks; run it in a worker thread so the event loop
            # keeps serving health checks and other connections meanwhile.
            action_np, inference_ms = await asyncio.to_thread(self._infer, obs_dict)
            return Response(
                content=pack_response(action_np, inference_ms),
                media_type="application/x-msgpack",
            )

        @web_app.post("/reset")
        async def http_reset():
            """Start a new episode: clear the policy's per-episode state."""
            await asyncio.to_thread(self._reset_policy)
            return {"status": "ok"}

        @web_app.websocket("/ws")
        async def ws_infer(ws: WebSocket):
            """WebSocket: persistent connection for control-loop streaming.

            Each connection starts a new episode, so the policy state is reset
            before the first observation is handled.
            """
            await ws.accept()
            from protocol import pack_response, unpack_obs

            await asyncio.to_thread(self._reset_policy)
            try:
                while True:
                    data = await ws.receive_bytes()
                    obs_dict = unpack_obs(data)
                    action_np, inference_ms = await asyncio.to_thread(
                        self._infer, obs_dict
                    )
                    await ws.send_bytes(pack_response(action_np, inference_ms))
            except WebSocketDisconnect:
                pass  # Client disconnected normally.
            except Exception:
                # Bad payload or inference failure: record it instead of
                # silently dropping the connection, then tell the client.
                logger.exception("WebSocket inference failed; closing connection")
                # 1011 is the WebSocket "internal error" close code. Best
                # effort only: the connection may already be gone.
                with contextlib.suppress(Exception):
                    await ws.close(code=1011)

        return web_app