Skip to content

Commit ff97493

Browse files
authored
Merge pull request #2 from DLR-KI/enhancement/inference
refactoring inference and add proprocessing download endpoint
2 parents d6bce32 + 9b9a062 commit ff97493

File tree

9 files changed

+514
-72
lines changed

9 files changed

+514
-72
lines changed

fl_server_api/openapi.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ def create_error_response(
6767
"""Generic OpenAPI 403 response."""
6868

6969

70+
error_response_404 = create_error_response(
71+
"Not found",
72+
"Not found",
73+
"The server cannot find the requested resource.",
74+
"Provide valid request data."
75+
)
76+
"""Generic OpenAPI 404 response."""
77+
78+
7079
def custom_preprocessing_hook(endpoints: List[Tuple[str, str, str, Callable]]):
7180
"""
7281
Hide the "/api/dummy/" endpoint from the OpenAPI schema.

fl_server_api/serializers/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def to_representation(self, instance):
9393
del data["weights"]
9494
if self.context.get("with-stats", False):
9595
data["stats"] = self.analyze_torch_model(instance)
96+
if isinstance(instance, GlobalModel):
97+
data["has_preprocessing"] = bool(instance.preprocessing)
9698
return data
9799

98100
def analyze_torch_model(self, instance: Model):
@@ -175,6 +177,7 @@ class ModelSerializerNoWeights(ModelSerializer):
175177
class Meta:
176178
model = Model
177179
exclude = ["polymorphic_ctype", "weights"]
180+
include = ["has_preprocessing"]
178181

179182

180183
class ModelSerializerNoWeightsWithStats(ModelSerializerNoWeights):
@@ -186,6 +189,7 @@ class Meta:
186189
model = Model
187190
exclude = ["polymorphic_ctype", "weights"]
188191
include = ["stats"]
192+
include = ["has_preprocessing", "stats"]
189193

190194

191195
#######################################################################################################################

fl_server_api/tests/test_inference.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
#
44
# SPDX-License-Identifier: Apache-2.0
55

6+
import base64
67
from django.core.files.uploadedfile import SimpleUploadedFile
78
from django.test import TestCase
89
import json
10+
import io
911
import pickle
1012
import torch
13+
import torch.nn
14+
from torchvision.transforms.functional import to_pil_image
1115
from uuid import uuid4
1216

1317
from fl_server_core.tests import BASE_URL, Dummy
@@ -168,3 +172,103 @@ def _inference_result(self, torch_model: torch.nn.Module):
168172
self.assertIsNotNone(inference)
169173
inference_tensor = torch.as_tensor(inference)
170174
self.assertTrue(torch.all(torch.tensor([2, 0, 0]) == inference_tensor))
175+
176+
def test_inference_input_shape_positive(self):
177+
inp = from_torch_tensor(torch.zeros(3, 3))
178+
model = Dummy.create_model(input_shape=[None, 3])
179+
training = Dummy.create_training(actor=self.user, model=model)
180+
input_file = SimpleUploadedFile(
181+
"input.pt",
182+
inp,
183+
content_type="application/octet-stream"
184+
)
185+
response = self.client.post(
186+
f"{BASE_URL}/inference/",
187+
{"model_id": str(training.model.id), "model_input": input_file}
188+
)
189+
self.assertEqual(response.status_code, 200)
190+
191+
def test_inference_input_shape_negative(self):
192+
inp = from_torch_tensor(torch.zeros(3, 3))
193+
model = Dummy.create_model(input_shape=[None, 5])
194+
training = Dummy.create_training(actor=self.user, model=model)
195+
input_file = SimpleUploadedFile(
196+
"input.pt",
197+
inp,
198+
content_type="application/octet-stream"
199+
)
200+
with self.assertLogs("root", level="WARNING") as cm:
201+
response = self.client.post(
202+
f"{BASE_URL}/inference/",
203+
{"model_id": str(training.model.id), "model_input": input_file}
204+
)
205+
self.assertEqual(cm.output, [
206+
"WARNING:django.request:Bad Request: /api/inference/",
207+
])
208+
self.assertEqual(response.status_code, 400)
209+
self.assertEqual(response.json()[0], "Input shape does not match model input shape.")
210+
211+
def test_inference_input_pil_image(self):
212+
img = to_pil_image(torch.zeros(1, 5, 5))
213+
img_byte_arr = io.BytesIO()
214+
img.save(img_byte_arr, format="jpeg")
215+
img_byte_arr = img_byte_arr.getvalue()
216+
217+
torch.manual_seed(42)
218+
torch_model = torch.jit.script(torch.nn.Sequential(
219+
torch.nn.Conv2d(1, 2, 3),
220+
torch.nn.Flatten(),
221+
torch.nn.Linear(3*3, 2)
222+
))
223+
model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model))
224+
training = Dummy.create_training(actor=self.user, model=model)
225+
input_file = SimpleUploadedFile(
226+
"input.pt",
227+
img_byte_arr,
228+
content_type="application/octet-stream"
229+
)
230+
response = self.client.post(
231+
f"{BASE_URL}/inference/",
232+
{"model_id": str(training.model.id), "model_input": input_file}
233+
)
234+
self.assertEqual(response.status_code, 200)
235+
236+
results = pickle.loads(response.content)
237+
self.assertEqual({}, results["uncertainty"])
238+
inference = results["inference"]
239+
self.assertIsNotNone(inference)
240+
inference_tensor = torch.as_tensor(inference)
241+
self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor))
242+
243+
def test_inference_input_pil_image_base64(self):
244+
img = to_pil_image(torch.zeros(1, 5, 5))
245+
img_byte_arr = io.BytesIO()
246+
img.save(img_byte_arr, format="jpeg")
247+
img_byte_arr = img_byte_arr.getvalue()
248+
inp = base64.b64encode(img_byte_arr)
249+
250+
torch.manual_seed(42)
251+
torch_model = torch.jit.script(torch.nn.Sequential(
252+
torch.nn.Conv2d(1, 2, 3),
253+
torch.nn.Flatten(),
254+
torch.nn.Linear(3*3, 2)
255+
))
256+
model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model))
257+
training = Dummy.create_training(actor=self.user, model=model)
258+
input_file = SimpleUploadedFile(
259+
"input.pt",
260+
inp,
261+
content_type="application/octet-stream"
262+
)
263+
response = self.client.post(
264+
f"{BASE_URL}/inference/",
265+
{"model_id": str(training.model.id), "model_input": input_file}
266+
)
267+
self.assertEqual(response.status_code, 200)
268+
269+
results = pickle.loads(response.content)
270+
self.assertEqual({}, results["uncertainty"])
271+
inference = results["inference"]
272+
self.assertIsNotNone(inference)
273+
inference_tensor = torch.as_tensor(inference)
274+
self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor))

fl_server_api/tests/test_model.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def test_get_model_metadata(self):
149149
self.assertEqual(str(model.name), response_json["name"])
150150
self.assertEqual(str(model.description), response_json["description"])
151151
self.assertEqual(model.input_shape, response_json["input_shape"])
152+
self.assertFalse(response_json["has_preprocessing"])
152153
# check stats
153154
stats = response_json["stats"]
154155
self.assertIsNotNone(stats)
@@ -232,6 +233,28 @@ def test_get_model_metadata(self):
232233
self.assertIsNotNone(layer4["output_bytes"])
233234
self.assertIsNotNone(layer4["macs"])
234235

236+
def test_get_model_metadata_with_preprocessing(self):
237+
model_bytes = from_torch_module(torch.nn.Sequential(
238+
torch.nn.Linear(3, 64),
239+
torch.nn.ELU(),
240+
torch.nn.Linear(64, 1),
241+
))
242+
torch_model_preprocessing = from_torch_module(transforms.Compose([
243+
transforms.ToImage(),
244+
transforms.ToDtype(torch.float32, scale=True),
245+
transforms.Normalize(mean=(0.,), std=(1.,)),
246+
]))
247+
model = Dummy.create_model(weights=model_bytes, preprocessing=torch_model_preprocessing, input_shape=[None, 3])
248+
response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
249+
self.assertEqual(200, response.status_code)
250+
self.assertEqual("application/json", response["content-type"])
251+
response_json = response.json()
252+
self.assertEqual(str(model.id), response_json["id"])
253+
self.assertEqual(str(model.name), response_json["name"])
254+
self.assertEqual(str(model.description), response_json["description"])
255+
self.assertEqual(model.input_shape, response_json["input_shape"])
256+
self.assertTrue(response_json["has_preprocessing"])
257+
235258
def test_get_model_metadata_torchscript_model(self):
236259
torchscript_model_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
237260
torch.nn.Linear(3, 64),
@@ -552,6 +575,30 @@ def test_upload_model_preprocessing_v2_Compose_good(self):
552575
self.assertIsNotNone(model.preprocessing)
553576
self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module))
554577

578+
def test_download_model_preprocessing(self):
579+
torch_model_preprocessing = from_torch_module(torch.jit.script(torch.nn.Sequential(
580+
transforms.Normalize(mean=(0.,), std=(1.,)),
581+
)))
582+
model = Dummy.create_model(owner=self.user, preprocessing=torch_model_preprocessing)
583+
response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
584+
self.assertEqual(200, response.status_code)
585+
self.assertEqual("application/octet-stream", response["content-type"])
586+
torch_model = torch.jit.load(io.BytesIO(response.content))
587+
self.assertIsNotNone(torch_model)
588+
self.assertTrue(isinstance(torch_model, torch.nn.Module))
589+
590+
def test_download_model_preprocessing_with_undefined_preprocessing(self):
591+
model = Dummy.create_model(owner=self.user, preprocessing=None)
592+
with self.assertLogs("django.request", level="WARNING") as cm:
593+
response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
594+
self.assertEqual(cm.output, [
595+
f"WARNING:django.request:Not Found: /api/models/{model.id}/preprocessing/",
596+
])
597+
self.assertEqual(404, response.status_code)
598+
response_json = response.json()
599+
self.assertIsNotNone(response_json)
600+
self.assertEqual(f"Model '{model.id}' has no preprocessing model defined.", response_json["detail"])
601+
555602
@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
556603
def test_upload_update(self, apply_async: MagicMock):
557604
model = Dummy.create_model(owner=self.user, round=0)

fl_server_api/urls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
{"get": "get_model_metrics", "post": "create_model_metrics"}
3939
), name="model-metrics"),
4040
path("models/<str:id>/preprocessing/", view=Model.as_view(
41-
{"post": "upload_model_preprocessing"}
41+
{"get": "get_model_proprecessing", "post": "upload_model_preprocessing"}
4242
), name="model-preprocessing"),
4343
path("models/<str:id>/swag/", view=Model.as_view({"post": "create_swag_stats"}), name="model-swag"),
4444
# trainings

0 commit comments

Comments
 (0)