-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathvit_det_backbone_test.py
More file actions
48 lines (42 loc) · 1.46 KB
/
vit_det_backbone_test.py
File metadata and controls
48 lines (42 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import pytest
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.tests.test_case import TestCase
class ViTDetBackboneTest(TestCase):
    """Tests for `ViTDetBackbone` built from a deliberately tiny config."""

    def setUp(self):
        """Build the shared constructor kwargs and a dummy input batch."""
        self.init_kwargs = {
            "image_shape": (16, 16, 3),
            "patch_size": 2,
            "hidden_size": 4,
            "num_layers": 2,
            # With num_layers=2 the only valid layer positions are 0 and 1
            # (assuming 0-based layer numbering — TODO confirm against
            # ViTDetBackbone). Marking layer 1 as global means the model
            # exercises both a windowed and a global-attention layer;
            # indices beyond num_layers - 1 (e.g. [2, 5, 8, 11]) would
            # presumably never match and leave every layer windowed.
            "global_attention_layer_indices": [1],
            "intermediate_dim": 4 * 4,
            "num_heads": 2,
            "num_output_channels": 2,
            "window_size": 2,
        }
        # A single all-ones image matching `image_shape` above.
        self.input_data = np.ones((1, 16, 16, 3), dtype="float32")

    def test_backbone_basics(self):
        """Run the shared backbone checks and verify the output shape."""
        self.run_backbone_test(
            cls=ViTDetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            # 16 / patch_size (2) = 8 per spatial side; last dim matches
            # num_output_channels (2).
            expected_output_shape=(1, 8, 8, 2),
            # NOTE(review): mixed-precision and quantization checks are
            # disabled here; the reason is not visible in this file.
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        """Round-trip the backbone through the model-saving test helper."""
        self.run_model_saving_test(
            cls=ViTDetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    def test_litert_export(self):
        """Export the backbone to LiteRT and compare against the Keras model."""
        self.run_litert_export_test(
            cls=ViTDetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            # Presumably compares output statistics rather than exact
            # values; "*" appears to apply the max/mean thresholds to every
            # output — confirm against TestCase.run_litert_export_test.
            comparison_mode="statistical",
            output_thresholds={"*": {"max": 1e-3, "mean": 1e-4}},
        )