Skip to content

Commit 853190d

Browse files
committed
fix: update openai_function_call
1 parent 5f05d46 commit 853190d

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

openai_streaming/fn_dispatcher.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def o_func(func):
2727
"""
2828
if hasattr(func, 'func'):
2929
return o_func(func.func)
30+
if hasattr(func, '__func'):
31+
return o_func(func.__func)
3032
return func
3133

3234

openai_streaming/openai_function.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
#
2424
# Since the original project has taken a huge pivot and provide many unnecessary features - this is a stripped version
2525
# of the openai_function decorator copied from
26-
# https://github.com/jxnl/instructor/blob/0.2.3/openai_function_call/function_calls.py
26+
# https://github.com/jxnl/instructor/blob/0.2.8/instructor/function_calls.py
2727

2828
import json
29+
from docstring_parser import parse
2930
from functools import wraps
3031
from typing import Any, Callable
3132
from pydantic import validate_arguments
@@ -35,7 +36,7 @@ def _remove_a_key(d, remove_key) -> None:
3536
"""Remove a key from a dictionary recursively"""
3637
if isinstance(d, dict):
3738
for key in list(d.keys()):
38-
if key == remove_key:
39+
if key == remove_key and "type" in d.keys():
3940
del d[key]
4041
else:
4142
_remove_a_key(d[key], remove_key)
@@ -70,20 +71,27 @@ def sum(a: int, b: int) -> int:
7071
def __init__(self, func: Callable) -> None:
7172
self.func = func
7273
self.validate_func = validate_arguments(func)
74+
self.docstring = parse(self.func.__doc__ or "")
75+
7376
parameters = self.validate_func.model.model_json_schema()
7477
parameters["properties"] = {
7578
k: v
7679
for k, v in parameters["properties"].items()
7780
if k not in ("v__duplicate_kwargs", "args", "kwargs")
7881
}
82+
for param in self.docstring.params:
83+
if (name := param.arg_name) in parameters["properties"] and (
84+
description := param.description
85+
):
86+
parameters["properties"][name]["description"] = description
7987
parameters["required"] = sorted(
8088
k for k, v in parameters["properties"].items() if not "default" in v
8189
)
8290
_remove_a_key(parameters, "additionalProperties")
8391
_remove_a_key(parameters, "title")
8492
self.openai_schema = {
8593
"name": self.func.__name__,
86-
"description": self.func.__doc__,
94+
"description": self.docstring.short_description,
8795
"parameters": parameters,
8896
}
8997
self.model = self.validate_func.model
@@ -106,7 +114,7 @@ def from_response(self, completion, throw_error=True):
106114
Returns:
107115
result (any): result of the function call
108116
"""
109-
message = completion.choices[0].message
117+
message = completion["choices"][0]["message"]
110118

111119
if throw_error:
112120
assert "function_call" in message, "No function call detected"

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ python = "^3.9"
1818
openai = "^0.27.8"
1919
json-streamer = "^0.1.0"
2020
pydantic = "^2.0.2"
21+
docstring-parser = "^0.15"
2122

2223
[dev-dependencies]
2324
pytest = "^6.2"

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
openai==0.27.8
22
json-streamer==0.1.0
33
pydantic==2.0.2
4+
docstring-parser==0.15

0 commit comments

Comments
 (0)