Skip to content

Commit fb89404

Browse files
committed
feat+wip: move predict method to FastAPI
This requires that we change all model args/responses from Marshmallow to Pydantic. Most of the code is in this change, we can split it later on two different changes (marshmallow + pydantic and FastAPI for predict).
1 parent e005e71 commit fb89404

File tree

5 files changed

+298
-60
lines changed

5 files changed

+298
-60
lines changed

deepaas/api/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from deepaas.api.v2 import debug as v2_debug
2121
from deepaas.api.v2 import models as v2_model
22-
# from deepaas.api.v2 import predict as v2_predict
22+
from deepaas.api.v2 import predict as v2_predict
2323
# from deepaas.api.v2 import responses
2424
# from deepaas.api.v2 import train as v2_train
2525
from deepaas import log
@@ -41,6 +41,7 @@ def get_app(enable_train=True, enable_predict=True):
4141

4242
APP.include_router(v2_debug.router, tags=["debug"])
4343
APP.include_router(v2_model.get_router(), tags=["models"])
44+
APP.include_router(v2_predict.get_router(), tags=["predict"])
4445

4546
# APP.router.add_get("/", get_version, name="v2", allow_head=False)
4647
# v2_debug.setup_routes(APP)

deepaas/api/v2/predict.py

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
# License for the specific language governing permissions and limitations
1515
# under the License.
1616

17-
from aiohttp import web
18-
import aiohttp_apispec
19-
from webargs import aiohttpparser
20-
import webargs.core
17+
# from aiohttp import web
18+
# import aiohttp_apispec
19+
# from webargs import aiohttpparser
20+
# import webargs.core
21+
22+
import fastapi
23+
import fastapi.encoders
24+
import fastapi.exceptions
2125

2226
from deepaas.api.v2 import responses
2327
from deepaas.api.v2 import utils
@@ -33,68 +37,80 @@ def _get_model_response(model_name, model_obj):
3337
return responses.Prediction
3438

3539

36-
def _get_handler(model_name, model_obj):
37-
aux = model_obj.get_predict_args()
38-
accept = aux.get("accept", None)
39-
if accept:
40-
accept.validate.choices.append("*/*")
41-
accept.load_default = accept.validate.choices[0]
42-
accept.location = "headers"
40+
router = fastapi.APIRouter(prefix="/models")
41+
42+
43+
def _get_handler_for_model(model_name, model_obj):
44+
"""Auxiliary function to get the handler for a model.
45+
46+
This function returns a handler for a model that can be used to
47+
register the routes in the router.
4348
44-
handler_args = webargs.core.dict2schema(aux)
45-
handler_args.opts.ordered = True
49+
"""
4650

47-
response = _get_model_response(model_name, model_obj)
51+
user_declared_args = model_obj.get_predict_args()
52+
pydantic_schema = utils.get_pydantic_schema_from_marshmallow_fields(
53+
"PydanticSchema",
54+
user_declared_args,
55+
)
4856

4957
class Handler(object):
58+
"""Class to handle the model metadata endpoints."""
59+
5060
model_name = None
5161
model_obj = None
5262

5363
def __init__(self, model_name, model_obj):
5464
self.model_name = model_name
5565
self.model_obj = model_obj
5666

57-
@aiohttp_apispec.docs(
58-
tags=["models"],
59-
summary="Make a prediction given the input data",
60-
produces=accept.validate.choices if accept else None,
61-
)
62-
@aiohttp_apispec.querystring_schema(handler_args)
63-
@aiohttp_apispec.response_schema(response(), 200)
64-
@aiohttp_apispec.response_schema(responses.Failure(), 400)
65-
async def post(self, request):
66-
args = await aiohttpparser.parser.parse(handler_args, request)
67-
task = self.model_obj.predict(**args)
68-
await task
69-
70-
ret = task.result()["output"]
67+
async def predict(self, args: pydantic_schema = fastapi.Depends()):
68+
"""Make a prediction given the input data."""
69+
dict_args = args.model_dump(by_alias=True)
70+
71+
ret = await self.model_obj.predict(**args.model_dump(by_alias=True))
7172

7273
if isinstance(ret, model.v2.wrapper.ReturnedFile):
7374
ret = open(ret.filename, "rb")
7475

75-
accept = args.get("accept", "application/json")
76-
if accept not in ["application/json", "*/*"]:
77-
response = web.Response(
78-
body=ret,
79-
content_type=accept,
80-
)
81-
return response
8276
if self.model_obj.has_schema:
83-
self.model_obj.validate_response(ret)
84-
return web.json_response(ret)
77+
# FIXME(aloga): Validation does not work, as we are converting from
78+
# Marshmallow to Pydantic, check this as son as possible.
79+
# self.model_obj.validate_response(ret)
80+
return fastapi.responses.JSONResponse(ret)
81+
82+
return fastapi.responses.JSONResponse(
83+
content={"status": "OK", "predictions": ret}
84+
)
8585

86-
return web.json_response({"status": "OK", "predictions": ret})
86+
def register_routes(self, router):
87+
"""Register the routes in the router."""
88+
89+
response = _get_model_response(self.model_name, self.model_obj)
90+
91+
router.add_api_route(
92+
f"/{self.model_name}/predict",
93+
self.predict,
94+
methods=["POST"],
95+
response_model=response,
96+
)
8797

8898
return Handler(model_name, model_obj)
8999

90100

91-
def setup_routes(app, enable=True):
92-
# In the next lines we iterate over the loaded models and create the
93-
# different resources for each model. This way we can also load the
94-
# expected parameters if needed (as in the training method).
95-
for model_name, model_obj in model.V2_MODELS.items():
96-
if enable:
97-
hdlr = _get_handler(model_name, model_obj)
98-
else:
99-
hdlr = utils.NotEnabledHandler()
100-
app.router.add_post("/models/%s/predict/" % model_name, hdlr.post)
101+
def get_router():
102+
"""Auxiliary function to get the router.
103+
104+
We use this function to be able to include the router in the main
105+
application and do things before it gets included.
106+
107+
In this case we explicitly include the model precit endpoint.
108+
109+
"""
110+
model_name = model.V2_MODEL_NAME
111+
model_obj = model.V2_MODEL
112+
113+
hdlr = _get_handler_for_model(model_name, model_obj)
114+
hdlr.register_routes(router)
115+
116+
return router

deepaas/api/v2/responses.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,6 @@
3333
# versions = fields.List(fields.Nested(Version))
3434

3535

36-
class Failure(marshmallow.Schema):
37-
message = fields.Str(required=True, description="Failure message")
38-
39-
40-
class Prediction(marshmallow.Schema):
41-
status = fields.String(required=True, description="Response status message")
42-
predictions = fields.Str(required=True, description="String containing predictions")
43-
44-
4536
class Training(marshmallow.Schema):
4637
uuid = fields.UUID(required=True, description="Training identifier")
4738
date = fields.DateTime(required=True, description="Training start time")
@@ -110,3 +101,12 @@ class ModelList(pydantic.BaseModel):
110101
...,
111102
description="List of loaded models"
112103
)
104+
105+
106+
class Prediction(pydantic.BaseModel):
107+
status: str = pydantic.Field(description="Response status message")
108+
predictions: str = pydantic.Field(description="String containing predictions")
109+
110+
111+
class Failure(pydantic.BaseModel):
112+
message: str = pydantic.Field(description="Failure message")

0 commit comments

Comments
 (0)