Commit 1bc8186

Merge pull request #5 from sensein/flow
Adding speech to visemes as a child of BaseHandler
2 parents: 7176a1b + c7b85e1

14 files changed: +401 -231 lines

.gitignore

+3 -1

@@ -1,4 +1,6 @@
 __pycache__
 tmp
 cache
-mlx_models/
+mlx_models/
+asset/
+config/

README.md

+11

@@ -28,6 +28,7 @@ This repository implements a speech-to-speech cascaded pipeline consisting of th
 2. **Speech to Text (STT)**
 3. **Language Model (LM)**
 4. **Text to Speech (TTS)**
+5. **Speech to Visemes (STV)**

 ### Modularity
 The pipeline provides a fully open and modular approach, with a focus on leveraging models available through the Transformers library on the Hugging Face hub. The code is designed for easy modification, and we already support device-specific and external library implementations:
@@ -50,6 +51,9 @@ The pipeline provides a fully open and modular approach, with a focus on leverag
 - [MeloTTS](https://github.com/myshell-ai/MeloTTS)
 - [ChatTTS](https://github.com/2noise/ChatTTS?tab=readme-ov-file)

+**STV**
+- [Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/en/model_doc/wav2vec2_phoneme) + [Phoneme to viseme mapping](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis-viseme?tabs=visemeid&pivots=programming-language-python#map-phonemes-to-visemes)
+
 ## Setup

 Clone the repository:
@@ -216,6 +220,13 @@ For example:
 --lm_model_name google/gemma-2b-it
 ```

+
+### STV parameters
+See the [Wav2Vec2STVHandlerArguments](arguments_classes/w2v_stv_arguments.py) class. Notably:
+- `stv_model_name` defaults to `bookbot/wav2vec2-ljspeech-gruut`, chosen because it is accurate and fast enough
+- `stv_skip`: set it to `True` if you don't need visemes
+
+
 ### Generation parameters

 Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed.
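
To make the phoneme-to-viseme mapping concrete, here is a minimal, self-contained sketch of the lookup the STV step performs on character-level ASR output. The map entries, phonemes, and timestamps below are illustrative assumptions; the actual table ships in `STV/phoneme_viseme_map.json` and follows the Azure viseme IDs linked above.

```python
# Hypothetical excerpt of a phoneme-to-viseme map (the real one is
# STV/phoneme_viseme_map.json); viseme IDs follow the Azure convention.
phoneme_viseme_map = {
    "h": [12],   # illustrative entry
    "ə": [1],    # illustrative entry
    "l": [14],   # illustrative entry
    "oʊ": [8],   # illustrative entry
}

# ASR output with character-level timestamps, as returned by
# pipeline(..., return_timestamps="char"); the values are made up.
chunks = [
    {"text": "h", "timestamp": (0.10, 0.18)},
    {"text": "ə", "timestamp": (0.18, 0.26)},
    {"text": "l", "timestamp": (0.26, 0.34)},
    {"text": "oʊ", "timestamp": (0.34, 0.50)},
]

# Each phoneme maps to zero or more visemes, each keeping the phoneme's timestamp.
visemes = [
    {"viseme": v, "timestamp": c["timestamp"]}
    for c in chunks
    for v in phoneme_viseme_map.get(c["text"], [])
]
print(visemes)
# [{'viseme': 12, 'timestamp': (0.1, 0.18)}, {'viseme': 1, 'timestamp': (0.18, 0.26)}, ...]
```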

STT/paraformer_handler.py

-1

@@ -28,7 +28,6 @@ def setup(
         device="cuda",
         gen_kwargs={},
     ):
-        print(model_name)
         if len(model_name.split("/")) > 1:
             model_name = model_name.split("/")[-1]
         self.device = device
File renamed without changes.

STV/w2v_stv_handler.py

+253 (new file)

@@ -0,0 +1,253 @@
import json
import logging
import time
from typing import Any, Dict, Generator, List

import numpy as np
from rich.console import Console
from transformers import pipeline

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)
console = Console()


class Wav2Vec2STVHandler(BaseHandler):
    """
    Handles the Speech-To-Viseme generation using a Wav2Vec2 model for automatic
    speech recognition (ASR) and phoneme mapping to visemes.

    Attributes:
        MIN_AUDIO_LENGTH (float): Minimum length of audio (in seconds) required
            for phoneme extraction.
    """

    MIN_AUDIO_LENGTH = 0.5  # Minimum audio length in seconds for phoneme extraction

    def setup(
        self,
        should_listen: bool,
        model_name: str = "bookbot/wav2vec2-ljspeech-gruut",
        blocksize: int = 512,
        device: str = "cuda",
        skip: bool = False,
        gen_kwargs: Dict[str, Any] = {},  # Not used
    ) -> None:
        """
        Initializes the handler by loading the ASR model and the phoneme-to-viseme map.

        Args:
            should_listen (bool): Flag indicating whether the speech-to-speech pipeline
                should start listening to the user or not.
            model_name (str): Name of the ASR model to use.
                Defaults to "bookbot/wav2vec2-ljspeech-gruut".
            blocksize (int): Size of each audio block when processing audio.
                Defaults to 512.
            device (str): Device to run the model on ("cuda", "mps", or "cpu").
                Defaults to "cuda".
            skip (bool): If True, the speech-to-viseme process is skipped.
                Defaults to False.
            gen_kwargs (dict): Additional parameters for speech generation.

        Returns:
            None
        """
        self.device = device
        self.gen_kwargs = gen_kwargs
        self.blocksize = blocksize
        self.should_listen = should_listen
        self.skip = skip

        # Load the phoneme-to-viseme map from the JSON file,
        # inspired by https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-ssml-phonetic-sets
        phoneme_viseme_map_file = "STV/phoneme_viseme_map.json"
        with open(phoneme_viseme_map_file, "r") as f:
            self.phoneme_viseme_map = json.load(f)

        # Initialize the ASR pipeline using the specified model and device
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            device=device,
            torch_dtype="auto",
        )
        self.expected_sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate

        # Initialize an empty buffer to accumulate audio batch data
        self.audio_batch = {
            "waveform": np.array([]),
            "sampling_rate": self.expected_sampling_rate,
        }
        self.text_batch = None
        self.should_listen_flag = False

        self.warmup()  # Perform model warmup

    def warmup(self) -> None:
        """Warms up the model with a dummy input to prepare it for inference.

        Returns:
            None
        """
        logger.info(f"Warming up {self.__class__.__name__}")
        start_time = time.time()

        # Create dummy input for warmup inference
        dummy_input = np.random.randn(self.blocksize).astype(np.int16)
        _ = self.speech_to_visemes(dummy_input)

        warmup_time = time.time() - start_time
        logger.info(
            f"{self.__class__.__name__}: warmed up in {warmup_time:.4f} seconds!"
        )

    def speech_to_visemes(self, audio: Any) -> List[Dict[str, Any]]:
        """
        Converts speech audio to visemes by performing Automatic Speech Recognition (ASR)
        and mapping phonemes to visemes.

        Args:
            audio (Any): The input audio data.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing mapped visemes
                and their corresponding timestamps.

        Note:
            Heuristically, the input audio should be at least 0.5 seconds long for
            proper phoneme extraction.
        """

        def _map_phonemes_to_visemes(
            data: Dict[str, Any],
        ) -> List[Dict[str, Any]]:
            """
            Maps extracted phonemes to their corresponding visemes based on a predefined map.

            Args:
                data (Dict[str, Any]): Dictionary containing phoneme data where data['chunks']
                    holds a list of phonemes and their timestamps.

            Returns:
                List[Dict[str, Any]]: A list of dictionaries with viseme IDs and their
                    corresponding timestamps.
            """
            viseme_list = []
            chunks = data.get("chunks", [])

            # Map each phoneme to its corresponding visemes
            for chunk in chunks:
                phoneme = chunk.get("text", None)
                timestamp = chunk.get("timestamp", None)
                visemes = self.phoneme_viseme_map.get(phoneme, [])

                for viseme in visemes:
                    viseme_list.append({"viseme": viseme, "timestamp": timestamp})

            return viseme_list

        # Perform ASR to extract phoneme data, including timestamps
        try:
            asr_result = self.asr_pipeline(audio, return_timestamps="char")
        except Exception as e:
            logger.error(f"ASR error: {e}")
            return []
        # Map the phonemes obtained from ASR to visemes
        return _map_phonemes_to_visemes(asr_result)

    def process(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
        """
        Processes incoming audio to generate visemes and yields blocks of audio data
        along with the corresponding viseme data.

        Args:
            data (Dict[str, Any]): Dictionary containing audio, text, and potentially
                additional information.

        Yields:
            Dict: A dictionary containing the audio waveform and, optionally, viseme data,
                text, and additional information.
        """

        if "sentence_end" in data and data["sentence_end"]:
            self.should_listen_flag = True
        if self.skip:  # Skip viseme extraction if the flag is set
            yield {
                "audio": {
                    "waveform": data["audio"]["waveform"],
                    "sampling_rate": data["audio"]["sampling_rate"],
                },
                "text": data["text"] if "text" in data else None,
            }
        else:
            # Check if text data is present and save it for later
            if "text" in data and data["text"] is not None:
                self.text_batch = data["text"]
            # Concatenate new audio data into the buffer if available and valid
            if "audio" in data and data["audio"] is not None:
                audio_data = data["audio"]
                # Check that the sampling rate matches the expected one
                if audio_data.get("sampling_rate", None) != self.expected_sampling_rate:
                    logger.error(
                        f"Expected sampling rate {self.expected_sampling_rate}, "
                        f"but got {audio_data['sampling_rate']}."
                    )
                    return
                # Append the waveform to the audio buffer
                self.audio_batch["waveform"] = np.concatenate(
                    (self.audio_batch["waveform"], audio_data["waveform"]), axis=0
                )

            # Ensure the total audio length is sufficient for phoneme extraction
            if (
                len(self.audio_batch["waveform"]) / self.audio_batch["sampling_rate"]
                < self.MIN_AUDIO_LENGTH
            ):
                return
            else:
                logger.debug("Starting viseme inference...")

                # Perform viseme inference using the accumulated audio batch
                viseme_data = self.speech_to_visemes(self.audio_batch["waveform"])
                logger.debug("Viseme inference completed.")

                # Print the visemes and timestamps to the console
                for viseme in viseme_data:
                    console.print(
                        f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}"
                    )

                # Process the audio in chunks of the defined blocksize
                self.audio_batch["waveform"] = self.audio_batch["waveform"].astype(
                    np.int16
                )
                for i in range(0, len(self.audio_batch["waveform"]), self.blocksize):
                    chunk_waveform = self.audio_batch["waveform"][
                        i : i + self.blocksize
                    ]
                    padded_waveform = np.pad(
                        chunk_waveform, (0, self.blocksize - len(chunk_waveform))
                    )

                    chunk_data = {
                        "audio": {
                            "waveform": padded_waveform,
                            "sample_rate": self.audio_batch["sampling_rate"],
                        }
                    }

                    # Add text and viseme data only in the first chunk
                    if i == 0:
                        if self.text_batch:
                            chunk_data["text"] = self.text_batch
                        if viseme_data and len(viseme_data) > 0:
                            chunk_data["visemes"] = viseme_data
                    yield chunk_data

                # Reset the audio and text buffers after processing
                self.audio_batch = {
                    "waveform": np.array([]),
                    "sampling_rate": self.expected_sampling_rate,
                }
                self.text_batch = ""

        if self.should_listen_flag:
            self.should_listen.set()
            self.should_listen_flag = False
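
For reference, the ASR-phoneme-to-viseme path implemented by `Wav2Vec2STVHandler.speech_to_visemes` can also be exercised outside the pipeline. The sketch below is not part of the commit; it assumes the map added in this commit is available at `STV/phoneme_viseme_map.json` and uses random noise in place of real speech, so the printed visemes are only meaningful with actual audio.

```python
# Standalone sketch: run the phoneme ASR model with character-level timestamps
# and map each recognized phoneme to viseme IDs, mirroring speech_to_visemes.
import json

import numpy as np
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="bookbot/wav2vec2-ljspeech-gruut",
    device="cpu",
)

with open("STV/phoneme_viseme_map.json", "r") as f:
    phoneme_viseme_map = json.load(f)

# One second of noise at the model's sampling rate stands in for real speech.
sampling_rate = asr.feature_extractor.sampling_rate
audio = np.random.randn(sampling_rate).astype(np.float32)

result = asr(audio, return_timestamps="char")
for chunk in result.get("chunks", []):
    for viseme in phoneme_viseme_map.get(chunk["text"], []):
        print(f"viseme {viseme} at {chunk['timestamp']}")
```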
