-
Notifications
You must be signed in to change notification settings - Fork 4.1k
online mix noise audio data in training step #2622
base: master
Are you sure you want to change the base?
Changes from 6 commits
681f470
421243d
d08efad
b0a14b5
ba1a587
aebd08d
d255c3f
ec25136
484134e
4f24f08
1f57ece
66cc7c4
b7eb0f4
ccae7cc
8cc95f9
9e2648a
2269514
0b8147c
42bc45b
289722d
9334e79
25736e0
40b431b
f7d1279
7792226
c4c3ced
c151b1d
c089b7f
491a4b0
735cbbb
2fa91e8
6b820bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import tensorflow as tf | ||
import tensorflow.compat.v1 as tfv1 | ||
from tensorflow.python.ops import gen_audio_ops as contrib_audio | ||
import os | ||
|
||
def collect_noise_filenames(walk_dirs): | ||
assert isinstance(walk_dirs, list) | ||
|
||
for d in walk_dirs: | ||
for dirpath, _, filenames in os.walk(d): | ||
for filename in filenames: | ||
if filename.endswith('.wav'): | ||
yield os.path.join(dirpath, filename) | ||
|
||
def noise_file_to_audio(noise_file): | ||
samples = tf.io.read_file(noise_file) | ||
decoded = contrib_audio.decode_wav(samples, desired_channels=1) | ||
return decoded.audio | ||
|
||
def augment_noise(audio, | ||
noise_audio, | ||
change_audio_db_max=0, | ||
change_audio_db_min=-10, | ||
change_noise_db_max=-15, | ||
change_noise_db_min=-25): | ||
|
||
decoded_audio_len = tf.shape(audio)[0] | ||
noise_decoded_audio_len = tf.shape(noise_audio)[0] | ||
|
||
multiply = tf.math.floordiv(decoded_audio_len, noise_decoded_audio_len) + 1 | ||
noise_audio_tile = tf.tile(noise_audio, [multiply, 1]) | ||
|
||
# now noise_decoded_len must > decoded_len | ||
noise_decoded_audio_len = tf.shape(noise_audio_tile)[0] | ||
|
||
mix_decoded_start_end_points = tfv1.random_uniform( | ||
[2], minval=0, maxval=decoded_audio_len-1, dtype=tf.int32) | ||
mix_decoded_start_point = tf.math.reduce_min(mix_decoded_start_end_points) | ||
mix_decoded_end_point = tf.math.reduce_max( | ||
mix_decoded_start_end_points) + 1 | ||
mix_decoded_width = mix_decoded_end_point - mix_decoded_start_point | ||
|
||
left_zeros = tf.zeros(shape=[mix_decoded_start_point, 1]) | ||
|
||
mix_noise_decoded_start_point = tfv1.random_uniform( | ||
[], minval=0, maxval=noise_decoded_audio_len - mix_decoded_width, dtype=tf.int32) | ||
mix_noise_decoded_end_point = mix_noise_decoded_start_point + mix_decoded_width | ||
extract_noise_decoded = noise_audio_tile[mix_noise_decoded_start_point:mix_noise_decoded_end_point, :] | ||
|
||
right_zeros = tf.zeros( | ||
shape=[decoded_audio_len - mix_decoded_end_point, 1]) | ||
|
||
mixed_noise = tf.concat( | ||
[left_zeros, extract_noise_decoded, right_zeros], axis=0) | ||
|
||
choosen_audio_db = tfv1.random_uniform( | ||
[], minval=change_audio_db_min, maxval=change_audio_db_max) | ||
audio_ratio = tf.math.pow(10.0, choosen_audio_db / 10) | ||
|
||
choosen_noise_db = tfv1.random_uniform( | ||
[], minval=change_noise_db_min, maxval=change_noise_db_max) | ||
noise_ratio = tf.math.pow(10.0, choosen_noise_db / 10) | ||
return tf.multiply(audio, audio_ratio) + tf.multiply(mixed_noise, noise_ratio) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from util.flags import FLAGS | ||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up | ||
from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT | ||
from util.audio_augmentation import augment_noise, noise_file_to_audio, collect_noise_filenames | ||
|
||
|
||
def read_csvs(csv_files): | ||
|
@@ -64,11 +65,26 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False): | |
return mfccs, tf.shape(input=mfccs)[0] | ||
|
||
|
||
def audiofile_to_features(wav_filename, train_phase=False): | ||
def audiofile_to_features(wav_filename, train_phase=False, noise_iterator=None): | ||
samples = tf.io.read_file(wav_filename) | ||
decoded = contrib_audio.decode_wav(samples, desired_channels=1) | ||
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase) | ||
audio = decoded.audio | ||
|
||
# augment audio | ||
if train_phase and noise_iterator: | ||
audio = augment_noise( | ||
audio, | ||
noise_iterator.get_next(), | ||
change_audio_db_max=FLAGS.audio_aug_mix_noise_max_audio_db, | ||
change_audio_db_min=FLAGS.audio_aug_mix_noise_min_audio_db, | ||
change_noise_db_max=FLAGS.audio_aug_mix_noise_max_noise_db, | ||
change_noise_db_min=FLAGS.audio_aug_mix_noise_min_noise_db, | ||
) | ||
|
||
|
||
features, features_len = samples_to_mfccs(audio, decoded.sample_rate, train_phase=train_phase) | ||
|
||
# augment features | ||
if train_phase: | ||
if FLAGS.data_aug_features_multiplicative > 0: | ||
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) | ||
|
@@ -79,9 +95,9 @@ def audiofile_to_features(wav_filename, train_phase=False): | |
return features, features_len | ||
|
||
|
||
def entry_to_features(wav_filename, transcript, train_phase): | ||
def entry_to_features(wav_filename, transcript, train_phase, noise_iterator=None): | ||
# https://bugs.python.org/issue32117 | ||
features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase) | ||
features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase, noise_iterator=noise_iterator) | ||
return wav_filename, features, features_len, tf.SparseTensor(*transcript) | ||
|
||
|
||
|
@@ -120,7 +136,22 @@ def batch_fn(wav_filenames, features, features_len, transcripts): | |
return tf.data.Dataset.zip((wav_filenames, features, transcripts)) | ||
|
||
num_gpus = len(Config.available_devices) | ||
process_fn = partial(entry_to_features, train_phase=train_phase) | ||
|
||
if train_phase and FLAGS.audio_aug_mix_noise_walk_dirs: | ||
# because we have to determine the shuffle size, so we could not use generator | ||
noise_filenames = tf.convert_to_tensor( | ||
list(collect_noise_filenames(FLAGS.audio_aug_mix_noise_walk_dirs.split(','))), | ||
dtype=tf.string) | ||
print(">>> Collect {} noise files for mixing audio".format(noise_filenames.shape[0])) | ||
noise_dataset = (tf.data.Dataset.from_tensor_slices(noise_filenames) | ||
.map(noise_file_to_audio, num_parallel_calls=tf.data.experimental.AUTOTUNE) | ||
.shuffle(noise_filenames.shape[0]) | ||
.cache(FLAGS.audio_aug_mix_noise_cache) | ||
.repeat()) | ||
noise_iterator = tf.compat.v1.data.make_one_shot_iterator(noise_dataset) | ||
else: | ||
noise_iterator = None | ||
process_fn = partial(entry_to_features, train_phase=train_phase, noise_iterator=noise_iterator) | ||
|
||
|
||
dataset = (tf.data.Dataset.from_generator(generate_values, | ||
output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) | ||
|
Uh oh!
There was an error while loading. Please reload this page.