From e13179a07ebb3ddcc56a93e03f549008f869a7f6 Mon Sep 17 00:00:00 2001
From: Dongseong Hwang
Date: Fri, 23 May 2025 14:07:41 -0700
Subject: [PATCH] Add max_len arg to fake_speech_source().

fake_speech_source() generates extremely short fake audio samples
(lengths of 1, 2, 3, and so on), which causes gradients to be zero in
gradient unit tests. Add a max_len argument so tests can generate
longer samples.
---
 axlearn/audio/input_asr_test.py            | 50 ++++++++++---------
 axlearn/audio/spectrum_augmenter_test.py   |  2 +-
 axlearn/common/input_fake.py               |  5 +-
 .../audio/conformer/common_test.py         | 10 ++--
 .../audio/conformer/librispeech_trainer.py |  9 ++--
 5 files changed, 44 insertions(+), 32 deletions(-)

diff --git a/axlearn/audio/input_asr_test.py b/axlearn/audio/input_asr_test.py
index b46b85f52..5e69a8c95 100644
--- a/axlearn/audio/input_asr_test.py
+++ b/axlearn/audio/input_asr_test.py
@@ -9,7 +9,7 @@
 
 import seqio
 import tensorflow as tf
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 
 from axlearn.audio import input_asr
 from axlearn.common import input_fake, input_tf_data
@@ -28,25 +28,24 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
             max_len=5,
             expected=[
                 {
-                    "inputs": tf.constant([-29515.0, 0, 0, 0, 0]),
-                    "paddings": tf.constant([0, 1, 1, 1, 1]),
-                },
-                {
-                    "inputs": tf.constant([14620.0, -21206.0, 0, 0, 0]),
+                    "inputs": tf.constant([-29515.0, -18256.0, 0, 0, 0]),
                     "paddings": tf.constant([0, 0, 1, 1, 1]),
                 },
                 {
-                    "inputs": tf.constant([-3954.0, -15555.0, 18074.0, 0, 0]),
+                    "inputs": tf.constant([14620.0, -21206.0, -4254.0, 0, 0]),
                     "paddings": tf.constant([0, 0, 0, 1, 1]),
                 },
+                {
+                    "inputs": tf.constant([-3954.0, -15555.0, 18074.0, 22466.0, 0]),
+                    "paddings": tf.constant([0, 0, 0, 0, 1]),
+                },
             ],
         ),
         dict(
             # Test a basic case with filtering.
             max_len=2,
             expected=[
-                {"inputs": tf.constant([-29515.0, 0]), "paddings": tf.constant([0, 1])},
-                {"inputs": tf.constant([14620.0, -21206.0]), "paddings": tf.constant([0, 0])},
+                {"inputs": tf.constant([-29515.0, -18256.0]), "paddings": tf.constant([0, 0])},
             ],
         ),
         dict(
@@ -55,8 +54,8 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
             truncate=True,
             expected=[
                 {
-                    "inputs": tf.constant([-29515.0, 0]),
-                    "paddings": tf.constant([0, 1]),
+                    "inputs": tf.constant([-29515.0, -18256.0]),
+                    "paddings": tf.constant([0, 0]),
                 },
                 {
                     "inputs": tf.constant([14620.0, -21206.0]),
@@ -74,17 +73,17 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
             scale=2**15,
             expected=[
                 {
-                    "inputs": tf.constant([-0.9007263, 0.0, 0.0, 0.0, 0.0]),
-                    "paddings": tf.constant([0, 1, 1, 1, 1]),
-                },
-                {
-                    "inputs": tf.constant([0.446167, -0.64715576, 0.0, 0.0, 0.0]),
+                    "inputs": tf.constant([-0.9007263, -0.5571289, 0.0, 0.0, 0.0]),
                     "paddings": tf.constant([0, 0, 1, 1, 1]),
                 },
                 {
-                    "inputs": tf.constant([-0.1206665, -0.47470093, 0.5515747, 0.0, 0.0]),
+                    "inputs": tf.constant([0.446167, -0.64715576, -0.12982178, 0.0, 0.0]),
                     "paddings": tf.constant([0, 0, 0, 1, 1]),
                 },
+                {
+                    "inputs": tf.constant([-0.1206665, -0.47470093, 0.5515747, 0.6856079, 0.0]),
+                    "paddings": tf.constant([0, 0, 0, 0, 1]),
+                },
             ],
         ),
         dict(
@@ -94,11 +93,7 @@ class SpeechInputTest(TestCase, tf.test.TestCase):
             input_key="input_speech",
             expected=[
                 {
-                    "inputs": tf.constant([-0.9007263, 0.0]),
-                    "paddings": tf.constant([0, 1]),
-                },
-                {
-                    "inputs": tf.constant([0.446167, -0.64715576]),
+                    "inputs": tf.constant([-0.9007263, -0.5571289]),
                     "paddings": tf.constant([0, 0]),
                 },
             ],
@@ -108,7 +103,7 @@
     def test_speech_input(
         self,
         max_len: int,
-        expected: dict[str, Any],
+        expected: list[dict[str, Any]],
        truncate: bool = False,
         input_key: str = "speech",
         scale: Optional[float] = None,
@@ -122,12 +117,15 @@ def test_speech_input(
 
         # Use a fake speech source with only speech inputs.
         source = input_tf_data.with_processor(
             config_for_function(input_fake.fake_speech_source).set(
-                speech_key=input_key, num_examples=10
+                speech_key=input_key, num_examples=10, max_len=5
             ),
             processor=config_for_function(input_tf_data.select_fields).set(fields=[input_key]),
             is_training=False,
         )
 
         actual = list(processor(source()).take(3))
+        expected = [
+            dict(inputs=d["inputs"], paddings=tf.cast(d["paddings"], tf.bool)) for d in expected
+        ]
         tf.nest.map_structure(self.assertAllClose, expected, actual)
@@ -481,3 +479,7 @@ def test_filter_by_length(
             {k: tf.constant(v, dtype=tf.int32) for k, v in expect.items()} for expect in expected
         ]
         tf.nest.map_structure(self.assertAllEqual, expected, actual)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/axlearn/audio/spectrum_augmenter_test.py b/axlearn/audio/spectrum_augmenter_test.py
index b785468cc..269bc2c3c 100644
--- a/axlearn/audio/spectrum_augmenter_test.py
+++ b/axlearn/audio/spectrum_augmenter_test.py
@@ -217,7 +217,7 @@ def test_visualize(self, input_shape: Sequence[int], **kwargs):
             dict(inputs=inputs, paddings=paddings), is_training=True, **kwargs
         )
         # pylint: disable-next=import-outside-toplevel
-        import matplotlib.pyplot as plt
+        import matplotlib.pyplot as plt  # pytype: disable=import-error
 
         _, plots = plt.subplots(outputs.shape[0], 1)
         for plot, output in zip(plots, outputs):
diff --git a/axlearn/common/input_fake.py b/axlearn/common/input_fake.py
index 0cb38da53..ca9c151dd 100644
--- a/axlearn/common/input_fake.py
+++ b/axlearn/common/input_fake.py
@@ -418,6 +418,7 @@ def fake_classification_source_instruct_lm(
 def fake_speech_source(
     *,
     is_training: bool,
+    max_len: int = 100,
     num_examples: int = 100,
     speech_key: str = "speech",
     shuffle_buffer_size: Optional[int] = None,
@@ -426,7 +427,9 @@
 
     Args:
         is_training: A boolean indicating whether it is in the training mode.
+        max_len: Maximum sequence length (in samples) for generated speech data.
         num_examples: Integer of number of examples in the dataset.
+        speech_key: Key name for the audio field in each example dict.
         shuffle_buffer_size: Shuffle buffer size used for training.
 
     Returns:
@@ -441,7 +444,7 @@
                 jax.random.PRNGKey(ix),
                 minval=-(2**15),
                 maxval=2**15,
-                shape=[ix % 100 + 1],
+                shape=[min(max_len // 2 + ix, max_len)],
             ),
         }
         for ix in range(num_examples)
diff --git a/axlearn/experiments/audio/conformer/common_test.py b/axlearn/experiments/audio/conformer/common_test.py
index 383bc35ee..08e12ed97 100644
--- a/axlearn/experiments/audio/conformer/common_test.py
+++ b/axlearn/experiments/audio/conformer/common_test.py
@@ -7,7 +7,7 @@
 import pytest
 import seqio
 import tensorflow as tf
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 
 from axlearn.audio.encoder_asr import SpeechFeatureLayer
 from axlearn.common.config import config_for_class
@@ -46,7 +46,7 @@ def visit_fn(_, value):
         # Dropping all text.
         dict(max_source_len=5, max_target_len=5, expect_count=0),
         # Dropping some speech and pad text.
-        dict(max_source_len=4, max_target_len=8, expect_count=4),
+        dict(max_source_len=4, max_target_len=8, expect_count=3),
     )
     @pytest.mark.skipif(not os.path.exists(_bpe_vocab_file), reason="Missing testdata.")
     def test_asr_input(self, max_source_len: int, max_target_len: int, expect_count: int):
@@ -63,7 +63,7 @@ def test_asr_input(self, max_source_len: int, max_target_len: int, expect_count:
         )
 
         def source():
-            speech_ds = fake_speech_source(is_training=False, num_examples=5)()
+            speech_ds = fake_speech_source(is_training=False, num_examples=5, max_len=5)()
             text_ds = fake_text_source(is_training=False, batch_size=5)()
             return tf.data.Dataset.zip((speech_ds, text_ds)).map(lambda s, t: {**s, **t})
 
@@ -98,3 +98,7 @@ def source():
                 tf.reduce_all(ex["target"]["input_ids"][1:] == ex["target_labels"][:-1])
             )
         self.assertEqual(expect_count, actual_count)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/axlearn/experiments/audio/conformer/librispeech_trainer.py b/axlearn/experiments/audio/conformer/librispeech_trainer.py
index 3b475e4b1..8b1c61f87 100644
--- a/axlearn/experiments/audio/conformer/librispeech_trainer.py
+++ b/axlearn/experiments/audio/conformer/librispeech_trainer.py
@@ -63,12 +63,15 @@
 from axlearn.experiments.trainer_config_utils import TrainerConfigFn
 
 
-def source_config(is_training: bool, split: str) -> InstantiableConfig[BuildDatasetFn]:
+def source_config(
+    is_training: bool, split: str, max_source_len: int
+) -> InstantiableConfig[BuildDatasetFn]:
     """Builds a source dataset config for librispeech.
 
     Args:
         is_training: Whether source is for training.
         split: Dataset split.
+        max_source_len: Max speech length.
 
     Returns:
         A config that instantiates to a data source.
@@ -78,7 +81,7 @@
     def fake_asr_source(is_training: bool):
         def fn():
             text_ds = fake_text_source(is_training=is_training)()
-            speech_ds = fake_speech_source(is_training=is_training)()
+            speech_ds = fake_speech_source(is_training=is_training, max_len=max_source_len)()
             return tf.data.Dataset.zip((speech_ds, text_ds)).map(lambda s, t: {**s, **t})
 
         return fn
@@ -317,7 +320,7 @@ def _input_fn(is_training: bool, split: str, eos_id: Optional[int] = None) -> In
     global_batch_size = 2048 if is_training else 512
 
     return Input.default_config().set(
-        source=source_config(is_training=is_training, split=split),
+        source=source_config(is_training=is_training, split=split, max_source_len=max_source_len),
         processor=config_for_function(asr_input).set(
             max_source_len=max_source_len,
             max_target_len=max_target_len,
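
Below is a minimal sketch of the resulting length schedule, for illustration
only; the max_len value, the loop, and the asserts are assumptions of this
example, not anything the patch itself adds. With the new argument, example ix
is generated with shape [min(max_len // 2 + ix, max_len)], so every sample is
at least max_len // 2 samples long instead of starting at length 1:

    from axlearn.common.input_fake import fake_speech_source

    # Request up to 1 second of fake 16 kHz audio per example.
    # (max_len=16000 is a per-test choice here, not a library default.)
    ds = fake_speech_source(is_training=False, num_examples=3, max_len=16000)()
    for ix, example in enumerate(ds):
        # Lengths follow min(max_len // 2 + ix, max_len): 8000, 8001, 8002.
        assert example["speech"].shape[0] == min(16000 // 2 + ix, 16000)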