-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy path vae_backbone_test.py
More file actions
52 lines (46 loc) · 1.68 KB
/
vae_backbone_test.py
File metadata and controls
52 lines (46 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pytest
from keras import ops
from keras_hub.src.models.vae.vae_backbone import VAEBackbone
from keras_hub.src.tests.test_case import TestCase
class VAEBackboneTest(TestCase):
    """Tests for `VAEBackbone`: forward pass, saving, and LiteRT export."""

    def setUp(self):
        """Build a small VAE config and a constant dummy image batch."""
        self.height, self.width = 64, 64
        self.init_kwargs = {
            # Four encoder stages and four decoder stages, each with 32
            # filters and a single block.
            "encoder_num_filters": [32, 32, 32, 32],
            "encoder_num_blocks": [1, 1, 1, 1],
            "decoder_num_filters": [32, 32, 32, 32],
            "decoder_num_blocks": [1, 1, 1, 1],
            # Use `mode` to generate a deterministic output.
            "sampler_method": "mode",
        }
        # All-ones batch of shape (2, height, width, 3); the trailing 3 is
        # presumably the RGB channel axis (channels-last) — confirm against
        # VAEBackbone's expected input layout.
        self.input_data = ops.ones((2, self.height, self.width, 3))

    def test_backbone_basics(self):
        """Run the shared backbone checks on a forward pass.

        The reconstruction is expected to have the same shape as the
        input image batch.
        """
        self.run_backbone_test(
            cls=VAEBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, self.height, self.width, 3),
        )

    @pytest.mark.large
    def test_saved_model(self):
        """Check that the backbone survives a save/load round trip."""
        self.run_model_saving_test(
            cls=VAEBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    # Non-strict xfail: the test is reported as XFAIL while the upstream
    # `tfl.pow` legalization gap exists, and as XPASS (not a failure) once
    # it is fixed.
    @pytest.mark.xfail(
        strict=False,
        reason=(
            "Upstream litert-torch limitation: VAE uses pow ops which fail "
            "TFLite legalization ('failed to legalize operation tfl.pow'). "
            "Will pass once TFLite built-ins cover tfl.pow."
        ),
    )
    def test_litert_export(self):
        """Export the backbone to LiteRT and compare its outputs.

        The comparison is statistical rather than exact. It allows a
        maximum deviation of 3e-3 and a mean deviation of 3e-4 across all
        outputs ("*").
        """
        self.run_litert_export_test(
            cls=VAEBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            comparison_mode="statistical",
            output_thresholds={"*": {"max": 3e-3, "mean": 3e-4}},
        )