-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy path_json_schema.py
More file actions
210 lines (167 loc) · 8.29 KB
/
_json_schema.py
File metadata and controls
210 lines (167 loc) · 8.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from __future__ import annotations as _annotations
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal
from .exceptions import UserError
# Alias used throughout this module for a JSON schema held as a plain `dict` with string keys and arbitrary values.
JsonSchema = dict[str, Any]
@dataclass(init=False)
class JsonSchemaTransformer(ABC):
    """Walks a JSON schema, applying transformations to it at each level.

    The transformer is called during a model's prepare_request() step to build the JSON schema
    before it is sent to the model provider.

    Subclasses implement `transform`, which is invoked on every non-boolean sub-schema after its
    children (properties, items, union members) have been handled.

    Note: We may eventually want to rework tools to build the JSON schema from the type directly, using a subclass of
    pydantic.json_schema.GenerateJsonSchema, rather than making use of this machinery.
    """

    def __init__(
        self,
        schema: JsonSchema,
        *,
        strict: bool | None = None,
        prefer_inlined_defs: bool = False,
        simplify_nullable_unions: bool = False,  # TODO (v2): Remove this, no longer used
    ):
        """Store the schema to walk and the options controlling the walk.

        Args:
            schema: The JSON schema to transform. It is never mutated; `walk` works on deep copies.
            strict: Stored on the instance for use by subclasses' `transform`.
            prefer_inlined_defs: If true, `$ref`s into `$defs` are replaced by the referenced definition
                wherever that is possible (i.e. everywhere except recursive references).
            simplify_nullable_unions: If true, two-member unions with `{'type': 'null'}` are collapsed into
                the other member with `nullable: True` added.
        """
        self.schema = schema

        self.strict = strict
        """The `strict` parameter forces the conversion of the original JSON schema (`self.schema`) of a `ToolDefinition` or `OutputObjectDefinition` to a format supported by the model provider.

        The "strict mode" offered by model providers ensures that the model's output adheres closely to the defined schema. However, not all model providers offer it, and their support for various schema features may differ. For example, a model provider's required schema may not support certain validation constraints like `minLength` or `pattern`.
        """
        self.is_strict_compatible = True
        """Whether the schema is compatible with strict mode.

        This value is used to set `ToolDefinition.strict` or `OutputObjectDefinition.strict` when their values are `None`.
        """

        self.prefer_inlined_defs = prefer_inlined_defs
        self.simplify_nullable_unions = simplify_nullable_unions

        # Private deep copy of the definitions: handlers mutate these entries in place.
        self.defs: dict[str, JsonSchema] = deepcopy(self.schema.get('$defs', {}))
        # Keys of the `$defs` entries currently being inlined (innermost last).
        self.refs_stack: list[str] = []
        # Keys of `$defs` entries found to reference themselves (directly or indirectly).
        self.recursive_refs = set[str]()

    @abstractmethod
    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Make changes to the schema."""
        return schema

    def walk(self) -> JsonSchema:
        """Transform `self.schema` and return the result.

        Returns:
            - the handled root schema with handled `$defs` attached, when definitions are not inlined;
            - a `{'$defs': ..., '$ref': ...}` wrapper when definitions are inlined but recursive
              references were found (those cannot be inlined);
            - otherwise just the handled root schema.
        """
        schema = deepcopy(self.schema)

        # First, handle everything but $defs:
        schema.pop('$defs', None)
        handled = self._handle(schema)

        if not self.prefer_inlined_defs and self.defs:
            handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()}

        elif self.recursive_refs:
            # If we are preferring inlined defs and there are recursive refs, we _have_ to use a $defs+$ref structure
            # We try to use whatever the original root key was, but if it is already in use,
            # we modify it to avoid collisions.
            defs = {key: self.defs[key] for key in self.recursive_refs}
            root_ref = self.schema.get('$ref')
            root_key = None if root_ref is None else re.sub(r'^#/\$defs/', '', root_ref)
            if root_key is None:  # pragma: no cover
                root_key = self.schema.get('title', 'root')
            while root_key in defs:
                # Modify the root key until it is not already in use
                root_key = f'{root_key}_root'

            defs[root_key] = handled
            return {'$defs': defs, '$ref': f'#/$defs/{root_key}'}

        return handled

    def _handle(self, schema: JsonSchema) -> JsonSchema:
        """Handle one sub-schema: resolve inlinable `$ref`s, recurse into children, then `transform`.

        Boolean schemas (`True` / `False`) are returned unchanged and never passed to `transform`.

        Raises:
            UserError: If a `$ref` points at a definition that is not present in `self.defs`.
        """
        if isinstance(schema, bool):
            return schema

        nested_refs = 0
        try:
            if self.prefer_inlined_defs:
                while ref := schema.get('$ref'):
                    key = re.sub(r'^#/\$defs/', '', ref)
                    if key in self.recursive_refs:
                        break
                    if key in self.refs_stack:
                        self.recursive_refs.add(key)
                        break  # recursive ref can't be unpacked

                    def_schema = self.defs.get(key)
                    if def_schema is None:  # pragma: no cover
                        raise UserError(f'Could not find $ref definition for {key}')

                    # Only record the key once we know the definition exists.
                    self.refs_stack.append(key)
                    nested_refs += 1
                    schema = def_schema

                    if isinstance(schema, bool):
                        # A boolean definition has nothing further to resolve or transform.
                        return schema

            # Handle the schema based on its type / structure
            type_ = schema.get('type')
            if type_ == 'object':
                schema = self._handle_object(schema)
            elif type_ == 'array':
                schema = self._handle_array(schema)
            elif type_ is None:
                schema = self._handle_union(schema, 'anyOf')
                if not isinstance(schema, bool):
                    schema = self._handle_union(schema, 'oneOf')

            # Guard: _handle_union may return a bool (e.g., from {'oneOf': [True]})
            if isinstance(schema, bool):
                return schema

            # Apply the base transform
            return self.transform(schema)
        finally:
            # Always unwind the refs pushed by this call, including when an exception propagates.
            if nested_refs > 0:
                del self.refs_stack[-nested_refs:]

    def _handle_object(self, schema: JsonSchema) -> JsonSchema:
        """Handle the sub-schemas of an object schema; mutates and returns `schema`.

        The in-place update is deliberate: when `schema` is an entry of `self.defs`, `walk` emits that
        same entry for recursive definitions, so the handled children must be stored on it.
        """
        if properties := schema.get('properties'):
            schema['properties'] = {key: self._handle(value) for key, value in properties.items()}

        if (additional_properties := schema.get('additionalProperties')) is not None:
            # `_handle` returns boolean schemas unchanged, so no special-casing is needed here.
            schema['additionalProperties'] = self._handle(additional_properties)

        if (pattern_properties := schema.get('patternProperties')) is not None:
            schema['patternProperties'] = {key: self._handle(value) for key, value in pattern_properties.items()}

        return schema

    def _handle_array(self, schema: JsonSchema) -> JsonSchema:
        """Handle `prefixItems` and `items` of an array schema; mutates and returns `schema`."""
        if prefix_items := schema.get('prefixItems'):
            schema['prefixItems'] = [self._handle(item) for item in prefix_items]

        if items := schema.get('items'):
            schema['items'] = self._handle(items)

        return schema

    def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf']) -> JsonSchema:
        """Handle the members of an `anyOf` / `oneOf` union, if `schema` has one.

        The input `schema` is not modified: it may be a shared entry of `self.defs` that is inlined
        again at another reference site, or emitted as a recursive definition by `walk`.

        Returns:
            `schema` itself when it has no `union_kind` key; the single member merged with the remaining
            keys (or the member itself if it is a boolean schema) when only one member is left;
            otherwise a copy of `schema` whose `union_kind` holds the handled members.
        """
        if union_kind not in schema:
            return schema

        # Work on a shallow copy so that removing the union key does not strip it from the caller's dict.
        remainder = schema.copy()
        members = remainder.pop(union_kind)

        handled = [self._handle(member) for member in members]

        # TODO (v2): Remove this feature, no longer used
        if self.simplify_nullable_unions:
            handled = self._simplify_nullable_union(handled)

        if len(handled) == 1:
            # In this case, no need to retain the union
            only_member = handled[0]
            if isinstance(only_member, bool):
                return only_member
            return only_member | remainder

        # If we have keys besides the union kind (such as title or discriminator), keep them without modifications
        remainder[union_kind] = handled
        return remainder

    @staticmethod
    def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]:
        """Collapse a `[X, {'type': 'null'}]` union into `[X + {'nullable': True}]`.

        Any other input is returned unchanged, including unions whose non-null member is a boolean
        schema (which cannot carry a `nullable` key).
        """
        # TODO (v2): Remove this method, no longer used
        null_schema = {'type': 'null'}
        if len(cases) != 2 or null_schema not in cases:
            return cases

        non_null_cases = [item for item in cases if item != null_schema]
        if not non_null_cases:  # pragma: no cover
            # they are both null, so just return one of them
            return [cases[0]]

        non_null_schema = non_null_cases[0]
        if isinstance(non_null_schema, bool):
            return cases

        # Create a new schema based on the non-null part, mark as nullable.
        # Note that an empty schema (`{}`) is a valid non-null member and must not be dropped.
        new_schema = deepcopy(non_null_schema)
        new_schema['nullable'] = True
        return [new_schema]
class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer):
    """A transformer whose only effect is to inline `$defs` into the places that reference them."""

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        """Configure the base transformer with `prefer_inlined_defs` switched on."""
        super().__init__(
            schema,
            prefer_inlined_defs=True,
            strict=strict,
        )

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Return each sub-schema as is; the inlining itself is performed by the base class walk."""
        return schema