-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathinference_api.py
More file actions
108 lines (89 loc) · 3.97 KB
/
inference_api.py
File metadata and controls
108 lines (89 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import httpx
from httpx import Response
import logging
log = logging.getLogger(__name__)
class InferenceAPI:
    """Thin HTTP client for a TorchServe-style inference API.

    Wraps an ``httpx.Client`` and exposes helpers for the server's
    prediction, explanation and workflow-prediction endpoints. All
    helpers POST the payload as a multipart file field named ``data``.
    """

    def __init__(self,
                 address: str,
                 error_callback: Optional[Callable] = None) -> None:
        """Create a client bound to an inference server.

        Args:
            address (str): Base URL of the inference server,
                e.g. ``http://localhost:8080``.
            error_callback (Optional[Callable]): httpx response hook called
                for every response. Defaults to
                :meth:`default_error_callback`, which logs non-200 statuses.
        """
        self.address = address
        if not error_callback:
            error_callback = self.default_error_callback
        # NOTE(review): timeout is in seconds (~16.7 min) — presumably sized
        # for long-running inference requests; confirm this is intentional.
        self.client = httpx.Client(timeout=1000,
                                   event_hooks={"response": [error_callback]})

    @staticmethod
    def default_error_callback(response: Response) -> None:
        """Default response hook: log a warning on any non-200 status."""
        if response.status_code != 200:
            # The message is a warning, so emit it at WARNING level
            # (was log.info, which hid it at default log configuration).
            log.warning(f"Warn - status code: {response.status_code},{response}")

    # get available models
    # determine type of model (NOT currently possible, we don't know which endpoint maps to what type)
    # auto generate python client based on model spec? (use library)
    # or define explicitly like below (not sustainable)
    def image_classify(self, endpoint, image):
        # TODO: unimplemented stub — intended to POST *image* to
        # ``address + endpoint`` and return the response.
        # ``self`` was missing from the signature, which made any
        # ``instance.image_classify(endpoint, image)`` call a TypeError.
        pass

    def _post_file(self,
                   route: str,
                   name: str,
                   input_,
                   version: Optional[str] = None) -> Response:
        """POST *input_* as multipart ``data`` to ``{address}/{route}/{name}[/{version}]``.

        Shared implementation for the prediction/explanation/workflow helpers.
        """
        req_url = f"{self.address}/{route}/{name}"
        if version:
            req_url += f"/{version}"
        return self.client.post(req_url, files={'data': input_})

    ### Model
    def get_predictions(self,
                        input_,
                        model_name: str,
                        version: Optional[str] = None):
        """Get the prediction of a model according to the provided input.

        CURL equivalence:
            curl http://localhost:8080/predictions/resnet-18/2.0 -T kitten_small.jpg
        or
            curl http://localhost:8080/predictions/resnet-18/2.0 -F "data=@kitten_small.jpg"

        Args:
            input_ ([type]): Buffer or Tensor to send to the endpoint as payload
            model_name (str): name of the model to use
            version (Optional[str]): Version number of the model. Defaults to None.

        Returns:
            httpx.Response: The response from the Torch server.
        """
        return self._post_file('predictions', model_name, input_, version)

    def get_explanations(self,
                         input_,
                         model_name: str,
                         version: Optional[str] = None):
        """Get Explanations from the model.

        (sets is_explain to True in the model handler, which leads to calling its the ``explain_handle`` method).

        CURL equivalence:
            curl http://localhost:8080/explanations/resnet-18/2.0 -T kitten_small.jpg
        or
            curl http://localhost:8080/explanations/resnet-18/2.0 -F "data=@kitten_small.jpg"

        Args:
            input_ ([type]): Buffer or Tensor to send to the endpoint as payload
            model_name (str): name of the model to use
            version (Optional[str]): Version number of the model. Defaults to None.

        Returns:
            httpx.Response: The response from the Torch server.
        """
        return self._post_file('explanations', model_name, input_, version)

    ### Workflow
    def get_workflow_predictions(self, input_: str, workflow_name: str):
        """Get the prediction of a workflow according to the provided input.

        CURL equivalence:
            curl http://localhost:8080/wfpredict/myworkflow -T kitten_small.jpg
        or
            curl http://localhost:8080/wfpredict/myworkflow -F "data=@kitten_small.jpg"

        Args:
            input_ ([type]): Buffer or Tensor to send to the endpoint as payload
            workflow_name (str): name of the workflow to use

        Returns:
            httpx.Response: The response from the Torch server.
        """
        return self._post_file('wfpredict', workflow_name, input_)