9
9
10
10
import seqio
11
11
import tensorflow as tf
12
- from absl .testing import parameterized
12
+ from absl .testing import absltest , parameterized
13
13
14
14
from axlearn .audio import input_asr
15
15
from axlearn .common import input_fake , input_tf_data
@@ -28,25 +28,24 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
28
28
max_len = 5 ,
29
29
expected = [
30
30
{
31
- "inputs" : tf .constant ([- 29515.0 , 0 , 0 , 0 , 0 ]),
32
- "paddings" : tf .constant ([0 , 1 , 1 , 1 , 1 ]),
33
- },
34
- {
35
- "inputs" : tf .constant ([14620.0 , - 21206.0 , 0 , 0 , 0 ]),
31
+ "inputs" : tf .constant ([- 29515.0 , - 18256.0 , 0 , 0 , 0 ]),
36
32
"paddings" : tf .constant ([0 , 0 , 1 , 1 , 1 ]),
37
33
},
38
34
{
39
- "inputs" : tf .constant ([- 3954 .0 , - 15555 .0 , 18074 .0 , 0 , 0 ]),
35
+ "inputs" : tf .constant ([14620 .0 , - 21206 .0 , - 4254 .0 , 0 , 0 ]),
40
36
"paddings" : tf .constant ([0 , 0 , 0 , 1 , 1 ]),
41
37
},
38
+ {
39
+ "inputs" : tf .constant ([- 3954.0 , - 15555.0 , 18074.0 , 22466.0 , 0 ]),
40
+ "paddings" : tf .constant ([0 , 0 , 0 , 0 , 1 ]),
41
+ },
42
42
],
43
43
),
44
44
dict (
45
45
# Test a basic case with filtering.
46
46
max_len = 2 ,
47
47
expected = [
48
- {"inputs" : tf .constant ([- 29515.0 , 0 ]), "paddings" : tf .constant ([0 , 1 ])},
49
- {"inputs" : tf .constant ([14620.0 , - 21206.0 ]), "paddings" : tf .constant ([0 , 0 ])},
48
+ {"inputs" : tf .constant ([- 29515.0 , - 18256.0 ]), "paddings" : tf .constant ([0 , 0 ])},
50
49
],
51
50
),
52
51
dict (
@@ -55,8 +54,8 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
55
54
truncate = True ,
56
55
expected = [
57
56
{
58
- "inputs" : tf .constant ([- 29515.0 , 0 ]),
59
- "paddings" : tf .constant ([0 , 1 ]),
57
+ "inputs" : tf .constant ([- 29515.0 , - 18256. 0 ]),
58
+ "paddings" : tf .constant ([0 , 0 ]),
60
59
},
61
60
{
62
61
"inputs" : tf .constant ([14620.0 , - 21206.0 ]),
@@ -74,17 +73,17 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
74
73
scale = 2 ** 15 ,
75
74
expected = [
76
75
{
77
- "inputs" : tf .constant ([- 0.9007263 , 0.0 , 0.0 , 0.0 , 0.0 ]),
78
- "paddings" : tf .constant ([0 , 1 , 1 , 1 , 1 ]),
79
- },
80
- {
81
- "inputs" : tf .constant ([0.446167 , - 0.64715576 , 0.0 , 0.0 , 0.0 ]),
76
+ "inputs" : tf .constant ([- 0.9007263 , - 0.5571289 , 0.0 , 0.0 , 0.0 ]),
82
77
"paddings" : tf .constant ([0 , 0 , 1 , 1 , 1 ]),
83
78
},
84
79
{
85
- "inputs" : tf .constant ([- 0.1206665 , - 0.47470093 , 0.5515747 , 0.0 , 0.0 ]),
80
+ "inputs" : tf .constant ([0.446167 , - 0.64715576 , - 0.12982178 , 0.0 , 0.0 ]),
86
81
"paddings" : tf .constant ([0 , 0 , 0 , 1 , 1 ]),
87
82
},
83
+ {
84
+ "inputs" : tf .constant ([- 0.1206665 , - 0.47470093 , 0.5515747 , 0.6856079 , 0.0 ]),
85
+ "paddings" : tf .constant ([0 , 0 , 0 , 0 , 1 ]),
86
+ },
88
87
],
89
88
),
90
89
dict (
@@ -94,11 +93,7 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
94
93
input_key = "input_speech" ,
95
94
expected = [
96
95
{
97
- "inputs" : tf .constant ([- 0.9007263 , 0.0 ]),
98
- "paddings" : tf .constant ([0 , 1 ]),
99
- },
100
- {
101
- "inputs" : tf .constant ([0.446167 , - 0.64715576 ]),
96
+ "inputs" : tf .constant ([- 0.9007263 , - 0.5571289 ]),
102
97
"paddings" : tf .constant ([0 , 0 ]),
103
98
},
104
99
],
@@ -108,7 +103,7 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
108
103
def test_speech_input (
109
104
self ,
110
105
max_len : int ,
111
- expected : dict [str , Any ],
106
+ expected : list [ dict [str , Any ] ],
112
107
truncate : bool = False ,
113
108
input_key : str = "speech" ,
114
109
scale : Optional [float ] = None ,
@@ -122,12 +117,15 @@ def test_speech_input(
122
117
# Use a fake speech source with only speech inputs.
123
118
source = input_tf_data .with_processor (
124
119
config_for_function (input_fake .fake_speech_source ).set (
125
- speech_key = input_key , num_examples = 10
120
+ speech_key = input_key , num_examples = 10 , max_len = 5
126
121
),
127
122
processor = config_for_function (input_tf_data .select_fields ).set (fields = [input_key ]),
128
123
is_training = False ,
129
124
)
130
125
actual = list (processor (source ()).take (3 ))
126
+ expected = [
127
+ dict (inputs = d ["inputs" ], paddings = tf .cast (d ["paddings" ], tf .bool )) for d in expected
128
+ ]
131
129
tf .nest .map_structure (self .assertAllClose , expected , actual )
132
130
133
131
@@ -481,3 +479,7 @@ def test_filter_by_length(
481
479
{k : tf .constant (v , dtype = tf .int32 ) for k , v in expect .items ()} for expect in expected
482
480
]
483
481
tf .nest .map_structure (self .assertAllEqual , expected , actual )
482
+
483
+
484
+ if __name__ == "__main__" :
485
+ absltest .main ()
0 commit comments