Skip to content

Commit 429df53

Browse files
fix: patch OpenAI client in distil module to support dispatch mode
1 parent 848f1af commit 429df53

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

instructor/distil.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from openai import OpenAI
2323
from .processing.function_calls import openai_schema
24+
from .core.patch import patch
2425

2526

2627
P = ParamSpec("P")
@@ -123,7 +124,7 @@ def __init__(
123124
self.finetune_format = finetune_format
124125
self.indent = indent
125126
self.include_code_body = include_code_body
126-
self.client = openai_client or OpenAI()
127+
self.client = patch(openai_client or OpenAI())
127128

128129
self.logger = logging.getLogger(self.name)
129130
for handler in log_handlers or []:
@@ -184,7 +185,7 @@ def _dispatch(*args: P.args, **kwargs: P.kwargs) -> ChatCompletion:
184185
return self.client.chat.completions.create(
185186
**openai_kwargs,
186187
model=model,
187-
response_model=return_base_model, # type: ignore - TODO figure out why `response_model` is not recognized
188+
response_model=return_base_model,
188189
)
189190

190191
@functools.wraps(fn)

tests/test_distil.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
2+
import unittest
3+
from typing import Generator
4+
import sys
5+
from pydantic import BaseModel
6+
from instructor.distil import (
7+
get_signature_from_fn,
8+
format_function,
9+
is_return_type_base_model_or_instance,
10+
Instructions,
11+
)
12+
from unittest.mock import MagicMock, patch
13+
14+
class User(BaseModel):
15+
name: str
16+
age: int
17+
18+
def sample_function(a: int, b: int) -> User:
19+
"""Sample docstring."""
20+
return User(name="Test", age=a + b)
21+
22+
class TestDistil(unittest.TestCase):
23+
def setUp(self):
24+
# Patch OpenAI to avoid missing API key errors
25+
self.openai_patcher = patch("instructor.distil.OpenAI")
26+
self.mock_openai = self.openai_patcher.start()
27+
28+
# Patch instructor.distil.patch to avoid actual patching logic
29+
self.instructor_patch_patcher = patch("instructor.distil.patch")
30+
self.mock_instructor_patch = self.instructor_patch_patcher.start()
31+
self.mock_instructor_patch.side_effect = lambda c, mode=None: c
32+
33+
def tearDown(self):
34+
self.openai_patcher.stop()
35+
self.instructor_patch_patcher.stop()
36+
37+
def test_get_signature_from_fn(self):
38+
sig = get_signature_from_fn(sample_function)
39+
self.assertIn("def sample_function(a: int, b: int) -> ", sig)
40+
self.assertIn("Sample docstring", sig)
41+
42+
def test_format_function(self):
43+
formatted = format_function(sample_function)
44+
self.assertIn("def sample_function", formatted)
45+
self.assertIn("return User(name=\"Test\", age=a + b)", formatted)
46+
self.assertIn("Sample docstring", formatted)
47+
48+
def test_is_return_type_base_model_or_instance(self):
49+
self.assertTrue(is_return_type_base_model_or_instance(sample_function))
50+
51+
def invalid_function():
52+
pass
53+
54+
def int_function() -> int:
55+
return 1
56+
57+
with self.assertRaises(AssertionError):
58+
is_return_type_base_model_or_instance(invalid_function)
59+
60+
self.assertFalse(is_return_type_base_model_or_instance(int_function))
61+
62+
def test_instructions_init_patches_client(self):
63+
# Test that client is patched
64+
Instructions()
65+
self.mock_instructor_patch.assert_called()
66+
67+
def test_instructions_distil_decorator(self):
68+
instructions = Instructions()
69+
# Mock the client
70+
instructions.client = MagicMock()
71+
instructions.track = MagicMock()
72+
73+
@instructions.distil
74+
def tracked_function(x: int) -> User:
75+
return User(name="Tracked", age=x)
76+
77+
result = tracked_function(10)
78+
self.assertEqual(result.name, "Tracked")
79+
self.assertEqual(result.age, 10)
80+
instructions.track.assert_called_once()
81+
82+
def test_instructions_dispatch_mode(self):
83+
instructions = Instructions()
84+
# Mock the client
85+
instructions.client = MagicMock()
86+
instructions.client.chat.completions.create.return_value = User(name="Dispatched", age=20)
87+
88+
@instructions.distil(mode="dispatch")
89+
def dispatched_function(x: int) -> User:
90+
return User(name="Original", age=x)
91+
92+
result = dispatched_function(20)
93+
94+
# In dispatch mode, it should call client.create and return its result
95+
self.assertEqual(result.name, "Dispatched")
96+
instructions.client.chat.completions.create.assert_called_once()
97+
98+
# Verify arguments passed to create
99+
call_kwargs = instructions.client.chat.completions.create.call_args[1]
100+
self.assertEqual(call_kwargs["response_model"], User)
101+
self.assertIn("messages", call_kwargs)
102+
103+
if __name__ == "__main__":
104+
unittest.main()

0 commit comments

Comments
 (0)