-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathinference.py
More file actions
73 lines (55 loc) · 2.29 KB
/
inference.py
File metadata and controls
73 lines (55 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import onnxruntime as ort
from transformers import WhisperFeatureExtractor
from audio_utils import truncate_audio_to_last_n_seconds
# Location of the ONNX model file that is passed to build_session() at module
# level. The path is relative, so it is resolved against the process's current
# working directory.
ONNX_MODEL_PATH = "smart-turn-v3.1.onnx"
def build_session(onnx_path):
    """
    Create an ONNX Runtime inference session for the model at ``onnx_path``.

    The session is configured for sequential operator execution, a single
    inter-op thread, and the full set of graph optimizations.

    Args:
        onnx_path: Path to the ONNX model file to load.

    Returns:
        The configured ``ort.InferenceSession``.
    """
    options = ort.SessionOptions()
    # The three settings are independent of one another.
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.inter_op_num_threads = 1
    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

    inference_session = ort.InferenceSession(onnx_path, sess_options=options)
    return inference_session
# Shared module-level objects, created once when the module is imported and
# reused by every predict_endpoint() call.
# chunk_length=8 presumably means 8 seconds, matching the 8-second window used
# in predict_endpoint() -- confirm against the WhisperFeatureExtractor docs.
feature_extractor = WhisperFeatureExtractor(chunk_length=8)
# NOTE: the model is loaded here, at import time, from ONNX_MODEL_PATH.
session = build_session(ONNX_MODEL_PATH)
def predict_endpoint(audio_array):
    """
    Classify an audio segment as a finished (complete) or unfinished
    (incomplete) speaker turn.

    The audio is first reduced to an 8-second window, converted to Whisper
    input features, and then scored by the module-level ONNX session.

    Args:
        audio_array: Numpy array containing audio samples at 16kHz

    Returns:
        Dictionary with two entries:
            - prediction: 1 for complete, 0 for incomplete
            - probability: Probability of completion (sigmoid output)
    """
    window_seconds = 8
    sample_rate = 16000

    # Keep the last 8 seconds of audio (the helper is expected to truncate
    # from the front, or pad, so the result is 8 seconds long).
    clipped = truncate_audio_to_last_n_seconds(audio_array, n_seconds=window_seconds)

    # Turn the waveform into Whisper input features, padded/truncated to
    # exactly window_seconds * sample_rate samples.
    extracted = feature_extractor(
        clipped,
        sampling_rate=sample_rate,
        return_tensors="np",
        padding="max_length",
        max_length=window_seconds * sample_rate,
        truncation=True,
        do_normalize=True,
    )

    # Drop the extractor's leading batch axis, cast to float32 for ONNX, then
    # put a batch axis of size one back in front.
    features = extracted.input_features.squeeze(0).astype(np.float32)
    batched_features = features[np.newaxis, ...]

    # Run the model; the first output holds the completion probability
    # (the ONNX model returns sigmoid probabilities).
    model_outputs = session.run(None, {"input_features": batched_features})
    probability = model_outputs[0][0].item()

    # Threshold at 0.5: 1 for Complete, 0 for Incomplete.
    return {
        "prediction": int(probability > 0.5),
        "probability": probability,
    }
# Example usage
if __name__ == "__main__":
    # Smoke test: 16000 random samples, i.e. one second of noise at 16kHz.
    random_samples = np.random.randn(16000)
    noise = random_samples.astype(np.float32)

    result = predict_endpoint(noise)

    # Report the class label and the probability (four decimal places).
    print(f"Prediction: {result['prediction']}")
    print(f"Probability: {result['probability']:.4f}")