Skip to content

Commit 8795198

Browse files
committed
Fix code qa
1 parent 1ef7286 commit 8795198

File tree

1 file changed

+63
-59
lines changed

1 file changed

+63
-59
lines changed

litellm/llms/sap/chat/transformation.py

Lines changed: 63 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,50 @@ def validate_dict(data: dict, model) -> dict:
6464
return model(**data).model_dump(by_alias=True, exclude_unset=True)
6565

6666

67+
def _messages_to_sap_template(messages: List[Dict[str, str]]) -> list: # type: ignore[type-arg]
68+
template = []
69+
for message in messages:
70+
if message["role"] == "user":
71+
template.append(validate_dict(message, SAPUserMessage))
72+
elif message["role"] == "assistant":
73+
template.append(validate_dict(message, SAPAssistantMessage))
74+
elif message["role"] == "tool":
75+
template.append(validate_dict(message, SAPToolChatMessage))
76+
else:
77+
template.append(validate_dict(message, SAPMessage))
78+
return template
79+
80+
81+
def _tools_response_format_and_stream(
82+
optional_params: dict, model_params: dict
83+
) -> Tuple[dict, dict, dict]:
84+
tools_ = optional_params.pop("tools", [])
85+
tools_ = [validate_dict(tool, ChatCompletionTool) for tool in tools_]
86+
tools: dict = {"tools": tools_} if tools_ else {}
87+
88+
response_format = model_params.pop("response_format", {})
89+
resp_type = response_format.get("type", None)
90+
if resp_type:
91+
if resp_type == "json_schema":
92+
response_format = validate_dict(
93+
response_format, ResponseFormatJSONSchema
94+
)
95+
else:
96+
response_format = validate_dict(response_format, ResponseFormat)
97+
response_format = {"response_format": response_format}
98+
99+
model_params.pop("stream", False)
100+
stream_config: dict = {}
101+
if "stream_options" in optional_params:
102+
stream_options = optional_params.pop("stream_options", {})
103+
if "chunk_size" in stream_options:
104+
stream_config["chunk_size"] = stream_options.get("chunk_size")
105+
if "delimiters" in stream_options:
106+
stream_config["delimiters"] = stream_options.get("delimiters")
107+
108+
return tools, response_format, stream_config
109+
110+
67111
class GenAIHubOrchestrationConfig(OpenAIGPTConfig):
68112
frequency_penalty: Optional[int] = None
69113
function_call: Optional[Union[str, dict]] = None
@@ -292,57 +336,19 @@ def transform_request(
292336
optional_params = dict(optional_params)
293337
optional_params.pop("deployment_url", None)
294338

295-
template = messages
296-
297339
excluded_params = _SAP_MODEL_PARAMS_EXCLUDED_KEYS
298340
model_params = {
299341
k: v for k, v in optional_params.items() if k not in excluded_params
300342
}
301343

302344
model_version = optional_params.pop("model_version", "latest")
303-
template = []
304-
for message in messages:
305-
if message["role"] == "user":
306-
template.append(validate_dict(message, SAPUserMessage))
307-
elif message["role"] == "assistant":
308-
template.append(validate_dict(message, SAPAssistantMessage))
309-
elif message["role"] == "tool":
310-
template.append(validate_dict(message, SAPToolChatMessage))
311-
else:
312-
template.append(validate_dict(message, SAPMessage))
313-
314-
tools_ = optional_params.pop("tools", [])
315-
tools_ = [validate_dict(tool, ChatCompletionTool) for tool in tools_]
316-
if tools_ != []:
317-
tools = {"tools": tools_}
318-
else:
319-
tools = {}
345+
template = _messages_to_sap_template(messages)
320346

321-
response_format = model_params.pop("response_format", {})
322-
resp_type = response_format.get("type", None)
323-
if resp_type:
324-
if resp_type == "json_schema":
325-
response_format = validate_dict(
326-
response_format, ResponseFormatJSONSchema
327-
)
328-
else:
329-
response_format = validate_dict(response_format, ResponseFormat)
330-
response_format = {"response_format": response_format}
331-
model_params.pop("stream", False)
332-
stream_config = {}
333-
if "stream_options" in optional_params:
334-
stream_options = optional_params.pop("stream_options", {})
335-
if "chunk_size" in stream_options:
336-
stream_config["chunk_size"] = stream_options.get("chunk_size")
337-
if "delimiters" in stream_options:
338-
stream_config["delimiters"] = stream_options.get("delimiters")
347+
tools, response_format, stream_config = _tools_response_format_and_stream(
348+
optional_params, model_params
349+
)
339350

340351
placeholder_values = optional_params.pop("placeholder_values", None)
341-
placeholder_values = (
342-
{"placeholder_values": placeholder_values}
343-
if placeholder_values is not None
344-
else {}
345-
)
346352

347353
fallback_modules = optional_params.pop("fallback_sap_modules", [])
348354

@@ -373,26 +379,24 @@ def transform_request(
373379
)
374380
)
375381

376-
if fallback_modules:
377-
modules_payload = modules
378-
else:
379-
modules_payload = modules[0] # type: ignore
380-
381-
request_body = {
382-
"config": {
383-
"modules": {
384-
"prompt_templating": {
385-
"prompt": {"template": template, **tools, **response_format},
386-
"model": {
387-
"name": model,
388-
"params": model_params,
389-
"version": model_version,
390-
},
382+
config_payload: Dict[str, Any] = {
383+
"modules": {
384+
"prompt_templating": {
385+
"prompt": {"template": template, **tools, **response_format},
386+
"model": {
387+
"name": model,
388+
"params": model_params,
389+
"version": model_version,
391390
},
392391
},
393-
"stream": stream_config,
394-
}
392+
},
395393
}
394+
if stream_config:
395+
config_payload["stream"] = stream_config
396+
397+
request_body: Dict[str, Any] = {"config": config_payload}
398+
if placeholder_values is not None:
399+
request_body["placeholder_values"] = placeholder_values
396400

397401
body = validate_dict(request_body, OrchestrationRequest)
398402

0 commit comments

Comments
 (0)