-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwhisperx_handler.py
More file actions
81 lines (66 loc) · 2.75 KB
/
whisperx_handler.py
File metadata and controls
81 lines (66 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import io
import json
import logging
import os
import tempfile

import numpy as np
import soundfile as sf
import torch
import whisperx
from pydub import AudioSegment
from ts.torch_handler.base_handler import BaseHandler

from whisperx_model import whisperXModel
class CustomASRHandler(BaseHandler):
    """TorchServe handler that transcribes audio with a WhisperX model.

    Each request may carry its audio in one of three fields, checked in
    this order:

    * ``"data"``  -- raw audio bytes (decoded via a temporary ``.wav`` file);
    * ``"audio"`` -- raw audio bytes (decoded via a temporary ``.mp4`` file);
    * ``"body"``  -- a path to an audio file that already exists on the server.

    Each request produces ``{"transcription": ..., "alignment": ...}``.
    """

    _logger = logging.getLogger(__name__)

    def initialize(self, context):
        """Read ``config_wx.json`` from the model directory and build the model.

        Recognised config keys (all optional), passed positionally to
        ``whisperXModel``: ``size`` (default ``"small"``), ``device``
        (default ``"cpu"``) and ``compute`` (default ``"int8"``).

        Raises:
            RuntimeError: ``model_dir`` is absent from the system properties.
            OSError / json.JSONDecodeError: the config file is missing or invalid.
        """
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        if not model_dir:
            raise RuntimeError("TorchServe did not provide 'model_dir' in system_properties")
        config_path = os.path.join(model_dir, "config_wx.json")
        with open(config_path, "r", encoding="utf-8") as config_file:
            config = json.load(config_file)
        self.model = whisperXModel(
            config.get("size", "small"),
            config.get("device", "cpu"),
            config.get("compute", "int8"),
        )
        self.initialized = True

    def preprocess(self, data):
        """Decode the audio carried by the first request in ``data``.

        Args:
            data: list of request dicts; only ``data[0]`` is read.

        Returns:
            Whatever ``whisperx.load_audio`` returns for that request's audio
            (presumably a mono float waveform -- confirm against whisperx).

        Raises:
            ValueError: no recognised audio field, an empty payload, or audio
                that could not be decoded.
        """
        request = data[0]
        if "data" in request:
            return self._load_audio_bytes(request.get("data"), suffix=".wav")
        if "audio" in request:
            return self._load_audio_bytes(request.get("audio"), suffix=".mp4")
        if "body" in request:
            return self._load_audio_path(request.get("body"))
        raise ValueError("No valid audio field in data.")

    def _load_audio_bytes(self, audio_bytes, suffix):
        """Decode in-memory audio by writing it to a private temporary file.

        Every call gets its own uniquely named file, so concurrent requests
        never overwrite each other's audio, and the file is always removed.

        Raises:
            ValueError: the payload is empty, or writing/decoding it failed.
        """
        if not audio_bytes:
            raise ValueError("No audio data received")
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            # Close (and thereby flush) before decoding so the loader reads
            # the complete file through its own handle.
            with temp_file:
                temp_file.write(audio_bytes)
            self._logger.debug("Temporary file created: %s", temp_file.name)
            return whisperx.load_audio(temp_file.name)
        except Exception as load_error:
            raise ValueError(f"Failed to load audio file: {load_error}") from load_error
        finally:
            os.remove(temp_file.name)
            self._logger.debug("Temporary file deleted: %s", temp_file.name)

    def _load_audio_path(self, body):
        """Decode an audio file whose server-side path is sent as the request body.

        Raises:
            ValueError: the body is empty or whitespace only.
        """
        # The raw body may arrive as bytes/bytearray rather than str; a
        # bytearray is not accepted as a path, so decode it first.
        if isinstance(body, (bytes, bytearray)):
            body = body.decode("utf-8")
        audio_path = (body or "").strip()
        if not audio_path:
            raise ValueError("No audio file path received")
        # NOTE(review): the client fully controls this path, so it can point
        # at any file readable by the server process (and, if the loader
        # accepts URLs, at remote resources). Restrict it to an allow-listed
        # directory before exposing this endpoint publicly.
        return whisperx.load_audio(audio_path)

    def inference(self, data):
        """Transcribe one decoded waveform.

        Returns:
            A one-element list ``[{"transcription": ..., "alignment": ...}]``
            built from the two values returned by ``self.model(data)``.
        """
        self._logger.info("TRANSCRIBING NOW")
        decode, alignment = self.model(data)
        return [{"transcription": decode, "alignment": alignment}]

    def postprocess(self, inference_output):
        """Return ``inference_output`` unchanged, logging it at DEBUG level."""
        # DEBUG rather than INFO: full transcripts can be large and may
        # contain sensitive speech content.
        self._logger.debug("SERVER RESPONSE: %s", inference_output)
        return inference_output

    def handle(self, data, context):
        """Entry point: run preprocess -> inference -> postprocess per request.

        Every request in the batch is transcribed, so the returned list has
        one entry per request. With a single request this matches the
        previous behaviour exactly.
        """
        outputs = []
        for request in data:
            waveform = self.preprocess([request])
            outputs.extend(self.inference(waveform))
        return self.postprocess(outputs)