23
23
#
24
24
# Since the original project has taken a huge pivot and provide many unnecessary features - this is a stripped version
25
25
# of the openai_function decorator copied from
26
- # https://github.com/jxnl/instructor/blob/0.2.3/openai_function_call /function_calls.py
26
+ # https://github.com/jxnl/instructor/blob/0.2.8/instructor /function_calls.py
27
27
28
28
import json
29
+ from docstring_parser import parse
29
30
from functools import wraps
30
31
from typing import Any , Callable
31
32
from pydantic import validate_arguments
@@ -35,7 +36,7 @@ def _remove_a_key(d, remove_key) -> None:
35
36
"""Remove a key from a dictionary recursively"""
36
37
if isinstance (d , dict ):
37
38
for key in list (d .keys ()):
38
- if key == remove_key :
39
+ if key == remove_key and "type" in d . keys () :
39
40
del d [key ]
40
41
else :
41
42
_remove_a_key (d [key ], remove_key )
@@ -70,20 +71,27 @@ def sum(a: int, b: int) -> int:
70
71
def __init__ (self , func : Callable ) -> None :
71
72
self .func = func
72
73
self .validate_func = validate_arguments (func )
74
+ self .docstring = parse (self .func .__doc__ or "" )
75
+
73
76
parameters = self .validate_func .model .model_json_schema ()
74
77
parameters ["properties" ] = {
75
78
k : v
76
79
for k , v in parameters ["properties" ].items ()
77
80
if k not in ("v__duplicate_kwargs" , "args" , "kwargs" )
78
81
}
82
+ for param in self .docstring .params :
83
+ if (name := param .arg_name ) in parameters ["properties" ] and (
84
+ description := param .description
85
+ ):
86
+ parameters ["properties" ][name ]["description" ] = description
79
87
parameters ["required" ] = sorted (
80
88
k for k , v in parameters ["properties" ].items () if not "default" in v
81
89
)
82
90
_remove_a_key (parameters , "additionalProperties" )
83
91
_remove_a_key (parameters , "title" )
84
92
self .openai_schema = {
85
93
"name" : self .func .__name__ ,
86
- "description" : self .func . __doc__ ,
94
+ "description" : self .docstring . short_description ,
87
95
"parameters" : parameters ,
88
96
}
89
97
self .model = self .validate_func .model
@@ -106,7 +114,7 @@ def from_response(self, completion, throw_error=True):
106
114
Returns:
107
115
result (any): result of the function call
108
116
"""
109
- message = completion . choices [0 ]. message
117
+ message = completion [ " choices" ] [0 ][ " message" ]
110
118
111
119
if throw_error :
112
120
assert "function_call" in message , "No function call detected"
0 commit comments