-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
268 lines (222 loc) · 9.05 KB
/
utils.py
File metadata and controls
268 lines (222 loc) · 9.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import json
import os
import re
import time
import traceback
from typing import Any
import ffmpeg
from halo import Halo
from mms.align_utils import get_alignments, get_spans, get_uroman_tokens
from mms.text_normalization import text_normalize
from timestamp_types import File, FileTimestamps, Match, Section
mms_languages = json.load(open("mms_languages.json"))
def match_files(
files: list[File],
) -> list[Match]:
"""
Match audio and text files by name (without extension).
"""
audio_extensions = {".wav", ".mp3"}
text_extensions = {".txt", ".usfm"}
# Dictionary to store matched files by name (without extension)
matched_files = {}
for file_name, path in files:
name, ext = file_name.rsplit(".", 1)
ext = f".{ext}"
if ext in audio_extensions:
if name not in matched_files:
matched_files[name] = (None, None)
matched_files[name] = (
(file_name, path),
matched_files[name][1],
) # Store audio file
elif ext in text_extensions:
if name not in matched_files:
matched_files[name] = (None, None)
matched_files[name] = (
matched_files[name][0],
(file_name, path),
) # Store text file
# Filter out pairs where either the audio or text is missing
return [match for match in matched_files.values() if None not in match]
def align_matches(
folder: str,
language: str | None,
separator: str,
matches: list[Match],
model: Any,
dictionary: Any,
):
"""
Align audio and text files and write output to Firestore.
"""
spinner = Halo("Aligning...").start()
file_timestamps: list[FileTimestamps] = []
progress = 0
total_length = 0
for index, match in enumerate(matches):
if match[0] is None or match[1] is None:
continue
try:
audio_path = match[0][1]
audio_type = match[0][0].split(".")[-1]
chapter_id = ".".join(match[0][0].split(".")[0:-1])
wav_output = audio_path.replace(f".{audio_type}", "_output.wav")
spinner.text = f"Converting audio to {wav_output}..."
spinner.start()
total_length += float(ffmpeg.probe(audio_path)["streams"][0]["duration"])
stream = ffmpeg.input(audio_path)
stream = ffmpeg.output(stream, wav_output, acodec="pcm_s16le", ar=16000)
stream = ffmpeg.overwrite_output(stream)
ffmpeg.run(
stream,
overwrite_output=True,
cmd=["ffmpeg", "-loglevel", "error"], # type: ignore
)
spinner.succeed(f"Audio converted to {wav_output}.")
# Identify the session language. This is time
# consuming so we only do it for the first file and assume
# all files are the same language.
if index == 0 and language is None:
# Cut down audio to 10 seconds for language
# identification.
spinner.text = "Identifying language..."
spinner.start()
cut_output = f"{folder}/cut_output.wav"
stream = ffmpeg.input(wav_output)
stream = ffmpeg.output(stream, cut_output, t=30)
stream = ffmpeg.overwrite_output(stream)
ffmpeg.run(
stream,
overwrite_output=True,
cmd=["ffmpeg", "-loglevel", "error"], # type: ignore
)
language = identify_language(cut_output)
# Check if language is valid.
language_match = next(
(item for item in mms_languages if item["iso"] == language), None
)
if language_match is None or not language_match["align"]:
spinner.fail(f"Detected language {language} not supported.")
return
else:
spinner.succeed(f"Valid language identified as {language}.")
# Remove the cut file
os.remove(cut_output)
text_path = match[1][1]
text_extension = match[1][0].split(".")[-1]
text_file = open(text_path, "r", encoding="utf-8")
lines_to_timestamp = []
if text_extension == "json":
verses = json.load(text_file)
for verse in verses:
lines_to_timestamp.append(verse["text"])
elif text_extension == "txt":
# Read the separator from the query parameter and adjust
# it so it can be used in the split function.
if separator == "lineBreak":
separator = "\n"
elif separator == "squareBracket":
separator = "["
elif separator == "downArrow":
separator = "⬇️"
lines_to_timestamp = text_file.read().strip(separator).split(separator)
lines_to_timestamp = [line for line in lines_to_timestamp if line.strip()]
elif text_extension == "usfm":
# Define the tags to ignore
ignore_tags = [
"\\c",
"\\p",
"\\s",
"\\s1",
"\\s2",
"\\f",
"\\ft",
"\\fr",
"\\x",
"\\xt",
"\\xo",
"\\r",
"\\t",
"\\m",
]
# Compile a regex to match tags we want to ignore
ignore_regex = re.compile(
r"|".join(re.escape(tag) for tag in ignore_tags)
)
current_verse = ""
for line in text_file:
if ignore_regex.match(line.strip()):
continue
if line.startswith(r"\v"): # USFM verse marker
if current_verse:
cleaned_verse = re.sub(
r"\\[a-z]+\s?", "", current_verse.strip()
)
lines_to_timestamp.append(cleaned_verse)
current_verse = line.strip() # Start a new verse
else:
current_verse += " " + line.strip()
if current_verse: # Append the last verse after the loop
cleaned_verse = re.sub(r"\\[a-z]+\s?", "", current_verse.strip())
lines_to_timestamp.append(cleaned_verse)
norm_lines_to_timestamp = [
text_normalize(
line.strip(), language if language is not None else "eng"
)
for line in lines_to_timestamp
]
uroman_lines_to_timestamp = get_uroman_tokens(
norm_lines_to_timestamp, language
)
uroman_lines_to_timestamp = ["<star>"] + uroman_lines_to_timestamp
lines_to_timestamp = ["<star>"] + lines_to_timestamp
norm_lines_to_timestamp = ["<star>"] + norm_lines_to_timestamp
spinner.succeed("Text normalized and romanized.")
spinner.text = "Aligning..."
spinner.start()
segments, stride = get_alignments(
wav_output,
uroman_lines_to_timestamp,
model,
dictionary,
)
spans = get_spans(uroman_lines_to_timestamp, segments)
sections = []
for i, t in enumerate(lines_to_timestamp):
if i == 0:
continue
span = spans[i]
seg_start_idx = span[0].start
seg_end_idx = span[-1].end
audio_start_sec = round(seg_start_idx * stride / 1000, 2)
audio_end_sec = round(seg_end_idx * stride / 1000, 2)
section: Section = {
"verse_id": f"{chapter_id}.{i}",
"timings": (audio_start_sec, audio_end_sec),
"timings_str": (
time.strftime("%H:%M:%S", time.gmtime(audio_start_sec)),
time.strftime("%H:%M:%S", time.gmtime(audio_end_sec)),
),
"text": t,
"uroman_tokens": uroman_lines_to_timestamp[i],
}
sections.append(section)
except Exception:
spinner.fail("Failed to align.")
print(traceback.format_exc())
return
spinner.succeed("Alignment done.")
spinner.text = "Cleaning up..."
spinner.start()
os.remove(wav_output)
spinner.succeed("Cleaned up.")
file_timestamps.append(
{
"audio_file": match[0][0],
"text_file": match[1][0],
"sections": sections,
}
)
progress += 1
return file_timestamps