Commit 1414ed4

Merge pull request #3 from sensein/original
Updating the speech-to-speech fork with visemes
2 parents 90c38d4 + 6ba4f97 commit 1414ed4

14 files changed: +495 -137

Dockerfile.arm64 (+13)

@@ -0,0 +1,13 @@
+FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
+
+ENV PYTHONUNBUFFERED 1
+
+WORKDIR /usr/src/app
+
+# Install packages
+RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .

LLM/mlx_language_model.py (+16 -3)

@@ -9,6 +9,14 @@
 
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
 
 class MLXLanguageModelHandler(BaseHandler):
     """
@@ -44,7 +52,7 @@ def setup(
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
 
-        dummy_input_text = "Write me a poem about Machine Learning."
+        dummy_input_text = "Repeat the word 'home'."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
 
         n_steps = 2
@@ -61,6 +69,11 @@ def warmup(self):
 
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
 
@@ -86,9 +99,9 @@ def process(self, prompt):
             output += t
             curr_output += t
             if curr_output.endswith((".", "?", "!", "<|end|>")):
-                yield curr_output.replace("<|end|>", "")
+                yield (curr_output.replace("<|end|>", ""), language_code)
                 curr_output = ""
         generated_text = output.replace("<|end|>", "")
         torch.mps.empty_cache()
 
-        self.chat.append({"role": "assistant", "content": generated_text})
+        self.chat.append({"role": "assistant", "content": generated_text})
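The change above is the core of this PR's language handling: the STT stage now hands the LLM a `(text, language_code)` tuple, and the LLM re-attaches the code to every sentence it yields so the TTS stage can pick a matching voice. A minimal sketch of that hand-off (function names and the canned output are illustrative, not from the repo):

```python
# Illustrative sketch of the (text, language_code) tuple contract added in this PR.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {"en": "english", "fr": "french", "zh": "chinese"}

def llm_stage(stt_output):
    language_code = None
    if isinstance(stt_output, tuple):      # STT now yields (transcript, language_code)
        prompt, language_code = stt_output
        prompt = (
            f"Please reply to my message in "
            f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
        )
    else:
        prompt = stt_output                # plain strings still work
    for sentence in ["Bonjour.", "Comment puis-je aider ?"]:  # stand-in for model output
        yield (sentence, language_code)    # downstream TTS uses the code to pick a voice

for sentence, lang in llm_stage(("Bonjour", "fr")):
    print(lang, sentence)
```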

LLM/openai_api_language_model.py (+89)

@@ -0,0 +1,89 @@
+from openai import OpenAI
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+import time
+logger = logging.getLogger(__name__)
+
+console = Console()
+from nltk import sent_tokenize
+
+class OpenApiModelHandler(BaseHandler):
+    """
+    Handles the language model part.
+    """
+    def setup(
+        self,
+        model_name="deepseek-chat",
+        device="cuda",
+        gen_kwargs={},
+        base_url =None,
+        api_key=None,
+        stream=False,
+        user_role="user",
+        chat_size=1,
+        init_chat_role="system",
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.model_name = model_name
+        self.stream = stream
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial promt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+        start = time.time()
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant"},
+                {"role": "user", "content": "Hello"},
+            ],
+            stream=self.stream
+        )
+        end = time.time()
+        logger.info(
+            f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
+        )
+    def process(self, prompt):
+        logger.debug("call api language model...")
+        self.chat.append({"role": self.user_role, "content": prompt})
+
+        language_code = None
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=[
+                {"role": self.user_role, "content": prompt},
+            ],
+            stream=self.stream
+        )
+        if self.stream:
+            generated_text, printable_text = "", ""
+            for chunk in response:
+                new_text = chunk.choices[0].delta.content or ""
+                generated_text += new_text
+                printable_text += new_text
+                sentences = sent_tokenize(printable_text)
+                if len(sentences) > 1:
+                    yield sentences[0], language_code
+                    printable_text = new_text
+            self.chat.append({"role": "assistant", "content": generated_text})
+            # don't forget last sentence
+            yield printable_text, language_code
+        else:
+            generated_text = response.choices[0].message.content
+            self.chat.append({"role": "assistant", "content": generated_text})
+            yield generated_text, language_code
+
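The streaming branch of `process` above splits the API response into sentences with `nltk.sent_tokenize` so the TTS stage can start speaking before the full reply arrives. A self-contained sketch of that chunking logic, with a list of strings standing in for the streamed deltas (assumes nltk's punkt tokenizer data is installed):

```python
# Minimal sketch of the stream=True sentence-chunking in OpenApiModelHandler.process.
# The simulated chunks stand in for chunk.choices[0].delta.content from the API stream.
from nltk import sent_tokenize  # assumes punkt data, e.g. nltk.download("punkt")

def stream_sentences(chunks, language_code=None):
    generated_text, printable_text = "", ""
    for new_text in chunks:
        generated_text += new_text
        printable_text += new_text
        sentences = sent_tokenize(printable_text)
        if len(sentences) > 1:              # a sentence boundary has been crossed
            yield sentences[0], language_code
            printable_text = new_text       # restart the buffer from the newest chunk
    yield printable_text, language_code     # don't forget the last sentence

for sentence, lang in stream_sentences(["Hello ", "there. ", "How are ", "you?"], "en"):
    print(lang, sentence)                   # -> "en Hello there." then "en How are you?"
```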

README.md (+70 -18)

@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
 
 ### Server/Client Approach
 
-To run the pipeline on the server:
-```bash
-python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
-```
+1. Run the pipeline on the server:
+   ```bash
+   python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
+   ```
 
-Then run the client locally to handle sending microphone input and receiving generated audio:
-```bash
-python listen_and_play.py --host <IP address of your server>
-```
+2. Run the client locally to handle microphone input and receive generated audio:
+   ```bash
+   python listen_and_play.py --host <IP address of your server>
+   ```
 
-### Local approach (running on Mac)
-To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
-```bash
-python s2s_pipeline.py --local_mac_optimal_settings
-```
+### Local Approach (Mac)
+
+1. For optimal settings on Mac:
+   ```bash
+   python s2s_pipeline.py --local_mac_optimal_settings
+   ```
 
-You can also pass `--device mps` to have all the models set to device mps.
-The local mac optimal settings set the mode to be local as explained above and change the models to:
-- LightningWhisperMLX
-- MLX LM
-- MeloTTS
+This setting:
+- Adds `--device mps` to use MPS for all models.
+- Sets LightningWhisperMLX for STT
+- Sets MLX LM for language model
+- Sets MeloTTS for TTS
 
 ### Recommended usage with Cuda
 
@@ -117,6 +118,57 @@ python s2s_pipeline.py \
 
 For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
 
+
+### Multi-language Support
+
+The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
+
+#### With the server version:
+
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
+```
+
+Or for one language in particular, chinese in this example
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
+```
+
+#### Local Mac Setup
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
+```
+
+Or for one language in particular, chinese in this example
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
+```
+
+
 ## Command-line Usage
 
 ### Model Parameters

STT/lightning_whisper_mlx_handler.py (+32 -3)

@@ -4,12 +4,22 @@
 from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
+from copy import copy
 import torch
 
 logger = logging.getLogger(__name__)
 
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
+
 
 class LightningWhisperSTTHandler(BaseHandler):
     """
@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-large-v3",
-        device="cuda",
+        device="mps",
         torch_dtype="float16",
         compile_mode=None,
         language=None,
@@ -29,6 +39,9 @@ def setup(
         model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+        self.start_language = language
+        self.last_language = language
+
         self.warmup()
 
     def warmup(self):
@@ -47,10 +60,26 @@ def process(self, spoken_prompt):
         global pipeline_start
         pipeline_start = perf_counter()
 
-        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+        if self.start_language != 'auto':
+            transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
+        else:
+            transcription_dict = self.model.transcribe(spoken_prompt)
+            language_code = transcription_dict["language"]
+            if language_code not in SUPPORTED_LANGUAGES:
+                logger.warning(f"Whisper detected unsupported language: {language_code}")
+                if self.last_language in SUPPORTED_LANGUAGES:  # reprocess with the last language
+                    transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
+                else:
+                    transcription_dict = {"text": "", "language": "en"}
+            else:
+                self.last_language = language_code
+
+        pred_text = transcription_dict["text"].strip()
+        language_code = transcription_dict["language"]
         torch.mps.empty_cache()
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)
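This keeps the pipeline within the languages the rest of the stack (notably MeloTTS) supports: in auto mode, an unsupported Whisper detection triggers a re-transcription with the last accepted language, and failing that the utterance is dropped with an English default. A compact, hypothetical restatement of that policy as a pure function (names and the toy transcriber are illustrative, not from the repo):

```python
# Hypothetical restatement of the fallback policy in LightningWhisperSTTHandler.process.
SUPPORTED_LANGUAGES = ["en", "fr", "es", "zh", "ja", "ko"]

def resolve_transcription(transcribe, audio, last_language):
    """`transcribe(audio, language=None)` stands in for LightningWhisperMLX.transcribe."""
    result = transcribe(audio)                    # auto-detect first
    detected = result["language"]
    if detected in SUPPORTED_LANGUAGES:
        return result, detected                   # accept and remember this language
    if last_language in SUPPORTED_LANGUAGES:
        return transcribe(audio, language=last_language), last_language
    return {"text": "", "language": "en"}, last_language  # drop the utterance

# toy transcriber that always "detects" Dutch, which is not in SUPPORTED_LANGUAGES
fake = lambda audio, language=None: {"text": "hello", "language": language or "nl"}
print(resolve_transcription(fake, b"...", "en"))  # re-transcribed with language="en"
```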

TTS/melo_handler.py (+2 -2)

@@ -13,7 +13,7 @@
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "en": "EN_NEWEST",
+    "en": "EN",
     "fr": "FR",
     "es": "ES",
     "zh": "ZH",
@@ -22,7 +22,7 @@
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "en": "EN-Newest",
+    "en": "EN-BR",
     "fr": "FR",
     "es": "ES",
     "zh": "ZH",

TTS/parler_handler.py (+1 -1)

@@ -70,7 +70,7 @@ def setup(
 
         if self.compile_mode not in (None, "default"):
             logger.warning(
-                "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
+                "Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'"
             )
             self.compile_mode = "default"

VAD/vad_handler.py (+4)

@@ -86,3 +86,7 @@ def process(self, audio_chunk):
                         )
                     array = enhanced.numpy().squeeze()
                 yield array
+
+    @property
+    def min_time_to_debug(self):
+        return 0.00001
