|
3 | 3 | # |
4 | 4 | # SPDX-License-Identifier: Apache-2.0 |
5 | 5 |
|
| 6 | +import base64 |
6 | 7 | from django.core.files.uploadedfile import SimpleUploadedFile |
7 | 8 | from django.test import TestCase |
8 | 9 | import json |
| 10 | +import io |
9 | 11 | import pickle |
10 | 12 | import torch |
| 13 | +import torch.nn |
| 14 | +from torchvision.transforms.functional import to_pil_image |
11 | 15 | from uuid import uuid4 |
12 | 16 |
|
13 | 17 | from fl_server_core.tests import BASE_URL, Dummy |
@@ -168,3 +172,103 @@ def _inference_result(self, torch_model: torch.nn.Module): |
168 | 172 | self.assertIsNotNone(inference) |
169 | 173 | inference_tensor = torch.as_tensor(inference) |
170 | 174 | self.assertTrue(torch.all(torch.tensor([2, 0, 0]) == inference_tensor)) |
| 175 | + |
| 176 | + def test_inference_input_shape_positive(self): |
| 177 | + inp = from_torch_tensor(torch.zeros(3, 3)) |
| 178 | + model = Dummy.create_model(input_shape=[None, 3]) |
| 179 | + training = Dummy.create_training(actor=self.user, model=model) |
| 180 | + input_file = SimpleUploadedFile( |
| 181 | + "input.pt", |
| 182 | + inp, |
| 183 | + content_type="application/octet-stream" |
| 184 | + ) |
| 185 | + response = self.client.post( |
| 186 | + f"{BASE_URL}/inference/", |
| 187 | + {"model_id": str(training.model.id), "model_input": input_file} |
| 188 | + ) |
| 189 | + self.assertEqual(response.status_code, 200) |
| 190 | + |
| 191 | + def test_inference_input_shape_negative(self): |
| 192 | + inp = from_torch_tensor(torch.zeros(3, 3)) |
| 193 | + model = Dummy.create_model(input_shape=[None, 5]) |
| 194 | + training = Dummy.create_training(actor=self.user, model=model) |
| 195 | + input_file = SimpleUploadedFile( |
| 196 | + "input.pt", |
| 197 | + inp, |
| 198 | + content_type="application/octet-stream" |
| 199 | + ) |
| 200 | + with self.assertLogs("root", level="WARNING") as cm: |
| 201 | + response = self.client.post( |
| 202 | + f"{BASE_URL}/inference/", |
| 203 | + {"model_id": str(training.model.id), "model_input": input_file} |
| 204 | + ) |
| 205 | + self.assertEqual(cm.output, [ |
| 206 | + "WARNING:django.request:Bad Request: /api/inference/", |
| 207 | + ]) |
| 208 | + self.assertEqual(response.status_code, 400) |
| 209 | + self.assertEqual(response.json()[0], "Input shape does not match model input shape.") |
| 210 | + |
| 211 | + def test_inference_input_pil_image(self): |
| 212 | + img = to_pil_image(torch.zeros(1, 5, 5)) |
| 213 | + img_byte_arr = io.BytesIO() |
| 214 | + img.save(img_byte_arr, format="jpeg") |
| 215 | + img_byte_arr = img_byte_arr.getvalue() |
| 216 | + |
| 217 | + torch.manual_seed(42) |
| 218 | + torch_model = torch.jit.script(torch.nn.Sequential( |
| 219 | + torch.nn.Conv2d(1, 2, 3), |
| 220 | + torch.nn.Flatten(), |
| 221 | + torch.nn.Linear(3*3, 2) |
| 222 | + )) |
| 223 | + model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model)) |
| 224 | + training = Dummy.create_training(actor=self.user, model=model) |
| 225 | + input_file = SimpleUploadedFile( |
| 226 | + "input.pt", |
| 227 | + img_byte_arr, |
| 228 | + content_type="application/octet-stream" |
| 229 | + ) |
| 230 | + response = self.client.post( |
| 231 | + f"{BASE_URL}/inference/", |
| 232 | + {"model_id": str(training.model.id), "model_input": input_file} |
| 233 | + ) |
| 234 | + self.assertEqual(response.status_code, 200) |
| 235 | + |
| 236 | + results = pickle.loads(response.content) |
| 237 | + self.assertEqual({}, results["uncertainty"]) |
| 238 | + inference = results["inference"] |
| 239 | + self.assertIsNotNone(inference) |
| 240 | + inference_tensor = torch.as_tensor(inference) |
| 241 | + self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor)) |
| 242 | + |
| 243 | + def test_inference_input_pil_image_base64(self): |
| 244 | + img = to_pil_image(torch.zeros(1, 5, 5)) |
| 245 | + img_byte_arr = io.BytesIO() |
| 246 | + img.save(img_byte_arr, format="jpeg") |
| 247 | + img_byte_arr = img_byte_arr.getvalue() |
| 248 | + inp = base64.b64encode(img_byte_arr) |
| 249 | + |
| 250 | + torch.manual_seed(42) |
| 251 | + torch_model = torch.jit.script(torch.nn.Sequential( |
| 252 | + torch.nn.Conv2d(1, 2, 3), |
| 253 | + torch.nn.Flatten(), |
| 254 | + torch.nn.Linear(3*3, 2) |
| 255 | + )) |
| 256 | + model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model)) |
| 257 | + training = Dummy.create_training(actor=self.user, model=model) |
| 258 | + input_file = SimpleUploadedFile( |
| 259 | + "input.pt", |
| 260 | + inp, |
| 261 | + content_type="application/octet-stream" |
| 262 | + ) |
| 263 | + response = self.client.post( |
| 264 | + f"{BASE_URL}/inference/", |
| 265 | + {"model_id": str(training.model.id), "model_input": input_file} |
| 266 | + ) |
| 267 | + self.assertEqual(response.status_code, 200) |
| 268 | + |
| 269 | + results = pickle.loads(response.content) |
| 270 | + self.assertEqual({}, results["uncertainty"]) |
| 271 | + inference = results["inference"] |
| 272 | + self.assertIsNotNone(inference) |
| 273 | + inference_tensor = torch.as_tensor(inference) |
| 274 | + self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor)) |
0 commit comments