
Commit 9a67d87

Merge pull request #170 from TensorSpeech/dev/transducer

Update Transducer, Positional Encoding, Dataset

2 parents 283d1f2 + 53f1a57 commit 9a67d87

19 files changed: +979 -70 lines changed

examples/conformer/config.yml

Lines changed: 9 additions & 7 deletions

@@ -25,8 +25,8 @@ speech_config:
 
 decoder_config:
   vocabulary: null
-  target_vocab_size: 4096
-  max_subword_length: 4
+  target_vocab_size: 1000
+  max_subword_length: 10
   blank_at_zero: True
   beam_width: 5
   norm_score: True

@@ -40,7 +40,7 @@ model_config:
     filters: 144
     kernel_size: 3
     strides: 2
-  encoder_positional_encoding: sinusoid_concat
+  encoder_positional_encoding: sinusoid_concat_v2
   encoder_dmodel: 144
   encoder_num_blocks: 16
   encoder_head_size: 36

@@ -55,11 +55,12 @@ model_config:
   prediction_rnn_units: 320
   prediction_rnn_type: lstm
   prediction_rnn_implementation: 2
-  prediction_layer_norm: False
-  prediction_projection_units: 144
-  joint_dim: 640
-  prejoint_linear: False
+  prediction_layer_norm: True
+  prediction_projection_units: 0
+  joint_dim: 320
+  prejoint_linear: True
   joint_activation: tanh
+  joint_mode: add
 
 learning_config:
   train_dataset_config:

@@ -78,6 +79,7 @@ learning_config:
     tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
     shuffle: True
     cache: True
+    cache_percent: 0.2
     buffer_size: 100
     drop_remainder: True
     stage: train
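The new decoder_config values only take effect when the subword vocabulary is regenerated rather than loaded from file. A minimal sketch of that path, assuming Config accepts the YAML file path directly and using a placeholder corpus file; TFSubwordFeaturizer and the build_from_corpus signature appear in the contextnet training script later in this diff:

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.text_featurizers import TFSubwordFeaturizer

# Assumption: Config takes the config file path; the corpus path is a placeholder.
config = Config("examples/conformer/config.yml")
# decoder_config now requests a 1000-token vocabulary with subwords up to length 10.
text_featurizer = TFSubwordFeaturizer.build_from_corpus(
    config.decoder_config,
    corpus_files=["/path/to/transcripts.tsv"],
)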

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 19 additions & 10 deletions

@@ -48,6 +48,8 @@
 
 parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
 
+parser.add_argument("--validation", default=False, action="store_true", help="Enable validation dataset")
+
 args = parser.parse_args()
 
 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

@@ -83,35 +85,42 @@
     **vars(config.learning_config.train_dataset_config),
     indefinite=True
 )
-eval_dataset = ASRTFRecordDatasetKeras(
-    speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-    **vars(config.learning_config.eval_dataset_config),
-    indefinite=True
-)
+
+if args.validation:
+    eval_dataset = ASRTFRecordDatasetKeras(
+        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        **vars(config.learning_config.eval_dataset_config),
+        indefinite=True
+    )
 
 if args.compute_lengths:
     train_dataset.update_lengths(args.metadata_prefix)
-    eval_dataset.update_lengths(args.metadata_prefix)
+    if args.validation:
+        eval_dataset.update_lengths(args.metadata_prefix)
 
 # Update metadata calculated from both train and eval datasets
 train_dataset.load_metadata(args.metadata_prefix)
-eval_dataset.load_metadata(args.metadata_prefix)
+if args.validation:
+    eval_dataset.load_metadata(args.metadata_prefix)
 
 batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
 global_batch_size = batch_size
 global_batch_size *= strategy.num_replicas_in_sync
 
 train_data_loader = train_dataset.create(global_batch_size)
-eval_data_loader = eval_dataset.create(global_batch_size)
+eval_data_loader = eval_dataset.create(global_batch_size) if args.validation else None
+validation_steps = eval_dataset.total_steps if args.validation else None
 
 with strategy.scope():
     # build model
     conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
     conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
-    conformer.summary(line_length=120)
 
     if args.saved:
         conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
+        print('Loaded pretrained weights successfully')
+
+    conformer.summary(line_length=120)
 
     optimizer = tf.keras.optimizers.Adam(
         TransformerSchedule(

@@ -140,5 +149,5 @@
 conformer.fit(
     train_data_loader, epochs=config.learning_config.running_config.num_epochs,
     validation_data=eval_data_loader, callbacks=callbacks,
-    steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
+    steps_per_epoch=train_dataset.total_steps, validation_steps=validation_steps
 )
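With this change, evaluation is strictly opt-in: without --validation, fit receives validation_data=None and validation_steps=None, and Keras skips the evaluation loop entirely. A minimal sketch of the Keras behavior this relies on, using a toy model rather than this script's Conformer:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
dataset = tf.data.Dataset.from_tensor_slices(([[1.0]], [[2.0]])).batch(1)

# validation_data=None disables the per-epoch evaluation pass, which is what
# the guarded eval_data_loader / validation_steps pair achieves on TPU.
model.fit(dataset, epochs=1, validation_data=None, validation_steps=None)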

examples/contextnet/train_tpu_keras_subword_contextnet.py

Lines changed: 3 additions & 3 deletions

@@ -57,7 +57,7 @@
 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import TFSubwordFeaturizer, SentencePieceFeaturizer
 from tensorflow_asr.models.keras.contextnet import ContextNet
 from tensorflow_asr.optimizers.schedules import TransformerSchedule

@@ -69,10 +69,10 @@
     text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
 elif args.subwords and os.path.exists(args.subwords):
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+    text_featurizer = TFSubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
 else:
     print("Generating subwords ...")
-    text_featurizer = SubwordFeaturizer.build_from_corpus(
+    text_featurizer = TFSubwordFeaturizer.build_from_corpus(
         config.decoder_config,
         corpus_files=args.subwords_corpus
     )

setup.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@
 
 setuptools.setup(
     name="TensorFlowASR",
-    version="0.8.1",
+    version="0.8.2",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/augmentations/augments.py

Lines changed: 12 additions & 12 deletions

@@ -40,29 +40,29 @@
 
 
 class TFAugmentationExecutor:
-    def __init__(self, augmentations: list):
+    def __init__(self, augmentations: list, prob: float = 0.5):
         self.augmentations = augmentations
+        self.prob = prob
 
     @tf.function
     def augment(self, inputs):
        outputs = inputs
        for au in self.augmentations:
-            outputs = au.augment(outputs)
+            p = tf.random.uniform([])
+            outputs = tf.where(tf.less(p, self.prob), au.augment(outputs), outputs)
        return outputs
 
 
 class Augmentation:
     def __init__(self, config: dict = None, use_tf: bool = False):
         if not config: config = {}
-        if use_tf:
-            self.before = self.tf_parse(config.pop("before", {}))
-            self.after = self.tf_parse(config.pop("after", {}))
-        else:
-            self.before = self.parse(config.pop("before", {}))
-            self.after = self.parse(config.pop("after", {}))
+        prob = float(config.pop("prob", 0.5))
+        parser = self.tf_parse if use_tf else self.parse
+        self.before = parser(config.pop("before", {}), prob=prob)
+        self.after = parser(config.pop("after", {}), prob=prob)
 
     @staticmethod
-    def parse(config: dict) -> list:
+    def parse(config: dict, prob: float = 0.5) -> naf.Sometimes:
         augmentations = []
         for key, value in config.items():
             au = AUGMENTATIONS.get(key, None)

@@ -71,10 +71,10 @@ def parse(config: dict) -> list:
                                f"Available augmentations: {AUGMENTATIONS.keys()}")
             aug = au(**value) if value is not None else au()
             augmentations.append(aug)
-        return naf.Sometimes(augmentations)
+        return naf.Sometimes(augmentations, pipeline_p=prob)
 
     @staticmethod
-    def tf_parse(config: dict) -> list:
+    def tf_parse(config: dict, prob: float = 0.5) -> TFAugmentationExecutor:
         augmentations = []
         for key, value in config.items():
             au = TFAUGMENTATIONS.get(key, None)

@@ -83,4 +83,4 @@ def tf_parse(config: dict) -> list:
                                f"Available tf augmentations: {TFAUGMENTATIONS.keys()}")
             aug = au(**value) if value is not None else au()
             augmentations.append(aug)
-        return TFAugmentationExecutor(augmentations)
+        return TFAugmentationExecutor(augmentations, prob=prob)
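The application probability now comes from the augmentation config itself, popped as a top-level "prob" key before parsing. A minimal usage sketch, assuming SpecAugment-style keys (time_masking, freq_masking) and their num_masks/mask_factor parameters are registered in this repo's TFAUGMENTATIONS:

from tensorflow_asr.augmentations.augments import Augmentation

# "prob" applies to both the before and after pipelines; each augmentation
# then fires independently with probability 0.8 per call.
augmentation = Augmentation({
    "prob": 0.8,
    "after": {
        "time_masking": {"num_masks": 5, "mask_factor": 100},
        "freq_masking": {"num_masks": 1, "mask_factor": 27},
    },
}, use_tf=True)

Note one design consequence of the tf.where gating above: both branches are evaluated, so au.augment always runs and only its result is conditionally kept.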

tensorflow_asr/datasets/base_dataset.py

Lines changed: 5 additions & 2 deletions

@@ -37,10 +37,13 @@ def __init__(self,
                  stage: str = "train",
                  **kwargs):
         self.data_paths = data_paths or []
+        if not isinstance(self.data_paths, list):
+            raise ValueError('data_paths must be a list of string paths')
         self.augmentations = augmentations  # apply augmentation
-        self.cache = cache  # whether to cache WHOLE transformed dataset to memory
+        self.cache = cache  # whether to cache transformed dataset to memory
         self.shuffle = shuffle  # whether to shuffle tf.data.Dataset
-        if buffer_size <= 0 and shuffle: raise ValueError("buffer_size must be positive when shuffle is on")
+        if buffer_size <= 0 and shuffle:
+            raise ValueError("buffer_size must be positive when shuffle is on")
         self.buffer_size = buffer_size  # shuffle buffer size
         self.stage = stage  # for defining tfrecords files
         self.use_tf = use_tf
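The new guard catches a common mistake: a bare string path silently iterates character by character. A quick illustration with a hypothetical path:

# Iterating over a bare string "works" but yields one character per element.
for p in "/data/train.tsv":
    ...  # 15 one-character "paths"

for p in ["/data/train.tsv"]:
    ...  # 1 real path, as intended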

tensorflow_asr/featurizers/text_featurizers.py

Lines changed: 4 additions & 2 deletions

@@ -295,8 +295,10 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor:
         def cond(batch, total, _): return tf.less(batch, total)
 
         def body(batch, total, transcripts):
-            upoints = self.indices2upoints(indices[batch])
-            transcripts = transcripts.write(batch, tf.strings.unicode_encode(upoints, "UTF-8"))
+            norm_indices = self.normalize_indices(indices[batch])
+            norm_indices = tf.gather_nd(norm_indices, tf.where(tf.not_equal(norm_indices, 0)))
+            decoded = tf.numpy_function(self.subwords.decode, inp=[norm_indices], Tout=tf.string)
+            transcripts = transcripts.write(batch, decoded)
             return batch + 1, total, transcripts
 
         _, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts])
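This moves decoding from a pure-TF unicode reconstruction to the featurizer's own decoder. Since self.subwords.decode is a plain Python method, it must be bridged into the graph with tf.numpy_function. A standalone sketch of the same pattern, with a toy decoder standing in for the real subword model:

import tensorflow as tf

def toy_decode(ids):
    # stand-in for self.subwords.decode; ids arrives as a numpy array
    return " ".join(str(i) for i in ids).encode("utf-8")

indices = tf.constant([5, 9, 0, 0], dtype=tf.int32)
# strip blank/padding zeros first, exactly as the new body() does
indices = tf.gather_nd(indices, tf.where(tf.not_equal(indices, 0)))
decoded = tf.numpy_function(toy_decode, inp=[indices], Tout=tf.string)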

tensorflow_asr/models/conformer.py

Lines changed: 21 additions & 4 deletions

@@ -307,12 +307,17 @@ def __init__(self,
 
         if positional_encoding == "sinusoid":
             self.pe = PositionalEncoding(name=f"{name}_pe")
+        elif positional_encoding == "sinusoid_v2":
+            self.pe = PositionalEncoding(alpha=2, beta=0, name=f"{name}_pe")
         elif positional_encoding == "sinusoid_concat":
             self.pe = PositionalEncodingConcat(name=f"{name}_pe")
+        elif positional_encoding == "sinusoid_concat_v2":
+            self.pe = PositionalEncodingConcat(alpha=2, beta=-1, name=f"{name}_pe")
         elif positional_encoding == "subsampling":
             self.pe = tf.keras.layers.Activation("linear", name=f"{name}_pe")
         else:
-            raise ValueError("positional_encoding must be either 'sinusoid' or 'subsampling'")
+            raise ValueError("positional_encoding must be either 'sinusoid', "
+                             "'sinusoid_concat', 'sinusoid_v2', 'sinusoid_concat_v2' or 'subsampling'")
 
         self.linear = tf.keras.layers.Dense(
             dmodel, name=f"{name}_linear",

@@ -373,6 +378,7 @@ def __init__(self,
                  encoder_depth_multiplier: int = 1,
                  encoder_fc_factor: float = 0.5,
                  encoder_dropout: float = 0,
+                 encoder_trainable: bool = True,
                  prediction_embed_dim: int = 512,
                  prediction_embed_dropout: int = 0,
                  prediction_num_rnns: int = 1,

@@ -381,12 +387,16 @@ def __init__(self,
                  prediction_rnn_implementation: int = 2,
                  prediction_layer_norm: bool = True,
                  prediction_projection_units: int = 0,
+                 prediction_trainable: bool = True,
                  joint_dim: int = 1024,
                  joint_activation: str = "tanh",
                  prejoint_linear: bool = True,
+                 postjoint_linear: bool = False,
+                 joint_mode: str = "add",
+                 joint_trainable: bool = True,
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
-                 name: str = "conformer_transducer",
+                 name: str = "conformer",
                  **kwargs):
         super(Conformer, self).__init__(
             encoder=ConformerEncoder(

@@ -402,7 +412,9 @@ def __init__(self,
                 fc_factor=encoder_fc_factor,
                 dropout=encoder_dropout,
                 kernel_regularizer=kernel_regularizer,
-                bias_regularizer=bias_regularizer
+                bias_regularizer=bias_regularizer,
+                trainable=encoder_trainable,
+                name=f"{name}_encoder"
             ),
             vocabulary_size=vocabulary_size,
             embed_dim=prediction_embed_dim,

@@ -413,12 +425,17 @@ def __init__(self,
             rnn_implementation=prediction_rnn_implementation,
             layer_norm=prediction_layer_norm,
             projection_units=prediction_projection_units,
+            prediction_trainable=prediction_trainable,
             joint_dim=joint_dim,
             joint_activation=joint_activation,
             prejoint_linear=prejoint_linear,
+            postjoint_linear=postjoint_linear,
+            joint_mode=joint_mode,
+            joint_trainable=joint_trainable,
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer,
-            name=name, **kwargs
+            name=name,
+            **kwargs
         )
         self.dmodel = encoder_dmodel
         self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor
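The new encoder_trainable / prediction_trainable / joint_trainable flags make staged fine-tuning possible from constructor arguments alone. A hedged sketch, assuming every omitted architecture argument keeps the defaults shown in this diff:

from tensorflow_asr.models.conformer import Conformer

conformer = Conformer(
    vocabulary_size=1000,
    encoder_trainable=False,     # freeze the Conformer encoder
    prediction_trainable=True,   # keep training the prediction network
    joint_trainable=True,        # ...and the joint network
    joint_mode="add",            # matches the new config.yml entry
)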

tensorflow_asr/models/contextnet.py

Lines changed: 12 additions & 1 deletion

@@ -197,6 +197,7 @@ def __init__(self,
                  vocabulary_size: int,
                  encoder_blocks: List[dict],
                  encoder_alpha: float = 0.5,
+                 encoder_trainable: bool = True,
                  prediction_embed_dim: int = 512,
                  prediction_embed_dropout: int = 0,
                  prediction_num_rnns: int = 1,

@@ -205,9 +206,13 @@ def __init__(self,
                  prediction_rnn_implementation: int = 2,
                  prediction_layer_norm: bool = True,
                  prediction_projection_units: int = 0,
+                 prediction_trainable: bool = True,
                  joint_dim: int = 1024,
                  joint_activation: str = "tanh",
                  prejoint_linear: bool = True,
+                 postjoint_linear: bool = False,
+                 joint_mode: str = "add",
+                 joint_trainable: bool = True,
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
                  name: str = "contextnet",

@@ -218,6 +223,7 @@ def __init__(self,
                 alpha=encoder_alpha,
                 kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
+                trainable=encoder_trainable,
                 name=f"{name}_encoder"
             ),
             vocabulary_size=vocabulary_size,

@@ -228,13 +234,18 @@ def __init__(self,
             rnn_type=prediction_rnn_type,
             rnn_implementation=prediction_rnn_implementation,
             layer_norm=prediction_layer_norm,
+            prediction_trainable=prediction_trainable,
             projection_units=prediction_projection_units,
             joint_dim=joint_dim,
             joint_activation=joint_activation,
             prejoint_linear=prejoint_linear,
+            postjoint_linear=postjoint_linear,
+            joint_mode=joint_mode,
+            joint_trainable=joint_trainable,
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer,
-            name=name, **kwargs
+            name=name,
+            **kwargs
         )
         self.dmodel = self.encoder.blocks[-1].dmodel
         self.time_reduction_factor = 1
