Skip to content

Commit fa8d131

Browse files
committed
Update
[ghstack-poisoned]
1 parent 545c0ad commit fa8d131

8 files changed

Lines changed: 680 additions & 33 deletions

File tree

sota-implementations/vla_grpo/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,32 @@ throughput split into collection and optimization:
204204
- `throughput/train_decisions_per_s`
205205
- `throughput/optim_steps_per_s`
206206

207+
The collector can also be switched between synchronous and async execution
208+
paths for throughput experiments:
209+
210+
```bash
211+
# fully synchronous baseline
212+
python sota-implementations/vla_grpo/vla-grpo.py \
213+
collector.async_env=false collector.async_policy=false
214+
215+
# asynchronous env slots, but no policy auto-batching
216+
python sota-implementations/vla_grpo/vla-grpo.py \
217+
collector.async_env=true collector.async_policy=false env.num_envs=8
218+
219+
# asynchronous env slots plus auto-batched policy inference
220+
python sota-implementations/vla_grpo/vla-grpo.py \
221+
collector.async_env=true collector.async_policy=true env.num_envs=8 \
222+
collector.server_max_batch_size=8 collector.server_timeout=0.01
223+
```
224+
225+
`collector.async_env=true` uses `AsyncBatchedCollector` so faster environment
226+
slots do not wait at a global step barrier. `collector.async_policy=true` routes
227+
policy calls through an inference server; with multiple async env slots this
228+
enables auto-batching and logs `policy_server/*` counters such as average batch
229+
size, request rate, and queue/forward latency. The `false/true` combination is
230+
available as a policy-server plumbing ablation, but policy auto-batching is most
231+
meaningful when several env slots submit requests concurrently.
232+
207233
Eval rollouts can also be rendered to video (`logger.record_video=true`, on by
208234
default). A dedicated single-environment recorder is built with
209235
`from_pixels=True`: `ToyVLAEnv` renders the tracking scene, while `LiberoEnv`

sota-implementations/vla_grpo/config/vla_grpo_libero.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,23 @@ collector:
9191
min_replay_decisions: null # null/0 = collect one target group wave
9292
total_iters: 100 # the paper's total_epochs
9393
policy_device: null # null = policy.device; set e.g. cuda:1 for rollout inference
94+
# Execution-mode switches for throughput ablations:
95+
# - false/false: regular synchronous TorchRL Collector.
96+
# - true/false: async env slots with one request per policy forward.
97+
# - true/true: async env slots plus auto-batched policy inference.
98+
# - false/true: sync env stepping through the policy server path.
99+
async_env: false
100+
async_policy: false
101+
env_backend: threading # AsyncBatchedCollector env backend: threading | multiprocessing
102+
policy_backend: threading # inference transport: threading | multiprocessing | ray | monarch
103+
server_backend: thread # process server needs a policy_factory and is not used here
104+
server_max_batch_size: null # null = env.num_envs when async_policy=true
105+
server_min_batch_size: 1
106+
server_timeout: 0.01
107+
server_collect_stats: true
108+
server_stats_window_size: 1024
109+
max_inflight_per_env: 1
110+
storing_device: null
94111

95112
advantage:
96113
trajectory_return: sum # binary success return per trajectory

sota-implementations/vla_grpo/config/vla_grpo_toy.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ env:
1111
success_tol: 0.35 # sized so a random policy succeeds sometimes (cold-start signal)
1212
max_outer_steps: 6 # episode truncation, in chunk decisions
1313
render_size: 64 # side length of the from_pixels eval-video frame
14+
num_envs: 1 # async-env workers; ToyVLAEnv grouped rollouts are per worker
1415
seed: 0
1516

1617
tokenizer:
@@ -35,6 +36,23 @@ collector:
3536
max_same_policy_collect_attempts: 2
3637
min_replay_decisions: null # null/0 = collect one target group wave
3738
total_iters: 200
39+
# Execution-mode switches for throughput ablations:
40+
# - false/false: regular synchronous TorchRL Collector.
41+
# - true/false: async env slots with one request per policy forward.
42+
# - true/true: async env slots plus auto-batched policy inference.
43+
# - false/true: sync env stepping through the policy server path.
44+
async_env: false
45+
async_policy: false
46+
env_backend: threading # AsyncBatchedCollector env backend: threading | multiprocessing
47+
policy_backend: threading # inference transport: threading | multiprocessing | ray | monarch
48+
server_backend: thread # process server needs a policy_factory and is not used here
49+
server_max_batch_size: null # null = number of async envs when async_policy=true
50+
server_min_batch_size: 1
51+
server_timeout: 0.01
52+
server_collect_stats: true
53+
server_stats_window_size: 1024
54+
max_inflight_per_env: 1
55+
storing_device: null
3856

3957
advantage:
4058
trajectory_return: sum # binary success return per trajectory

sota-implementations/vla_grpo/test_openvla.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,187 @@ def hook(_):
356356
assert captured["kwargs"]["frames_per_batch"] == 4
357357

358358

359+
def test_make_collector_async_env_uses_async_batched_collector(monkeypatch):
360+
captured = {}
361+
362+
class _FakeAsyncCollector:
363+
def __init__(self, *args, **kwargs):
364+
captured["args"] = args
365+
captured["kwargs"] = kwargs
366+
367+
def __iter__(self):
368+
return self
369+
370+
def __next__(self):
371+
raise StopIteration
372+
373+
def server_stats(self, *, reset=False):
374+
return {"requests": 0}
375+
376+
def shutdown(self):
377+
captured["shutdown"] = True
378+
379+
class _FakeEnv:
380+
batch_size = torch.Size([1])
381+
device = torch.device("cpu")
382+
383+
cfg = SimpleNamespace(
384+
collector=SimpleNamespace(
385+
groups_per_iter=4,
386+
group_size=2,
387+
async_env=True,
388+
async_policy=True,
389+
server_min_batch_size=2,
390+
),
391+
env=SimpleNamespace(
392+
backend="toy",
393+
action_dim=2,
394+
state_dim=4,
395+
image_shape=(3, 8, 8),
396+
render_size=16,
397+
success_steps=2,
398+
success_tol=0.25,
399+
max_outer_steps=3,
400+
num_envs=4,
401+
seed=0,
402+
),
403+
)
404+
monkeypatch.setattr(utils, "AsyncBatchedCollector", _FakeAsyncCollector)
405+
406+
collector = utils.make_collector(
407+
cfg,
408+
_FakeEnv(),
409+
object(),
410+
torch.device("cpu"),
411+
tokenizer=object(),
412+
replay_buffer=object(),
413+
)
414+
collector._ensure_collector()
415+
416+
assert len(captured["kwargs"]["create_env_fn"]) == 4
417+
assert captured["kwargs"]["yield_completed_trajectories"]
418+
server_config = captured["kwargs"]["server_config"]
419+
assert server_config.max_batch_size == 4
420+
assert server_config.min_batch_size == 2
421+
422+
423+
def test_make_collector_async_env_without_policy_batching(monkeypatch):
424+
captured = {}
425+
426+
class _FakeAsyncCollector:
427+
def __init__(self, *args, **kwargs):
428+
captured["kwargs"] = kwargs
429+
430+
def __iter__(self):
431+
return self
432+
433+
def __next__(self):
434+
raise StopIteration
435+
436+
def server_stats(self, *, reset=False):
437+
return {}
438+
439+
def shutdown(self):
440+
pass
441+
442+
class _FakeEnv:
443+
batch_size = torch.Size([1])
444+
device = torch.device("cpu")
445+
446+
cfg = SimpleNamespace(
447+
collector=SimpleNamespace(
448+
groups_per_iter=2,
449+
group_size=2,
450+
async_env=True,
451+
async_policy=False,
452+
),
453+
env=SimpleNamespace(
454+
backend="toy",
455+
action_dim=2,
456+
state_dim=4,
457+
image_shape=(3, 8, 8),
458+
render_size=16,
459+
success_steps=2,
460+
success_tol=0.25,
461+
max_outer_steps=3,
462+
num_envs=2,
463+
seed=0,
464+
),
465+
)
466+
monkeypatch.setattr(utils, "AsyncBatchedCollector", _FakeAsyncCollector)
467+
468+
collector = utils.make_collector(
469+
cfg,
470+
_FakeEnv(),
471+
object(),
472+
torch.device("cpu"),
473+
tokenizer=object(),
474+
)
475+
collector._ensure_collector()
476+
477+
server_config = captured["kwargs"]["server_config"]
478+
assert server_config.max_batch_size == 1
479+
assert server_config.timeout == 0.0
480+
481+
482+
def test_make_collector_sync_env_can_use_policy_server(monkeypatch):
483+
captured = {}
484+
485+
class _FakeCollector:
486+
def __init__(self, *args, **kwargs):
487+
captured["collector_args"] = args
488+
captured["collector_kwargs"] = kwargs
489+
self.requested_frames_per_batch = kwargs["frames_per_batch"]
490+
491+
def shutdown(self, *args, **kwargs):
492+
captured["collector_shutdown"] = True
493+
494+
def reset(self, *args, **kwargs):
495+
captured["collector_reset"] = True
496+
497+
class _FakeServer:
498+
def __init__(self, *args, **kwargs):
499+
captured["server_args"] = args
500+
captured["server_kwargs"] = kwargs
501+
502+
def start(self):
503+
return self
504+
505+
def shutdown(self):
506+
captured["server_shutdown"] = True
507+
508+
def stats(self, *, reset=False):
509+
return {"requests": 0}
510+
511+
class _FakeEnv:
512+
batch_size = torch.Size([2])
513+
device = None
514+
515+
policy = SimpleNamespace(
516+
in_keys=["observation"], out_keys=[("vla_action", "tokens")]
517+
)
518+
cfg = SimpleNamespace(
519+
collector=SimpleNamespace(
520+
groups_per_iter=2,
521+
group_size=1,
522+
async_policy=True,
523+
),
524+
env=SimpleNamespace(max_outer_steps=3),
525+
)
526+
monkeypatch.setattr(utils, "Collector", _FakeCollector)
527+
monkeypatch.setattr(utils, "InferenceServer", _FakeServer)
528+
529+
collector = utils.make_collector(cfg, _FakeEnv(), policy, torch.device("cpu"))
530+
531+
assert isinstance(collector, utils._ServerBackedCollector)
532+
assert isinstance(captured["collector_args"][1], utils.PolicyClientModule)
533+
assert captured["server_kwargs"]["server_config"].max_batch_size == 2
534+
assert captured["collector_kwargs"]["policy_device"] == torch.device("cpu")
535+
assert captured["collector_kwargs"]["trust_policy"] is True
536+
collector.shutdown()
537+
assert captured["server_shutdown"]
538+
539+
359540
def test_make_replay_buffer_scales_capacity_with_overcollection():
360541
cfg = SimpleNamespace(
361542
collector=SimpleNamespace(

0 commit comments

Comments
 (0)