Skip to content

Commit cacdcf4

Browse files
committed
refactor: unify ModelForwarders to make tflite_serving responsible of prediction types
1 parent 1854dd0 commit cacdcf4

16 files changed

Lines changed: 113 additions & 487 deletions

edge_model_serving/tflite_serving/src/tflite_serving/api_routes.py

Lines changed: 0 additions & 207 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
11
import io
22
import logging
33
from typing import Any, Dict, List, Optional, cast
4-
from typing import Any, Dict, List, Optional, cast
54

65
import numpy as np
76
from fastapi import APIRouter, HTTPException, Request
87
from PIL import Image
98

10-
from tflite_serving.schemas import (
11-
ClassificationPrediction,
12-
DetectedObject,
13-
DetectionPrediction,
14-
ModelMetadataResponse,
15-
PredictionResponse,
16-
)
17-
from PIL import Image
18-
199
from tflite_serving.schemas import (
2010
ClassificationPrediction,
2111
DetectedObject,
@@ -142,159 +132,17 @@ def _get_state(request: Any) -> tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]
142132
# ---------------------------------------------------------------------------
143133

144134

145-
# ---------------------------------------------------------------------------
146-
# Helpers
147-
# ---------------------------------------------------------------------------
148-
149-
150-
def _preprocess(
151-
image: Image.Image,
152-
input_shape: List[int],
153-
input_dtype: Any,
154-
normalization: str,
155-
) -> np.ndarray:
156-
"""Resize, reformat, and normalize an image to match the model's expected input."""
157-
_, height, width, channels = input_shape
158-
image = image.convert("L" if channels == 1 else "RGB")
159-
image = image.resize((width, height), Image.Resampling.LANCZOS)
160-
arr = np.array(image, dtype=np.float32)
161-
if normalization == "mobilenet":
162-
arr = (arr / 127.0) - 1.0
163-
elif normalization == "yolo":
164-
arr = arr / 255.0
165-
# "uint8" → no normalization, cast to target dtype below
166-
arr = np.expand_dims(arr, axis=0).astype(input_dtype)
167-
return arr
168-
169-
170-
def _postprocess_classification(
171-
outputs: List[np.ndarray], class_names: List[str]
172-
) -> ClassificationPrediction:
173-
scores = outputs[0][0]
174-
best_idx = int(np.argmax(scores))
175-
return ClassificationPrediction(
176-
prediction_type="class",
177-
label=class_names[best_idx],
178-
probability=round(float(scores[best_idx]), 5),
179-
)
180-
181-
182-
def _postprocess_object_detection(
183-
outputs: List[np.ndarray], class_names: List[str]
184-
) -> DetectionPrediction:
185-
boxes = outputs[0]
186-
classes = outputs[1].astype(int)
187-
scores = outputs[2]
188-
189-
detected_objects: Dict[str, DetectedObject] = {}
190-
for i, (box, cls, score) in enumerate(zip(boxes[0], classes[0], scores[0])):
191-
label = class_names[cls] if cls < len(class_names) else str(cls)
192-
detected_objects[f"object_{i + 1}"] = DetectedObject(
193-
location=[round(c, 4) for c in box.tolist()],
194-
objectness=round(float(score), 5),
195-
label=label,
196-
)
197-
return DetectionPrediction(
198-
prediction_type="objects", detected_objects=detected_objects
199-
)
200-
201-
202-
def _postprocess_yolo(
203-
outputs: List[np.ndarray], class_names: List[str], input_array: np.ndarray
204-
) -> DetectionPrediction:
205-
raw = outputs[0][0]
206-
# Rotate the YOLO output tensor
207-
rotated = []
208-
for i in range(len(raw[0]), 0, -1):
209-
rotated.append([x[i - 1] for x in raw])
210-
rotated = np.array(rotated)
211-
212-
boxes, scores, class_ids = yolo_extract_boxes_information(rotated)
213-
boxes, scores, class_ids = non_max_suppression(boxes, scores, class_ids)
214-
severities = compute_severities(input_array[0], boxes)
215-
216-
detected_objects: Dict[str, DetectedObject] = {}
217-
for i, (box, score, cls_id, severity) in enumerate(
218-
zip(boxes, scores, class_ids, severities)
219-
):
220-
label = (
221-
class_names[int(cls_id)]
222-
if int(cls_id) < len(class_names)
223-
else str(int(cls_id))
224-
)
225-
detected_objects[f"object_{i + 1}"] = DetectedObject(
226-
location=[round(c, 4) for c in box],
227-
objectness=round(float(score), 5),
228-
label=label,
229-
severity=severity,
230-
)
231-
return DetectionPrediction(
232-
prediction_type="objects", detected_objects=detected_objects
233-
)
234-
235-
236-
def _get_state(request: Any) -> tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
237-
"""Extract typed model registries from app state.
238-
239-
``request`` is typed ``Any`` because ``Request.app`` is ``ASGIApp`` in
240-
Starlette's stubs and does not expose ``.state`` — this is a known
241-
Starlette limitation with no clean typing solution at call-site level.
242-
"""
243-
interpreters = cast(Dict[str, Any], request.app.state.model_interpreters)
244-
metadata_registry = cast(
245-
Dict[str, Dict[str, Any]], request.app.state.model_metadata
246-
)
247-
return interpreters, metadata_registry
248-
249-
250-
# ---------------------------------------------------------------------------
251-
# Routes
252-
# ---------------------------------------------------------------------------
253-
254-
255135
@api_router.get("/")
256136
async def info() -> str:
257137
return "tflite-server docs at ip:port/docs"
258-
async def info() -> str:
259-
return "tflite-server docs at ip:port/docs"
260138

261139

262140
@api_router.get("/models")
263141
async def get_models(request: Request) -> List[str]:
264142
interpreters, _ = _get_state(request)
265143
return list(interpreters.keys())
266-
async def get_models(request: Request) -> List[str]:
267-
interpreters, _ = _get_state(request)
268-
return list(interpreters.keys())
269-
270-
271-
@api_router.get(
272-
"/models/{model_name}/metadata",
273-
response_model=ModelMetadataResponse,
274-
)
275-
async def get_model_metadata(
276-
model_name: str, request: Request
277-
) -> ModelMetadataResponse:
278-
interpreters, metadata_registry = _get_state(request)
279-
if model_name not in interpreters:
280-
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
281-
interpreter = interpreters[model_name]
282-
input_details: List[Dict[str, Any]] = interpreter.get_input_details()
283-
metadata: Dict[str, Any] = metadata_registry.get(model_name, {})
284-
return ModelMetadataResponse(
285-
input_shape=input_details[0]["shape"].tolist(),
286-
input_dtype=np.dtype(input_details[0]["dtype"]).name,
287-
output_type=metadata.get("output_type"),
288-
class_names=metadata.get("class_names"),
289-
normalization=metadata.get("normalization"),
290-
)
291144

292145

293-
@api_router.post(
294-
"/models/{model_name}/versions/{model_version}:predict",
295-
response_model=PredictionResponse,
296-
response_model_exclude_none=True,
297-
)
298146
@api_router.get(
299147
"/models/{model_name}/metadata",
300148
response_model=ModelMetadataResponse,
@@ -349,35 +197,6 @@ async def predict(
349197
input_shape: List[int] = input_details[0]["shape"].tolist()
350198
input_dtype: Any = input_details[0]["dtype"]
351199

352-
logging.info(
353-
f"Predicting with '{model_name}' | shape={input_shape} | output_type={output_type}"
354-
)
355-
model_name: str, model_version: str, request: Request # noqa: ARG001
356-
) -> PredictionResponse:
357-
interpreters, metadata_registry = _get_state(request)
358-
if model_name not in interpreters:
359-
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
360-
361-
interpreter = interpreters[model_name]
362-
metadata: Dict[str, Any] = metadata_registry.get(model_name, {})
363-
output_type: Optional[str] = metadata.get("output_type")
364-
class_names: List[str] = metadata.get("class_names", [])
365-
normalization: str = metadata.get("normalization", "uint8")
366-
367-
if output_type is None:
368-
raise HTTPException(
369-
status_code=422,
370-
detail=(
371-
f"No metadata.json found for model '{model_name}'. "
372-
"Add a metadata.json alongside the .tflite file to enable inference."
373-
),
374-
)
375-
376-
input_details: List[Dict[str, Any]] = interpreter.get_input_details()
377-
output_details: List[Dict[str, Any]] = interpreter.get_output_details()
378-
input_shape: List[int] = input_details[0]["shape"].tolist()
379-
input_dtype: Any = input_details[0]["dtype"]
380-
381200
logging.info(
382201
f"Predicting with '{model_name}' | shape={input_shape} | output_type={output_type}"
383202
)
@@ -389,14 +208,6 @@ async def predict(
389208
status_code=400, detail="Empty request body — expected raw image bytes"
390209
)
391210

392-
image = Image.open(io.BytesIO(body))
393-
input_array = _preprocess(image, input_shape, input_dtype, normalization)
394-
body: bytes = await request.body()
395-
if not body:
396-
raise HTTPException(
397-
status_code=400, detail="Empty request body — expected raw image bytes"
398-
)
399-
400211
image = Image.open(io.BytesIO(body))
401212
input_array = _preprocess(image, input_shape, input_dtype, normalization)
402213

@@ -418,24 +229,6 @@ async def predict(
418229
detail=f"Unknown output_type '{output_type}' in metadata.json for model '{model_name}'.",
419230
)
420231

421-
except HTTPException:
422-
raise
423-
outputs: List[np.ndarray] = [
424-
interpreter.get_tensor(d["index"]) for d in output_details
425-
]
426-
427-
if output_type == "classification":
428-
return _postprocess_classification(outputs, class_names)
429-
elif output_type == "object_detection":
430-
return _postprocess_object_detection(outputs, class_names)
431-
elif output_type == "yolo":
432-
return _postprocess_yolo(outputs, class_names, input_array)
433-
else:
434-
raise HTTPException(
435-
status_code=422,
436-
detail=f"Unknown output_type '{output_type}' in metadata.json for model '{model_name}'.",
437-
)
438-
439232
except HTTPException:
440233
raise
441234
except Exception as e:

edge_orchestrator/config/marker_classif_with_1_raspberry_cam.json

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,8 @@
77
"position": "front",
88
"model_forwarder_config": {
99
"model_name": "marker_quality_control",
10-
"model_type": "classification",
1110
"model_serving_url": "http://edge_model_serving:8501/",
12-
"expected_image_resolution": {
13-
"width": 224,
14-
"height": 224
15-
},
16-
"model_version": "1",
17-
"class_names": [
18-
"OK",
19-
"KO"
20-
]
11+
"model_version": "1"
2112
},
2213
"camera_rule_config": {
2314
"camera_rule_type": "expected_label_rule",
@@ -38,4 +29,4 @@
3829
"expected_decision": "OK",
3930
"threshold": 1
4031
}
41-
}
32+
}

edge_orchestrator/config/marker_classif_with_1_usb_cam.json

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,8 @@
88
"position": "front",
99
"model_forwarder_config": {
1010
"model_name": "marker_quality_control",
11-
"model_type": "classification",
1211
"model_serving_url": "http://edge_model_serving:8501/",
13-
"expected_image_resolution": {
14-
"width": 224,
15-
"height": 224
16-
},
17-
"model_version": "1",
18-
"class_names": [
19-
"OK",
20-
"KO"
21-
]
12+
"model_version": "1"
2213
},
2314
"camera_rule_config": {
2415
"camera_rule_type": "expected_label_rule",
@@ -39,4 +30,4 @@
3930
"expected_decision": "OK",
4031
"threshold": 1
4132
}
42-
}
33+
}

edge_orchestrator/config/marker_classif_with_2_fake_cam.json

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,8 @@
88
"position": "front",
99
"model_forwarder_config": {
1010
"model_name": "marker_quality_control",
11-
"model_type": "classification",
1211
"model_serving_url": "http://edge_model_serving:8501/",
13-
"expected_image_resolution": {
14-
"width": 224,
15-
"height": 224
16-
},
17-
"model_version": "1",
18-
"class_names": [
19-
"OK",
20-
"KO"
21-
]
12+
"model_version": "1"
2213
},
2314
"camera_rule_config": {
2415
"camera_rule_type": "expected_label_rule",
@@ -32,17 +23,8 @@
3223
"position": "back",
3324
"model_forwarder_config": {
3425
"model_name": "marker_quality_control",
35-
"model_type": "classification",
3626
"model_serving_url": "http://edge_model_serving:8501/",
37-
"expected_image_resolution": {
38-
"width": 224,
39-
"height": 224
40-
},
41-
"model_version": "1",
42-
"class_names": [
43-
"OK",
44-
"KO"
45-
]
27+
"model_version": "1"
4628
},
4729
"camera_rule_config": {
4830
"camera_rule_type": "expected_label_rule",
@@ -63,4 +45,4 @@
6345
"expected_decision": "OK",
6446
"threshold": 1
6547
}
66-
}
48+
}

edge_orchestrator/config/mobilenet_ssd_v2_coco_with_2_usb_cam.json

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,8 @@
88
"position": "front",
99
"model_forwarder_config": {
1010
"model_name": "yolo_coco_nano",
11-
"model_type": "object_detection",
1211
"model_serving_url": "http://edge_model_serving:8501/",
13-
"expected_image_resolution": {
14-
"width": 320,
15-
"height": 320
16-
},
17-
"model_version": "1",
18-
"class_names_filepath": "model_labels/coco_labels.txt"
12+
"model_version": "1"
1913
},
2014
"camera_rule_config": {
2115
"camera_rule_type": "min_nb_objects_rule",
@@ -30,14 +24,8 @@
3024
"position": "back",
3125
"model_forwarder_config": {
3226
"model_name": "yolo_coco_nano",
33-
"model_type": "object_detection",
3427
"model_serving_url": "http://edge_model_serving:8501/",
35-
"expected_image_resolution": {
36-
"width": 320,
37-
"height": 320
38-
},
39-
"model_version": "1",
40-
"class_names_filepath": "model_labels/coco_labels.txt"
28+
"model_version": "1"
4129
},
4230
"camera_rule_config": {
4331
"camera_rule_type": "min_nb_objects_rule",
@@ -59,4 +47,4 @@
5947
"expected_decision": "OK",
6048
"threshold": 1
6149
}
62-
}
50+
}

0 commit comments

Comments
 (0)