14
14
# License for the specific language governing permissions and limitations
15
15
# under the License.
16
16
17
- from aiohttp import web
18
- import aiohttp_apispec
19
- from webargs import aiohttpparser
20
- import webargs .core
17
+ # from aiohttp import web
18
+ # import aiohttp_apispec
19
+ # from webargs import aiohttpparser
20
+ # import webargs.core
21
+
22
+ import fastapi
23
+ import fastapi .encoders
24
+ import fastapi .exceptions
21
25
22
26
from deepaas .api .v2 import responses
23
27
from deepaas .api .v2 import utils
@@ -33,68 +37,80 @@ def _get_model_response(model_name, model_obj):
33
37
return responses .Prediction
34
38
35
39
36
- def _get_handler (model_name , model_obj ):
37
- aux = model_obj .get_predict_args ()
38
- accept = aux .get ("accept" , None )
39
- if accept :
40
- accept .validate .choices .append ("*/*" )
41
- accept .load_default = accept .validate .choices [0 ]
42
- accept .location = "headers"
40
+ router = fastapi .APIRouter (prefix = "/models" )
41
+
42
+
43
+ def _get_handler_for_model (model_name , model_obj ):
44
+ """Auxiliary function to get the handler for a model.
45
+
46
+ This function returns a handler for a model that can be used to
47
+ register the routes in the router.
43
48
44
- handler_args = webargs .core .dict2schema (aux )
45
- handler_args .opts .ordered = True
49
+ """
46
50
47
- response = _get_model_response (model_name , model_obj )
51
+ user_declared_args = model_obj .get_predict_args ()
52
+ pydantic_schema = utils .get_pydantic_schema_from_marshmallow_fields (
53
+ "PydanticSchema" ,
54
+ user_declared_args ,
55
+ )
48
56
49
57
class Handler (object ):
58
+ """Class to handle the model metadata endpoints."""
59
+
50
60
model_name = None
51
61
model_obj = None
52
62
53
63
def __init__ (self , model_name , model_obj ):
54
64
self .model_name = model_name
55
65
self .model_obj = model_obj
56
66
57
- @aiohttp_apispec .docs (
58
- tags = ["models" ],
59
- summary = "Make a prediction given the input data" ,
60
- produces = accept .validate .choices if accept else None ,
61
- )
62
- @aiohttp_apispec .querystring_schema (handler_args )
63
- @aiohttp_apispec .response_schema (response (), 200 )
64
- @aiohttp_apispec .response_schema (responses .Failure (), 400 )
65
- async def post (self , request ):
66
- args = await aiohttpparser .parser .parse (handler_args , request )
67
- task = self .model_obj .predict (** args )
68
- await task
69
-
70
- ret = task .result ()["output" ]
67
+ async def predict (self , args : pydantic_schema = fastapi .Depends ()):
68
+ """Make a prediction given the input data."""
69
+ dict_args = args .model_dump (by_alias = True )
70
+
71
+ ret = await self .model_obj .predict (** args .model_dump (by_alias = True ))
71
72
72
73
if isinstance (ret , model .v2 .wrapper .ReturnedFile ):
73
74
ret = open (ret .filename , "rb" )
74
75
75
- accept = args .get ("accept" , "application/json" )
76
- if accept not in ["application/json" , "*/*" ]:
77
- response = web .Response (
78
- body = ret ,
79
- content_type = accept ,
80
- )
81
- return response
82
76
if self .model_obj .has_schema :
83
- self .model_obj .validate_response (ret )
84
- return web .json_response (ret )
77
+ # FIXME(aloga): Validation does not work, as we are converting from
78
+ # Marshmallow to Pydantic, check this as son as possible.
79
+ # self.model_obj.validate_response(ret)
80
+ return fastapi .responses .JSONResponse (ret )
81
+
82
+ return fastapi .responses .JSONResponse (
83
+ content = {"status" : "OK" , "predictions" : ret }
84
+ )
85
85
86
- return web .json_response ({"status" : "OK" , "predictions" : ret })
86
+ def register_routes (self , router ):
87
+ """Register the routes in the router."""
88
+
89
+ response = _get_model_response (self .model_name , self .model_obj )
90
+
91
+ router .add_api_route (
92
+ f"/{ self .model_name } /predict" ,
93
+ self .predict ,
94
+ methods = ["POST" ],
95
+ response_model = response ,
96
+ )
87
97
88
98
return Handler (model_name , model_obj )
89
99
90
100
91
- def setup_routes (app , enable = True ):
92
- # In the next lines we iterate over the loaded models and create the
93
- # different resources for each model. This way we can also load the
94
- # expected parameters if needed (as in the training method).
95
- for model_name , model_obj in model .V2_MODELS .items ():
96
- if enable :
97
- hdlr = _get_handler (model_name , model_obj )
98
- else :
99
- hdlr = utils .NotEnabledHandler ()
100
- app .router .add_post ("/models/%s/predict/" % model_name , hdlr .post )
101
+ def get_router ():
102
+ """Auxiliary function to get the router.
103
+
104
+ We use this function to be able to include the router in the main
105
+ application and do things before it gets included.
106
+
107
+ In this case we explicitly include the model precit endpoint.
108
+
109
+ """
110
+ model_name = model .V2_MODEL_NAME
111
+ model_obj = model .V2_MODEL
112
+
113
+ hdlr = _get_handler_for_model (model_name , model_obj )
114
+ hdlr .register_routes (router )
115
+
116
+ return router
0 commit comments