11import io
22import logging
33from typing import Any , Dict , List , Optional , cast
4- from typing import Any , Dict , List , Optional , cast
54
65import numpy as np
76from fastapi import APIRouter , HTTPException , Request
87from PIL import Image
98
10- from tflite_serving .schemas import (
11- ClassificationPrediction ,
12- DetectedObject ,
13- DetectionPrediction ,
14- ModelMetadataResponse ,
15- PredictionResponse ,
16- )
17- from PIL import Image
18-
199from tflite_serving .schemas import (
2010 ClassificationPrediction ,
2111 DetectedObject ,
@@ -142,159 +132,17 @@ def _get_state(request: Any) -> tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]
142132# ---------------------------------------------------------------------------
143133
144134
145- # ---------------------------------------------------------------------------
146- # Helpers
147- # ---------------------------------------------------------------------------
148-
149-
150- def _preprocess (
151- image : Image .Image ,
152- input_shape : List [int ],
153- input_dtype : Any ,
154- normalization : str ,
155- ) -> np .ndarray :
156- """Resize, reformat, and normalize an image to match the model's expected input."""
157- _ , height , width , channels = input_shape
158- image = image .convert ("L" if channels == 1 else "RGB" )
159- image = image .resize ((width , height ), Image .Resampling .LANCZOS )
160- arr = np .array (image , dtype = np .float32 )
161- if normalization == "mobilenet" :
162- arr = (arr / 127.0 ) - 1.0
163- elif normalization == "yolo" :
164- arr = arr / 255.0
165- # "uint8" → no normalization, cast to target dtype below
166- arr = np .expand_dims (arr , axis = 0 ).astype (input_dtype )
167- return arr
168-
169-
170- def _postprocess_classification (
171- outputs : List [np .ndarray ], class_names : List [str ]
172- ) -> ClassificationPrediction :
173- scores = outputs [0 ][0 ]
174- best_idx = int (np .argmax (scores ))
175- return ClassificationPrediction (
176- prediction_type = "class" ,
177- label = class_names [best_idx ],
178- probability = round (float (scores [best_idx ]), 5 ),
179- )
180-
181-
182- def _postprocess_object_detection (
183- outputs : List [np .ndarray ], class_names : List [str ]
184- ) -> DetectionPrediction :
185- boxes = outputs [0 ]
186- classes = outputs [1 ].astype (int )
187- scores = outputs [2 ]
188-
189- detected_objects : Dict [str , DetectedObject ] = {}
190- for i , (box , cls , score ) in enumerate (zip (boxes [0 ], classes [0 ], scores [0 ])):
191- label = class_names [cls ] if cls < len (class_names ) else str (cls )
192- detected_objects [f"object_{ i + 1 } " ] = DetectedObject (
193- location = [round (c , 4 ) for c in box .tolist ()],
194- objectness = round (float (score ), 5 ),
195- label = label ,
196- )
197- return DetectionPrediction (
198- prediction_type = "objects" , detected_objects = detected_objects
199- )
200-
201-
202- def _postprocess_yolo (
203- outputs : List [np .ndarray ], class_names : List [str ], input_array : np .ndarray
204- ) -> DetectionPrediction :
205- raw = outputs [0 ][0 ]
206- # Rotate the YOLO output tensor
207- rotated = []
208- for i in range (len (raw [0 ]), 0 , - 1 ):
209- rotated .append ([x [i - 1 ] for x in raw ])
210- rotated = np .array (rotated )
211-
212- boxes , scores , class_ids = yolo_extract_boxes_information (rotated )
213- boxes , scores , class_ids = non_max_suppression (boxes , scores , class_ids )
214- severities = compute_severities (input_array [0 ], boxes )
215-
216- detected_objects : Dict [str , DetectedObject ] = {}
217- for i , (box , score , cls_id , severity ) in enumerate (
218- zip (boxes , scores , class_ids , severities )
219- ):
220- label = (
221- class_names [int (cls_id )]
222- if int (cls_id ) < len (class_names )
223- else str (int (cls_id ))
224- )
225- detected_objects [f"object_{ i + 1 } " ] = DetectedObject (
226- location = [round (c , 4 ) for c in box ],
227- objectness = round (float (score ), 5 ),
228- label = label ,
229- severity = severity ,
230- )
231- return DetectionPrediction (
232- prediction_type = "objects" , detected_objects = detected_objects
233- )
234-
235-
236- def _get_state (request : Any ) -> tuple [Dict [str , Any ], Dict [str , Dict [str , Any ]]]:
237- """Extract typed model registries from app state.
238-
239- ``request`` is typed ``Any`` because ``Request.app`` is ``ASGIApp`` in
240- Starlette's stubs and does not expose ``.state`` — this is a known
241- Starlette limitation with no clean typing solution at call-site level.
242- """
243- interpreters = cast (Dict [str , Any ], request .app .state .model_interpreters )
244- metadata_registry = cast (
245- Dict [str , Dict [str , Any ]], request .app .state .model_metadata
246- )
247- return interpreters , metadata_registry
248-
249-
250- # ---------------------------------------------------------------------------
251- # Routes
252- # ---------------------------------------------------------------------------
253-
254-
255135@api_router .get ("/" )
256136async def info () -> str :
257137 return "tflite-server docs at ip:port/docs"
258- async def info () -> str :
259- return "tflite-server docs at ip:port/docs"
260138
261139
262140@api_router .get ("/models" )
263141async def get_models (request : Request ) -> List [str ]:
264142 interpreters , _ = _get_state (request )
265143 return list (interpreters .keys ())
266- async def get_models (request : Request ) -> List [str ]:
267- interpreters , _ = _get_state (request )
268- return list (interpreters .keys ())
269-
270-
271- @api_router .get (
272- "/models/{model_name}/metadata" ,
273- response_model = ModelMetadataResponse ,
274- )
275- async def get_model_metadata (
276- model_name : str , request : Request
277- ) -> ModelMetadataResponse :
278- interpreters , metadata_registry = _get_state (request )
279- if model_name not in interpreters :
280- raise HTTPException (status_code = 404 , detail = f"Model '{ model_name } ' not found" )
281- interpreter = interpreters [model_name ]
282- input_details : List [Dict [str , Any ]] = interpreter .get_input_details ()
283- metadata : Dict [str , Any ] = metadata_registry .get (model_name , {})
284- return ModelMetadataResponse (
285- input_shape = input_details [0 ]["shape" ].tolist (),
286- input_dtype = np .dtype (input_details [0 ]["dtype" ]).name ,
287- output_type = metadata .get ("output_type" ),
288- class_names = metadata .get ("class_names" ),
289- normalization = metadata .get ("normalization" ),
290- )
291144
292145
293- @api_router .post (
294- "/models/{model_name}/versions/{model_version}:predict" ,
295- response_model = PredictionResponse ,
296- response_model_exclude_none = True ,
297- )
298146@api_router .get (
299147 "/models/{model_name}/metadata" ,
300148 response_model = ModelMetadataResponse ,
@@ -349,35 +197,6 @@ async def predict(
349197 input_shape : List [int ] = input_details [0 ]["shape" ].tolist ()
350198 input_dtype : Any = input_details [0 ]["dtype" ]
351199
352- logging .info (
353- f"Predicting with '{ model_name } ' | shape={ input_shape } | output_type={ output_type } "
354- )
355- model_name : str , model_version : str , request : Request # noqa: ARG001
356- ) -> PredictionResponse :
357- interpreters , metadata_registry = _get_state (request )
358- if model_name not in interpreters :
359- raise HTTPException (status_code = 404 , detail = f"Model '{ model_name } ' not found" )
360-
361- interpreter = interpreters [model_name ]
362- metadata : Dict [str , Any ] = metadata_registry .get (model_name , {})
363- output_type : Optional [str ] = metadata .get ("output_type" )
364- class_names : List [str ] = metadata .get ("class_names" , [])
365- normalization : str = metadata .get ("normalization" , "uint8" )
366-
367- if output_type is None :
368- raise HTTPException (
369- status_code = 422 ,
370- detail = (
371- f"No metadata.json found for model '{ model_name } '. "
372- "Add a metadata.json alongside the .tflite file to enable inference."
373- ),
374- )
375-
376- input_details : List [Dict [str , Any ]] = interpreter .get_input_details ()
377- output_details : List [Dict [str , Any ]] = interpreter .get_output_details ()
378- input_shape : List [int ] = input_details [0 ]["shape" ].tolist ()
379- input_dtype : Any = input_details [0 ]["dtype" ]
380-
381200 logging .info (
382201 f"Predicting with '{ model_name } ' | shape={ input_shape } | output_type={ output_type } "
383202 )
@@ -389,14 +208,6 @@ async def predict(
389208 status_code = 400 , detail = "Empty request body — expected raw image bytes"
390209 )
391210
392- image = Image .open (io .BytesIO (body ))
393- input_array = _preprocess (image , input_shape , input_dtype , normalization )
394- body : bytes = await request .body ()
395- if not body :
396- raise HTTPException (
397- status_code = 400 , detail = "Empty request body — expected raw image bytes"
398- )
399-
400211 image = Image .open (io .BytesIO (body ))
401212 input_array = _preprocess (image , input_shape , input_dtype , normalization )
402213
@@ -418,24 +229,6 @@ async def predict(
418229 detail = f"Unknown output_type '{ output_type } ' in metadata.json for model '{ model_name } '." ,
419230 )
420231
421- except HTTPException :
422- raise
423- outputs : List [np .ndarray ] = [
424- interpreter .get_tensor (d ["index" ]) for d in output_details
425- ]
426-
427- if output_type == "classification" :
428- return _postprocess_classification (outputs , class_names )
429- elif output_type == "object_detection" :
430- return _postprocess_object_detection (outputs , class_names )
431- elif output_type == "yolo" :
432- return _postprocess_yolo (outputs , class_names , input_array )
433- else :
434- raise HTTPException (
435- status_code = 422 ,
436- detail = f"Unknown output_type '{ output_type } ' in metadata.json for model '{ model_name } '." ,
437- )
438-
439232 except HTTPException :
440233 raise
441234 except Exception as e :
0 commit comments