-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
84 lines (68 loc) · 2.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import inspect
from dataclasses import is_dataclass
from typing import Any, Callable, Optional
from pydantic import BaseModel
from unstructured_platform_plugins.schema.json_schema import (
parameters_to_json_schema,
response_to_json_schema,
)
def get_func(instance: Any, method_name: Optional[str] = None) -> Callable:
    """Resolve ``instance`` to the callable that should be invoked.

    Args:
        instance: A plain function, a class, or an object exposing the
            target method.
        method_name: Name of the method to look up on a class or object.
            Falls back to ``"__call__"`` when omitted or empty.

    Returns:
        * ``instance`` itself when it is a plain Python function;
        * the ``method_name`` attribute of a freshly created instance when
          ``instance`` is a class (the class is called with no arguments);
        * the bound method ``method_name`` when ``instance`` is an object
          that has one.

    Raises:
        ValueError: If none of the cases above apply, including when the
            attribute exists on an object but is not a bound method.
    """
    attr_name = method_name or "__call__"

    if inspect.isfunction(instance):
        return instance

    if inspect.isclass(instance):
        # Classes are instantiated without arguments before the lookup.
        return getattr(instance(), attr_name)

    if hasattr(instance, attr_name):
        bound = getattr(instance, attr_name)
        # Only bound methods are accepted here; other attribute kinds fall
        # through to the error below.
        if inspect.ismethod(bound):
            return bound

    raise ValueError(f"type of instance not recognized: {type(instance)}")
def get_plugin_id(instance: Any, method_name: Optional[str] = None) -> str:
    """Derive a plugin id string from ``instance``.

    The id is obtained as follows:

    * plain function -> the value returned by calling it with no arguments;
    * class -> the class is instantiated with no arguments and the value
      returned by its ``method_name`` method is used;
    * object that has a ``method_name`` attribute -> the value returned by
      that attribute when it is a bound method (otherwise no id is found);
    * anything else -> ``instance`` itself is used as the id.

    Args:
        instance: Function, class, object, or literal id value.
        method_name: Method to call on a class or object. Falls back to
            ``"__call__"`` when omitted or empty.

    Returns:
        The id converted with ``str()``.

    Raises:
        ValueError: If the resolved id is falsy, or if its string form is
            not a valid Python identifier.
    """
    attr_name = method_name or "__call__"

    if inspect.isfunction(instance):
        candidate = instance()
    elif inspect.isclass(instance):
        candidate = getattr(instance(), attr_name)()
    elif hasattr(instance, attr_name):
        bound = getattr(instance, attr_name)
        # Attributes that are not bound methods yield no id at all.
        candidate = bound() if inspect.ismethod(bound) else None
    else:
        # No callable entry point: treat the value itself as the id.
        candidate = instance

    if not candidate:
        raise ValueError(f"id could not be parsed from instance {instance}")

    plugin_id = str(candidate)
    if not plugin_id.isidentifier():
        raise ValueError(f"'{plugin_id}' is not a valid identifier")
    return plugin_id
def get_input_schema(func: Callable) -> dict:
    """Return the JSON schema for the parameters of ``func``.

    The parameters are read from ``inspect.signature(func)`` and handed, as
    a list of ``inspect.Parameter`` objects, to ``parameters_to_json_schema``.
    """
    params = [*inspect.signature(func).parameters.values()]
    return parameters_to_json_schema(params)
def get_output_sig(func: Callable) -> Optional[Any]:
    """Return the return annotation of ``func``, or ``None`` if it has none.

    A function explicitly annotated with ``-> None`` also yields ``None``,
    so the two cases are indistinguishable to the caller.
    """
    annotation = inspect.signature(func).return_annotation
    if annotation is inspect.Signature.empty:
        return None
    return annotation
def get_output_schema(func: Callable) -> dict:
    """Return the JSON schema for the return annotation of ``func``.

    The annotation (or ``None`` when absent) is obtained via
    ``get_output_sig`` and passed to ``response_to_json_schema``.
    """
    return_annotation = get_output_sig(func)
    return response_to_json_schema(return_annotation)
def get_schema_dict(func) -> dict:
    """Return the input and output schemas of ``func`` in one mapping.

    The result has exactly two keys: ``"inputs"`` (from
    ``get_input_schema``) and ``"outputs"`` (from ``get_output_schema``).
    """
    inputs = get_input_schema(func)
    outputs = get_output_schema(func)
    return {"inputs": inputs, "outputs": outputs}
def map_inputs(func: Callable, raw_inputs: dict[str, Any]) -> dict[str, Any]:
    """Convert raw input values to the types annotated on ``func``.

    For each parameter of ``func`` that has an entry in ``raw_inputs``:

    * if the annotation is a dataclass and the value is a ``dict``, the
      value is replaced by ``annotation(**value)``;
    * otherwise, if the annotation is a subclass of pydantic ``BaseModel``,
      the value is replaced by ``annotation.parse_obj(value)``.

    All other entries are left as they are. Parameters with no entry in
    ``raw_inputs`` (for example ones relying on a default) are skipped
    rather than raising ``KeyError``.

    Args:
        func: Callable whose signature drives the conversion.
        raw_inputs: Mapping of parameter name to raw value. It is updated
            in place.

    Returns:
        The same ``raw_inputs`` dict, after conversion.
    """
    for name, param in inspect.signature(func).parameters.items():
        if name not in raw_inputs:
            # Nothing was supplied for this parameter, so there is nothing
            # to convert; the callee's own default (if any) applies.
            continue

        annotation = param.annotation
        value = raw_inputs[name]
        if is_dataclass(annotation) and isinstance(value, dict):
            raw_inputs[name] = annotation(**value)
        elif inspect.isclass(annotation) and issubclass(annotation, BaseModel):
            # NOTE(review): parse_obj is the pydantic v1 spelling; pydantic
            # v2 keeps it as a deprecated alias of model_validate. Confirm
            # which major version this project pins before switching.
            raw_inputs[name] = annotation.parse_obj(value)
    return raw_inputs