@@ -64,6 +64,50 @@ def validate_dict(data: dict, model) -> dict:
6464 return model (** data ).model_dump (by_alias = True , exclude_unset = True )
6565
6666
67+ def _messages_to_sap_template (messages : List [Dict [str , str ]]) -> list : # type: ignore[type-arg]
68+ template = []
69+ for message in messages :
70+ if message ["role" ] == "user" :
71+ template .append (validate_dict (message , SAPUserMessage ))
72+ elif message ["role" ] == "assistant" :
73+ template .append (validate_dict (message , SAPAssistantMessage ))
74+ elif message ["role" ] == "tool" :
75+ template .append (validate_dict (message , SAPToolChatMessage ))
76+ else :
77+ template .append (validate_dict (message , SAPMessage ))
78+ return template
79+
80+
81+ def _tools_response_format_and_stream (
82+ optional_params : dict , model_params : dict
83+ ) -> Tuple [dict , dict , dict ]:
84+ tools_ = optional_params .pop ("tools" , [])
85+ tools_ = [validate_dict (tool , ChatCompletionTool ) for tool in tools_ ]
86+ tools : dict = {"tools" : tools_ } if tools_ else {}
87+
88+ response_format = model_params .pop ("response_format" , {})
89+ resp_type = response_format .get ("type" , None )
90+ if resp_type :
91+ if resp_type == "json_schema" :
92+ response_format = validate_dict (
93+ response_format , ResponseFormatJSONSchema
94+ )
95+ else :
96+ response_format = validate_dict (response_format , ResponseFormat )
97+ response_format = {"response_format" : response_format }
98+
99+ model_params .pop ("stream" , False )
100+ stream_config : dict = {}
101+ if "stream_options" in optional_params :
102+ stream_options = optional_params .pop ("stream_options" , {})
103+ if "chunk_size" in stream_options :
104+ stream_config ["chunk_size" ] = stream_options .get ("chunk_size" )
105+ if "delimiters" in stream_options :
106+ stream_config ["delimiters" ] = stream_options .get ("delimiters" )
107+
108+ return tools , response_format , stream_config
109+
110+
67111class GenAIHubOrchestrationConfig (OpenAIGPTConfig ):
68112 frequency_penalty : Optional [int ] = None
69113 function_call : Optional [Union [str , dict ]] = None
@@ -292,57 +336,19 @@ def transform_request(
292336 optional_params = dict (optional_params )
293337 optional_params .pop ("deployment_url" , None )
294338
295- template = messages
296-
297339 excluded_params = _SAP_MODEL_PARAMS_EXCLUDED_KEYS
298340 model_params = {
299341 k : v for k , v in optional_params .items () if k not in excluded_params
300342 }
301343
302344 model_version = optional_params .pop ("model_version" , "latest" )
303- template = []
304- for message in messages :
305- if message ["role" ] == "user" :
306- template .append (validate_dict (message , SAPUserMessage ))
307- elif message ["role" ] == "assistant" :
308- template .append (validate_dict (message , SAPAssistantMessage ))
309- elif message ["role" ] == "tool" :
310- template .append (validate_dict (message , SAPToolChatMessage ))
311- else :
312- template .append (validate_dict (message , SAPMessage ))
313-
314- tools_ = optional_params .pop ("tools" , [])
315- tools_ = [validate_dict (tool , ChatCompletionTool ) for tool in tools_ ]
316- if tools_ != []:
317- tools = {"tools" : tools_ }
318- else :
319- tools = {}
345+ template = _messages_to_sap_template (messages )
320346
321- response_format = model_params .pop ("response_format" , {})
322- resp_type = response_format .get ("type" , None )
323- if resp_type :
324- if resp_type == "json_schema" :
325- response_format = validate_dict (
326- response_format , ResponseFormatJSONSchema
327- )
328- else :
329- response_format = validate_dict (response_format , ResponseFormat )
330- response_format = {"response_format" : response_format }
331- model_params .pop ("stream" , False )
332- stream_config = {}
333- if "stream_options" in optional_params :
334- stream_options = optional_params .pop ("stream_options" , {})
335- if "chunk_size" in stream_options :
336- stream_config ["chunk_size" ] = stream_options .get ("chunk_size" )
337- if "delimiters" in stream_options :
338- stream_config ["delimiters" ] = stream_options .get ("delimiters" )
347+ tools , response_format , stream_config = _tools_response_format_and_stream (
348+ optional_params , model_params
349+ )
339350
340351 placeholder_values = optional_params .pop ("placeholder_values" , None )
341- placeholder_values = (
342- {"placeholder_values" : placeholder_values }
343- if placeholder_values is not None
344- else {}
345- )
346352
347353 fallback_modules = optional_params .pop ("fallback_sap_modules" , [])
348354
@@ -373,26 +379,24 @@ def transform_request(
373379 )
374380 )
375381
376- if fallback_modules :
377- modules_payload = modules
378- else :
379- modules_payload = modules [0 ] # type: ignore
380-
381- request_body = {
382- "config" : {
383- "modules" : {
384- "prompt_templating" : {
385- "prompt" : {"template" : template , ** tools , ** response_format },
386- "model" : {
387- "name" : model ,
388- "params" : model_params ,
389- "version" : model_version ,
390- },
382+ config_payload : Dict [str , Any ] = {
383+ "modules" : {
384+ "prompt_templating" : {
385+ "prompt" : {"template" : template , ** tools , ** response_format },
386+ "model" : {
387+ "name" : model ,
388+ "params" : model_params ,
389+ "version" : model_version ,
391390 },
392391 },
393- "stream" : stream_config ,
394- }
392+ },
395393 }
394+ if stream_config :
395+ config_payload ["stream" ] = stream_config
396+
397+ request_body : Dict [str , Any ] = {"config" : config_payload }
398+ if placeholder_values is not None :
399+ request_body ["placeholder_values" ] = placeholder_values
396400
397401 body = validate_dict (request_body , OrchestrationRequest )
398402
0 commit comments