Conversation
Optional batched VAD feature.
It's possible to get a similar speedup, but with mathematical equivalence, if you thread-pool the encoder and run the decoder on batches as they finish. Also, the model still wastes FLOPs on the STFT, which is a free 5%.
I don't know how to do that because the model is stateful. Can you share/make a PR with such a func producing mathematical equivalence?
```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


class SileroVADModelFast:
    def __init__(self, encoder_path, decoder_path):
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        opts.enable_cpu_mem_arena = False
        opts.log_severity_level = 4

        self.encoder_session = onnxruntime.InferenceSession(
            encoder_path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )
        self.decoder_session = onnxruntime.InferenceSession(
            decoder_path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def __call__(
        self,
        audio: np.ndarray,
        threads: int = 1,
        window_size_samples: int = 512,
        context_size_samples: int = 64,
    ) -> np.ndarray:
        assert (
            audio.ndim == 2
        ), "Input should be a 2D array with size (batch_size, num_samples)"

        batch_size, num_samples = audio.shape
        rhs_padding = window_size_samples - num_samples % window_size_samples
        audio = np.pad(audio, ((0, 0), (context_size_samples, rhs_padding)))
        num_samples = audio.shape[1]

        encoder_batch_size = 256
        batch_samples = encoder_batch_size // batch_size * window_size_samples
        input_size = window_size_samples + context_size_samples

        # LSTM state for the stateful decoder
        h = np.zeros((1, batch_size, 128), dtype=np.float32)
        c = np.zeros((1, batch_size, 128), dtype=np.float32)

        def encode(i):
            # Slice out one encoder batch (plus left context) and turn it into
            # overlapping windows without copying, via stride tricks.
            batch = audio[:, i : i + batch_samples + context_size_samples]
            shape = (batch_size, batch.shape[1] // window_size_samples, input_size)
            strides = (
                batch.strides[0],
                batch.strides[1] * window_size_samples,
                batch.strides[1],
            )
            batch = np.lib.stride_tricks.as_strided(batch, shape, strides)
            return self.encoder_session.run(None, {"input": batch})[0]

        outputs = []
        with ThreadPoolExecutor(threads) as executor:
            # The stateless encoder runs in parallel across threads; the
            # stateful decoder consumes the results in order on this thread.
            futures = [
                executor.submit(encode, i)
                for i in range(0, num_samples, batch_samples)
            ]
            for future in futures:
                batch = future.result()
                output, h, c = self.decoder_session.run(
                    None, {"input": batch, "h": h, "c": c}
                )
                outputs.append(output)

        return np.concatenate(outputs, axis=0).T
```
```python
model = SileroVADModelFast("enc.onnx", "dec.onnx")
model(audio[None], 8)
```
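For intuition, the `as_strided` call above builds overlapping windows over the padded signal without copying any samples. A toy-sized sketch of the same indexing (window 4 and context 2 standing in for the real 512/64):

```python
import numpy as np

# Toy sizes standing in for window_size_samples=512 / context_size_samples=64
window, context = 4, 2
batch_size, num_windows = 1, 3

# Signal with `context` extra samples on the left, enough for 3 windows
audio = np.arange(
    batch_size * (context + num_windows * window), dtype=np.float32
).reshape(batch_size, -1)

shape = (batch_size, num_windows, window + context)
strides = (audio.strides[0], audio.strides[1] * window, audio.strides[1])
views = np.lib.stride_tricks.as_strided(audio, shape, strides)

# Each window advances by `window` samples and keeps `context` samples of
# overlap with its predecessor; no data is copied.
print(views[0, 0])  # [0. 1. 2. 3. 4. 5.]
print(views[0, 1])  # [4. 5. 6. 7. 8. 9.]
```

Each row of `views` is a (window + context)-sample view into the same buffer, which is exactly the input shape the encoder expects.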
6.6 hrs:

| threads | time | time |
| --- | --- | --- |
| 2 | 15523.9138 | 12746.2859 |
| 3 | 10647.9426 | 8864.6105 |
| 4 | 8231.9762 | 6742.6077 |
| 5 | 6804.4266 | 5604.9741 |
| 6 | 5940.452 | 4912.6163 |
| 7 | 5332.413 | 4346.9303 |
| 8 | 4776.8938 | 4432.3617 |
| 9 | 4498.9325 | 4454.5463 |
| 10 | 4311.0307 | 4468.7211 |
| 11 | 4176.293 | 4430.4164 |
| 12 | 4139.2589 | plateaus |
| 13 | 3995.533 | |
| 14 | 3730.9739 | |
| 15 | 3854.1674 | |
| 16 | 3723.9984 | |

Needed onnx files attached. Faster when <10 threads, but doesn't scale. Probs were identical with main, which took like 30s.
@sssshhhhhh Thanks for the func. As I understand, you benchmarked on a CPU with 8 logical processors, right? |
|
16, it's limited by decoder speed. Can always chunk like you do if more speed is needed. But at 1s/hr already idk if it'll make a big impact on latency. |
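For reference, "chunking" here means what the PR does: splitting one long signal into rows that the stateful decoder can process as a single batch. A minimal sketch with a hypothetical `chunk_audio` helper (not the PR's actual code):

```python
import numpy as np

def chunk_audio(audio: np.ndarray, num_chunks: int) -> np.ndarray:
    """Split a 1-D signal into `num_chunks` equal-length rows, zero-padding
    the tail, so the decoder sees them as one batch. Caveat: the LSTM state
    is reset at every chunk boundary, which is why batched probs can differ
    slightly from a sequential run."""
    chunk_len = -(-len(audio) // num_chunks)  # ceiling division
    padded = np.pad(audio, (0, chunk_len * num_chunks - len(audio)))
    return padded.reshape(num_chunks, chunk_len)

audio = np.arange(10, dtype=np.float32)
batched = chunk_audio(audio, 4)
print(batched.shape)  # (4, 3)
```

This trades exactness at chunk boundaries for decoder parallelism, whereas the thread-pooled encoder above keeps a single sequential decoder pass and stays mathematically equivalent.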
Could you benchmark |
|
14 threads: 3663.7819. Pretty sure it's bandwidth bound; power draw isn't that high for all-core.
|
Tested your func: on my CPU with 16 threads it isn't as slow as in your benchmark, just 5% slower.
Optional batched VAD feature.

Prompted by #1388 [user wanted a GPU option because it's ~5x faster in a slow-CPU/fast-GPU env; this PR should give a similar speed increase on CPU].

Enabled when the VAD attribute `vad_batch_size > 1`. The optimal `vad_batch_size` value is probably somewhere between 1.5x and 2x the number of CPU threads.

RAM usage increase [check on 2h audio]:

Speed tests:

- Up to 4.3 times faster with `vad_batch_size=16` on a CPU with 8 logical processors.
- Up to 8.9 times faster with `vad_batch_size=24` on a CPU with 16 logical processors.

Probs difference [inspected 2h audio with `vad_batch_size=8`]:

- VAD timestamps are a bit different, not worse - not better.
- 5% of the 1233 total timestamps differed; 92% of those diffs were insignificant.

Just some stats for nerds.

Depends on #1406 [but its influence on the probs is not significant].
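A sketch of how such a timestamp-diff comparison could be computed from two per-window probability arrays (a hypothetical `timestamp_diff_stats` helper, not the actual script used for the stats above; the 0.5 threshold and 0.05 "insignificant" margin are assumptions):

```python
import numpy as np

def timestamp_diff_stats(probs_a, probs_b, threshold=0.5, eps=0.05):
    """Compare per-window speech probabilities from two VAD runs.
    Returns (total differing decisions, how many of those are
    'insignificant', i.e. the two probs are within eps of each other,
    both hovering near the threshold)."""
    a, b = np.asarray(probs_a), np.asarray(probs_b)
    differs = (a >= threshold) != (b >= threshold)
    insignificant = differs & (np.abs(a - b) < eps)
    return int(differs.sum()), int(insignificant.sum())

a = np.array([0.1, 0.49, 0.9, 0.7])
b = np.array([0.1, 0.51, 0.9, 0.2])
print(timestamp_diff_stats(a, b))  # (2, 1)
```

Here the second window flips decision but only because both probs sit right at the threshold (an "insignificant" diff), while the fourth window is a genuine disagreement.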