@@ -48,6 +48,8 @@ async def register(self, key: K, value: V) -> None:
4848class _AsyncCallablesComputer (Generic [T ]):
4949 """Evaluate a set of async callables across every sample in a dataset."""
5050
51+ _QUEUE_BUFFER_FACTOR = 3
52+
5153 def __init__ (
5254 self ,
5355 dataset : Dataset ,
@@ -57,6 +59,7 @@ def __init__(
5759 """Configure the computer with the dataset, callables, and concurrency cap."""
5860 self .dataset = dataset
5961 self .callables = callables
62+ self .max_concurrency = max_concurrency
6063 if max_concurrency == - 1 :
6164 self .semaphore = None
6265 else :
@@ -80,16 +83,55 @@ async def _queue(self, sample_id: Any, callable_id: str) -> None:
8083
async def run(self) -> Dict[Tuple[Any, str], T]:
    """Run every (sample, callable) computation and return the registry store.

    Two scheduling strategies are used, selected by ``self.semaphore``:

    * ``self.semaphore is None`` (set in ``__init__`` when
      ``max_concurrency == -1``): all sample ids are materialised and one
      task is spawned per (sample, callable) pair.
    * otherwise: a single producer streams work items into a bounded
      in-memory queue that is drained by ``max(1, self.max_concurrency)``
      workers, so the number of live tasks does not grow with the dataset.

    Returns:
        ``self._registry.store`` after all scheduled work has finished, or
        an empty dict when no callables are configured.
    """
    metrics_names = list(self.callables.keys())
    if not metrics_names:
        # Nothing to compute, so do not iterate the dataset at all.
        return {}

    # Unlimited concurrency: callers explicitly opted out of a cap, so keep
    # the one-task-per-work-item fan-out.
    if self.semaphore is None:
        sample_ids = [sample_id async for sample_id in self.dataset.ids()]
        async with anyio.create_task_group() as tg:
            for sample_id in sample_ids:
                for metric_name in metrics_names:
                    tg.start_soon(self._queue, sample_id, metric_name)
        return self._registry.store

    # Bounded concurrency: avoid spawning one task per (sample, metric) pair,
    # which for large datasets could mean millions of pending tasks.
    num_workers = max(1, self.max_concurrency)
    queue_max_size = max(1, num_workers * self._QUEUE_BUFFER_FACTOR)
    # Both ends carry ``(sample_id, metric_name)`` tuples.
    send_stream, receive_stream = anyio.create_memory_object_stream(queue_max_size)

    async def producer() -> None:
        """Enqueue every work item, then close the send end to signal completion."""
        async with send_stream:
            async for sample_id in self.dataset.ids():
                for metric_name in metrics_names:
                    # Blocks while the buffer is full, which throttles enumeration.
                    await send_stream.send((sample_id, metric_name))

    async def worker() -> None:
        """Process work items until the producer closes the stream."""
        # Iteration ends once the send end is closed and the buffer is drained.
        async for sample_id, metric_name in receive_stream:
            await self._queue(sample_id, metric_name)

    # Close both stream ends when the task group exits, on success or failure.
    # The receive end is shared by all workers, so it must only be closed here,
    # after every worker has finished; closing it inside a worker would break
    # the others. Closing the already-closed send end again is a no-op.
    async with send_stream, receive_stream:
        async with anyio.create_task_group() as tg:
            tg.start_soon(producer)
            for _ in range(num_workers):
                tg.start_soon(worker)

    return self._registry.store
94136
95137
0 commit comments