Skip to content

Commit 2771576

Browse files
lukass16cursoragent
andcommitted
docs(robot): document [N, T, A] infer contract and BatchedAgent ownership
Spell out on Model.infer/ainfer that implementations must keep the leading batch dim N (ainfer indexes [0], BatchedModel scatters rows along it) and add a one-line assert in LeRobotModel.infer. Document that BatchedAgent mutates the passed-in agent in place, leaving it permanently batched. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent b4c5d06 commit 2771576

2 files changed

Lines changed: 19 additions & 4 deletions

File tree

hud/agents/robot/batching.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,19 @@ class BatchedAgent(Agent):
104104
105105
Requires an in-process batchable model; :class:`~hud.agents.robot.model.RemoteModel`
106106
is not supported (the OpenPI server protocol has no batched-request shape).
107+
108+
Takes ownership of ``agent``: it swaps ``agent.model`` for a :class:`BatchedModel` wrapper
109+
in place (so the wrapper is shared by every per-run clone). The passed-in instance is
110+
therefore permanently batched — hand :class:`BatchedAgent` a dedicated agent and don't
111+
also use that same instance for direct, unbatched :class:`RobotAgent` rollouts.
107112
"""
108113

109114
def __init__(self, agent: RobotAgent, *, batch_size: int, max_wait_s: float = 0.05) -> None:
110115
if agent.model is None:
111116
raise RuntimeError("BatchedAgent needs agent.model set")
112117
self._template = agent
113-
# Wrap once; every per-run clone shares this batcher by reference.
118+
# Wrap once, in place: the passed-in agent is now permanently batched (see class doc).
119+
# Every per-run clone shares this batcher by reference.
114120
agent.model = BatchedModel(agent.model, batch_size=batch_size, max_wait_s=max_wait_s)
115121

116122
async def __call__(self, run: Run, **kwargs: Any) -> None:

hud/agents/robot/model.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@ class Model:
3030
"""
3131

3232
def infer(self, batch: Any) -> ActionArray:
33-
"""runs policy on a batch, returns [N, T, A] action chunk"""
33+
"""Run the policy on an ``[N, ...]`` batch, return an ``[N, T, A]`` chunk.
34+
35+
Implementations MUST keep the leading batch dim ``N`` (even for ``N == 1``):
36+
:meth:`ainfer` indexes ``[0]`` and :class:`~hud.agents.robot.batching.BatchedModel`
37+
scatters rows along it, so a squeezed ``[T, A]`` silently breaks both.
38+
"""
3439
raise NotImplementedError
3540

3641
async def ainfer(self, batch: Any) -> ActionArray:
37-
"""Awaited single-rollout entry: run :meth:`infer` in a thread, return its ``[T, A]``."""
42+
"""Awaited single-rollout entry: run :meth:`infer` in a thread, return its single
43+
``[T, A]`` row. Indexing ``[0]`` assumes :meth:`infer` honors the ``[N, T, A]`` contract.
44+
"""
3845
return (await asyncio.to_thread(self.infer, batch))[0]
3946

4047

@@ -65,7 +72,9 @@ def infer(self, batch: Any) -> ActionArray:
6572
if self._first_inference:
6673
print("[agent] first inference done — inference is now fast", flush=True)
6774
self._first_inference = False
68-
return chunk.float().cpu().numpy()
75+
arr = chunk.float().cpu().numpy()
76+
assert arr.ndim == 3, f"expected [N, T, A] chunk, got {arr.shape}" # LeRobot keeps the N dim
77+
return arr
6978

7079

7180

0 commit comments

Comments
 (0)