Skip to content

Commit 99aa4df

Browse files
committed
Add max_len arg to fake_speech_source().
fake_speech_source() previously generated extremely short fake audio samples (lengths of 1, 2, 3, and so on), which caused gradients to be zero in the gradient unit tests; the new max_len argument lets tests request longer samples.
1 parent 97f5108 commit 99aa4df

File tree

5 files changed

+44
-31
lines changed

5 files changed

+44
-31
lines changed

axlearn/audio/input_asr_test.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import seqio
1111
import tensorflow as tf
12-
from absl.testing import parameterized
12+
from absl.testing import absltest, parameterized
1313

1414
from axlearn.audio import input_asr
1515
from axlearn.common import input_fake, input_tf_data
@@ -28,25 +28,24 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
2828
max_len=5,
2929
expected=[
3030
{
31-
"inputs": tf.constant([-29515.0, 0, 0, 0, 0]),
32-
"paddings": tf.constant([0, 1, 1, 1, 1]),
33-
},
34-
{
35-
"inputs": tf.constant([14620.0, -21206.0, 0, 0, 0]),
31+
"inputs": tf.constant([-29515.0, -18256.0, 0, 0, 0]),
3632
"paddings": tf.constant([0, 0, 1, 1, 1]),
3733
},
3834
{
39-
"inputs": tf.constant([-3954.0, -15555.0, 18074.0, 0, 0]),
35+
"inputs": tf.constant([14620.0, -21206.0, -4254.0, 0, 0]),
4036
"paddings": tf.constant([0, 0, 0, 1, 1]),
4137
},
38+
{
39+
"inputs": tf.constant([-3954.0, -15555.0, 18074.0, 22466.0, 0]),
40+
"paddings": tf.constant([0, 0, 0, 0, 1]),
41+
},
4242
],
4343
),
4444
dict(
4545
# Test a basic case with filtering.
4646
max_len=2,
4747
expected=[
48-
{"inputs": tf.constant([-29515.0, 0]), "paddings": tf.constant([0, 1])},
49-
{"inputs": tf.constant([14620.0, -21206.0]), "paddings": tf.constant([0, 0])},
48+
{"inputs": tf.constant([-29515.0, -18256.0]), "paddings": tf.constant([0, 0])},
5049
],
5150
),
5251
dict(
@@ -55,8 +54,8 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
5554
truncate=True,
5655
expected=[
5756
{
58-
"inputs": tf.constant([-29515.0, 0]),
59-
"paddings": tf.constant([0, 1]),
57+
"inputs": tf.constant([-29515.0, -18256.0]),
58+
"paddings": tf.constant([0, 0]),
6059
},
6160
{
6261
"inputs": tf.constant([14620.0, -21206.0]),
@@ -74,17 +73,17 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
7473
scale=2**15,
7574
expected=[
7675
{
77-
"inputs": tf.constant([-0.9007263, 0.0, 0.0, 0.0, 0.0]),
78-
"paddings": tf.constant([0, 1, 1, 1, 1]),
79-
},
80-
{
81-
"inputs": tf.constant([0.446167, -0.64715576, 0.0, 0.0, 0.0]),
76+
"inputs": tf.constant([-0.9007263, -0.5571289, 0.0, 0.0, 0.0]),
8277
"paddings": tf.constant([0, 0, 1, 1, 1]),
8378
},
8479
{
85-
"inputs": tf.constant([-0.1206665, -0.47470093, 0.5515747, 0.0, 0.0]),
80+
"inputs": tf.constant([0.446167, -0.64715576, -0.12982178, 0.0, 0.0]),
8681
"paddings": tf.constant([0, 0, 0, 1, 1]),
8782
},
83+
{
84+
"inputs": tf.constant([-0.1206665, -0.47470093, 0.5515747, 0.6856079, 0.0]),
85+
"paddings": tf.constant([0, 0, 0, 0, 1]),
86+
},
8887
],
8988
),
9089
dict(
@@ -94,11 +93,7 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
9493
input_key="input_speech",
9594
expected=[
9695
{
97-
"inputs": tf.constant([-0.9007263, 0.0]),
98-
"paddings": tf.constant([0, 1]),
99-
},
100-
{
101-
"inputs": tf.constant([0.446167, -0.64715576]),
96+
"inputs": tf.constant([-0.9007263, -0.5571289]),
10297
"paddings": tf.constant([0, 0]),
10398
},
10499
],
@@ -108,7 +103,7 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
108103
def test_speech_input(
109104
self,
110105
max_len: int,
111-
expected: dict[str, Any],
106+
expected: list[dict[str, Any]],
112107
truncate: bool = False,
113108
input_key: str = "speech",
114109
scale: Optional[float] = None,
@@ -122,12 +117,15 @@ def test_speech_input(
122117
# Use a fake speech source with only speech inputs.
123118
source = input_tf_data.with_processor(
124119
config_for_function(input_fake.fake_speech_source).set(
125-
speech_key=input_key, num_examples=10
120+
speech_key=input_key, num_examples=10, max_len=5
126121
),
127122
processor=config_for_function(input_tf_data.select_fields).set(fields=[input_key]),
128123
is_training=False,
129124
)
130125
actual = list(processor(source()).take(3))
126+
expected = [
127+
dict(inputs=d["inputs"], paddings=tf.cast(d["paddings"], tf.bool)) for d in expected
128+
]
131129
tf.nest.map_structure(self.assertAllClose, expected, actual)
132130

133131

@@ -481,3 +479,7 @@ def test_filter_by_length(
481479
{k: tf.constant(v, dtype=tf.int32) for k, v in expect.items()} for expect in expected
482480
]
483481
tf.nest.map_structure(self.assertAllEqual, expected, actual)
482+
483+
484+
if __name__ == "__main__":
485+
absltest.main()

axlearn/audio/spectrum_augmenter_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def test_visualize(self, input_shape: Sequence[int], **kwargs):
216216
outputs = self._generate_masks(
217217
dict(inputs=inputs, paddings=paddings), is_training=True, **kwargs
218218
)
219+
# pytype: disable=import-error
219220
# pylint: disable-next=import-outside-toplevel
220221
import matplotlib.pyplot as plt
221222

axlearn/common/input_fake.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def fake_classification_source_instruct_lm(
418418
def fake_speech_source(
419419
*,
420420
is_training: bool,
421+
max_len: int = 100,
421422
num_examples: int = 100,
422423
speech_key: str = "speech",
423424
shuffle_buffer_size: Optional[int] = None,
@@ -426,7 +427,9 @@ def fake_speech_source(
426427
427428
Args:
428429
is_training: A boolean indicating whether it is in the training mode.
430+
max_len: Maximum sequence length (in samples) for generated speech data.
429431
num_examples: Integer of number of examples in the dataset.
432+
speech_key: Key name for the audio field in each example dict.
430433
shuffle_buffer_size: Shuffle buffer size used for training.
431434
432435
Returns:
@@ -441,7 +444,7 @@ def fake_speech_source(
441444
jax.random.PRNGKey(ix),
442445
minval=-(2**15),
443446
maxval=2**15,
444-
shape=[ix % 100 + 1],
447+
shape=[min(max_len // 2 + ix, max_len)],
445448
),
446449
}
447450
for ix in range(num_examples)

axlearn/experiments/audio/conformer/common_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
import seqio
99
import tensorflow as tf
10-
from absl.testing import parameterized
10+
from absl.testing import absltest, parameterized
1111

1212
from axlearn.audio.encoder_asr import SpeechFeatureLayer
1313
from axlearn.common.config import config_for_class
@@ -46,7 +46,7 @@ def visit_fn(_, value):
4646
# Dropping all text.
4747
dict(max_source_len=5, max_target_len=5, expect_count=0),
4848
# Dropping some speech and pad text.
49-
dict(max_source_len=4, max_target_len=8, expect_count=4),
49+
dict(max_source_len=4, max_target_len=8, expect_count=3),
5050
)
5151
@pytest.mark.skipif(not os.path.exists(_bpe_vocab_file), reason="Missing testdata.")
5252
def test_asr_input(self, max_source_len: int, max_target_len: int, expect_count: int):
@@ -63,7 +63,7 @@ def test_asr_input(self, max_source_len: int, max_target_len: int, expect_count:
6363
)
6464

6565
def source():
66-
speech_ds = fake_speech_source(is_training=False, num_examples=5)()
66+
speech_ds = fake_speech_source(is_training=False, num_examples=5, max_len=5)()
6767
text_ds = fake_text_source(is_training=False, batch_size=5)()
6868
return tf.data.Dataset.zip((speech_ds, text_ds)).map(lambda s, t: {**s, **t})
6969

@@ -98,3 +98,7 @@ def source():
9898
tf.reduce_all(ex["target"]["input_ids"][1:] == ex["target_labels"][:-1])
9999
)
100100
self.assertEqual(expect_count, actual_count)
101+
102+
103+
if __name__ == "__main__":
104+
absltest.main()

axlearn/experiments/audio/conformer/librispeech_trainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,15 @@
6363
from axlearn.experiments.trainer_config_utils import TrainerConfigFn
6464

6565

66-
def source_config(is_training: bool, split: str) -> InstantiableConfig[BuildDatasetFn]:
66+
def source_config(
67+
is_training: bool, split: str, max_source_len: int
68+
) -> InstantiableConfig[BuildDatasetFn]:
6769
"""Builds a source dataset config for librispeech.
6870
6971
Args:
7072
is_training: Whether source is for training.
7173
split: Dataset split.
74+
max_source_len: Max speech length.
7275
7376
Returns:
7477
A config that instantiates to a data source.
@@ -78,7 +81,7 @@ def source_config(is_training: bool, split: str) -> InstantiableConfig[BuildData
7881
def fake_asr_source(is_training: bool):
7982
def fn():
8083
text_ds = fake_text_source(is_training=is_training)()
81-
speech_ds = fake_speech_source(is_training=is_training)()
84+
speech_ds = fake_speech_source(is_training=is_training, max_len=max_source_len)()
8285
return tf.data.Dataset.zip((speech_ds, text_ds)).map(lambda s, t: {**s, **t})
8386

8487
return fn
@@ -317,7 +320,7 @@ def _input_fn(is_training: bool, split: str, eos_id: Optional[int] = None) -> In
317320
global_batch_size = 2048 if is_training else 512
318321

319322
return Input.default_config().set(
320-
source=source_config(is_training=is_training, split=split),
323+
source=source_config(is_training=is_training, split=split, max_source_len=max_source_len),
321324
processor=config_for_function(asr_input).set(
322325
max_source_len=max_source_len,
323326
max_target_len=max_target_len,

0 commit comments

Comments (0)