-
-
Notifications
You must be signed in to change notification settings - Fork 170
Expand file tree
/
Copy pathwhisperx.py
More file actions
192 lines (158 loc) · 7.78 KB
/
whisperx.py
File metadata and controls
192 lines (158 loc) · 7.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import numpy as np
from .backend import Backend, Transcription, Segment
import os, math
from tqdm import tqdm # type: ignore
import uuid
from faster_whisper import WhisperModel, download_model
import whisperx
class WhisperxBackend(Backend):
    """Transcription backend built on WhisperX (faster-whisper + forced alignment).

    Loads a locally cached faster-whisper model, transcribes audio, aligns the
    output to obtain word-level timestamps, and regroups the words into
    subtitle-friendly segments.
    """

    # Inference device: "cpu" or "cuda".
    device: str = "cuda"
    # faster-whisper compute type: "int8" or "float16".
    quantization: str = "float16"
    batch_size: int = 25
    model: WhisperModel | None = None

    def __init__(self, model_size, device: str = "cuda"):
        self.model_size = model_size
        self.device = device
        # Shared initialization hook defined on Backend (not visible here).
        self.__post_init__()

    def model_path(self) -> str:
        """Return the local directory of the cached model.

        Raises:
            RuntimeError: if the model directory does not exist.
            KeyError: if WHISPER_MODELS_DIR is not set in the environment.
        """
        local_model_path = os.path.join(
            os.environ["WHISPER_MODELS_DIR"], f"faster-whisper-{self.model_size}"
        )
        if os.path.exists(local_model_path):
            return local_model_path
        raise RuntimeError(f"model not found in {local_model_path}")

    def load(self) -> None:
        """Load the WhisperX model onto the configured device."""
        print(f"Loading model: {self.model_path()}, {self.device}, {self.quantization}")
        self.model = whisperx.load_model(
            self.model_path(), device=self.device, compute_type=self.quantization
        )

    def get_model(self) -> None:
        """Ensure the model exists locally, downloading it when not yet cached."""
        print(f"Downloading model {self.model_size}...")
        local_model_path = os.path.join(
            os.environ["WHISPER_MODELS_DIR"], f"faster-whisper-{self.model_size}"
        )
        local_model_cache = os.path.join(local_model_path, "cache")
        # exist_ok avoids the check-then-create race of the previous version.
        os.makedirs(local_model_path, exist_ok=True)
        try:
            # Probe the cache first: local_files_only=True succeeds only when
            # the model is already present.
            download_model(self.model_size, output_dir=local_model_path, local_files_only=True, cache_dir=local_model_cache)
            print("Model already cached...")
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # still propagate. Any cache miss triggers a real download.
            download_model(self.model_size, output_dir=local_model_path, local_files_only=False, cache_dir=local_model_cache)

    @staticmethod
    def _split_lineIfNeeded(words, max_splits=12, max_words=12):
        """Recursively split a word list into subtitle lines of at most
        ``max_words`` words.

        Split-point preference: the comma closest to the middle, else the
        largest inter-word silence in a window around the middle, else the
        exact midpoint.

        Args:
            words: list of word dicts carrying 'word', 'start' and 'end'.
            max_splits: recursion budget; when exhausted the line is returned
                unsplit regardless of length.
            max_words: maximum words per line (previously a hard-coded 12).

        Returns:
            list of ``{'words': [...]}`` line dicts.
        """
        # Nothing to split, or the recursion budget is spent.
        if not words or max_splits <= 0:
            return [{'words': words}]
        # Short enough already.
        if len(words) <= max_words:
            return [{'words': words}]
        middle = len(words) // 2
        # Prefer the comma nearest the middle; the last word is excluded so a
        # split can never leave an empty right half.
        comma_indices = [i for i, word in enumerate(words[:-1]) if ',' in word['word']]
        split_index = min(comma_indices, key=lambda idx: abs(middle - idx), default=None)
        if split_index is None:
            # No comma: pick the largest timing gap within the middle ~40% of
            # the line (20% of the words on each side of the center).
            window_start = max(0, middle - len(words) // 5)
            window_end = min(len(words), middle + len(words) // 5)
            max_gap_size = 0
            for i in range(window_start, window_end - 1):
                gap_size = words[i + 1]['start'] - words[i]['end']
                if gap_size > max_gap_size:
                    max_gap_size = gap_size
                    split_index = i
        # Last resort: cut at the midpoint.
        if split_index is None:
            split_index = middle
        left_part = words[:split_index + 1]
        right_part = words[split_index + 1:]
        # Recurse on both halves; each recursion consumes one unit of budget.
        split_left_part = WhisperxBackend._split_lineIfNeeded(words=left_part, max_splits=max_splits - 1, max_words=max_words)
        split_right_part = WhisperxBackend._split_lineIfNeeded(words=right_part, max_splits=max_splits - 1, max_words=max_words)
        return split_left_part + split_right_part

    def transcribe(
        self, input: np.ndarray, silent: bool = False, language: str | None = None
    ) -> Transcription:
        """
        Return word-level transcription data.
        Word-level probabilities are calculated by ctranslate2.models.Whisper.align
        """
        print("Transcribing with WhisperX...")
        assert self.model is not None
        result = self.model.transcribe(
            input,
            language=language,
        )
        language_code = result["language"]
        print(f"Language code: {language_code}")
        # BUG FIX: the alignment device was hard-coded to "cuda", which broke
        # instances constructed with device="cpu"; honor self.device instead.
        model_align, metadata = whisperx.load_align_model(language_code=language_code, device=self.device)
        result = whisperx.align(result['segments'], model_align, metadata, input, self.device, return_char_alignments=False)
        all_file_words = []
        # Debug dump of the raw alignment result. Best-effort: a missing or
        # unwritable log directory must not abort the transcription.
        try:
            with open("/var/log/whishper/segments.json", "w") as f:
                f.write(str(result))
        except OSError:
            pass
        for segment in result['segments']:
            for word in segment['words']:
                all_file_words.append(word)
        srt_output = []
        line_buffer = []
        for i, word in enumerate(all_file_words):
            # BUG FIX: the original tested `word.get('start')` for truthiness,
            # so a word legitimately starting at 0.0 was treated as missing a
            # timestamp. Compare both values against None explicitly.
            if word.get('start') is not None and word.get('end') is not None:
                word['start'] = round(word['start'], 3)
                word['end'] = round(word['end'], 3)
            else:
                # Approximate: place the word just after the previous word's
                # end (not accurate, but keeps timestamps monotonic).
                word['start'] = round(all_file_words[i - 1]['end'] + 0.01, 3) if i > 0 else 0
                word['end'] = round(all_file_words[i - 1]['end'] + 0.01, 3) if i > 0 else 0
        # Post-process 'end' times: enforce a 0.5s minimum duration, but never
        # overlap the start of the next word.
        for i in range(len(all_file_words) - 1):
            word = all_file_words[i]
            next_word = all_file_words[i + 1]
            if 'end' not in word:
                word['end'] = round(next_word['start'] - 0.001, 3)
            word['end'] = max(word['end'], round(word['start'] + 0.5, 3))  # Ensure minimum duration
            word['end'] = min(word['end'], round(next_word['start'] - 0.001, 3))  # Ensure not overlapping with next word
        # Group words into lines, breaking at sentence-ending punctuation and
        # splitting overlong lines.
        for word in all_file_words:
            line_buffer.append(word)
            if word['word'].endswith(('.', '?', '!')):  # Check for sentence-ending punctuation
                if len(line_buffer) > 12:
                    srt_output.extend(WhisperxBackend._split_lineIfNeeded(words=line_buffer, max_splits=12))
                else:
                    srt_output.append({'words': line_buffer})
                line_buffer = []
        # Words left over after the loop form a final (unterminated) line.
        if line_buffer:
            srt_output.extend(WhisperxBackend._split_lineIfNeeded(words=line_buffer, max_splits=3))
        result_segments = []
        # Store the segments.
        for line in srt_output:
            if len(line['words']) == 0:
                continue
            segment_id = uuid.uuid4().hex
            start = line['words'][0]['start']
            end = line['words'][-1]['end']
            text = " ".join(word['word'] for word in line['words'])
            # Alignment output here carries no usable per-word score; report 0.
            score = 0
            segment_extract: Segment = {
                "id": segment_id,
                "text": text,
                "start": start,
                "end": end,
                "score": score,
                "words": line['words'],
            }
            result_segments.append(segment_extract)
        text = " ".join(segment["text"] for segment in result_segments)
        # Collapse any run of whitespace into single spaces.
        text = ' '.join(text.strip().split())
        # Duration is the end of the last segment; BUG FIX: the original raised
        # IndexError on an empty transcription.
        duration = result_segments[-1]["end"] if result_segments else 0
        transcription: Transcription = {
            "text": text,
            "language": language_code,
            "duration": duration,
            "segments": result_segments,
        }
        return transcription