Skip to content

Commit 9cb990d

Browse files
committed
Merge branch 'dev'
2 parents af77434 + abeb1e5 commit 9cb990d

10 files changed

Lines changed: 119 additions & 44 deletions

File tree

agently/core/ModelRequest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,9 @@ def __init__(
432432
self.get_data = FunctionShifter.syncify(self.async_get_data)
433433
self.get_data_object = FunctionShifter.syncify(self.async_get_data_object)
434434

435+
self.start = self.get_data
436+
self.async_start = self.async_get_data
437+
435438
def set_prompt(
436439
self,
437440
key: "PromptStandardSlot | str",

agently/core/TriggerFlow/BluePrint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def create_execution(
132132
*,
133133
execution_id: str | None = None,
134134
skip_exceptions: bool = False,
135+
concurrency: int | None = None,
135136
):
136137
handlers_snapshot: TriggerFlowAllHandlers = {
137138
"event": {k: v.copy() for k, v in self._handlers["event"].items()},
@@ -143,6 +144,7 @@ def create_execution(
143144
trigger_flow=trigger_flow,
144145
id=execution_id,
145146
skip_exceptions=skip_exceptions,
147+
concurrency=concurrency,
146148
)
147149

148150
def copy(self, *, name: str | None = None):

agently/core/TriggerFlow/Execution.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import uuid
1717
import asyncio
1818
import warnings
19+
from contextvars import ContextVar
1920

2021
from typing import Any, Literal, TYPE_CHECKING
2122

@@ -37,6 +38,7 @@ def __init__(
3738
trigger_flow: "TriggerFlow",
3839
id: str | None = None,
3940
skip_exceptions: bool = False,
41+
concurrency: int | None = None,
4042
):
4143
# Basic Attributions
4244
self.id = id if id is not None else uuid.uuid4().hex
@@ -45,6 +47,11 @@ def __init__(
4547
self._runtime_data = RuntimeData()
4648
self._system_runtime_data = RuntimeData()
4749
self._skip_exceptions = skip_exceptions
50+
self._concurrency_semaphore = asyncio.Semaphore(concurrency) if concurrency and concurrency > 0 else None
51+
self._concurrency_depth = ContextVar(
52+
f"trigger_flow_execution_concurrency_depth_{ self.id }",
53+
default=0,
54+
)
4855

4956
# Settings
5057
self.settings = Settings(
@@ -126,19 +133,29 @@ async def async_emit(
126133
},
127134
self.settings,
128135
)
129-
tasks.append(
130-
asyncio.ensure_future(
131-
FunctionShifter.asyncify(handler)(
132-
TriggerFlowEventData(
133-
trigger_event=trigger_event,
134-
trigger_type=trigger_type,
135-
value=value,
136-
execution=self,
137-
_layer_marks=_layer_marks,
138-
)
139-
)
136+
async def run_handler(handler_func):
137+
if self._concurrency_semaphore is None:
138+
return await handler_func
139+
depth = self._concurrency_depth.get()
140+
token = self._concurrency_depth.set(depth + 1)
141+
try:
142+
if depth > 0:
143+
return await handler_func
144+
async with self._concurrency_semaphore:
145+
return await handler_func
146+
finally:
147+
self._concurrency_depth.reset(token)
148+
149+
handler_task = FunctionShifter.asyncify(handler)(
150+
TriggerFlowEventData(
151+
trigger_event=trigger_event,
152+
trigger_type=trigger_type,
153+
value=value,
154+
execution=self,
155+
_layer_marks=_layer_marks,
140156
)
141157
)
158+
tasks.append(asyncio.ensure_future(run_handler(handler_task)))
142159

143160
if tasks:
144161
await asyncio.gather(*tasks, return_exceptions=self._skip_exceptions)

agently/core/TriggerFlow/TriggerFlow.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,19 @@ def wrapper(func: "TriggerFlowHandler"):
9999
self._blue_print.chunks[handler_or_name.__name__] = chunk
100100
return chunk
101101

102-
def create_execution(self, *, skip_exceptions: bool | None = None):
102+
def create_execution(
103+
self,
104+
*,
105+
skip_exceptions: bool | None = None,
106+
concurrency: int | None = None,
107+
):
103108
execution_id = uuid.uuid4().hex
104109
skip_exceptions = skip_exceptions if skip_exceptions is not None else self._skip_exceptions
105110
execution = self._blue_print.create_execution(
106111
self,
107112
execution_id=execution_id,
108113
skip_exceptions=skip_exceptions,
114+
concurrency=concurrency,
109115
)
110116
self._executions[execution_id] = execution
111117
return execution
@@ -118,8 +124,14 @@ def remove_execution(self, execution: "TriggerFlowExecution | str"):
118124
if execution.id in self._executions:
119125
del self._executions[execution.id]
120126

121-
async def async_start_execution(self, initial_value: Any, *, wait_for_result: bool = False):
122-
execution = self.create_execution()
127+
async def async_start_execution(
128+
self,
129+
initial_value: Any,
130+
*,
131+
wait_for_result: bool = False,
132+
concurrency: int | None = None,
133+
):
134+
execution = self.create_execution(concurrency=concurrency)
123135
await execution.async_start(initial_value, wait_for_result=wait_for_result)
124136
return execution
125137

@@ -192,8 +204,9 @@ async def async_start(
192204
*,
193205
wait_for_result: bool = True,
194206
timeout: int | None = 10,
207+
concurrency: int | None = None,
195208
):
196-
execution = await self.async_start_execution(initial_value)
209+
execution = await self.async_start_execution(initial_value, concurrency=concurrency)
197210
if wait_for_result:
198211
return await execution.async_get_result(timeout=timeout)
199212

@@ -202,8 +215,9 @@ def get_async_runtime_stream(
202215
initial_value: Any = None,
203216
*,
204217
timeout: int | None = 10,
218+
concurrency: int | None = None,
205219
):
206-
execution = self.create_execution()
220+
execution = self.create_execution(concurrency=concurrency)
207221
return execution.get_async_runtime_stream(
208222
initial_value,
209223
timeout=timeout,
@@ -214,8 +228,9 @@ def get_runtime_stream(
214228
initial_value: Any = None,
215229
*,
216230
timeout: int | None = 10,
231+
concurrency: int | None = None,
217232
):
218-
execution = self.create_execution()
233+
execution = self.create_execution(concurrency=concurrency)
219234
return execution.get_runtime_stream(
220235
initial_value,
221236
timeout=timeout,

agently/core/TriggerFlow/process/BaseProcess.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
import uuid
17-
from asyncio import Event
17+
from asyncio import Event, Semaphore
1818
from threading import Lock
1919

2020
from typing import Callable, Any, Literal, TYPE_CHECKING, overload, cast
@@ -266,11 +266,13 @@ def batch(
266266
self,
267267
*chunks: "TriggerFlowChunk | TriggerFlowHandler | tuple[str, TriggerFlowHandler]",
268268
side_branch: bool = False,
269+
concurrency: int | None = None,
269270
):
270271
batch_trigger = f"Batch-{ uuid.uuid4().hex }"
271272
results = {}
272273
triggers_to_wait = {}
273274
trigger_to_chunk_name = {}
275+
semaphore = Semaphore(concurrency) if concurrency and concurrency > 0 else None
274276

275277
async def wait_all_chunks(data: "TriggerFlowEventData"):
276278
if data.event in triggers_to_wait:
@@ -295,10 +297,18 @@ async def wait_all_chunks(data: "TriggerFlowEventData"):
295297
triggers_to_wait[chunk.trigger] = False
296298
trigger_to_chunk_name[chunk.trigger] = chunk.name
297299
results[chunk.name] = None
300+
301+
if semaphore is None:
302+
handler = chunk.async_call
303+
else:
304+
async def handler(data: "TriggerFlowEventData", _chunk=chunk):
305+
async with semaphore:
306+
return await _chunk.async_call(data)
307+
298308
self._blue_print.add_handler(
299309
self.trigger_type,
300310
self.trigger_event,
301-
chunk.async_call,
311+
handler,
302312
)
303313
self._blue_print.add_event_handler(chunk.trigger, wait_all_chunks)
304314

agently/core/TriggerFlow/process/ForEachProcess.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
class TriggerFlowForEachProcess(TriggerFlowBaseProcess):
25-
def for_each(self):
25+
def for_each(self, *, concurrency: int | None = None):
2626
for_each_id = uuid.uuid4().hex
2727
for_each_block_data = TriggerFlowBlockData(
2828
outer_block=self._block_data,
@@ -31,40 +31,46 @@ def for_each(self):
3131
},
3232
)
3333
send_item_trigger = f"ForEach-{ for_each_id }-Send"
34+
semaphore = asyncio.Semaphore(concurrency) if concurrency and concurrency > 0 else None
3435

3536
async def send_items(data: "TriggerFlowEventData"):
3637
data.layer_in()
3738
for_each_instance_id = data.layer_mark
3839
assert for_each_instance_id is not None
3940

4041
send_tasks = []
41-
if not isinstance(data.value, str) and isinstance(data.value, Sequence):
42-
items = list(data.value)
43-
for item in items:
44-
data.layer_in()
45-
item_id = data.layer_mark
46-
assert item_id is not None
47-
data._system_runtime_data.set(f"for_each_results.{ for_each_instance_id }.{ item_id }", EMPTY)
48-
send_tasks.append(
49-
data.async_emit(
50-
send_item_trigger,
51-
item,
52-
data._layer_marks.copy(),
53-
)
54-
)
55-
data.layer_out()
56-
await asyncio.gather(*send_tasks)
57-
else:
42+
def prepare_item(item):
5843
data.layer_in()
5944
item_id = data.layer_mark
6045
assert item_id is not None
46+
layer_marks = data._layer_marks.copy()
6147
data._system_runtime_data.set(f"for_each_results.{ for_each_instance_id }.{ item_id }", EMPTY)
62-
await data.async_emit(
63-
send_item_trigger,
64-
data.value,
65-
data._layer_marks.copy(),
66-
)
6748
data.layer_out()
49+
return item_id, layer_marks, item
50+
51+
async def emit_item(item, layer_marks):
52+
if semaphore is None:
53+
await data.async_emit(
54+
send_item_trigger,
55+
item,
56+
layer_marks,
57+
)
58+
else:
59+
async with semaphore:
60+
await data.async_emit(
61+
send_item_trigger,
62+
item,
63+
layer_marks,
64+
)
65+
if not isinstance(data.value, str) and isinstance(data.value, Sequence):
66+
items = list(data.value)
67+
for item in items:
68+
_, layer_marks, item_value = prepare_item(item)
69+
send_tasks.append(emit_item(item_value, layer_marks))
70+
await asyncio.gather(*send_tasks)
71+
else:
72+
_, layer_marks, item_value = prepare_item(data.value)
73+
await emit_item(item_value, layer_marks)
6874

6975
self.to(send_items)
7076

examples/trigger_flow/batch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ async def echo(data: TriggerFlowEventData):
1717
("echo_2", echo),
1818
("echo_3", echo),
1919
("echo_4", echo),
20+
concurrency=2,
2021
).end()
2122
execution = flow.create_execution()
2223
result = execution.start("Agently")

examples/trigger_flow/for_each.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ async def handle(data: TriggerFlowEventData):
1212

1313
flow_1 = TriggerFlow()
1414

15-
flow_1.for_each().to(handle).end_for_each().to(lambda data: data.value).end()
15+
flow_1.for_each(concurrency=2).to(handle).end_for_each().to(lambda data: data.value).end()
1616

1717
execution_1 = flow_1.create_execution()
1818
result = execution_1.start(
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import asyncio
2+
from agently import TriggerFlow, TriggerFlowEventData
3+
4+
5+
async def handle(data: TriggerFlowEventData):
6+
print(f"Hi, { data.value }")
7+
await asyncio.sleep(2)
8+
return f"handled: { data.value }"
9+
10+
11+
flow = TriggerFlow()
12+
flow.batch(
13+
("a", handle),
14+
("b", handle),
15+
("c", handle),
16+
("d", handle),
17+
).end()
18+
19+
execution = flow.create_execution(concurrency=2)
20+
result = execution.start("hello")
21+
print(result)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "agently"
3-
version = "4.0.6.11"
3+
version = "4.0.7"
44
description = ""
55
authors = [
66
{name = "Agently Team",email = "developer@agently.tech"},

0 commit comments

Comments
 (0)