Skip to content

Commit 83dd71c

Browse files
committed
Fix xAI tools validation and Cohere streaming
1 parent 7ca2dcc commit 83dd71c

File tree

5 files changed

+224
-43
lines changed

5 files changed

+224
-43
lines changed

instructor/dsl/iterable.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,41 @@ def extract_json(
150150
) -> Generator[str, None, None]:
151151
for chunk in completion:
152152
try:
153+
if mode in {Mode.COHERE_TOOLS, Mode.COHERE_JSON_SCHEMA}:
154+
event_type = getattr(chunk, "event_type", None)
155+
if event_type == "text-generation":
156+
if text := getattr(chunk, "text", None):
157+
yield text
158+
elif event_type == "tool-calls-chunk":
159+
delta = getattr(chunk, "tool_call_delta", None)
160+
args = getattr(delta, "parameters", None) or getattr(
161+
delta, "text", None
162+
)
163+
if args:
164+
yield args
165+
elif text := getattr(chunk, "text", None):
166+
yield text
167+
elif event_type == "tool-calls-generation":
168+
tool_calls = getattr(chunk, "tool_calls", None)
169+
if tool_calls:
170+
yield json.dumps(tool_calls[0].parameters)
171+
elif text := getattr(chunk, "text", None):
172+
yield text
173+
else:
174+
chunk_type = getattr(chunk, "type", None)
175+
if chunk_type == "content-delta":
176+
delta = getattr(chunk, "delta", None)
177+
message = getattr(delta, "message", None)
178+
content = getattr(message, "content", None)
179+
if text := getattr(content, "text", None):
180+
yield text
181+
elif chunk_type == "tool-call-delta":
182+
delta = getattr(chunk, "delta", None)
183+
message = getattr(delta, "message", None)
184+
tool_calls = getattr(message, "tool_calls", None)
185+
function = getattr(tool_calls, "function", None)
186+
if args := getattr(function, "arguments", None):
187+
yield args
153188
if mode == Mode.ANTHROPIC_JSON:
154189
if json_chunk := chunk.delta.text:
155190
yield json_chunk
@@ -232,6 +267,41 @@ async def extract_json_async(
232267
) -> AsyncGenerator[str, None]:
233268
async for chunk in completion:
234269
try:
270+
if mode in {Mode.COHERE_TOOLS, Mode.COHERE_JSON_SCHEMA}:
271+
event_type = getattr(chunk, "event_type", None)
272+
if event_type == "text-generation":
273+
if text := getattr(chunk, "text", None):
274+
yield text
275+
elif event_type == "tool-calls-chunk":
276+
delta = getattr(chunk, "tool_call_delta", None)
277+
args = getattr(delta, "parameters", None) or getattr(
278+
delta, "text", None
279+
)
280+
if args:
281+
yield args
282+
elif text := getattr(chunk, "text", None):
283+
yield text
284+
elif event_type == "tool-calls-generation":
285+
tool_calls = getattr(chunk, "tool_calls", None)
286+
if tool_calls:
287+
yield json.dumps(tool_calls[0].parameters)
288+
elif text := getattr(chunk, "text", None):
289+
yield text
290+
else:
291+
chunk_type = getattr(chunk, "type", None)
292+
if chunk_type == "content-delta":
293+
delta = getattr(chunk, "delta", None)
294+
message = getattr(delta, "message", None)
295+
content = getattr(message, "content", None)
296+
if text := getattr(content, "text", None):
297+
yield text
298+
elif chunk_type == "tool-call-delta":
299+
delta = getattr(chunk, "delta", None)
300+
message = getattr(delta, "message", None)
301+
tool_calls = getattr(message, "tool_calls", None)
302+
function = getattr(tool_calls, "function", None)
303+
if args := getattr(function, "arguments", None):
304+
yield args
235305
if mode == Mode.ANTHROPIC_JSON:
236306
if json_chunk := chunk.delta.text:
237307
yield json_chunk

instructor/dsl/partial.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -287,6 +287,41 @@ def extract_json(
287287
specific handling to extract the relevant JSON data."""
288288
for chunk in completion:
289289
try:
290+
if mode in {Mode.COHERE_TOOLS, Mode.COHERE_JSON_SCHEMA}:
291+
event_type = getattr(chunk, "event_type", None)
292+
if event_type == "text-generation":
293+
if text := getattr(chunk, "text", None):
294+
yield text
295+
elif event_type == "tool-calls-chunk":
296+
delta = getattr(chunk, "tool_call_delta", None)
297+
args = getattr(delta, "parameters", None) or getattr(
298+
delta, "text", None
299+
)
300+
if args:
301+
yield args
302+
elif text := getattr(chunk, "text", None):
303+
yield text
304+
elif event_type == "tool-calls-generation":
305+
tool_calls = getattr(chunk, "tool_calls", None)
306+
if tool_calls:
307+
yield json.dumps(tool_calls[0].parameters)
308+
elif text := getattr(chunk, "text", None):
309+
yield text
310+
else:
311+
chunk_type = getattr(chunk, "type", None)
312+
if chunk_type == "content-delta":
313+
delta = getattr(chunk, "delta", None)
314+
message = getattr(delta, "message", None)
315+
content = getattr(message, "content", None)
316+
if text := getattr(content, "text", None):
317+
yield text
318+
elif chunk_type == "tool-call-delta":
319+
delta = getattr(chunk, "delta", None)
320+
message = getattr(delta, "message", None)
321+
tool_calls = getattr(message, "tool_calls", None)
322+
function = getattr(tool_calls, "function", None)
323+
if args := getattr(function, "arguments", None):
324+
yield args
290325
if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
291326
yield chunk.data.choices[0].delta.content
292327
if mode == Mode.MISTRAL_TOOLS:
@@ -378,6 +413,41 @@ async def extract_json_async(
378413
) -> AsyncGenerator[str, None]:
379414
async for chunk in completion:
380415
try:
416+
if mode in {Mode.COHERE_TOOLS, Mode.COHERE_JSON_SCHEMA}:
417+
event_type = getattr(chunk, "event_type", None)
418+
if event_type == "text-generation":
419+
if text := getattr(chunk, "text", None):
420+
yield text
421+
elif event_type == "tool-calls-chunk":
422+
delta = getattr(chunk, "tool_call_delta", None)
423+
args = getattr(delta, "parameters", None) or getattr(
424+
delta, "text", None
425+
)
426+
if args:
427+
yield args
428+
elif text := getattr(chunk, "text", None):
429+
yield text
430+
elif event_type == "tool-calls-generation":
431+
tool_calls = getattr(chunk, "tool_calls", None)
432+
if tool_calls:
433+
yield json.dumps(tool_calls[0].parameters)
434+
elif text := getattr(chunk, "text", None):
435+
yield text
436+
else:
437+
chunk_type = getattr(chunk, "type", None)
438+
if chunk_type == "content-delta":
439+
delta = getattr(chunk, "delta", None)
440+
message = getattr(delta, "message", None)
441+
content = getattr(message, "content", None)
442+
if text := getattr(content, "text", None):
443+
yield text
444+
elif chunk_type == "tool-call-delta":
445+
delta = getattr(chunk, "delta", None)
446+
message = getattr(delta, "message", None)
447+
tool_calls = getattr(message, "tool_calls", None)
448+
function = getattr(tool_calls, "function", None)
449+
if args := getattr(function, "arguments", None):
450+
yield args
381451
if mode == Mode.ANTHROPIC_JSON:
382452
if json_chunk := chunk.delta.text:
383453
yield json_chunk

instructor/processing/function_calls.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -79,21 +79,29 @@ def _extract_text_content(completion: Any) -> str:
7979

8080

8181
def _validate_model_from_json(
82-
cls: type[Model],
82+
cls: type[Any],
8383
json_str: str,
8484
validation_context: Optional[dict[str, Any]] = None,
8585
strict: Optional[bool] = None,
86-
) -> Model:
86+
) -> Any:
8787
"""Validate model from JSON string with appropriate error handling."""
8888
try:
89-
if strict:
90-
return cls.model_validate_json(
91-
json_str, context=validation_context, strict=True
92-
)
93-
else:
89+
if hasattr(cls, "model_validate_json"):
90+
if strict:
91+
return cls.model_validate_json(
92+
json_str, context=validation_context, strict=True
93+
)
9494
# Allow control characters
9595
parsed = json.loads(json_str, strict=False)
9696
return cls.model_validate(parsed, context=validation_context, strict=False)
97+
98+
adapter = TypeAdapter(cls)
99+
if strict:
100+
return adapter.validate_json(
101+
json_str, context=validation_context, strict=True
102+
)
103+
parsed = json.loads(json_str, strict=False)
104+
return adapter.validate_python(parsed, context=validation_context, strict=False)
97105
except json.JSONDecodeError as e:
98106
logger.debug(f"JSON decode error: {e}")
99107
raise

instructor/providers/cohere/client.py

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,13 @@
11
from __future__ import annotations
22

3+
import inspect
4+
from collections.abc import Awaitable
5+
from typing import Any, TypeVar, cast, overload
6+
37
import cohere
48
import instructor
5-
from typing import (
6-
TypeVar,
7-
overload,
8-
)
9-
from typing import Any
10-
from typing_extensions import ParamSpec
119
from pydantic import BaseModel
10+
from typing_extensions import ParamSpec
1211

1312

1413
T_Model = TypeVar("T_Model", bound=BaseModel)
@@ -81,17 +80,32 @@ def from_cohere(
8180
kwargs["_cohere_client_version"] = client_version
8281

8382
if is_async:
83+
84+
async def async_wrapper(*args: Any, **call_kwargs: Any):
85+
if call_kwargs.pop("stream", False):
86+
return client.chat_stream(*args, **call_kwargs)
87+
result = client.chat(*args, **call_kwargs)
88+
if inspect.isawaitable(result):
89+
return await cast(Awaitable[Any], result)
90+
return result
91+
8492
return instructor.AsyncInstructor(
8593
client=client,
86-
create=instructor.patch(create=client.chat, mode=mode),
94+
create=instructor.patch(create=async_wrapper, mode=mode),
8795
provider=instructor.Provider.COHERE,
8896
mode=mode,
8997
**kwargs,
9098
)
9199
else:
100+
101+
def sync_wrapper(*args: Any, **call_kwargs: Any):
102+
if call_kwargs.pop("stream", False):
103+
return client.chat_stream(*args, **call_kwargs)
104+
return client.chat(*args, **call_kwargs)
105+
92106
return instructor.Instructor(
93107
client=client,
94-
create=instructor.patch(create=client.chat, mode=mode),
108+
create=instructor.patch(create=sync_wrapper, mode=mode),
95109
provider=instructor.Provider.COHERE,
96110
mode=mode,
97111
**kwargs,

0 commit comments

Comments
 (0)