|
| 1 | +""" |
| 2 | +
|
| 3 | +This demo script showcases real-time speech enhancement using the "Cnam-LMSSC/EBEN_throat_microphone" model: |
| 4 | +
|
| 5 | +--> https://huggingface.co/Cnam-LMSSC/EBEN_throat_microphone |
| 6 | +
|
| 7 | +This a non-causal version of EBEN that we make causal by waiting for enough future context |
| 8 | +which corresponds to half of the receptive field of the model (207ms). |
| 9 | +
|
| 10 | +""" |
| 11 | + |
| 12 | +import torch |
| 13 | +import sounddevice as sd |
| 14 | +import numpy as np |
| 15 | +import threading |
| 16 | +import sys |
| 17 | +import termios |
| 18 | +import tty |
| 19 | +from vibravox.torch_modules.dnn.eben_generator import EBENGenerator |
| 20 | + |
# ======================================================
# KEYBOARD LISTENER (non-blocking)
# ======================================================
# Start in low-latency passthrough; press 'e' to enable enhancement.
eben_enabled = False  # default OFF (passthrough)
| 26 | + |
def keyboard_listener():
    """Watch single keypresses on stdin and toggle the global EBEN flag.

    'e' enables enhancement, 'd' switches back to raw passthrough.
    The terminal is put into cbreak mode so keys are read without Enter;
    the original terminal settings are restored when the loop exits.
    """
    global eben_enabled

    # Key -> (new flag state, message to print)
    actions = {
        "e": (True, "🔊 EBEN ENABLED: 207ms latency"),
        "d": (False, "🔇 EBEN DISABLED (passthrough): 16ms latency"),
    }

    stdin_fd = sys.stdin.fileno()
    saved_attrs = termios.tcgetattr(stdin_fd)
    tty.setcbreak(stdin_fd)
    try:
        while True:
            key = sys.stdin.read(1)
            if key in actions:
                state, message = actions[key]
                eben_enabled = state
                print(message)
    finally:
        # Restore the terminal even if the thread is interrupted.
        termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
| 43 | + |
| 44 | + |
# Run the key listener in the background; daemon=True so it does not
# keep the process alive after the main loop exits.
listener_thread = threading.Thread(target=keyboard_listener, daemon=True)
listener_thread.start()
| 47 | + |
# ======================================================
# EBEN + AUDIO SETUP
# ======================================================
sample_rate = 16_000
hop_size = 256  # (=16ms) minimal number of samples needed to produce a new chunk of output given enough input samples
window_size = 6624  # (=414ms) first valid length greater than 6340 (model receptive field) + 256 (hop_size)
# The receptive field can be computed from the architecture, or empirically by
# feeding a long all-zero signal with a single impulse and measuring the
# length of the non-zero output span.

# Valid region within the output window: the centered hop_size samples,
# which have full left/right context inside the receptive field.
half_hop = hop_size // 2
valid_start = window_size // 2 - half_hop
valid_end = window_size // 2 + half_hop

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = EBENGenerator.from_pretrained("Cnam-LMSSC/EBEN_throat_microphone").eval().to(device)

# Input and output share the same one-channel, hop-sized, float32 configuration.
stream_config = dict(
    samplerate=sample_rate,
    channels=1,
    blocksize=hop_size,
    dtype="float32",
)
in_stream = sd.InputStream(**stream_config)
out_stream = sd.OutputStream(**stream_config)

for stream in (in_stream, out_stream):
    stream.start()

print("🎤 Live enhancement running... CTRL+C to stop")
print("Press 'e' to ENABLE EBEN, 'd' to DISABLE EBEN")

# Rolling buffers: raw input awaiting processing / enhanced audio awaiting playback
in_buffer = torch.zeros(1, 1, 0, device=device)
out_buffer = np.zeros(0, dtype=np.float32)
| 88 | + |
try:
    while True:
        # (1) Read one hop of throat microphone input
        # (for mic ref: https://vibravox.cnam.fr/documentation/hardware/sensors/throat)
        in_block, _ = in_stream.read(hop_size)

        if not eben_enabled:
            # Direct pass-through. Also drop any buffered context so that
            # re-enabling EBEN does not play stale audio or feed the model a
            # time-discontinuous window.
            if in_buffer.shape[-1] or out_buffer.shape[0]:
                in_buffer = torch.zeros(1, 1, 0, device=device)
                out_buffer = np.zeros(0, dtype=np.float32)
            out_stream.write(in_block)
            continue

        # (2) Append to the rolling buffer. Only convert to a device tensor
        # here, after the passthrough check, so disabled mode does no
        # needless host->device copies. in_block is (hop_size, channels);
        # transpose to (channels, hop_size) and add a batch dim.
        block_t = torch.from_numpy(in_block.T).unsqueeze(0).to(device)
        in_buffer = torch.cat([in_buffer, block_t], dim=-1)

        # (3) Process full windows: keep only the centered hop of the output,
        # which has complete context within the model's receptive field,
        # then slide the input window by one hop.
        while in_buffer.shape[-1] >= window_size:
            input_chunk = in_buffer[:, :, :window_size]

            with torch.no_grad():  # inference only, no autograd graph needed
                enhanced_chunk, _ = model(input_chunk)

            valid = enhanced_chunk[:, :, valid_start:valid_end]
            valid_np = valid.squeeze().detach().cpu().numpy().astype(np.float32)

            out_buffer = np.concatenate([out_buffer, valid_np])
            in_buffer = in_buffer[:, :, hop_size:]

        # (4) Play one hop of enhanced audio; emit silence while the
        # pipeline is still filling up to the model's look-ahead.
        if out_buffer.shape[0] >= hop_size:
            play_chunk = out_buffer[:hop_size]
            out_buffer = out_buffer[hop_size:]
        else:
            play_chunk = np.zeros(hop_size, dtype=np.float32)

        out_stream.write(play_chunk.reshape(-1, 1))

except KeyboardInterrupt:
    print("\n🛑 Exiting...")

finally:
    # Always release the audio devices, even on error.
    in_stream.stop()
    in_stream.close()
    out_stream.stop()
    out_stream.close()
0 commit comments