Skip to content

Commit e3f40db

Browse files
authored
Refa: make RAGFlow more asynchronous 2 (infiniflow#11689)
### What problem does this PR solve?

Make RAGFlow more asynchronous 2. Related: infiniflow#11551, infiniflow#11579, infiniflow#11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
1 parent b5ad7b7 commit e3f40db

File tree

15 files changed

+667
-305
lines changed

15 files changed

+667
-305
lines changed

agent/canvas.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,19 @@ async def _run_batch(f, t):
416416

417417
loop = asyncio.get_running_loop()
418418
tasks = []
419+
420+
def _run_async_in_thread(coro_func, **call_kwargs):
421+
return asyncio.run(coro_func(**call_kwargs))
422+
419423
i = f
420424
while i < t:
421425
cpn = self.get_component_obj(self.path[i])
422426
task_fn = None
427+
call_kwargs = None
423428

424429
if cpn.component_name.lower() in ["begin", "userfillup"]:
425-
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
430+
call_kwargs = {"inputs": kwargs.get("inputs", {})}
431+
task_fn = cpn.invoke
426432
i += 1
427433
else:
428434
for _, ele in cpn.get_input_elements().items():
@@ -431,13 +437,18 @@ async def _run_batch(f, t):
431437
t -= 1
432438
break
433439
else:
434-
task_fn = partial(cpn.invoke, **cpn.get_input())
440+
call_kwargs = cpn.get_input()
441+
task_fn = cpn.invoke
435442
i += 1
436443

437444
if task_fn is None:
438445
continue
439446

440-
tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
447+
invoke_async = getattr(cpn, "invoke_async", None)
448+
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
449+
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
450+
else:
451+
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
441452

442453
if tasks:
443454
await asyncio.gather(*tasks)

agent/component/agent_with_tools.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import asyncio
1617
import json
1718
import logging
1819
import os
@@ -239,6 +240,86 @@ def clean_formated_answer(ans: str) -> str:
239240
self.set_output("use_tools", use_tools)
240241
return ans
241242

243+
async def _invoke_async(self, **kwargs):
    """
    Async entry point of the agent component.

    Builds the user prompt from ``user_prompt`` / ``reasoning`` / ``context``
    in ``kwargs``, then either:

    * delegates to ``LLM._invoke`` in a worker thread when no tools are configured,
    * registers a streaming partial as the ``content`` output when a "message"
      component is downstream (and no exception goto / output schema is set), or
    * runs the tool-using loop to completion and stores the final answer.

    Returns the parsed object when an output schema is configured and parsing
    succeeds, the answer text otherwise, and ``None`` when the run is canceled,
    fails, or is handed off to the streaming partial.
    """
    if self.check_if_canceled("Agent processing"):
        return

    if kwargs.get("user_prompt"):
        usr_pmt = ""
        if kwargs.get("reasoning"):
            usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
        if kwargs.get("context"):
            usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
        if usr_pmt:
            usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
        else:
            usr_pmt = str(kwargs["user_prompt"])
        self._param.prompts = [{"role": "user", "content": usr_pmt}]

    if not self.tools:
        if self.check_if_canceled("Agent processing"):
            return
        # The plain LLM path is synchronous; keep it off the event loop.
        return await asyncio.to_thread(LLM._invoke, self, **kwargs)

    prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
    output_schema = self._get_output_schema()
    schema_prompt = ""
    if output_schema:
        schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
        schema_prompt = structured_output_prompt(schema)

    component = self._canvas.get_component(self._id)
    downstreams = component["downstream"] if component else []
    ex = self.exception_handler()
    feeds_message = any(
        self._canvas.get_component_obj(cid).component_name.lower() == "message"
        for cid in downstreams
    )
    if feeds_message and not (ex and ex["goto"]) and not output_schema:
        # Hand the stream to the downstream message component instead of draining it here.
        self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
        return

    _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
    use_tools = []
    pieces = []
    async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
        if self.check_if_canceled("Agent processing"):
            return
        pieces.append(delta_ans)
    ans = "".join(pieces)

    if ans.find("**ERROR**") >= 0:
        logging.error(f"Agent._chat got error. response: {ans}")
        if self.get_exception_default_value():
            self.set_output("content", self.get_exception_default_value())
        else:
            self.set_output("_ERROR", ans)
        return

    if output_schema:
        def clean_formated_answer(ans: str) -> str:
            # Strip any reasoning preamble and the surrounding ```json fence.
            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
            return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

        error = ""
        for _ in range(self._param.max_retries + 1):
            try:
                obj = json_repair.loads(clean_formated_answer(ans))
                self.set_output("structured", obj)
                if use_tools:
                    self.set_output("use_tools", use_tools)
                return obj
            except Exception:
                error = "The answer cannot be parsed as JSON"
                # `_force_format_to_schema` is a blocking call (its result is used
                # without awaiting), so run it in a thread rather than on the loop.
                ans = await asyncio.to_thread(self._force_format_to_schema, ans, schema_prompt)

        self.set_output("_ERROR", error)
        return

    self.set_output("content", ans)
    if use_tools:
        self.set_output("use_tools", use_tools)
    return ans
322+
242323
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
243324
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
244325
answer_without_toolcall = ""
@@ -261,6 +342,54 @@ def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
261342
if use_tools:
262343
self.set_output("use_tools", use_tools)
263344

345+
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
    """
    Stream the tool-using agent's answer chunk by chunk.

    Yields each text delta as it arrives. On an "**ERROR**" delta the stream
    stops: the configured exception default value is stored and yielded if
    there is one, otherwise the delta is recorded under "_ERROR". When the
    stream completes normally, the concatenated answer is stored as "content"
    and any used tools as "use_tools".
    """
    budget = int(self.chat_mdl.max_length * 0.97)
    msg = message_fit_in([{"role": "system", "content": prompt}, *msg], budget)[1]

    tools_used = []
    chunks = []
    stream = self._react_with_tools_streamly_async(prompt, msg, tools_used, user_defined_prompt)
    async for piece, _tokens in stream:
        if self.check_if_canceled("Agent streaming"):
            return

        if "**ERROR**" in piece:
            if not self.get_exception_default_value():
                self.set_output("_ERROR", piece)
                return
            self.set_output("content", self.get_exception_default_value())
            yield self.get_exception_default_value()
            return

        chunks.append(piece)
        yield piece

    self.set_output("content", "".join(chunks))
    if tools_used:
        self.set_output("use_tools", tools_used)
366+
367+
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
    """
    Async adapter around the synchronous ``_react_with_tools_streamly`` generator.

    The synchronous generator runs in a worker thread and every
    ``(delta_ans, token_count)`` pair it produces is forwarded through a queue
    and yielded here as soon as it arrives, so the event loop is never blocked
    and consumers see results incrementally.

    Raises:
        Exception: whatever the synchronous generator raised, re-raised in the
            consumer after the items produced before the failure.
    """
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # private end-of-stream marker

    def worker():
        try:
            for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
                loop.call_soon_threadsafe(queue.put_nowait, (delta_ans, tk))
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, done)

    # Start the worker in the background. Awaiting it here would hold back every
    # item until the whole synchronous flow had finished, defeating streaming.
    worker_task = asyncio.create_task(asyncio.to_thread(worker))

    while True:
        item = await queue.get()
        if item is done:
            break
        if isinstance(item, Exception):
            raise item
        yield item

    # The marker is enqueued from the worker's `finally`, so this returns promptly.
    await worker_task
392+
264393
def _gen_citations(self, text):
265394
retrievals = self._canvas.get_reference()
266395
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
@@ -433,4 +562,3 @@ def reset(self, only_output=False):
433562
for k in self._param.inputs.keys():
434563
self._param.inputs[k]["value"] = None
435564
self._param.debug_inputs = {}
436-

agent/component/base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17+
import asyncio
1718
import re
1819
import time
1920
from abc import ABC
@@ -445,6 +446,34 @@ def invoke(self, **kwargs) -> dict[str, Any]:
445446
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
446447
return self.output()
447448

449+
async def invoke_async(self, **kwargs) -> dict[str, Any]:
    """
    Run this component from async code and return its outputs.

    Dispatch order: a coroutine ``_invoke_async`` if the component defines
    one, else ``_invoke`` awaited directly when it is itself a coroutine
    function, else ``_invoke`` executed in a worker thread.

    Any exception is logged and converted into either the configured
    exception default value or an "_ERROR" output. Returns ``None`` when the
    component was canceled before it started.
    """
    self.set_output("_created_time", time.perf_counter())
    try:
        if self.check_if_canceled("Component processing"):
            return None

        native = getattr(self, "_invoke_async", None)
        if native and asyncio.iscoroutinefunction(native):
            pending = native(**kwargs)
        elif asyncio.iscoroutinefunction(self._invoke):
            pending = self._invoke(**kwargs)
        else:
            # Plain synchronous implementation: keep it off the event loop.
            pending = asyncio.to_thread(self._invoke, **kwargs)
        await pending
    except Exception as exc:
        if not self.get_exception_default_value():
            self.set_output("_ERROR", str(exc))
        else:
            self.set_exception_default_value()
        logging.exception(exc)

    self._param.debug_inputs = {}
    elapsed = time.perf_counter() - self.output("_created_time")
    self.set_output("_elapsed_time", elapsed)
    return self.output()
476+
448477
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
449478
def _invoke(self, **kwargs):
450479
raise NotImplementedError()

agent/component/llm.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import asyncio
1617
import json
1718
import logging
1819
import os
1920
import re
21+
import threading
2022
from copy import deepcopy
21-
from typing import Any, Generator
23+
from typing import Any, Generator, AsyncGenerator
2224
import json_repair
2325
from functools import partial
2426
from common.constants import LLMType
@@ -171,6 +173,13 @@ def _generate(self, msg:list[dict], **kwargs) -> str:
171173
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
172174
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
173175

176+
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
    """
    Produce a full (non-streamed) completion without blocking the event loop.

    Uses the chat model's ``async_chat`` when it exists, passing ``images``
    only when the component holds any; otherwise runs the synchronous
    ``_generate`` in a worker thread. ``msg[0]`` carries the system prompt,
    the remaining entries are the conversation history.
    """
    if not hasattr(self.chat_mdl, "async_chat"):
        return await asyncio.to_thread(self._generate, msg, **kwargs)

    system, history = msg[0]["content"], msg[1:]
    gen_conf = self._param.gen_conf()
    if self.imgs:
        return await self.chat_mdl.async_chat(system, history, gen_conf, images=self.imgs, **kwargs)
    return await self.chat_mdl.async_chat(system, history, gen_conf, **kwargs)
182+
174183
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
175184
ans = ""
176185
last_idx = 0
@@ -205,6 +214,69 @@ def delta(txt):
205214
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
206215
yield delta(txt)
207216

217+
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a completion as text deltas without blocking the event loop.

    The chat model's ``async_chat_streamly`` is used when it exists (with
    ``images`` only when the component holds any). Each item it produces is
    treated as the cumulative text so far and converted into a delta, with
    ``<think>`` / ``</think>`` markers emitted as separate chunks. Without a
    native async stream, the synchronous ``_generate_streamly`` runs in a
    daemon thread and its items are forwarded through a queue unchanged.

    Raises:
        Exception: whatever the synchronous fallback generator raised.
    """
    async def delta_wrapper(txt_iter):
        ans = ""
        last_idx = 0
        endswith_think = False

        def delta(txt):
            nonlocal ans, last_idx, endswith_think
            delta_ans = txt[last_idx:]
            ans = txt

            think_pos = delta_ans.find("<think>")
            if think_pos == 0:
                last_idx += len("<think>")
                return "<think>"
            if think_pos > 0:
                # Emit only the text before the marker and advance by exactly that
                # much, so the next call starts at "<think>". The offset must be
                # taken before slicing; searching the already-truncated prefix
                # yields -1 and would move the cursor backwards.
                last_idx += think_pos
                return delta_ans[:think_pos]
            if delta_ans.endswith("</think>"):
                endswith_think = True
            elif endswith_think:
                endswith_think = False
                return "</think>"

            last_idx = len(ans)
            if ans.endswith("</think>"):
                # Leave the closing marker unread so it is reported on the next call.
                last_idx -= len("</think>")
            return re.sub(r"(<think>|</think>)", "", delta_ans)

        async for t in txt_iter:
            yield delta(t)

    if hasattr(self.chat_mdl, "async_chat_streamly"):
        call_kwargs = {"images": self.imgs, **kwargs} if self.imgs else kwargs
        stream = self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **call_kwargs)
        async for t in delta_wrapper(stream):
            yield t
        return

    # Fallback: bridge the synchronous generator through a thread and a queue.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # private end-of-stream marker

    def worker():
        try:
            for item in self._generate_streamly(msg, **kwargs):
                loop.call_soon_threadsafe(queue.put_nowait, item)
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, done)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = await queue.get()
        if item is done:
            break
        if isinstance(item, Exception):
            raise item
        yield item
279+
208280
async def _stream_output_async(self, prompt, msg):
209281
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
210282
answer = ""

agent/tools/base.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import re
1818
import time
1919
from copy import deepcopy
20+
import asyncio
2021
from functools import partial
2122
from typing import TypedDict, List, Any
2223
from agent.component.base import ComponentParamBase, ComponentBase
@@ -50,10 +51,14 @@ def __init__(self, tools_map: dict[str, object], callback: partial):
5051
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
    """
    Execute the tool registered under ``name`` and report the call to the callback.

    MCP sessions are called through their own ``tool_call`` with a timeout
    argument of 60. Tools exposing a coroutine ``invoke_async`` are driven to
    completion with ``asyncio.run``; this method is synchronous and must
    therefore not be called from a thread that already runs an event loop.
    All other tools are invoked directly.

    Args:
        name: key of the tool in ``self.tools_map``.
        arguments: keyword arguments forwarded to the tool.

    Returns:
        Whatever the tool returned.

    Raises:
        AssertionError: if no tool is registered under ``name``.
    """
    if name not in self.tools_map:
        # Raised explicitly so the guard also holds when asserts are stripped (-O).
        raise AssertionError(f"LLM tool {name} does not exist")

    st = timer()
    tool_obj = self.tools_map[name]
    if isinstance(tool_obj, MCPToolCallSession):
        resp = tool_obj.tool_call(name, arguments, 60)
    elif hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
        resp = asyncio.run(tool_obj.invoke_async(**arguments))
    else:
        # Already in synchronous code: call the tool directly rather than
        # starting an event loop and a worker thread just to block on it.
        resp = tool_obj.invoke(**arguments)

    self.callback(name, arguments, resp, elapsed_time=timer()-st)
    return resp
@@ -139,6 +144,33 @@ def invoke(self, **kwargs):
139144
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
140145
return res
141146

147+
async def invoke_async(self, **kwargs):
    """
    Run this tool from async code and return its result.

    Dispatch order: a coroutine ``_invoke_async`` if the tool defines one,
    else ``_invoke`` awaited directly when it is itself a coroutine function,
    else ``_invoke`` executed in a worker thread.

    A failure is logged, recorded under the "_ERROR" output, and its message
    becomes the return value. Returns ``None`` when the tool was canceled
    before it started.
    """
    if self.check_if_canceled("Tool processing"):
        return None

    self.set_output("_created_time", time.perf_counter())
    try:
        native = getattr(self, "_invoke_async", None)
        if native and asyncio.iscoroutinefunction(native):
            pending = native(**kwargs)
        elif asyncio.iscoroutinefunction(self._invoke):
            pending = self._invoke(**kwargs)
        else:
            # Plain synchronous implementation: keep it off the event loop.
            pending = asyncio.to_thread(self._invoke, **kwargs)
        res = await pending
    except Exception as exc:
        message = str(exc)
        self._param.outputs["_ERROR"] = {"value": message}
        logging.exception(exc)
        res = message
    self._param.debug_inputs = []

    elapsed = time.perf_counter() - self.output("_created_time")
    self.set_output("_elapsed_time", elapsed)
    return res
173+
142174
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
143175
chunks = []
144176
aggs = []

0 commit comments

Comments
 (0)