Skip to content

Commit e479891

Browse files
authored
Dummy demo from non-causal model
1 parent cba3e57 commit e479891

1 file changed

Lines changed: 132 additions & 0 deletions

File tree

scripts/dummy_eben_demo.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
3+
This demo script showcases real-time speech enhancement using the "Cnam-LMSSC/EBEN_throat_microphone" model:
4+
5+
--> https://huggingface.co/Cnam-LMSSC/EBEN_throat_microphone
6+
7+
This a non-causal version of EBEN that we make causal by waiting for enough future context
8+
which corresponds to half of the receptive field of the model (207ms).
9+
10+
"""
11+
12+
import torch
13+
import sounddevice as sd
14+
import numpy as np
15+
import threading
16+
import sys
17+
import termios
18+
import tty
19+
from vibravox.torch_modules.dnn.eben_generator import EBENGenerator
20+
21+
# ======================================================
22+
# KEYBOARD LISTENER (non-blocking)
23+
# ======================================================
24+
eben_enabled = False # default ON
25+
26+
27+
def keyboard_listener():
28+
global eben_enabled
29+
fd = sys.stdin.fileno()
30+
old = termios.tcgetattr(fd)
31+
tty.setcbreak(fd)
32+
try:
33+
while True:
34+
ch = sys.stdin.read(1)
35+
if ch == "e":
36+
eben_enabled = True
37+
print("🔊 EBEN ENABLED: 207ms latency")
38+
elif ch == "d":
39+
eben_enabled = False
40+
print("🔇 EBEN DISABLED (passthrough): 16ms latency")
41+
finally:
42+
termios.tcsetattr(fd, termios.TCSADRAIN, old)
43+
44+
45+
listener_thread = threading.Thread(target=keyboard_listener, daemon=True)
46+
listener_thread.start()
47+
48+
# ======================================================
49+
# EBEN + AUDIO SETUP
50+
# ======================================================
51+
sample_rate = 16_000
52+
hop_size = 256 # (=16ms) minimal number of samples needed to produce a new chunk of output given enough input samples
53+
window_size = 6624 # (=414ms) first valid length greater than 6340 (model receptive field) + 256 (hop_size)
54+
# the model receptive field can be computed from the architecture or simply by feeding a very long signal
55+
# of zeros except at one position and checking the length of the output that is non-zero
56+
57+
# Valid region within the output window
58+
valid_start = (window_size - hop_size) // 2
59+
valid_end = (window_size + hop_size) // 2
60+
61+
device = "cuda" if torch.cuda.is_available() else "cpu"
62+
print(f"Using device: {device}")
63+
model = EBENGenerator.from_pretrained("Cnam-LMSSC/EBEN_throat_microphone")
64+
model = model.eval().to(device)
65+
66+
in_stream = sd.InputStream(
67+
samplerate=sample_rate,
68+
channels=1,
69+
blocksize=hop_size,
70+
dtype="float32",
71+
)
72+
out_stream = sd.OutputStream(
73+
samplerate=sample_rate,
74+
channels=1,
75+
blocksize=hop_size,
76+
dtype="float32",
77+
)
78+
79+
in_stream.start()
80+
out_stream.start()
81+
82+
print("🎤 Live enhancement running... CTRL+C to stop")
83+
print("Press 'e' to ENABLE EBEN, 'd' to DISABLE EBEN")
84+
85+
# Rolling buffers
86+
in_buffer = torch.zeros(1, 1, 0, device=device)
87+
out_buffer = np.zeros(0, dtype=np.float32)
88+
89+
try:
90+
while True:
91+
# (1) Read throat microphone input (for mic ref: https://vibravox.cnam.fr/documentation/hardware/sensors/throat)
92+
in_block, _ = in_stream.read(hop_size)
93+
block_np = in_block.T
94+
block_t = torch.from_numpy(block_np).unsqueeze(0).to(device)
95+
96+
if not eben_enabled:
97+
out_stream.write(in_block) # direct pass-through
98+
continue
99+
100+
# (2) Append to rolling buffer
101+
in_buffer = torch.cat([in_buffer, block_t], dim=-1)
102+
103+
# (3) Process windows if enough samples
104+
while in_buffer.shape[-1] >= window_size:
105+
input_chunk = in_buffer[:, :, :window_size]
106+
107+
with torch.no_grad():
108+
enhanced_chunk, _ = model(input_chunk)
109+
110+
valid = enhanced_chunk[:, :, valid_start:valid_end]
111+
valid_np = valid.squeeze().detach().cpu().numpy().astype(np.float32)
112+
113+
out_buffer = np.concatenate([out_buffer, valid_np])
114+
in_buffer = in_buffer[:, :, hop_size:]
115+
116+
# (4) Play output
117+
if out_buffer.shape[0] >= hop_size:
118+
play_chunk = out_buffer[:hop_size]
119+
out_buffer = out_buffer[hop_size:]
120+
else:
121+
play_chunk = np.zeros(hop_size, dtype=np.float32)
122+
123+
out_stream.write(play_chunk.reshape(-1, 1))
124+
125+
except KeyboardInterrupt:
126+
print("\n🛑 Exiting...")
127+
128+
finally:
129+
in_stream.stop()
130+
in_stream.close()
131+
out_stream.stop()
132+
out_stream.close()

0 commit comments

Comments
 (0)