Skip to content

Commit ab489b1

Browse files
Optimize task generation
1 parent 8a21016 commit ab489b1

File tree

2 files changed: +90 −8 lines changed

pyagentspec/src/pyagentspec/evaluation/_computers/_async_callables_computers.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ async def register(self, key: K, value: V) -> None:
4848
class _AsyncCallablesComputer(Generic[T]):
4949
"""Evaluate a set of async callables across every sample in a dataset."""
5050

51+
_QUEUE_BUFFER_FACTOR = 3
52+
5153
def __init__(
5254
self,
5355
dataset: Dataset,
@@ -57,6 +59,7 @@ def __init__(
5759
"""Configure the computer with the dataset, callables, and concurrency cap."""
5860
self.dataset = dataset
5961
self.callables = callables
62+
self.max_concurrency = max_concurrency
6063
if max_concurrency == -1:
6164
self.semaphore = None
6265
else:
@@ -80,16 +83,55 @@ async def _queue(self, sample_id: Any, callable_id: str) -> None:
8083

8184
async def run(self) -> Dict[Tuple[Any, str], T]:
    """Kick off all pending computations and return the populated registry.

    Returns:
        The registry store mapping ``(sample_id, callable_name)`` keys to
        each computation's result.
    """
    metrics_names = list(self.callables.keys())
    if not metrics_names:
        # Nothing to compute: avoid iterating the dataset at all.
        return {}

    # For "unlimited" concurrency we still spawn one task per work item since callers
    # explicitly opted out of concurrency caps. The producer/worker pattern below
    # is primarily meant to prevent memory blow-ups when a bounded concurrency limit is used.
    if self.semaphore is None:
        # Materialise identifiers up-front to avoid holding async generators open
        # while scheduling the computation fan-out.
        sample_ids = [sample_id async for sample_id in self.dataset.ids()]
        async with anyio.create_task_group() as tg:
            for sample_id in sample_ids:
                for metric_name in metrics_names:
                    tg.start_soon(self._queue, sample_id, metric_name)
        return self._registry.store

    # Avoid spawning one task per (sample, metric) pair: for large datasets
    # that can create millions of tasks and consume large amounts of memory.
    #
    # Instead, use a producer/worker pattern:
    # - one producer enumerates dataset sample ids and enqueues work items
    # - N workers consume items from the queue and run computations
    num_workers = max(1, self.max_concurrency)
    queue_max_size = max(1, num_workers * self._QUEUE_BUFFER_FACTOR)
    work_queue: anyio.abc.ObjectSendStream[Tuple[Any, str]]
    receive_stream: anyio.abc.ObjectReceiveStream[Tuple[Any, str]]
    work_queue, receive_stream = anyio.create_memory_object_stream(queue_max_size)

    async def producer() -> None:
        # Closing the send side (on normal exit or error) signals
        # ``EndOfStream`` to the workers once buffered items are drained.
        async with work_queue:
            async for sample_id in self.dataset.ids():
                for metric_name in metrics_names:
                    await work_queue.send((sample_id, metric_name))

    async def worker(worker_id: int) -> None:
        del worker_id  # only used to distinguish spawned tasks
        while True:
            try:
                sample_id, metric_name = await receive_stream.receive()
            except anyio.EndOfStream:
                # Producer finished and the queue is fully drained.
                return
            await self._queue(sample_id, metric_name)

    # Fix: close the receive side once all workers are done so the stream's
    # resources are released deterministically instead of being leaked until
    # garbage collection.
    async with receive_stream:
        async with anyio.create_task_group() as tg:
            tg.start_soon(producer)
            for i in range(num_workers):
                tg.start_soon(worker, i)

    return self._registry.store
94136

95137

pyagentspec/tests/evaluation/_computers/test_evaluator_concurrency.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,46 @@ async def test_unlimited_concurrency() -> None:
121121
assert num_runnings_sequence[-i - 1] == i
122122

123123

124+
@pytest.mark.anyio
async def test_run_does_not_spawn_one_task_per_item() -> None:
    """
    Ensure ``_AsyncCallablesComputer.run`` does not create O(N) tasks.

    This is a regression test for memory blow-ups when datasets are large.
    With a bounded ``max_concurrency`` the run loop must spawn exactly one
    producer task plus ``max_concurrency`` worker tasks, regardless of the
    dataset size.
    """

    class CountingTaskGroup:
        """Minimal stand-in for an anyio task group that only counts spawns."""

        def __init__(self, max_allowed: int) -> None:
            self.max_allowed = max_allowed
            self.started = 0

        async def __aenter__(self) -> "CountingTaskGroup":
            return self

        async def __aexit__(self, exc_type, exc, tb) -> None:
            return None

        def start_soon(self, func, *args) -> None:
            self.started += 1
            # Fail fast if run() ever falls back to one-task-per-item.
            assert self.started <= self.max_allowed

    max_concurrency = 10
    expected_tasks = 1 + max_concurrency  # one producer + N workers

    dataset = Dataset.from_dict([{"dummy_arg": i} for i in range(10000)])
    callables = {"dummy_callable": (lambda **kwargs: asyncio.sleep(0))}
    computer = _AsyncCallablesComputer(
        dataset=dataset,
        callables=callables,
        max_concurrency=max_concurrency,
    )

    import anyio  # imported here to keep the patch localized to this test

    created_groups: list = []

    def _counting_task_group() -> "CountingTaskGroup":
        group = CountingTaskGroup(max_allowed=expected_tasks)
        created_groups.append(group)
        return group

    original = anyio.create_task_group
    try:
        anyio.create_task_group = _counting_task_group
        await computer.run()
    finally:
        anyio.create_task_group = original

    # Fix: the original test passed vacuously when zero tasks were spawned.
    # Assert the exact expected fan-out so the regression guard has teeth.
    assert sum(group.started for group in created_groups) == expected_tasks
162+
163+
124164
@pytest.mark.anyio
125165
@pytest.mark.parametrize("max_concurrency", [5, 10, 20])
126166
async def test_firsts_begin_together(max_concurrency: int) -> None:

Comments (0)