Skip to content

Commit bc375d8

Browse files
authored
Fix Cohere streaming and xAI tools validation (#1983)
1 parent 7ca2dcc commit bc375d8

File tree

6 files changed

+621
-44
lines changed

6 files changed

+621
-44
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ repos:
3737
- id: ty-check
3838
name: Run Type Check (ty)
3939
entry: uv
40-
args: [run, ty, check]
40+
args: [run, ty, check, --ignore, unresolved-import]
4141
language: system
4242
files: ^instructor/
4343
pass_filenames: false

instructor/dsl/iterable.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,142 @@ def extract_cls_task_type(
148148
def extract_json(
149149
completion: Iterable[Any], mode: Mode
150150
) -> Generator[str, None, None]:
151+
json_started = False
151152
for chunk in completion:
152153
try:
154+
if mode in {Mode.COHERE_TOOLS, Mode.COHERE_JSON_SCHEMA}:
155+
event_type = getattr(chunk, "event_type", None)
156+
if event_type == "text-generation":
157+
if text := getattr(chunk, "text", None):
158+
if not json_started:
159+
json_start = min(
160+
(
161+
pos
162+
for pos in (text.find("{"), text.find("["))
163+
if pos != -1
164+
),
165+
default=-1,
166+
)
167+
if json_start == -1:
168+
continue
169+
json_started = True
170+
text = text[json_start:]
171+
yield text
172+
elif event_type == "tool-calls-chunk":
173+
delta = getattr(chunk, "tool_call_delta", None)
174+
args = getattr(delta, "parameters", None) or getattr(
175+
delta, "text", None
176+
)
177+
if args:
178+
if not json_started:
179+
json_start = min(
180+
(
181+
pos
182+
for pos in (args.find("{"), args.find("["))
183+
if pos != -1
184+
),
185+
default=-1,
186+
)
187+
if json_start == -1:
188+
continue
189+
json_started = True
190+
args = args[json_start:]
191+
yield args
192+
elif text := getattr(chunk, "text", None):
193+
if not json_started:
194+
json_start = min(
195+
(
196+
pos
197+
for pos in (text.find("{"), text.find("["))
198+
if pos != -1
199+
),
200+
default=-1,
201+
)
202+
if json_start == -1:
203+
continue
204+
json_started = True
205+
text = text[json_start:]
206+
yield text
207+
elif event_type == "tool-calls-generation":
208+
tool_calls = getattr(chunk, "tool_calls", None)
209+
if tool_calls:
210+
args = json.dumps(tool_calls[0].parameters)
211+
if not json_started:
212+
json_start = min(
213+
(
214+
pos
215+
for pos in (args.find("{"), args.find("["))
216+
if pos != -1
217+
),
218+
default=-1,
219+
)
220+
if json_start == -1:
221+
continue
222+
json_started = True
223+
args = args[json_start:]
224+
yield args
225+
elif text := getattr(chunk, "text", None):
226+
if not json_started:
227+
json_start = min(
228+
(
229+
pos
230+
for pos in (text.find("{"), text.find("["))
231+
if pos != -1
232+
),
233+
default=-1,
234+
)
235+
if json_start == -1:
236+
continue
237+
json_started = True
238+
text = text[json_start:]
239+
yield text
240+
else:
241+
chunk_type = getattr(chunk, "type", None)
242+
if chunk_type == "content-delta":
243+
delta = getattr(chunk, "delta", None)
244+
message = getattr(delta, "message", None)
245+
content = getattr(message, "content", None)
246+
if text := getattr(content, "text", None):
247+
if not json_started:
248+
json_start = min(
249+
(
250+
pos
251+
for pos in (
252+
text.find("{"),
253+
text.find("["),
254+
)
255+
if pos != -1
256+
),
257+
default=-1,
258+
)
259+
if json_start == -1:
260+
continue
261+
json_started = True
262+
text = text[json_start:]
263+
yield text
264+
elif chunk_type == "tool-call-delta":
265+
delta = getattr(chunk, "delta", None)
266+
message = getattr(delta, "message", None)
267+
tool_calls = getattr(message, "tool_calls", None)
268+
function = getattr(tool_calls, "function", None)
269+
if args := getattr(function, "arguments", None):
270+
if not json_started:
271+
json_start = min(
272+
(
273+
pos
274+
for pos in (
275+
args.find("{"),
276+
args.find("["),
277+
)
278+
if pos != -1
279+
),
280+
default=-1,
281+
)
282+
if json_start == -1:
283+
continue
284+
json_started = True
285+
args = args[json_start:]
286+
yield args
153287
if mode == Mode.ANTHROPIC_JSON:
154288
if json_chunk := chunk.delta.text:
155289
yield json_chunk
@@ -230,8 +364,142 @@ def extract_json(
230364
async def extract_json_async(
231365
completion: AsyncGenerator[Any, None], mode: Mode
232366
) -> AsyncGenerator[str, None]:
367+
json_started = False
233368
async for chunk in completion:
234369
try:
370+
if mode in {Mode.COHERE_TOOLS, Mode.COHERE_JSON_SCHEMA}:
371+
event_type = getattr(chunk, "event_type", None)
372+
if event_type == "text-generation":
373+
if text := getattr(chunk, "text", None):
374+
if not json_started:
375+
json_start = min(
376+
(
377+
pos
378+
for pos in (text.find("{"), text.find("["))
379+
if pos != -1
380+
),
381+
default=-1,
382+
)
383+
if json_start == -1:
384+
continue
385+
json_started = True
386+
text = text[json_start:]
387+
yield text
388+
elif event_type == "tool-calls-chunk":
389+
delta = getattr(chunk, "tool_call_delta", None)
390+
args = getattr(delta, "parameters", None) or getattr(
391+
delta, "text", None
392+
)
393+
if args:
394+
if not json_started:
395+
json_start = min(
396+
(
397+
pos
398+
for pos in (args.find("{"), args.find("["))
399+
if pos != -1
400+
),
401+
default=-1,
402+
)
403+
if json_start == -1:
404+
continue
405+
json_started = True
406+
args = args[json_start:]
407+
yield args
408+
elif text := getattr(chunk, "text", None):
409+
if not json_started:
410+
json_start = min(
411+
(
412+
pos
413+
for pos in (text.find("{"), text.find("["))
414+
if pos != -1
415+
),
416+
default=-1,
417+
)
418+
if json_start == -1:
419+
continue
420+
json_started = True
421+
text = text[json_start:]
422+
yield text
423+
elif event_type == "tool-calls-generation":
424+
tool_calls = getattr(chunk, "tool_calls", None)
425+
if tool_calls:
426+
args = json.dumps(tool_calls[0].parameters)
427+
if not json_started:
428+
json_start = min(
429+
(
430+
pos
431+
for pos in (args.find("{"), args.find("["))
432+
if pos != -1
433+
),
434+
default=-1,
435+
)
436+
if json_start == -1:
437+
continue
438+
json_started = True
439+
args = args[json_start:]
440+
yield args
441+
elif text := getattr(chunk, "text", None):
442+
if not json_started:
443+
json_start = min(
444+
(
445+
pos
446+
for pos in (text.find("{"), text.find("["))
447+
if pos != -1
448+
),
449+
default=-1,
450+
)
451+
if json_start == -1:
452+
continue
453+
json_started = True
454+
text = text[json_start:]
455+
yield text
456+
else:
457+
chunk_type = getattr(chunk, "type", None)
458+
if chunk_type == "content-delta":
459+
delta = getattr(chunk, "delta", None)
460+
message = getattr(delta, "message", None)
461+
content = getattr(message, "content", None)
462+
if text := getattr(content, "text", None):
463+
if not json_started:
464+
json_start = min(
465+
(
466+
pos
467+
for pos in (
468+
text.find("{"),
469+
text.find("["),
470+
)
471+
if pos != -1
472+
),
473+
default=-1,
474+
)
475+
if json_start == -1:
476+
continue
477+
json_started = True
478+
text = text[json_start:]
479+
yield text
480+
elif chunk_type == "tool-call-delta":
481+
delta = getattr(chunk, "delta", None)
482+
message = getattr(delta, "message", None)
483+
tool_calls = getattr(message, "tool_calls", None)
484+
function = getattr(tool_calls, "function", None)
485+
if args := getattr(function, "arguments", None):
486+
if not json_started:
487+
json_start = min(
488+
(
489+
pos
490+
for pos in (
491+
args.find("{"),
492+
args.find("["),
493+
)
494+
if pos != -1
495+
),
496+
default=-1,
497+
)
498+
if json_start == -1:
499+
continue
500+
json_started = True
501+
args = args[json_start:]
502+
yield args
235503
if mode == Mode.ANTHROPIC_JSON:
236504
if json_chunk := chunk.delta.text:
237505
yield json_chunk

0 commit comments

Comments
 (0)