Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions instructor/distil.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from openai import OpenAI
from .processing.function_calls import openai_schema
from .core.patch import patch


P = ParamSpec("P")
Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
self.finetune_format = finetune_format
self.indent = indent
self.include_code_body = include_code_body
self.client = openai_client or OpenAI()
self.client = patch(openai_client or OpenAI())

self.logger = logging.getLogger(self.name)
for handler in log_handlers or []:
Expand Down Expand Up @@ -184,7 +185,7 @@ def _dispatch(*args: P.args, **kwargs: P.kwargs) -> ChatCompletion:
return self.client.chat.completions.create(
**openai_kwargs,
model=model,
response_model=return_base_model, # type: ignore - TODO figure out why `response_model` is not recognized
response_model=return_base_model,
)

@functools.wraps(fn)
Expand Down
104 changes: 104 additions & 0 deletions tests/test_distil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@

import unittest
from typing import Generator
import sys
from pydantic import BaseModel
from instructor.distil import (
get_signature_from_fn,
format_function,
is_return_type_base_model_or_instance,
Instructions,
)
from unittest.mock import MagicMock, patch

class User(BaseModel):
name: str
age: int

def sample_function(a: int, b: int) -> User:
"""Sample docstring."""
return User(name="Test", age=a + b)

class TestDistil(unittest.TestCase):
def setUp(self):
# Patch OpenAI to avoid missing API key errors
self.openai_patcher = patch("instructor.distil.OpenAI")
self.mock_openai = self.openai_patcher.start()

# Patch instructor.distil.patch to avoid actual patching logic
self.instructor_patch_patcher = patch("instructor.distil.patch")
self.mock_instructor_patch = self.instructor_patch_patcher.start()
self.mock_instructor_patch.side_effect = lambda c, mode=None: c

def tearDown(self):
self.openai_patcher.stop()
self.instructor_patch_patcher.stop()

def test_get_signature_from_fn(self):
sig = get_signature_from_fn(sample_function)
self.assertIn("def sample_function(a: int, b: int) -> ", sig)
self.assertIn("Sample docstring", sig)

def test_format_function(self):
formatted = format_function(sample_function)
self.assertIn("def sample_function", formatted)
self.assertIn("return User(name=\"Test\", age=a + b)", formatted)
self.assertIn("Sample docstring", formatted)

def test_is_return_type_base_model_or_instance(self):
self.assertTrue(is_return_type_base_model_or_instance(sample_function))

def invalid_function():
pass

def int_function() -> int:
return 1

with self.assertRaises(AssertionError):
is_return_type_base_model_or_instance(invalid_function)

self.assertFalse(is_return_type_base_model_or_instance(int_function))

def test_instructions_init_patches_client(self):
# Test that client is patched
Instructions()
self.mock_instructor_patch.assert_called()

def test_instructions_distil_decorator(self):
instructions = Instructions()
# Mock the client
instructions.client = MagicMock()
instructions.track = MagicMock()

@instructions.distil
def tracked_function(x: int) -> User:
return User(name="Tracked", age=x)

result = tracked_function(10)
self.assertEqual(result.name, "Tracked")
self.assertEqual(result.age, 10)
instructions.track.assert_called_once()

def test_instructions_dispatch_mode(self):
instructions = Instructions()
# Mock the client
instructions.client = MagicMock()
instructions.client.chat.completions.create.return_value = User(name="Dispatched", age=20)

@instructions.distil(mode="dispatch")
def dispatched_function(x: int) -> User:
return User(name="Original", age=x)

result = dispatched_function(20)

# In dispatch mode, it should call client.create and return its result
self.assertEqual(result.name, "Dispatched")
instructions.client.chat.completions.create.assert_called_once()

# Verify arguments passed to create
call_kwargs = instructions.client.chat.completions.create.call_args[1]
self.assertEqual(call_kwargs["response_model"], User)
self.assertIn("messages", call_kwargs)

if __name__ == "__main__":
unittest.main()