20
20
import pandas as pd
21
21
import sys
22
22
import traceback
23
+ from pydantic import BaseModel
24
+ from fastapi import FastAPI , APIRouter , Request , HTTPException , Response , Header , status
25
+ from typing import List , Optional , Dict
26
+ import uvicorn
27
+ import asyncio
28
+ import json
23
29
24
30
# NB: We need to be careful what we import form mlflow here. Scoring server is used from within
25
31
# model's conda environment. The version of mlflow doing the serving (outside) and the version of
65
71
66
72
CONTENT_TYPE_FORMAT_RECORDS_ORIENTED = "pandas-records"
67
73
CONTENT_TYPE_FORMAT_SPLIT_ORIENTED = "pandas-split"
74
+ CONTENT_TYPE_RAW_JSON = "raw-json"
68
75
69
- FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED , CONTENT_TYPE_FORMAT_SPLIT_ORIENTED ]
76
+ FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED , CONTENT_TYPE_FORMAT_SPLIT_ORIENTED , CONTENT_TYPE_RAW_JSON ]
70
77
71
78
PREDICTIONS_WRAPPER_ATTR_NAME_ENV_KEY = "PREDICTIONS_WRAPPER_ATTR_NAME"
72
79
73
80
_logger = logging .getLogger (__name__ )
74
81
82
+ class RequestData (BaseModel ):
83
+ columns : List [str ] = []
84
+ data : list = []
85
+
86
+ def is_valid (self ):
87
+ return True
88
+
89
+ def get_dataframe (self ):
90
+ df = pd .DataFrame (data = self .data , columns = self .columns )
91
+ return df
75
92
76
93
def infer_and_parse_json_input (json_input , schema : Schema = None ):
77
94
"""
@@ -205,38 +222,38 @@ def _handle_serving_error(error_message, error_code, include_traceback=True):
205
222
e = MlflowException (message = error_message , error_code = error_code )
206
223
reraise (MlflowException , e )
207
224
208
-
209
225
def init (model : PyFuncModel ):
210
226
211
227
"""
212
228
Initialize the server. Loads pyfunc model from the path.
213
229
"""
214
- app = flask .Flask (__name__ )
230
+ fast_app = FastAPI (title = __name__ , version = "v1" )
231
+ fast_app .include_router (APIRouter ())
215
232
input_schema = model .metadata .get_input_schema ()
216
233
217
- @app . route ("/ping" , methods = [ "GET" ] )
234
+ @fast_app . get ("/ping" )
218
235
def ping (): # pylint: disable=unused-variable
219
236
"""
220
237
Determine if the container is working and healthy.
221
238
We declare it healthy if we can load the model successfully.
222
239
"""
223
- health = model is not None
224
- status = 200 if health else 404
225
- return flask . Response ( response = " \n " , status = status , mimetype = "application/json" )
240
+ if model is None :
241
+ raise HTTPException ( status_code = 404 , detail = "Model not loaded properly" )
242
+ return { "message" : "OK" }
226
243
227
- @app .route ("/invocations" , methods = ["POST" ])
228
- @catch_mlflow_exception
229
- def transformation (): # pylint: disable=unused-variable
244
+ @fast_app .post ("/invocations" )
245
+ def transformation (request_data : RequestData , content_type : Optional [str ] = Header (None )): # pylint: disable=unused-variable
230
246
"""
231
247
Do an inference on a single batch of data. In this sample server,
232
248
we take data as CSV or json, convert it to a Pandas DataFrame or Numpy,
233
249
generate predictions and convert them back to json.
234
250
"""
251
+ # data = _dataframe_from_json(request_data.json())
235
252
236
253
# Content-Type can include other attributes like CHARSET
237
254
# Content-type RFC: https://datatracker.ietf.org/doc/html/rfc2045#section-5.1
238
255
# TODO: Suport ";" in quoted parameter values
239
- type_parts = flask . request . content_type .split (";" )
256
+ type_parts = content_type .split (";" )
240
257
type_parts = list (map (str .strip , type_parts ))
241
258
mime_type = type_parts [0 ]
242
259
parameter_value_pairs = type_parts [1 :]
@@ -247,27 +264,31 @@ def transformation(): # pylint: disable=unused-variable
247
264
248
265
charset = parameter_values .get ("charset" , "utf-8" ).lower ()
249
266
if charset != "utf-8" :
250
- return flask . Response (
251
- response = "The scoring server only supports UTF-8" ,
252
- status = 415 ,
253
- mimetype = "text/plain" ,
267
+ return Response (
268
+ content = "The scoring server only supports UTF-8" ,
269
+ status_code = 415 ,
270
+ media_type = "text/plain"
254
271
)
255
272
256
273
content_format = parameter_values .get ("format" )
257
274
258
275
# Convert from CSV to pandas
259
276
if mime_type == CONTENT_TYPE_CSV and not content_format :
260
- data = flask . request . data . decode ( "utf-8" )
277
+ data = request_data . json ( )
261
278
csv_input = StringIO (data )
262
279
data = parse_csv_input (csv_input = csv_input )
280
+ elif mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_RAW_JSON :
281
+ if len (request_data .data ) != 0 :
282
+ data = dict (zip (request_data .columns , request_data .data [0 ]))
283
+ else :
284
+ data = {}
263
285
elif mime_type == CONTENT_TYPE_JSON and not content_format :
264
- json_str = flask .request .data .decode ("utf-8" )
265
- data = infer_and_parse_json_input (json_str , input_schema )
286
+ data = infer_and_parse_json_input (request_data .json (), input_schema )
266
287
elif (
267
288
mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_FORMAT_SPLIT_ORIENTED
268
289
):
269
290
data = parse_json_input (
270
- json_input = StringIO (flask . request . data . decode ( "utf-8" )),
291
+ json_input = StringIO (request_data . json ( )),
271
292
orient = "split" ,
272
293
schema = input_schema ,
273
294
)
@@ -276,29 +297,25 @@ def transformation(): # pylint: disable=unused-variable
276
297
and content_format == CONTENT_TYPE_FORMAT_RECORDS_ORIENTED
277
298
):
278
299
data = parse_json_input (
279
- json_input = StringIO (flask . request . data . decode ( "utf-8" )),
300
+ json_input = StringIO (request_data . json ( )),
280
301
orient = "records" ,
281
302
schema = input_schema ,
282
303
)
283
304
elif mime_type == CONTENT_TYPE_JSON_SPLIT_NUMPY and not content_format :
284
- data = parse_split_oriented_json_input_to_numpy (flask . request . data . decode ( "utf-8" ))
305
+ data = parse_split_oriented_json_input_to_numpy (request_data . json ( ))
285
306
else :
286
- return flask .Response (
287
- response = (
288
- "This predictor only supports the following content types and formats:"
307
+ return Response (
308
+ content = "This predictor only supports the following content types and formats:"
289
309
" Types: {supported_content_types}; Formats: {formats}."
290
310
" Got '{received_content_type}'." .format (
291
311
supported_content_types = CONTENT_TYPES ,
292
312
formats = FORMATS ,
293
- received_content_type = flask .request .content_type ,
294
- )
295
- ),
296
- status = 415 ,
297
- mimetype = "text/plain" ,
313
+ received_content_type = content_type ,
314
+ ),
315
+ status_code = 415 ,
316
+ media_type = "text/plain"
298
317
)
299
318
300
- # Do the prediction
301
-
302
319
try :
303
320
raw_predictions = model .predict (data )
304
321
except MlflowException as e :
@@ -314,11 +331,10 @@ def transformation(): # pylint: disable=unused-variable
314
331
),
315
332
error_code = BAD_REQUEST ,
316
333
)
317
- result = StringIO ()
318
- predictions_to_json (raw_predictions , result )
319
- return flask .Response (response = result .getvalue (), status = 200 , mimetype = "application/json" )
334
+ predictions = _get_jsonable_obj (raw_predictions , pandas_orient = "records" )
335
+ return predictions
320
336
321
- return app
337
+ return fast_app
322
338
323
339
324
340
def _predict (model_uri , input_path , output_path , content_type , json_format ):
@@ -342,8 +358,8 @@ def _predict(model_uri, input_path, output_path, content_type, json_format):
342
358
343
359
def _serve (model_uri , port , host ):
344
360
pyfunc_model = load_model (model_uri )
345
- init (pyfunc_model ). run ( port = port , host = host )
346
-
361
+ fast_app = init (pyfunc_model )
362
+ uvicorn . run ( fast_app , host = host , port = port , log_level = "info" )
347
363
348
364
def get_cmd (
349
365
model_uri : str , port : int = None , host : int = None , nworkers : int = None
@@ -362,8 +378,7 @@ def get_cmd(
362
378
args .append (f"-w { nworkers } " )
363
379
364
380
command = (
365
- f"gunicorn { ' ' .join (args )} ${{GUNICORN_CMD_ARGS}}"
366
- " -- mlflow.pyfunc.scoring_server.wsgi:app"
381
+ "gunicorn mlflow.pyfunc.scoring_server.wsgi:app --worker-class uvicorn.workers.UvicornWorker"
367
382
)
368
383
else :
369
384
args = []
0 commit comments