Skip to content

Commit b25cf0e

Browse files
committed
Sketch out getting model from mlflow for inference
1 parent 5e89f3b commit b25cf0e

2 files changed

Lines changed: 58 additions & 22 deletions

File tree

bats_ai/core/management/commands/registeronnxmodel.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def command():
2222
mlflow.onnx.log_model(
2323
onnx_model=onnx_model,
2424
artifact_path='onnx_model',
25+
# save_as_external_data=True,
2526
)
2627
model_uri = f'runs:/{run_id}/onnx_model'
2728
result = mlflow.register_model(model_uri=model_uri, name='prototype')

bats_ai/tasks/tasks.py

Lines changed: 57 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import io
22
import math
3+
import os
34
import tempfile
45

56
from PIL import Image
@@ -431,36 +432,57 @@ def predict(compressed_spectrogram_id: int):
431432
return label, score, confs
432433

433434

434-
def predict_compressed(image_file):
435+
def _fully_local_inference(image_file, use_mlflow_model):
435436
import json
436-
import os
437437

438438
import onnx
439439
import onnxruntime as ort
440440
import tqdm
441441

442442
img = Image.open(image_file)
443443

444-
relative = ('..',) * 3
445-
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets'))
446-
447-
onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx')
448-
assert os.path.exists(onnx_filename)
449-
450-
session = ort.InferenceSession(
451-
onnx_filename,
452-
providers=[
453-
(
454-
'CUDAExecutionProvider',
455-
{
456-
'cudnn_conv_use_max_workspace': '1',
457-
'device_id': 0,
458-
'cudnn_conv_algo_search': 'HEURISTIC',
459-
},
460-
),
461-
'CPUExecutionProvider',
462-
],
463-
)
444+
if not use_mlflow_model:
445+
relative = ('..',) * 3
446+
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets'))
447+
448+
onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx')
449+
assert os.path.exists(onnx_filename)
450+
451+
session = ort.InferenceSession(
452+
onnx_filename,
453+
providers=[
454+
(
455+
'CUDAExecutionProvider',
456+
{
457+
'cudnn_conv_use_max_workspace': '1',
458+
'device_id': 0,
459+
'cudnn_conv_algo_search': 'HEURISTIC',
460+
},
461+
),
462+
'CPUExecutionProvider',
463+
],
464+
)
465+
else:
466+
import mlflow
467+
import mlflow.onnx
468+
469+
MODEL_URI = 'models:/prototype/1'
470+
mlflow.set_tracking_uri(settings.MLFLOW_ENDPOINT)
471+
model = mlflow.onnx.load_model(model_uri=MODEL_URI)
472+
session = ort.InferenceSession(
473+
model.SerializeToString(),
474+
providers=[
475+
(
476+
'CUDAExecutionProvider',
477+
{
478+
'cudnn_conv_use_max_workspace': '1',
479+
'device_id': 0,
480+
'cudnn_conv_algo_search': 'HEURISTIC',
481+
},
482+
),
483+
'CPUExecutionProvider',
484+
],
485+
)
464486

465487
img = np.array(img)
466488

@@ -507,6 +529,19 @@ def predict_compressed(image_file):
507529
return label, score, confs
508530

509531

532+
def predict_compressed(image_file):
533+
# 0: use the local file and do inference with that
534+
# 1: get the file from mlflow and do inference locally
535+
# 2: do inference from deployed mlflow model
536+
inference_mode = int(os.getenv('INFERENCE_MODE', 0))
537+
if inference_mode == 1:
538+
pass
539+
elif inference_mode == 2:
540+
pass
541+
else:
542+
return _fully_local_inference(image_file, False)
543+
544+
510545
def train_body(experiment_name: str):
511546
import mlflow
512547
from mlflow.models import infer_signature

0 commit comments

Comments
 (0)