# Skill: add-model-server

Add a new VLA model server to the evaluation harness.

## Trigger

User asks to add/integrate a new model (e.g. "add OpenVLA server", "integrate RT-2").
## Steps

### 1. Gather Requirements

Ask the user (if not already provided):
- **Model name** (e.g. `openvla`)
- **Framework/library** (e.g. HuggingFace Transformers, custom repo)
- **Python dependencies** (torch version, model-specific packages)
- **Checkpoint source** (HuggingFace Hub model ID or local path)
- **Action output format** (dimension, chunk_size, continuous vs discrete)
- **Input requirements** (single image vs multi-view, needs proprioceptive state?)

### 2. Create Model Server Script

Create `src/vla_eval/model_servers/<name>.py` as a **uv script** (standalone, inline deps).

The file MUST start with a PEP 723 inline script metadata block:

```python
# /// script
# requires-python = "~=3.11"
# dependencies = [
#     "vla-eval",
#     "<model-package>",
#     "torch>=2.0",
#     "transformers>=4.40,<5",
#     "pillow>=9.0",
#     "numpy>=1.24",
# ]
#
# [tool.uv.sources]
# vla-eval = { path = "../../.." }
# <model-package> = { git = "https://github.com/org/repo.git", branch = "main" }
# ///
```

Subclass `PredictModelServer` (most models) or `ModelServer` (advanced async):

```python
from typing import Any

import numpy as np

from vla_eval.model_servers.base import SessionContext
from vla_eval.model_servers.predict import PredictModelServer
from vla_eval.model_servers.serve import serve


class MyModelServer(PredictModelServer):
    def __init__(self, checkpoint: str, *, chunk_size: int = 1, action_ensemble: str = "newest", **kwargs):
        super().__init__(chunk_size=chunk_size, action_ensemble=action_ensemble, **kwargs)
        self.checkpoint = checkpoint
        self._model = None

    def _load_model(self) -> None:
        """Lazily load the model on the first predict() call."""
        if self._model is not None:
            return
        import torch  # heavyweight import deferred until first use

        # Load model here...
        self._model = ...

    def predict(self, obs: dict[str, Any], ctx: SessionContext) -> dict[str, Any]:
        """Single-observation inference. Blocking call.

        Args:
            obs: {"images": {"cam_name": np.ndarray HWC uint8},
                  "task_description": str,
                  "states": np.ndarray (optional)}
            ctx: Session context (session_id, episode_id, step, is_first)

        Returns:
            {"actions": np.ndarray} with shape:
            - (action_dim,) if chunk_size == 1
            - (chunk_size, action_dim) if chunk_size > 1
        """
        self._load_model()
        actions = ...  # run inference here
        return {"actions": np.array(actions, dtype=np.float32)}
```

### Key Patterns (from existing implementations)

**PredictModelServer features (inherited automatically):**
- **Action chunking**: When `chunk_size > 1`, return a `(chunk_size, action_dim)` array. The framework auto-buffers and serves one action per step, re-inferring only when the buffer empties.
- **Action ensemble**: `"newest"` (default), `"average"`, `"ema"` — blends overlapping chunks. Set via `action_ensemble=` in `__init__`.
- **Batched inference**: Override `predict_batch()` and set `max_batch_size > 1` for GPU-batched multi-shard eval (see the sketch after this list).
- **Per-suite chunk_size**: Override `on_episode_start()` to set `self._session_chunk_sizes[ctx.session_id] = N` (see the CogACT example and the sketch below).
- **CI/LAAS**: Set `continuous_inference=True` for continuous inference mode (DRAFT).
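
A minimal sketch of the batched-inference and per-suite patterns, continuing the `MyModelServer` example from step 2. The exact `predict_batch()` signature, the `max_batch_size` attribute, and the `on_episode_start()` hook are assumptions inferred from the bullets above; verify them against `PredictModelServer` before copying:

```python
from typing import Any

import numpy as np

from vla_eval.model_servers.base import SessionContext


class MyBatchedModelServer(MyModelServer):
    # Assumption: the framework batches up to this many observations.
    max_batch_size = 8

    def predict_batch(
        self, obs_batch: list[dict[str, Any]], ctx_batch: list[SessionContext]
    ) -> list[dict[str, Any]]:
        """One GPU forward pass over all shards' observations (assumed signature)."""
        self._load_model()
        batch_actions = ...  # run batched inference here, one row per observation
        return [{"actions": np.array(a, dtype=np.float32)} for a in batch_actions]

    def on_episode_start(self, ctx: SessionContext) -> None:
        # Per-suite chunk_size pattern from the CogACT example;
        # hook signature assumed.
        self._session_chunk_sizes[ctx.session_id] = 8
```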

**Image handling:**
```python
from PIL import Image as PILImage

images = obs.get("images", {})
img_array = next(iter(images.values()))  # first camera
pil_image = PILImage.fromarray(img_array).convert("RGB")
```

**Task description:**
```python
text = obs.get("task_description", "")
```

**Lazy model loading**: Always use a `_load_model()` pattern. Do NOT load in `__init__`.

### 3. Add `if __name__ == "__main__"` Entry Point

The script must be runnable via `uv run`:

```python
import argparse
import logging

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="<Model> server (uv script)")
    parser.add_argument("--checkpoint", required=True, help="HF model ID or local path")
    parser.add_argument("--chunk_size", type=int, default=1)
    parser.add_argument("--action_ensemble", default="newest")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
    )

    server = MyModelServer(
        args.checkpoint,
        chunk_size=args.chunk_size,
        action_ensemble=args.action_ensemble,
    )

    logger.info("Pre-loading model...")
    server._load_model()
    logger.info("Model ready, starting server on ws://%s:%d", args.host, args.port)
    serve(server, host=args.host, port=args.port)
```

### 4. Create Config YAML

Create `configs/model_servers/<name>.yaml`:

```yaml
# <Model Name> model server — <benchmark> checkpoint
# Weight: <HuggingFace model ID>
# Benchmark: <target benchmark>

script: "src/vla_eval/model_servers/<name>.py"
args:
  checkpoint: <org/model-id>
  chunk_size: 1
  port: 8000
```
The CLI runs this via `vla-eval serve --config configs/model_servers/<name>.yaml`,
which translates to `uv run <script> --checkpoint <value> --chunk_size <value> --port <value>`.

### 5. Verify

1. Run `make check` — lint + format + type check.
2. Run `make test` — ensure existing tests still pass.
3. Suggest the user run `vla-eval test -c configs/model_servers/<name>.yaml`
   (starts the server, sends dummy observations from a StubBenchmark, and checks for a valid action response — requires `uv` + GPU + model weights).

### Reference Implementations

- **CogACT** (`model_servers/dexbotic/cogact.py`): Diffusion action head, per-suite chunk_size_map, batched inference, text template option
- **starVLA** (`model_servers/starvla.py`): Auto-detecting framework, HuggingFace checkpoint download, monkey-patches for upstream compat

### Server Hierarchy

```
ModelServer (ABC)           ← Advanced: async on_observation()
  └── PredictModelServer    ← Most models: blocking predict()
```

- Use `PredictModelServer` for standard request-response models (95% of cases)
- Use `ModelServer` only if you need async streaming or custom message handling (see the sketch below)
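
If you do need the async path, a subclass might look roughly like the following. This is a sketch under stated assumptions: the diagram above names `on_observation()` as the hook, but its signature, its async semantics, and how actions get sent back are all assumptions; check `model_servers/base.py` for the real interface.

```python
from typing import Any

import numpy as np

from vla_eval.model_servers.base import ModelServer, SessionContext


class MyStreamingServer(ModelServer):
    """Sketch only: assumes on_observation(obs, ctx) is an async hook and
    that returning an action dict sends it back to the client."""

    async def on_observation(self, obs: dict[str, Any], ctx: SessionContext) -> dict[str, Any]:
        # Custom message handling / async streaming would go here,
        # e.g. awaiting a shared inference worker instead of blocking.
        actions = ...  # run or await inference here
        return {"actions": np.array(actions, dtype=np.float32)}
```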