Commit 0eb0cf4
Fixed Ollama Structured Response not working #10616 (#10617)
* Fixed json.dumps in JSON Schema Validation Error
* Added response schema to Ollama chat for structured response
* Added test cases
* refactor(ollama): remove redundant response_format check. The response_format parameter conversion is already handled in utils.py's get_optional_params function, making the duplicate check in ollama_chat.py unnecessary. This change removes the redundant code while maintaining the same functionality.
1 parent b7fc726 commit 0eb0cf4

File tree: 3 files changed (+112 -13 lines)
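
What the fix enables, as a minimal end-to-end sketch: a Pydantic class passed as response_format to an Ollama chat model is now mapped to a JSON schema instead of erroring. This assumes a locally running Ollama server; the model name and the Event class are illustrative, while litellm.completion and its response_format parameter are existing litellm API.

    import litellm
    from pydantic import BaseModel


    class Event(BaseModel):
        name: str
        value: int


    response = litellm.completion(
        model="ollama_chat/llama3",  # assumption: any model pulled into local Ollama
        messages=[{"role": "user", "content": "Invent a small event and return it as JSON."}],
        response_format=Event,  # previously failed; now converted to a JSON schema for Ollama
    )
    print(response.choices[0].message.content)  # JSON text conforming to Event's schema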

litellm/litellm_core_utils/json_validation_rule.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ def validate_schema(schema: dict, response: str):
         response_dict = json.loads(response)
     except json.JSONDecodeError:
         raise JSONSchemaValidationError(
-            model="", llm_provider="", raw_response=response, schema=response
+            model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
         )
 
     try:
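
For context, a minimal sketch of how the corrected error surfaces to a caller. The import path for the error class is an assumption; only validate_schema's module is shown in the diff above.

    from litellm.litellm_core_utils.json_validation_rule import validate_schema
    from litellm import JSONSchemaValidationError  # assumption: exported at package level

    schema = {"type": "object", "properties": {"name": {"type": "string"}}}

    try:
        validate_schema(schema=schema, response="not valid json {")
    except JSONSchemaValidationError as err:
        # Before this fix, the error's schema field carried the raw response;
        # it now carries json.dumps(schema), so the failing schema is visible.
        print(err)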

litellm/llms/ollama_chat.py

Lines changed: 29 additions & 12 deletions

@@ -6,6 +6,7 @@
 import aiohttp
 import httpx
 from pydantic import BaseModel
+import inspect
 
 import litellm
 from litellm import verbose_logger
@@ -141,6 +142,13 @@ def map_openai_params(
         model: str,
         drop_params: bool,
     ) -> dict:
+        value = non_default_params["response_format"]
+        if inspect.isclass(value) and issubclass(value, BaseModel):
+            non_default_params["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {"schema": value.model_json_schema()}
+            }
+
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["num_predict"] = value
@@ -156,13 +164,13 @@ def map_openai_params(
                 optional_params["repeat_penalty"] = value
             if param == "stop":
                 optional_params["stop"] = value
-            if param == "response_format" and value["type"] == "json_object":
+            if param == "response_format" and isinstance(value, dict) and value.get("type") == "json_object":
                 optional_params["format"] = "json"
-            if param == "response_format" and value["type"] == "json_schema":
-                optional_params["format"] = value["json_schema"]["schema"]
+            if param == "response_format" and isinstance(value, dict) and value.get("type") == "json_schema":
+                if value.get("json_schema") and value["json_schema"].get("schema"):
+                    optional_params["format"] = value["json_schema"]["schema"]
             ### FUNCTION CALLING LOGIC ###
             if param == "tools":
-                # ollama actually supports json output
                 ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
                 try:
                     model_info = litellm.get_model_info(
@@ -185,14 +193,23 @@ def map_openai_params(
                 ][0]["function"]["name"]
 
             if param == "functions":
-                # ollama actually supports json output
-                optional_params["format"] = "json"
-                litellm.add_function_to_prompt = (
-                    True  # so that main.py adds the function call to the prompt
-                )
-                optional_params["functions_unsupported_model"] = non_default_params.get(
-                    "functions"
-                )
+                ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
+                try:
+                    model_info = litellm.get_model_info(
+                        model=model, custom_llm_provider="ollama"
+                    )
+                    if model_info.get("supports_function_calling") is True:
+                        optional_params["tools"] = value
+                    else:
+                        raise Exception
+                except Exception:
+                    optional_params["format"] = "json"
+                    litellm.add_function_to_prompt = (
+                        True  # so that main.py adds the function call to the prompt
+                    )
+                    optional_params["functions_unsupported_model"] = non_default_params.get(
+                        "functions"
+                    )
         non_default_params.pop("tool_choice", None)  # causes ollama requests to hang
         non_default_params.pop("functions", None)  # causes ollama requests to hang
         return optional_params
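
To make the new mapping concrete, a small sketch of calling map_openai_params directly, mirroring what the tests below exercise. The model name is illustrative, and the note about where the result goes reflects Ollama's documented format field rather than anything in this diff.

    from pydantic import BaseModel

    from litellm.llms.ollama_chat import OllamaChatConfig


    class Event(BaseModel):  # illustrative schema
        name: str
        value: int


    optional_params: dict = {}
    OllamaChatConfig().map_openai_params(
        non_default_params={"response_format": Event},
        optional_params=optional_params,
        model="ollama_chat/llama3",  # assumption: any locally available model
        drop_params=False,
    )

    # The Pydantic class is normalized to {"type": "json_schema", ...} and its
    # JSON schema lands in optional_params["format"], which is forwarded as the
    # "format" field of Ollama's /api/chat request (Ollama accepts "json" or a
    # full JSON schema there).
    print(optional_params["format"])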
Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import os
+import sys
+import pytest
+from pydantic import BaseModel
+import inspect
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../..')))
+
+from litellm.llms.ollama_chat import OllamaChatConfig
+
+class TestEvent(BaseModel):
+    name: str
+    value: int
+
+class TestOllamaChatConfigResponseFormat:
+    def test_map_openai_params_with_pydantic_model(self):
+        config = OllamaChatConfig()
+
+        non_default_params = {
+            "response_format": TestEvent
+        }
+        optional_params = {}
+
+        expected_schema_structure = TestEvent.model_json_schema()
+
+        config.map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model="ollama_chat/test-model",
+            drop_params=False
+        )
+
+        assert "format" in optional_params, "Transformed 'format' key not found in optional_params"
+
+        transformed_format = optional_params["format"]
+
+        assert transformed_format == expected_schema_structure, \
+            f"Transformed schema does not match expected. Got: {transformed_format}, Expected: {expected_schema_structure}"
+
+    def test_map_openai_params_with_dict_json_schema(self):
+        config = OllamaChatConfig()
+
+        direct_schema = TestEvent.model_json_schema()
+        response_format_dict = {
+            "type": "json_schema",
+            "json_schema": {"schema": direct_schema}
+        }
+
+        non_default_params = {
+            "response_format": response_format_dict
+        }
+        optional_params = {}
+
+        config.map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model="ollama_chat/test-model",
+            drop_params=False
+        )
+
+        assert "format" in optional_params
+        assert optional_params["format"] == direct_schema, \
+            f"Schema from dict did not pass through correctly. Got: {optional_params['format']}, Expected: {direct_schema}"
+
+    def test_map_openai_params_with_json_object(self):
+        config = OllamaChatConfig()
+
+        non_default_params = {
+            "response_format": {"type": "json_object"}
+        }
+        optional_params = {}
+
+        config.map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model="ollama_chat/test-model",
+            drop_params=False
+        )
+
+        assert "format" in optional_params
+        assert optional_params["format"] == "json", \
+            f"Expected 'json' for type 'json_object', got: {optional_params['format']}"