Skip to content

Commit 99ceb86

Browse files
authored
Structured output (#64)
* Add structured outputs tooling * Add response_format field and examples how to use it * Add tests for structured output
1 parent d72f500 commit 99ceb86

File tree

15 files changed

+2466
-7
lines changed

15 files changed

+2466
-7
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import json
7+
8+
import pydantic
9+
10+
from yandex_cloud_ml_sdk import AsyncYCloudML
11+
12+
13+
class Venue(pydantic.BaseModel):
14+
date: str
15+
place: str
16+
17+
18+
@pydantic.dataclasses.dataclass
19+
class VenueDataclass:
20+
date: str
21+
place: str
22+
name: str
23+
24+
25+
async def main() -> None:
    """Demonstrate each supported ``response_format`` kind with the async SDK.

    The same extraction prompt is sent with plain ``'json'`` mode (both via
    ``run`` and ``run_stream``), with a raw JSON schema, with a pydantic model
    class and with a pydantic dataclass; the model output of every request is
    printed.

    Raises:
        RuntimeError: if the installed pydantic major version is lower than 2
            when the pydantic-dataclass example is reached.
    """
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # NB: for now (24.02.2025) structured output is supported only by the
    # release candidate model version.
    model = sdk.models.completions('yandexgpt', model_version='rc')
    text = (
        'The conference will take place from May 10th to 12th, 2023, '
        'at 30 Avenue Corentin Cariou in Paris, France.'
    )
    # Every request below sends the same conversation; only response_format changes.
    messages = [
        {'role': 'system', 'text': 'Extract the date and venue information'},
        {'role': 'user', 'text': text},
    ]

    # We could ask the model to return data just in JSON format; the model
    # will figure out the structure by itself:
    model = model.configure(response_format='json')
    result = await model.run(messages)
    print('Any JSON:', result[0].text)

    # Now, if you need not just JSON, but a parsed Python structure, you will need to parse it.
    # Be aware that you may need to handle parsing exceptions in case the model returns incorrect json.
    # This could happen, for example, if you exceed the token limit.
    try:
        data = json.loads(result.text)
        print("Parsed JSON:", data)

        # Truncate the answer to show what a parsing failure looks like.
        bad_text = result.text[:5]
        json.loads(bad_text)
    except json.JSONDecodeError as e:
        print("JSON parsing error:", e)

    # You could use not only .run, but .run_stream as well as other methods too:
    print('Any JSON in streaming:')
    async for partial_result in model.run_stream(messages):
        print(f" {partial_result.text}")

    # NB: each example uses a slightly different schema so that the printed
    # results are easy to tell apart.
    # We could pass a raw JSON schema:
    model = model.configure(response_format={
        "json_schema": {
            "properties": {
                "DATE": {
                    "title": "Date",
                    "type": "string"
                },
                "PLACE": {
                    "title": "Place",
                    "type": "string"
                }
            },
            "required": ["DATE", "PLACE"],
            "title": "Venue",
            "type": "object"
        }
    })
    result = await model.run(messages)
    print('JSONSchema from raw jsonschema:', result[0].text)

    # Also we could use a pydantic.BaseModel descendant to describe the JSON
    # schema for structured output:
    model = model.configure(response_format=Venue)
    result = await model.run(messages)
    print('JSONSchema from Pydantic model:', result[0].text)

    # Lastly we could pass a pydantic dataclass.
    # Compare the major version numerically: a plain string comparison
    # against "2" would wrongly reject e.g. "10.0".
    if int(pydantic.VERSION.split('.')[0]) < 2:
        raise RuntimeError('pydantic>=2 is required to use a pydantic dataclass as response_format')
    model = model.configure(response_format=VenueDataclass)
    result = await model.run(messages)
    print('JSONSchema from Pydantic dataclass:', result[0].text)
107+
108+
109+
if __name__ == '__main__':
110+
asyncio.run(main())
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import json
6+
7+
import pydantic
8+
9+
from yandex_cloud_ml_sdk import YCloudML
10+
11+
12+
class Venue(pydantic.BaseModel):
13+
date: str
14+
place: str
15+
16+
17+
@pydantic.dataclasses.dataclass
18+
class VenueDataclass:
19+
date: str
20+
place: str
21+
name: str
22+
23+
24+
def main() -> None:
    """Demonstrate each supported ``response_format`` kind with the sync SDK.

    The same extraction prompt is sent with plain ``'json'`` mode (both via
    ``run`` and ``run_stream``), with a raw JSON schema, with a pydantic model
    class and with a pydantic dataclass; the model output of every request is
    printed.

    Raises:
        RuntimeError: if the installed pydantic major version is lower than 2
            when the pydantic-dataclass example is reached.
    """
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # NB: for now (24.02.2025) structured output is supported only by the
    # release candidate model version.
    model = sdk.models.completions('yandexgpt', model_version='rc')
    text = (
        'The conference will take place from May 10th to 12th, 2023, '
        'at 30 Avenue Corentin Cariou in Paris, France.'
    )
    # Every request below sends the same conversation; only response_format changes.
    messages = [
        {'role': 'system', 'text': 'Extract the date and venue information'},
        {'role': 'user', 'text': text},
    ]

    # We could ask the model to return data just in JSON format; the model
    # will figure out the structure by itself:
    model = model.configure(response_format='json')
    result = model.run(messages)
    print('Any JSON:', result[0].text)

    # Now, if you need not just JSON, but a parsed Python structure, you will need to parse it.
    # Be aware that you may need to handle parsing exceptions in case the model returns incorrect json.
    # This could happen, for example, if you exceed the token limit.
    try:
        data = json.loads(result.text)
        print("Parsed JSON:", data)

        # Truncate the answer to show what a parsing failure looks like.
        bad_text = result.text[:5]
        json.loads(bad_text)
    except json.JSONDecodeError as e:
        print("JSON parsing error:", e)

    # You could use not only .run, but .run_stream as well as other methods too:
    print('Any JSON in streaming:')
    for partial_result in model.run_stream(messages):
        print(f" {partial_result.text}")

    # NB: each example uses a slightly different schema so that the printed
    # results are easy to tell apart.
    # We could pass a raw JSON schema:
    model = model.configure(response_format={
        "json_schema": {
            "properties": {
                "DATE": {
                    "title": "Date",
                    "type": "string"
                },
                "PLACE": {
                    "title": "Place",
                    "type": "string"
                }
            },
            "required": ["DATE", "PLACE"],
            "title": "Venue",
            "type": "object"
        }
    })
    result = model.run(messages)
    print('JSONSchema from raw jsonschema:', result[0].text)

    # Also we could use a pydantic.BaseModel descendant to describe the JSON
    # schema for structured output:
    model = model.configure(response_format=Venue)
    result = model.run(messages)
    print('JSONSchema from Pydantic model:', result[0].text)

    # Lastly we could pass a pydantic dataclass.
    # Compare the major version numerically: a plain string comparison
    # against "2" would wrongly reject e.g. "10.0".
    if int(pydantic.VERSION.split('.')[0]) < 2:
        raise RuntimeError('pydantic>=2 is required to use a pydantic dataclass as response_format')
    model = model.configure(response_format=VenueDataclass)
    result = model.run(messages)
    print('JSONSchema from Pydantic dataclass:', result[0].text)
106+
107+
108+
if __name__ == '__main__':
109+
main()

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ classifiers = [
3333
requires-python = ">=3.9"
3434
dynamic = ["version"]
3535
dependencies = [
36-
"yandexcloud>=0.331.0",
37-
"grpcio>=1.62.0",
36+
"yandexcloud>=0.334.0",
37+
"grpcio>=1.70.0",
3838
"get-annotations",
3939
"httpx>=0.27,<1",
4040
"typing-extensions>=4",
@@ -50,6 +50,9 @@ langchain = [
5050
"langchain-core>=0.3",
5151
"pydantic<2.10"
5252
]
53+
pydantic = [
54+
"pydantic>2",
55+
]
5356

5457
[project.urls]
5558
Documentation = "https://yandex.cloud/ru/docs/foundation-models/"

src/yandex_cloud_ml_sdk/_models/completions/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import ReasoningOptions as ProtoReasoningOptions
99

1010
from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig
11+
from yandex_cloud_ml_sdk._types.structured_output import ResponseType
1112
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase
1213

1314
_m = ProtoReasoningOptions.ReasoningMode
@@ -27,3 +28,4 @@ class GPTModelConfig(BaseModelConfig):
2728
temperature: float | None = None
2829
max_tokens: int | None = None
2930
reasoning_mode: ReasoningModeType | None = None
31+
response_format: ResponseType | None = None

src/yandex_cloud_ml_sdk/_models/completions/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ModelAsyncMixin, ModelSyncMixin, ModelSyncStreamMixin, ModelTuneMixin, OperationTypeT
2020
)
2121
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation
22+
from yandex_cloud_ml_sdk._types.structured_output import ResponseType, schema_from_response_format
2223
from yandex_cloud_ml_sdk._types.tuning.datasets import TuningDatasetsType
2324
from yandex_cloud_ml_sdk._types.tuning.optimizers import BaseOptimizer
2425
from yandex_cloud_ml_sdk._types.tuning.schedulers import BaseScheduler
@@ -66,11 +67,13 @@ def configure( # type: ignore[override]
6667
temperature: UndefinedOr[float] = UNDEFINED,
6768
max_tokens: UndefinedOr[int] = UNDEFINED,
6869
reasoning_mode: UndefinedOr[ReasoningModeType] = UNDEFINED,
70+
response_format: UndefinedOr[ResponseType] = UNDEFINED,
6971
) -> Self:
7072
return super().configure(
7173
temperature=temperature,
7274
max_tokens=max_tokens,
7375
reasoning_mode=reasoning_mode,
76+
response_format=response_format,
7477
)
7578

7679
def _make_request(
@@ -80,6 +83,7 @@ def _make_request(
8083
stream: bool | None,
8184
) -> CompletionRequest:
8285
completion_options_kwargs: dict[str, Any] = {}
86+
response_format_kwargs: dict[str, Any] = {}
8387

8488
if stream is not None:
8589
completion_options_kwargs['stream'] = stream
@@ -92,11 +96,19 @@ def _make_request(
9296
reasoning_mode = ReasoningMode._coerce(self._config.reasoning_mode)._to_proto()
9397
reasoning_options = ReasoningOptions(mode=reasoning_mode) # type: ignore[arg-type]
9498
completion_options_kwargs['reasoning_options'] = reasoning_options
99+
if self._config.response_format is not None:
100+
schema = schema_from_response_format(self._config.response_format)
101+
if isinstance(schema, str):
102+
response_format_kwargs['json_object'] = True
103+
else:
104+
assert isinstance(schema, dict)
105+
response_format_kwargs['json_schema'] = {'schema': schema}
95106

96107
return CompletionRequest(
97108
model_uri=self._uri,
98109
completion_options=CompletionOptions(**completion_options_kwargs),
99110
messages=messages_to_proto(messages),
111+
**response_format_kwargs,
100112
)
101113

102114
async def _run_sync_impl(

src/yandex_cloud_ml_sdk/_models/completions/result.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,15 @@ def __getitem__(self, slice_: slice, /) -> tuple[Alternative, ...]:
8686

8787
def __getitem__(self, index, /):
8888
return self.alternatives[index]
89+
90+
@property
def role(self) -> str:
    """Shortcut for the ``role`` of the first alternative (``self[0]``)."""
    first_alternative = self[0]
    return first_alternative.role
93+
94+
@property
def text(self) -> str:
    """Shortcut for the ``text`` of the first alternative (``self[0]``)."""
    first_alternative = self[0]
    return first_alternative.text
97+
98+
@property
def status(self) -> AlternativeStatus:
    """Shortcut for the ``status`` of the first alternative (``self[0]``)."""
    first_alternative = self[0]
    return first_alternative.status
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Literal, TypedDict, Union
4+
5+
from typing_extensions import TypeAlias, TypeGuard
6+
7+
from yandex_cloud_ml_sdk._logging import get_logger
8+
9+
logger = get_logger(__name__)
10+
11+
LITERAL_RESPONSE_FORMATS = ('json', )
12+
13+
StrResponseType = Literal['json']
14+
JsonSchemaType = dict[str, Any]
15+
16+
class JsonSchemaResponseType(TypedDict):
17+
json_schema: JsonSchemaType
18+
19+
ResponseType: TypeAlias = Union[StrResponseType, JsonSchemaResponseType, type]
20+
21+
try:
22+
import pydantic
23+
24+
PYDANTIC = True
25+
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
26+
27+
except ImportError:
28+
PYDANTIC = False
29+
PYDANTIC_V2 = False
30+
31+
32+
def is_pydantic_model_class(response_format: ResponseType) -> TypeGuard[type[pydantic.BaseModel]]:
    """Tell whether ``response_format`` is a proper subclass of ``pydantic.BaseModel``.

    Returns ``False`` when pydantic could not be imported, when the value is
    not a class, or when it is ``pydantic.BaseModel`` itself.
    """
    return (
        # Must be checked first: the name ``pydantic`` is unbound when the
        # optional import above failed, and ``and`` short-circuits here.
        PYDANTIC and
        isinstance(response_format, type) and
        issubclass(response_format, pydantic.BaseModel) and
        response_format is not pydantic.BaseModel
    )
39+
40+
41+
def schema_from_response_format(response_format: ResponseType) -> StrResponseType | JsonSchemaType:
    """Normalize a user-supplied ``response_format`` for building a request.

    Args:
        response_format: one of
            * a literal listed in ``LITERAL_RESPONSE_FORMATS``;
            * a dict carrying a non-empty JSON schema dict under the
              ``'json_schema'`` key (``'jsonschema'`` is accepted as well);
            * a ``pydantic.BaseModel`` subclass;
            * a pydantic dataclass.

    Returns:
        The literal string unchanged, or a JSON schema dict: a shallow copy of
        the supplied schema, or the schema generated by pydantic for the class.

    Raises:
        ValueError: for an unsupported literal, or for a dict whose schema is
            missing, empty or not a dict.
        TypeError: for any other value, including pydantic classes when
            pydantic is not installed.
    """
    result: StrResponseType | JsonSchemaType

    if isinstance(response_format, str):
        if response_format not in LITERAL_RESPONSE_FORMATS:
            raise ValueError(
                f"Literal response type '{response_format}' is not supported, use one of {LITERAL_RESPONSE_FORMATS}")
        result = response_format
    elif isinstance(response_format, dict):
        # TODO: in case we would get jsonschema dependency, it is a good
        # idea to add validation here

        json_schema = response_format.get('json_schema') or response_format.get('jsonschema')
        if not json_schema or not isinstance(json_schema, dict):
            raise ValueError(
                'json_schema field must be present in response_format field '
                'and must be a valid json schema dict'
            )
        # Shallow copy, so the top-level mapping is not shared with the caller.
        result = dict(json_schema)
    elif is_pydantic_model_class(response_format):
        result = response_format.model_json_schema()
    elif (
        # ``PYDANTIC`` goes first: ``pydantic`` is unbound if its import failed.
        PYDANTIC and
        isinstance(response_format, type) and
        pydantic.dataclasses.is_pydantic_dataclass(response_format)
    ):
        result = pydantic.TypeAdapter(response_format).json_schema()
    else:
        raise TypeError(
            "Response type could be only str, jsonschema dict or pydantic model class"
        )

    logger.debug('transform input response_format=%r to json_schema=%r', response_format, result)
    return result

0 commit comments

Comments
 (0)