Commit d211cb2

Docsum Gateway Fix (opea-project#902)
* update gateway
  Signed-off-by: Mustafa <[email protected]>
* update the gateway
  Signed-off-by: Mustafa <[email protected]>
* update the gateway
  Signed-off-by: Mustafa <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

Signed-off-by: Mustafa <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 405a632 commit d211cb2

File tree: 1 file changed (+59, -7 lines)


comps/cores/mega/gateway.py

Lines changed: 59 additions & 7 deletions
@@ -419,12 +419,62 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
             output_datatype=ChatCompletionResponse,
         )
 
-    async def handle_request(self, request: Request):
-        data = await request.json()
-        stream_opt = data.get("stream", True)
-        chat_request = ChatCompletionRequest.model_validate(data)
+    async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
+
+        if "application/json" in request.headers.get("content-type"):
+            data = await request.json()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+            prompt = self._handle_message(chat_request.messages)
+
+            initial_inputs_data = {data["type"]: prompt}
+
+        elif "multipart/form-data" in request.headers.get("content-type"):
+            data = await request.form()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+
+            data_type = data.get("type")
+
+            file_summaries = []
+            if files:
+                for file in files:
+                    file_path = f"/tmp/{file.filename}"
+
+                    if data_type is not None and data_type in ["audio", "video"]:
+                        raise ValueError(
+                            "Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
+                        )
+
+                    else:
+                        import aiofiles
+
+                        async with aiofiles.open(file_path, "wb") as f:
+                            await f.write(await file.read())
+
+                        docs = read_text_from_file(file, file_path)
+                        os.remove(file_path)
+
+                        if isinstance(docs, list):
+                            file_summaries.extend(docs)
+                        else:
+                            file_summaries.append(docs)
+
+            if file_summaries:
+                prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
+            else:
+                prompt = self._handle_message(chat_request.messages)
+
+            data_type = data.get("type")
+            if data_type is not None:
+                initial_inputs_data = {}
+                initial_inputs_data[data_type] = prompt
+            else:
+                initial_inputs_data = {"query": prompt}
+
+        else:
+            raise ValueError(f"Unknown request type: {request.headers.get('content-type')}")
 
-        prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -434,12 +484,14 @@ async def handle_request(self, request: Request):
             presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
             repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
-            language=chat_request.language if chat_request.language else "auto",
             model=chat_request.model if chat_request.model else None,
+            language=chat_request.language if chat_request.language else "auto",
         )
+
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={data["type"]: prompt}, llm_parameters=parameters
+            initial_inputs=initial_inputs_data, llm_parameters=parameters
         )
+
         for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
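For reference, here is a minimal client-side sketch of how the two content-type branches introduced above can be exercised with Python requests. It assumes the DocSum gateway is reachable at http://localhost:8888/v1/docsum (the port matches the constructor default in the hunk header; the route, the "type": "text" value, the "files" upload field name, and report.txt are illustrative assumptions, not taken from this commit).

# Client-side sketch only (not part of this commit): exercising both
# content-type branches of the updated handle_request.
import requests

url = "http://localhost:8888/v1/docsum"  # assumed endpoint route

# application/json branch: "type" becomes the key of initial_inputs_data.
resp = requests.post(
    url,
    json={
        "type": "text",  # assumed type value for plain-text summarization
        "messages": "Text summarization condenses a document into a shorter version.",
        "stream": False,
    },
)
print(resp.text)

# multipart/form-data branch: text files are read and appended to the prompt;
# audio/video uploads are rejected by the gateway and must go through the UI.
with open("report.txt", "rb") as fh:  # hypothetical local file
    resp = requests.post(
        url,
        data={"messages": "Summarize the attached document."},
        files=[("files", ("report.txt", fh, "text/plain"))],  # field name matches the FastAPI parameter
    )
print(resp.text)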
