-
Notifications
You must be signed in to change notification settings - Fork 88
Expand file tree
/
Copy pathdata_utils.py
More file actions
70 lines (54 loc) · 2.09 KB
/
data_utils.py
File metadata and controls
70 lines (54 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from datasets import load_dataset, Audio
from normalizer import EnglishTextNormalizer, BasicMultilingualTextNormalizer
from .eval_utils import read_manifest, write_manifest
def is_target_text_in_range(ref):
    """Return True if the reference transcript ``ref`` should be kept for scoring.

    A reference is rejected when, after stripping surrounding whitespace, it is
    empty or equals the sentinel "ignore time segment in scoring".
    """
    stripped_ref = ref.strip()
    return stripped_ref != "" and stripped_ref != "ignore time segment in scoring"
def get_text(sample):
    """Return the reference transcript stored in ``sample``.

    Candidate columns are checked in priority order: "text", "sentence",
    "normalized_text", "transcript", "transcription". The value of the first
    column present is returned unchanged, even if it is empty.

    Args:
        sample: A mapping-like dataset row supporting ``in`` and ``keys()``.

    Returns:
        The value stored under the first matching transcript column.

    Raises:
        ValueError: If none of the candidate columns is present. The message
            lists the columns the sample actually has.
    """
    # Order matters: earlier columns win when a sample has several of them.
    candidate_columns = (
        "text",
        "sentence",
        "normalized_text",
        "transcript",
        "transcription",
    )
    for column in candidate_columns:
        if column in sample:
            return sample[column]
    # The original message concatenated a plain literal to the f-string, so the
    # keys were never interpolated. Build the key list explicitly instead.
    available = ", ".join(map(str, sample.keys()))
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
        f"'transcript' or 'transcription'. Got sample with columns: {available}. "
        "Ensure a text column name is present in the dataset."
    )
# Module-level normalizer instances, constructed once at import time.
# `normalizer` (English) is the one applied by `normalize` to every sample's transcript.
normalizer = EnglishTextNormalizer()
# NOTE(review): `ml_normalizer` is not referenced in this section of the file;
# presumably intended for non-English evaluation elsewhere — verify before removing.
ml_normalizer = BasicMultilingualTextNormalizer()
def normalize(batch):
    """Attach raw and English-normalised transcripts to a dataset row.

    Writes the transcript found by ``get_text`` to ``batch["original_text"]``.
    Writes its normalised form, produced by the module-level ``normalizer``,
    to ``batch["norm_text"]``. The row is modified in place and also
    returned, so this works as a ``Dataset.map`` callback.
    """
    raw_transcript = get_text(batch)
    batch["original_text"] = raw_transcript
    batch["norm_text"] = normalizer(raw_transcript)
    return batch
def load_data(args):
    """Load one dataset split as described by the attributes of ``args``.

    Reads ``args.dataset_path``, which is passed as the dataset path/name, and
    ``args.dataset``, which is passed as the config name. It also reads
    ``args.split`` and ``args.streaming``. These are forwarded to
    ``datasets.load_dataset`` with ``token=True``, which per the ``datasets``
    docs makes it send the locally stored Hugging Face token.
    """
    return load_dataset(
        args.dataset_path,
        args.dataset,
        split=args.split,
        streaming=args.streaming,
        token=True,
    )
def prepare_data(dataset, decode_audio=True, sampling_rate=16000):
    """Cast the audio column, add normalised transcripts, and drop unusable rows.

    Args:
        dataset: A ``datasets`` dataset (e.g. the result of ``load_data``) with
            an ``"audio"`` column and a transcript column recognised by
            ``get_text``.
        decode_audio: If True, cast ``"audio"`` to ``Audio(sampling_rate=...)``
            so samples are decoded at ``sampling_rate``. If False, make sure
            ``"audio"`` is a non-decoding ``Audio`` feature. A column that already
            has ``decode=False`` is left untouched.
        sampling_rate: Target rate in Hz for the ``Audio`` cast. Defaults to
            16000, the value this function previously hard-coded.

    Returns:
        The dataset after ``normalize`` has added ``original_text`` and
        ``norm_text``. Rows whose ``norm_text`` fails ``is_target_text_in_range``
        are removed.
    """
    if decode_audio:
        dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
    else:
        # NOTE(review): `features` may be None for some streaming datasets; this
        # lookup would then raise — confirm for the datasets used here.
        current_audio_feature = dataset.features["audio"]
        already_undecoded = (
            hasattr(current_audio_feature, "decode") and not current_audio_feature.decode
        )
        if not already_undecoded:
            # With decode=False, `datasets` returns the raw path/bytes, so the
            # sampling rate is recorded on the feature but not applied to the audio.
            dataset = dataset.cast_column(
                "audio", Audio(sampling_rate=sampling_rate, decode=False)
            )
    dataset = dataset.map(normalize)
    # `input_columns` makes the filter receive only the `norm_text` value.
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
    return dataset