@@ -356,6 +356,187 @@ def hook(_):
356356 assert captured ["kwargs" ]["frames_per_batch" ] == 4
357357
358358
359+ def test_make_collector_async_env_uses_async_batched_collector (monkeypatch ):
360+ captured = {}
361+
362+ class _FakeAsyncCollector :
363+ def __init__ (self , * args , ** kwargs ):
364+ captured ["args" ] = args
365+ captured ["kwargs" ] = kwargs
366+
367+ def __iter__ (self ):
368+ return self
369+
370+ def __next__ (self ):
371+ raise StopIteration
372+
373+ def server_stats (self , * , reset = False ):
374+ return {"requests" : 0 }
375+
376+ def shutdown (self ):
377+ captured ["shutdown" ] = True
378+
379+ class _FakeEnv :
380+ batch_size = torch .Size ([1 ])
381+ device = torch .device ("cpu" )
382+
383+ cfg = SimpleNamespace (
384+ collector = SimpleNamespace (
385+ groups_per_iter = 4 ,
386+ group_size = 2 ,
387+ async_env = True ,
388+ async_policy = True ,
389+ server_min_batch_size = 2 ,
390+ ),
391+ env = SimpleNamespace (
392+ backend = "toy" ,
393+ action_dim = 2 ,
394+ state_dim = 4 ,
395+ image_shape = (3 , 8 , 8 ),
396+ render_size = 16 ,
397+ success_steps = 2 ,
398+ success_tol = 0.25 ,
399+ max_outer_steps = 3 ,
400+ num_envs = 4 ,
401+ seed = 0 ,
402+ ),
403+ )
404+ monkeypatch .setattr (utils , "AsyncBatchedCollector" , _FakeAsyncCollector )
405+
406+ collector = utils .make_collector (
407+ cfg ,
408+ _FakeEnv (),
409+ object (),
410+ torch .device ("cpu" ),
411+ tokenizer = object (),
412+ replay_buffer = object (),
413+ )
414+ collector ._ensure_collector ()
415+
416+ assert len (captured ["kwargs" ]["create_env_fn" ]) == 4
417+ assert captured ["kwargs" ]["yield_completed_trajectories" ]
418+ server_config = captured ["kwargs" ]["server_config" ]
419+ assert server_config .max_batch_size == 4
420+ assert server_config .min_batch_size == 2
421+
422+
423+ def test_make_collector_async_env_without_policy_batching (monkeypatch ):
424+ captured = {}
425+
426+ class _FakeAsyncCollector :
427+ def __init__ (self , * args , ** kwargs ):
428+ captured ["kwargs" ] = kwargs
429+
430+ def __iter__ (self ):
431+ return self
432+
433+ def __next__ (self ):
434+ raise StopIteration
435+
436+ def server_stats (self , * , reset = False ):
437+ return {}
438+
439+ def shutdown (self ):
440+ pass
441+
442+ class _FakeEnv :
443+ batch_size = torch .Size ([1 ])
444+ device = torch .device ("cpu" )
445+
446+ cfg = SimpleNamespace (
447+ collector = SimpleNamespace (
448+ groups_per_iter = 2 ,
449+ group_size = 2 ,
450+ async_env = True ,
451+ async_policy = False ,
452+ ),
453+ env = SimpleNamespace (
454+ backend = "toy" ,
455+ action_dim = 2 ,
456+ state_dim = 4 ,
457+ image_shape = (3 , 8 , 8 ),
458+ render_size = 16 ,
459+ success_steps = 2 ,
460+ success_tol = 0.25 ,
461+ max_outer_steps = 3 ,
462+ num_envs = 2 ,
463+ seed = 0 ,
464+ ),
465+ )
466+ monkeypatch .setattr (utils , "AsyncBatchedCollector" , _FakeAsyncCollector )
467+
468+ collector = utils .make_collector (
469+ cfg ,
470+ _FakeEnv (),
471+ object (),
472+ torch .device ("cpu" ),
473+ tokenizer = object (),
474+ )
475+ collector ._ensure_collector ()
476+
477+ server_config = captured ["kwargs" ]["server_config" ]
478+ assert server_config .max_batch_size == 1
479+ assert server_config .timeout == 0.0
480+
481+
482+ def test_make_collector_sync_env_can_use_policy_server (monkeypatch ):
483+ captured = {}
484+
485+ class _FakeCollector :
486+ def __init__ (self , * args , ** kwargs ):
487+ captured ["collector_args" ] = args
488+ captured ["collector_kwargs" ] = kwargs
489+ self .requested_frames_per_batch = kwargs ["frames_per_batch" ]
490+
491+ def shutdown (self , * args , ** kwargs ):
492+ captured ["collector_shutdown" ] = True
493+
494+ def reset (self , * args , ** kwargs ):
495+ captured ["collector_reset" ] = True
496+
497+ class _FakeServer :
498+ def __init__ (self , * args , ** kwargs ):
499+ captured ["server_args" ] = args
500+ captured ["server_kwargs" ] = kwargs
501+
502+ def start (self ):
503+ return self
504+
505+ def shutdown (self ):
506+ captured ["server_shutdown" ] = True
507+
508+ def stats (self , * , reset = False ):
509+ return {"requests" : 0 }
510+
511+ class _FakeEnv :
512+ batch_size = torch .Size ([2 ])
513+ device = None
514+
515+ policy = SimpleNamespace (
516+ in_keys = ["observation" ], out_keys = [("vla_action" , "tokens" )]
517+ )
518+ cfg = SimpleNamespace (
519+ collector = SimpleNamespace (
520+ groups_per_iter = 2 ,
521+ group_size = 1 ,
522+ async_policy = True ,
523+ ),
524+ env = SimpleNamespace (max_outer_steps = 3 ),
525+ )
526+ monkeypatch .setattr (utils , "Collector" , _FakeCollector )
527+ monkeypatch .setattr (utils , "InferenceServer" , _FakeServer )
528+
529+ collector = utils .make_collector (cfg , _FakeEnv (), policy , torch .device ("cpu" ))
530+
531+ assert isinstance (collector , utils ._ServerBackedCollector )
532+ assert isinstance (captured ["collector_args" ][1 ], utils .PolicyClientModule )
533+ assert captured ["server_kwargs" ]["server_config" ].max_batch_size == 2
534+ assert captured ["collector_kwargs" ]["policy_device" ] == torch .device ("cpu" )
535+ assert captured ["collector_kwargs" ]["trust_policy" ] is True
536+ collector .shutdown ()
537+ assert captured ["server_shutdown" ]
538+
539+
359540def test_make_replay_buffer_scales_capacity_with_overcollection ():
360541 cfg = SimpleNamespace (
361542 collector = SimpleNamespace (
0 commit comments