@@ -419,12 +419,62 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
             output_datatype=ChatCompletionResponse,
         )

-    async def handle_request(self, request: Request):
-        data = await request.json()
-        stream_opt = data.get("stream", True)
-        chat_request = ChatCompletionRequest.model_validate(data)
+    async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
+
+        if "application/json" in request.headers.get("content-type"):
+            data = await request.json()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+            prompt = self._handle_message(chat_request.messages)
+
+            initial_inputs_data = {data["type"]: prompt}
+
+        elif "multipart/form-data" in request.headers.get("content-type"):
+            data = await request.form()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+
+            data_type = data.get("type")
+
+            file_summaries = []
+            if files:
+                for file in files:
+                    file_path = f"/tmp/{file.filename}"
+
+                    if data_type is not None and data_type in ["audio", "video"]:
+                        raise ValueError(
+                            "Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
+                        )
+
+                    else:
+                        import aiofiles
+
+                        async with aiofiles.open(file_path, "wb") as f:
+                            await f.write(await file.read())
+
+                        docs = read_text_from_file(file, file_path)
+                        os.remove(file_path)
+
+                        if isinstance(docs, list):
+                            file_summaries.extend(docs)
+                        else:
+                            file_summaries.append(docs)
+
+            if file_summaries:
+                prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
+            else:
+                prompt = self._handle_message(chat_request.messages)
+
+            data_type = data.get("type")
+            if data_type is not None:
+                initial_inputs_data = {}
+                initial_inputs_data[data_type] = prompt
+            else:
+                initial_inputs_data = {"query": prompt}
+
+        else:
+            raise ValueError(f"Unknown request type: {request.headers.get('content-type')}")

-        prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
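For reference, here is a minimal client-side sketch that exercises both branches of the new content-type dispatch. The endpoint URL is an assumption (port 8888 matches the constructor default above; the `/v1/docsum` route is not shown in this diff), so adjust it for your deployment:

```python
# Hypothetical client sketch, not part of the PR. Assumes the gateway
# listens on localhost:8888 and serves the DocSum route at /v1/docsum.
import requests

URL = "http://localhost:8888/v1/docsum"  # assumed endpoint

# Branch 1: Content-Type: application/json (the pre-existing path).
# "type" selects the key used for initial_inputs; "messages" carries the text.
resp = requests.post(
    URL,
    json={"type": "text", "messages": "Summarize: <your text here>", "stream": False},
)
print(resp.text)

# Branch 2: Content-Type: multipart/form-data (the new upload path).
# The multipart field name "files" matches the handler's `files` parameter.
with open("report.txt", "rb") as fh:
    resp = requests.post(
        URL,
        data={"type": "text", "messages": ""},
        files={"files": ("report.txt", fh, "text/plain")},
    )
print(resp.text)
```

`requests` sets the matching `Content-Type` header automatically: `json=` produces `application/json`, while passing `files=` produces `multipart/form-data`, so each call lands in the intended branch of the handler.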
@@ -434,12 +484,14 @@ async def handle_request(self, request: Request):
             presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
             repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
-            language=chat_request.language if chat_request.language else "auto",
             model=chat_request.model if chat_request.model else None,
+            language=chat_request.language if chat_request.language else "auto",
         )
+
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={data["type"]: prompt}, llm_parameters=parameters
+            initial_inputs=initial_inputs_data, llm_parameters=parameters
         )
+
         for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
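The multipart branch relies on `read_text_from_file(file, file_path)` returning either a single string or a list of text chunks, but that helper's definition sits outside this diff. A hypothetical stand-in that satisfies the same contract (not the repository's actual implementation, which presumably supports more formats) might look like this:

```python
import os


def read_text_from_file_sketch(file, save_file_name: str):
    """Hypothetical stand-in for read_text_from_file, shown only to
    illustrate the return contract used above: a plain string is
    appended to file_summaries, while a list of strings is extended
    into it. `file` mirrors the real call signature and is unused here."""
    ext = os.path.splitext(save_file_name)[1].lower()
    if ext in (".txt", ".md"):
        with open(save_file_name, encoding="utf-8", errors="ignore") as f:
            return f.read()
    # A real helper would branch here for PDF, DOCX, and similar formats.
    raise ValueError(f"Unsupported file type in this sketch: {ext}")
```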