Skip to content

Commit 28d60a5

Browse files
committed
feat+wip: move models endpoint to FastAPI
1 parent 89167f4 commit 28d60a5

File tree

3 files changed

+133
-77
lines changed

3 files changed

+133
-77
lines changed

deepaas/api/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from oslo_config import cfg
1919

2020
from deepaas.api.v2 import debug as v2_debug
21-
# from deepaas.api.v2 import models as v2_model
21+
from deepaas.api.v2 import models as v2_model
2222
# from deepaas.api.v2 import predict as v2_predict
2323
# from deepaas.api.v2 import responses
2424
# from deepaas.api.v2 import train as v2_train
@@ -40,6 +40,7 @@ def get_app(enable_train=True, enable_predict=True):
4040
v2_debug.setup_debug()
4141

4242
APP.include_router(v2_debug.router, tags=["debug"])
43+
APP.include_router(v2_model.get_router(), tags=["models"])
4344

4445
# APP.router.add_get("/", get_version, name="v2", allow_head=False)
4546
# v2_debug.setup_routes(APP)

deepaas/api/v2/models.py

Lines changed: 67 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,89 +14,104 @@
1414
# License for the specific language governing permissions and limitations
1515
# under the License.
1616

17-
import urllib.parse
18-
19-
from aiohttp import web
20-
import aiohttp_apispec
17+
import fastapi
2118

2219
from deepaas.api.v2 import responses
2320
from deepaas import model
2421

2522

26-
@aiohttp_apispec.docs(
27-
tags=["models"],
23+
router = fastapi.APIRouter(prefix="/models")
24+
25+
26+
@router.get(
27+
"/",
2828
summary="Return loaded models and its information",
29-
description="DEEPaaS can load several models and server them on the same "
30-
"endpoint, making a call to the root of the models namespace "
31-
"will return the loaded models, as long as their basic "
32-
"metadata.",
29+
description="Return list of DEEPaaS loaded models. In previous versions, DEEPaaS "
30+
"could load several models and serve them on the same endpoint.",
31+
tags=["models"],
32+
response_model=responses.ModelList,
3333
)
34-
@aiohttp_apispec.response_schema(responses.ModelMeta(), 200)
35-
async def index(request):
36-
"""Return loaded models and its information.
34+
async def index_models(
35+
request: fastapi.Request,
36+
):
37+
"""Return loaded models and its information."""
38+
39+
name = model.V2_MODEL_NAME
40+
model_obj = model.V2_MODEL
41+
m = {
42+
"id": name,
43+
"name": name,
44+
"links": [
45+
{
46+
"rel": "self",
47+
"href": str(request.url_for("get_model/" + name)),
48+
}
49+
],
50+
}
51+
meta = model_obj.get_metadata()
52+
m.update(meta)
53+
return {"models": [m]}
3754

38-
DEEPaaS can load several models and server them on the same endpoint,
39-
making a call to the root of the models namespace will return the
40-
loaded models, as long as their basic metadata.
41-
"""
4255

43-
models = []
44-
for name, obj in model.V2_MODELS.items():
45-
m = {
46-
"id": name,
47-
"name": name,
48-
"links": [
49-
{
50-
"rel": "self",
51-
"href": urllib.parse.urljoin("%s/" % request.path, name),
52-
}
53-
],
54-
}
55-
meta = obj.get_metadata()
56-
m.update(meta)
57-
models.append(m)
58-
return web.json_response({"models": models})
59-
60-
61-
def _get_handler(model_name, model_obj):
56+
def _get_handler_for_model(model_name, model_obj):
57+
"""Auxiliary function to get the handler for a model.
58+
59+
This function returns a handler for a model that can be used to
60+
register the routes in the router.
61+
"""
6262
class Handler(object):
63+
"""Class to handle the model metadata endpoints."""
6364
model_name = None
6465
model_obj = None
6566

6667
def __init__(self, model_name, model_obj):
6768
self.model_name = model_name
6869
self.model_obj = model_obj
6970

70-
@aiohttp_apispec.docs(
71-
tags=["models"],
72-
summary="Return model's metadata",
73-
)
74-
@aiohttp_apispec.response_schema(responses.ModelMeta(), 200)
75-
async def get(self, request):
71+
async def get(self, request: fastapi.Request):
72+
"""Return model's metadata."""
7673
m = {
7774
"id": self.model_name,
7875
"name": self.model_name,
7976
"links": [
8077
{
8178
"rel": "self",
82-
"href": request.path.rstrip("/"),
79+
"href": str(request.url),
8380
}
8481
],
8582
}
8683
meta = self.model_obj.get_metadata()
8784
m.update(meta)
8885

89-
return web.json_response(m)
86+
return m
87+
88+
def register_routes(self, router):
89+
"""Register routes for the model in the router."""
90+
router.add_api_route(
91+
f"/{self.model_name}",
92+
self.get,
93+
name="get_model/" + self.model_name,
94+
summary="Return model's metadata",
95+
tags=["models"],
96+
response_model=responses.ModelMeta,
97+
)
9098

9199
return Handler(model_name, model_obj)
92100

93101

94-
def setup_routes(app):
95-
app.router.add_get("/models/", index, allow_head=False)
102+
def get_router() -> fastapi.APIRouter:
103+
"""Auxiliary function to get the router.
104+
105+
We use this function to be able to include the router in the main
106+
application and do things before it gets included.
107+
108+
In this case we explicitly include the model's endpoints.
109+
110+
"""
111+
model_name = model.V2_MODEL_NAME
112+
model_obj = model.V2_MODEL
113+
114+
hdlr = _get_handler_for_model(model_name, model_obj)
115+
hdlr.register_routes(router)
96116

97-
# In the next lines we iterate over the loaded models and create the
98-
# different resources for each model. This way we can also load the
99-
# expected parameters if needed (as in the training method).
100-
for model_name, model_obj in model.V2_MODELS.items():
101-
hdlr = _get_handler(model_name, model_obj)
102-
app.router.add_get("/models/%s/" % model_name, hdlr.get, allow_head=False)
117+
return router

deepaas/api/v2/responses.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,23 @@
1414
# License for the specific language governing permissions and limitations
1515
# under the License.
1616

17+
import typing
18+
1719
import marshmallow
1820
from marshmallow import fields
1921
from marshmallow import validate
22+
import pydantic
2023

2124

22-
class Location(marshmallow.Schema):
23-
rel = fields.Str(required=True)
24-
href = fields.Url(required=True)
25-
type = fields.Str(required=True)
26-
27-
28-
class Version(marshmallow.Schema):
29-
version = fields.Str(required="True")
30-
id = fields.Str(required="True")
31-
links = fields.Nested(Location)
32-
type = fields.Str()
25+
# class Version(marshmallow.Schema):
26+
# version = fields.Str(required="True")
27+
# id = fields.Str(required="True")
28+
# # links = fields.Nested(Location)
29+
# type = fields.Str()
3330

3431

35-
class Versions(marshmallow.Schema):
36-
versions = fields.List(fields.Nested(Version))
32+
# class Versions(marshmallow.Schema):
33+
# versions = fields.List(fields.Nested(Version))
3734

3835

3936
class Failure(marshmallow.Schema):
@@ -45,17 +42,6 @@ class Prediction(marshmallow.Schema):
4542
predictions = fields.Str(required=True, description="String containing predictions")
4643

4744

48-
class ModelMeta(marshmallow.Schema):
49-
id = fields.Str(required=True, description="Model identifier") # noqa
50-
name = fields.Str(required=True, description="Model name")
51-
description = fields.Str(required=True, description="Model description")
52-
license = fields.Str(required=False, description="Model license")
53-
author = fields.Str(required=False, description="Model author")
54-
version = fields.Str(required=False, description="Model version")
55-
url = fields.Str(required=False, description="Model url")
56-
links = fields.List(fields.Nested(Location))
57-
58-
5945
class Training(marshmallow.Schema):
6046
uuid = fields.UUID(required=True, description="Training identifier")
6147
date = fields.DateTime(required=True, description="Training start time")
@@ -70,3 +56,57 @@ class Training(marshmallow.Schema):
7056

7157
class TrainingList(marshmallow.Schema):
7258
trainings = fields.List(fields.Nested(Training))
59+
60+
61+
# Pydantic models for the API
62+
63+
64+
class Location(pydantic.BaseModel):
65+
rel: str
66+
href: pydantic.AnyHttpUrl
67+
type: str = "application/json"
68+
69+
70+
class ModelMeta(pydantic.BaseModel):
71+
""""V2 model metadata.
72+
73+
This class is used to represent the metadata of a model in the V2 API, as we were
74+
doing in previous versions.
75+
"""
76+
id: str = pydantic.Field(..., description="Model identifier") # noqa
77+
name: str = pydantic.Field(..., description="Model name")
78+
description: typing.Optional[str] = pydantic.Field(
79+
description="Model description",
80+
default=None
81+
)
82+
summary: typing.Optional[str] = pydantic.Field(
83+
description="Model summary",
84+
default=None
85+
)
86+
license: typing.Optional[str] = pydantic.Field(
87+
description="Model license",
88+
default=None
89+
)
90+
author: typing.Optional[str] = pydantic.Field(
91+
description="Model author",
92+
default=None
93+
)
94+
version: typing.Optional[str] = pydantic.Field(
95+
description="Model version",
96+
default=None
97+
)
98+
url: typing.Optional[str] = pydantic.Field(
99+
description="Model url",
100+
default=None
101+
)
102+
# Links can be alist of Locations, or an empty list
103+
links: typing.List[Location] = pydantic.Field(
104+
description="Model links",
105+
)
106+
107+
108+
class ModelList(pydantic.BaseModel):
109+
models: typing.List[ModelMeta] = pydantic.Field(
110+
...,
111+
description="List of loaded models"
112+
)

0 commit comments

Comments
 (0)