|
| 1 | +import dataclasses |
| 2 | +import os |
| 3 | +import shutil |
| 4 | +import tempfile |
| 5 | +from absl.testing import absltest |
| 6 | +from absl.testing import parameterized |
| 7 | +import numpy as np |
| 8 | +from PIL import Image |
| 9 | +from tunix.processors import image_processor |
| 10 | + |
| 11 | + |
| 12 | +@dataclasses.dataclass(slots=True, kw_only=True) |
| 13 | +class DummyConfig: |
| 14 | + |
| 15 | + image_height: int = 32 |
| 16 | + image_width: int = 32 |
| 17 | + image_channels: int = 3 |
| 18 | + image_mean: tuple[float, ...] = (127.5, 127.5, 127.5) |
| 19 | + image_std: tuple[float, ...] = (127.5, 127.5, 127.5) |
| 20 | + |
| 21 | + |
| 22 | +class ImageProcessorTest(parameterized.TestCase): |
| 23 | + |
| 24 | + def setUp(self): |
| 25 | + super().setUp() |
| 26 | + self.height = 32 |
| 27 | + self.width = 32 |
| 28 | + self.channels = 3 |
| 29 | + config = DummyConfig( |
| 30 | + image_height=self.height, |
| 31 | + image_width=self.width, |
| 32 | + image_channels=self.channels, |
| 33 | + ) |
| 34 | + self.processor = image_processor.ImageProcessor(config) |
| 35 | + |
| 36 | + def _create_dummy_image_file(self, filename='test_image.png'): |
| 37 | + img_array = np.zeros((100, 100, 3), dtype=np.uint8) |
| 38 | + img = Image.fromarray(img_array) |
| 39 | + |
| 40 | + temp_dir = tempfile.mkdtemp() |
| 41 | + self.addCleanup(lambda: shutil.rmtree(temp_dir)) |
| 42 | + |
| 43 | + temp_file = os.path.join(temp_dir, filename) |
| 44 | + img.save(temp_file) |
| 45 | + return temp_file |
| 46 | + |
| 47 | + def test_process_none_image(self): |
| 48 | + processed_image = self.processor.preprocess_image(None) |
| 49 | + self.assertEqual( |
| 50 | + processed_image.shape, (self.height, self.width, self.channels) |
| 51 | + ) |
| 52 | + np.testing.assert_array_equal(processed_image, np.zeros((32, 32, 3))) |
| 53 | + |
| 54 | + def test_path_input(self): |
| 55 | + img_path = self._create_dummy_image_file() |
| 56 | + processed_image = self.processor.preprocess_image(img_path) |
| 57 | + self.assertEqual( |
| 58 | + processed_image.shape, (self.height, self.width, self.channels) |
| 59 | + ) |
| 60 | + np.testing.assert_allclose(processed_image, -1.0 * np.ones((32, 32, 3))) |
| 61 | + |
| 62 | + def test_array_input(self): |
| 63 | + img_array = np.zeros((100, 100, 3), dtype=np.uint8) |
| 64 | + processed_image = self.processor.preprocess_image(img_array) |
| 65 | + self.assertEqual( |
| 66 | + processed_image.shape, (self.height, self.width, self.channels) |
| 67 | + ) |
| 68 | + np.testing.assert_allclose(processed_image, -1.0 * np.ones((32, 32, 3))) |
| 69 | + |
| 70 | + @parameterized.named_parameters( |
| 71 | + dict(testcase_name='array', input_type='array'), |
| 72 | + dict(testcase_name='path', input_type='path'), |
| 73 | + ) |
| 74 | + def test_call_one_image(self, input_type): |
| 75 | + if input_type == 'array': |
| 76 | + images = [np.zeros((100, 100, 3), dtype=np.uint8)] |
| 77 | + elif input_type == 'path': |
| 78 | + images = [self._create_dummy_image_file()] |
| 79 | + |
| 80 | + processed_images = self.processor(images=images) # pylint: disable=undefined-variable |
| 81 | + self.assertLen(processed_images, 1) |
| 82 | + self.assertLen(processed_images[0], 1) |
| 83 | + self.assertEqual( |
| 84 | + processed_images[0][0].shape, (self.height, self.width, self.channels) # pytype: disable=attribute-error |
| 85 | + ) |
| 86 | + np.testing.assert_allclose( |
| 87 | + processed_images[0][0], -1.0 * np.ones((32, 32, 3)) |
| 88 | + ) |
| 89 | + |
| 90 | + def test_padding(self): |
| 91 | + img1 = np.zeros((100, 100, 3), dtype=np.uint8) |
| 92 | + img2 = np.zeros((50, 50, 3), dtype=np.uint8) |
| 93 | + images = [[img1], [img1, img2]] |
| 94 | + processed_images = self.processor(images=images) |
| 95 | + self.assertLen(processed_images, 2) |
| 96 | + self.assertLen(processed_images[0], 2) # Padded to 2 |
| 97 | + self.assertLen(processed_images[1], 2) |
| 98 | + np.testing.assert_allclose( |
| 99 | + processed_images[0][0], -1.0 * np.ones((32, 32, 3)) |
| 100 | + ) |
| 101 | + # Padded image should be zeros |
| 102 | + np.testing.assert_allclose(processed_images[0][1], np.zeros((32, 32, 3))) |
| 103 | + np.testing.assert_allclose( |
| 104 | + processed_images[1][0], -1.0 * np.ones((32, 32, 3)) |
| 105 | + ) |
| 106 | + np.testing.assert_allclose( |
| 107 | + processed_images[1][1], -1.0 * np.ones((32, 32, 3)) |
| 108 | + ) |
| 109 | + |
| 110 | + def test_call_with_none_in_batch(self): |
| 111 | + images = [None, [np.zeros((100, 100, 3), dtype=np.uint8)]] |
| 112 | + processed_images = self.processor(images=images) |
| 113 | + self.assertLen(processed_images, 2) |
| 114 | + self.assertLen(processed_images[0], 1) |
| 115 | + self.assertLen(processed_images[1], 1) |
| 116 | + np.testing.assert_allclose(processed_images[0][0], np.zeros((32, 32, 3))) |
| 117 | + np.testing.assert_allclose( |
| 118 | + processed_images[1][0], -1.0 * np.ones((32, 32, 3)) |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == '__main__': |
| 123 | + absltest.main() |
0 commit comments