-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathwhisper.py
More file actions
35 lines (26 loc) · 1.39 KB
/
whisper.py
File metadata and controls
35 lines (26 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import soundfile as sf
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizerFast
# load model and processor
# Module-level side effects: downloading/loading the 'base' Whisper checkpoint
# happens at import time. Use set_param_size() below to swap to another size.
tokenizer = WhisperTokenizerFast.from_pretrained('openai/whisper-base')
# The fast tokenizer is passed in explicitly so the processor reuses it
# instead of constructing its own.
processor = WhisperProcessor.from_pretrained('openai/whisper-base', tokenizer=tokenizer)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
# Prefer GPU when available; all inference tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device: ', device)
model.to(device)
# Clear any forced decoder ids so generate() is free to predict
# language/task tokens itself rather than being forced to a fixed prompt.
model.config.forced_decoder_ids = None
def is_cuda() -> bool:
    """Return True if inference is running on a CUDA device.

    NOTE(review): the per-call ``print`` of the device was removed — the
    module already logs the device once at import time, so printing it on
    every call was leftover debug noise for a simple predicate.
    """
    return device.type == 'cuda'
def transcribe(file_name: str, max_length: int = 1000) -> str:
    """Transcribe an audio file to text using the loaded Whisper model.

    Args:
        file_name: Path to an audio file readable by ``soundfile``.
        max_length: Maximum number of tokens to generate. Defaults to 1000,
            preserving the previously hard-coded limit.

    Returns:
        The whitespace-stripped transcription of the (single) audio input.
    """
    audio, sample_rate = sf.read(file_name)
    # sf.read returns a (frames, channels) array for multi-channel files;
    # Whisper's feature extractor expects a mono 1-D waveform, so downmix.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    # NOTE(review): the feature extractor assumes 16 kHz input — files at
    # other rates should be resampled upstream; verify against callers.
    input_features = processor(audio, sampling_rate=sample_rate, return_tensors="pt").input_features.to(device)
    predicted_ids = model.generate(input_features, max_length=max_length)
    transcription: str = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
def set_param_size(param_size: str = 'base'):
    """Reload the tokenizer, processor and model for a given Whisper size.

    Rebinds the module-level ``model``, ``processor`` and ``tokenizer``
    globals so subsequent ``transcribe()`` calls use the newly selected
    checkpoint; the fresh model is moved onto the active device.
    """
    global model, processor, tokenizer
    # Build the checkpoint id once and reuse it for all three loads.
    checkpoint = f'openai/whisper-{param_size}'
    tokenizer = WhisperTokenizerFast.from_pretrained(checkpoint)
    processor = WhisperProcessor.from_pretrained(checkpoint, tokenizer=tokenizer)
    model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
    model.to(device)