Skip to content

Commit a72ca59

Browse files
committed
FastAPI model changes
Signed-off-by: shubh chaurasia <[email protected]>
1 parent 2dc2f70 commit a72ca59

File tree

2 files changed

+56
-39
lines changed

2 files changed

+56
-39
lines changed

mlflow/pyfunc/scoring_server/__init__.py

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
import pandas as pd
2121
import sys
2222
import traceback
23+
from pydantic import BaseModel
24+
from fastapi import FastAPI, APIRouter, Request, HTTPException, Response, Header, status
25+
from typing import List, Optional, Dict
26+
import uvicorn
27+
import asyncio
28+
import json
2329

2430
# NB: We need to be careful what we import from mlflow here. Scoring server is used from within
2531
# model's conda environment. The version of mlflow doing the serving (outside) and the version of
@@ -65,13 +71,24 @@
6571

6672
CONTENT_TYPE_FORMAT_RECORDS_ORIENTED = "pandas-records"
6773
CONTENT_TYPE_FORMAT_SPLIT_ORIENTED = "pandas-split"
74+
CONTENT_TYPE_RAW_JSON = "raw-json"
6875

69-
FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED, CONTENT_TYPE_FORMAT_SPLIT_ORIENTED]
76+
FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED, CONTENT_TYPE_FORMAT_SPLIT_ORIENTED, CONTENT_TYPE_RAW_JSON]
7077

7178
PREDICTIONS_WRAPPER_ATTR_NAME_ENV_KEY = "PREDICTIONS_WRAPPER_ATTR_NAME"
7279

7380
_logger = logging.getLogger(__name__)
7481

82+
class RequestData(BaseModel):
    """
    Request body for the scoring endpoint: a list of column names plus a list of rows.
    """

    # Column names. Defaults to an empty list, meaning "no column names supplied".
    columns: List[str] = []
    # Input rows, handed to ``pandas.DataFrame`` as its ``data`` argument.
    data: list = []

    def is_valid(self):
        """
        Report whether the payload is usable.

        Placeholder: no checks are performed, so this always returns True.
        """
        return True

    def get_dataframe(self):
        """
        Build a pandas DataFrame from the payload.

        :return: DataFrame holding ``data``, labelled with ``columns`` when they were supplied.
        """
        # An empty ``columns`` list is the "not supplied" default. Passing it alongside
        # non-empty data makes pandas raise on the column-count mismatch, so in that one
        # case let pandas assign its default labels instead. Every other combination is
        # forwarded exactly as before.
        columns = self.columns if self.columns or not self.data else None
        return pd.DataFrame(data=self.data, columns=columns)
7592

7693
def infer_and_parse_json_input(json_input, schema: Schema = None):
7794
"""
@@ -205,38 +222,38 @@ def _handle_serving_error(error_message, error_code, include_traceback=True):
205222
e = MlflowException(message=error_message, error_code=error_code)
206223
reraise(MlflowException, e)
207224

208-
209225
def init(model: PyFuncModel):
210226

211227
"""
212228
Initialize the server. Loads pyfunc model from the path.
213229
"""
214-
app = flask.Flask(__name__)
230+
fast_app = FastAPI(title= __name__, version= "v1")
231+
fast_app.include_router(APIRouter())
215232
input_schema = model.metadata.get_input_schema()
216233

217-
@app.route("/ping", methods=["GET"])
234+
@fast_app.get("/ping")
218235
def ping(): # pylint: disable=unused-variable
219236
"""
220237
Determine if the container is working and healthy.
221238
We declare it healthy if we can load the model successfully.
222239
"""
223-
health = model is not None
224-
status = 200 if health else 404
225-
return flask.Response(response="\n", status=status, mimetype="application/json")
240+
if model is None:
241+
raise HTTPException(status_code=404, detail="Model not loaded properly")
242+
return {"message": "OK"}
226243

227-
@app.route("/invocations", methods=["POST"])
228-
@catch_mlflow_exception
229-
def transformation(): # pylint: disable=unused-variable
244+
@fast_app.post("/invocations")
245+
def transformation(request_data: RequestData, content_type: Optional[str] = Header(None)): # pylint: disable=unused-variable
230246
"""
231247
Do an inference on a single batch of data. In this sample server,
232248
we take data as CSV or json, convert it to a Pandas DataFrame or Numpy,
233249
generate predictions and convert them back to json.
234250
"""
251+
# data = _dataframe_from_json(request_data.json())
235252

236253
# Content-Type can include other attributes like CHARSET
237254
# Content-type RFC: https://datatracker.ietf.org/doc/html/rfc2045#section-5.1
238255
# TODO: Support ";" in quoted parameter values
239-
type_parts = flask.request.content_type.split(";")
256+
type_parts = content_type.split(";")
240257
type_parts = list(map(str.strip, type_parts))
241258
mime_type = type_parts[0]
242259
parameter_value_pairs = type_parts[1:]
@@ -247,27 +264,31 @@ def transformation(): # pylint: disable=unused-variable
247264

248265
charset = parameter_values.get("charset", "utf-8").lower()
249266
if charset != "utf-8":
250-
return flask.Response(
251-
response="The scoring server only supports UTF-8",
252-
status=415,
253-
mimetype="text/plain",
267+
return Response(
268+
content="The scoring server only supports UTF-8",
269+
status_code=415,
270+
media_type="text/plain"
254271
)
255272

256273
content_format = parameter_values.get("format")
257274

258275
# Convert from CSV to pandas
259276
if mime_type == CONTENT_TYPE_CSV and not content_format:
260-
data = flask.request.data.decode("utf-8")
277+
data = request_data.json()
261278
csv_input = StringIO(data)
262279
data = parse_csv_input(csv_input=csv_input)
280+
elif mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_RAW_JSON:
281+
if len(request_data.data) != 0:
282+
data = dict(zip(request_data.columns, request_data.data[0]))
283+
else:
284+
data = {}
263285
elif mime_type == CONTENT_TYPE_JSON and not content_format:
264-
json_str = flask.request.data.decode("utf-8")
265-
data = infer_and_parse_json_input(json_str, input_schema)
286+
data = infer_and_parse_json_input(request_data.json(), input_schema)
266287
elif (
267288
mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_FORMAT_SPLIT_ORIENTED
268289
):
269290
data = parse_json_input(
270-
json_input=StringIO(flask.request.data.decode("utf-8")),
291+
json_input=StringIO(request_data.json()),
271292
orient="split",
272293
schema=input_schema,
273294
)
@@ -276,29 +297,25 @@ def transformation(): # pylint: disable=unused-variable
276297
and content_format == CONTENT_TYPE_FORMAT_RECORDS_ORIENTED
277298
):
278299
data = parse_json_input(
279-
json_input=StringIO(flask.request.data.decode("utf-8")),
300+
json_input=StringIO(request_data.json()),
280301
orient="records",
281302
schema=input_schema,
282303
)
283304
elif mime_type == CONTENT_TYPE_JSON_SPLIT_NUMPY and not content_format:
284-
data = parse_split_oriented_json_input_to_numpy(flask.request.data.decode("utf-8"))
305+
data = parse_split_oriented_json_input_to_numpy(request_data.json())
285306
else:
286-
return flask.Response(
287-
response=(
288-
"This predictor only supports the following content types and formats:"
307+
return Response(
308+
content="This predictor only supports the following content types and formats:"
289309
" Types: {supported_content_types}; Formats: {formats}."
290310
" Got '{received_content_type}'.".format(
291311
supported_content_types=CONTENT_TYPES,
292312
formats=FORMATS,
293-
received_content_type=flask.request.content_type,
294-
)
295-
),
296-
status=415,
297-
mimetype="text/plain",
313+
received_content_type=content_type,
314+
),
315+
status_code=415,
316+
media_type="text/plain"
298317
)
299318

300-
# Do the prediction
301-
302319
try:
303320
raw_predictions = model.predict(data)
304321
except MlflowException as e:
@@ -314,11 +331,10 @@ def transformation(): # pylint: disable=unused-variable
314331
),
315332
error_code=BAD_REQUEST,
316333
)
317-
result = StringIO()
318-
predictions_to_json(raw_predictions, result)
319-
return flask.Response(response=result.getvalue(), status=200, mimetype="application/json")
334+
predictions = _get_jsonable_obj(raw_predictions, pandas_orient="records")
335+
return predictions
320336

321-
return app
337+
return fast_app
322338

323339

324340
def _predict(model_uri, input_path, output_path, content_type, json_format):
@@ -342,8 +358,8 @@ def _predict(model_uri, input_path, output_path, content_type, json_format):
342358

343359
def _serve(model_uri, port, host):
    """
    Load the pyfunc model at ``model_uri``, build the scoring app for it and run that
    app under uvicorn on the given ``host`` and ``port``.
    """
    application = init(load_model(model_uri))
    uvicorn.run(application, host=host, port=port, log_level="info")
347363

348364
def get_cmd(
349365
model_uri: str, port: int = None, host: int = None, nworkers: int = None
@@ -362,8 +378,7 @@ def get_cmd(
362378
args.append(f"-w {nworkers}")
363379

364380
command = (
365-
f"gunicorn {' '.join(args)} ${{GUNICORN_CMD_ARGS}}"
366-
" -- mlflow.pyfunc.scoring_server.wsgi:app"
381+
"gunicorn mlflow.pyfunc.scoring_server.wsgi:app --worker-class uvicorn.workers.UvicornWorker"
367382
)
368383
else:
369384
args = []

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def package_files(directory):
6868
"alembic<=1.4.1",
6969
# Required
7070
"docker>=4.0.0",
71+
"fastapi",
72+
"uvicorn",
7173
"Flask",
7274
"gunicorn; platform_system != 'Windows'",
7375
"numpy",

0 commit comments

Comments
 (0)