
Commit 90c38d4

Merge pull request #1 from sensein/speech_to_visemes
adding speech_to_visemes
2 parents 86c0808 + ba4368b commit 90c38d4

8 files changed: +520 -33 lines

TTS/STV/speech_to_visemes.py

+341
Large diffs are not rendered by default.
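The new module itself is not rendered in this view. Judging only from how the handlers below call it (SpeechToVisemes().process(mono int16 audio at 16 kHz) returning a list of dicts with 'viseme' and 'timestamp' keys), a minimal placeholder of the expected interface could look like the sketch below; the committed file is far larger, and everything past the docstring is an assumption, not the committed implementation.

import numpy as np

class SpeechToVisemes:
    """Placeholder sketch of the interface used by the TTS handlers below.

    The real TTS/STV/speech_to_visemes.py (+341 lines) presumably runs a phoneme
    recognizer or aligner; this stub only reproduces the output schema.
    """

    def process(self, audio_chunk: np.ndarray) -> list:
        # Assumed contract: 16 kHz mono int16 audio in, timestamped visemes out.
        sample_rate = 16000
        duration_s = len(audio_chunk) / sample_rate
        # Emit one neutral viseme per second, purely to illustrate the schema
        # consumed as viseme['viseme'] and viseme['timestamp'] by the handlers.
        return [{"viseme": "sil", "timestamp": float(t)} for t in range(int(duration_s) + 1)]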

TTS/chatTTS_handler.py

+57 -9

@@ -5,6 +5,7 @@
 import numpy as np
 from rich.console import Console
 import torch
+from .STV.speech_to_visemes import SpeechToVisemes
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -22,6 +23,7 @@ def setup(
         gen_kwargs={}, # Unused
         stream=True,
         chunk_size=512,
+        viseme_flag = True
     ):
         self.should_listen = should_listen
         self.device = device
@@ -33,6 +35,9 @@ def setup(
         self.params_infer_code = ChatTTS.Chat.InferCodeParams(
             spk_emb=rnd_spk_emb,
         )
+        self.viseme_flag = viseme_flag
+        if self.viseme_flag:
+            self.speech_to_visemes = SpeechToVisemes()
         self.warmup()
 
     def warmup(self):
@@ -61,22 +66,65 @@ def process(self, llm_sentence):
                 if gen[0] is None or len(gen[0]) == 0:
                     self.should_listen.set()
                     return
+
+                # Resample the audio to 16000 Hz
                 audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
-                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
-                while len(audio_chunk) > self.chunk_size:
-                    yield audio_chunk[: self.chunk_size]  # yield the first chunk_size bytes of data
-                    audio_chunk = audio_chunk[self.chunk_size :]  # remove the data already yielded
-                yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
+                # Ensure the audio is converted to mono (single channel)
+                if len(audio_chunk.shape) > 1:
+                    audio_chunk = librosa.to_mono(audio_chunk)
+                audio_chunk = (audio_chunk * 32768).astype(np.int16)
+
+                # Process visemes if viseme_flag is set
+                if self.viseme_flag:
+                    visemes = self.speech_to_visemes.process(audio_chunk)
+                    for viseme in visemes:
+                        console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
+                else:
+                    visemes = None
+
+                # Loop through audio chunks, yielding dict for each chunk
+                for i in range(0, len(audio_chunk), self.chunk_size):
+                    chunk_data = {
+                        "audio": np.pad(
+                            audio_chunk[i : i + self.chunk_size],
+                            (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
+                        )
+                    }
+                    # Include text and visemes for the first chunk
+                    if i == 0:
+                        chunk_data["text"] = llm_sentence  # Assuming llm_sentence is defined elsewhere
+                        chunk_data["visemes"] = visemes
+
+                    yield chunk_data
         else:
             wavs = wavs_gen
             if len(wavs[0]) == 0:
                 self.should_listen.set()
                 return
             audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
+            # Ensure the audio is converted to mono (single channel)
+            if len(audio_chunk.shape) > 1:
+                audio_chunk = librosa.to_mono(audio_chunk)
             audio_chunk = (audio_chunk * 32768).astype(np.int16)
+
+            if self.viseme_flag:
+                visemes = self.speech_to_visemes.process(audio_chunk)
+                for viseme in visemes:
+                    console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
+            else:
+                visemes = None
+
             for i in range(0, len(audio_chunk), self.chunk_size):
-                yield np.pad(
-                    audio_chunk[i : i + self.chunk_size],
-                    (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
-                )
+                chunk_data = {
+                    "audio": np.pad(
+                        audio_chunk[i : i + self.chunk_size],
+                        (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
+                    )
+                }
+                # For the first chunk, include text and visemes
+                if i == 0:
+                    chunk_data["text"] = llm_sentence
+                    chunk_data["visemes"] = visemes
+                yield chunk_data
+
         self.should_listen.set()
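With this change the handler no longer yields bare int16 arrays: every yielded item is a dict whose 'audio' entry is padded to chunk_size samples, and only the first chunk also carries 'text' and 'visemes'. A hedged consumer-side sketch of that contract (the handler object and sentence are stand-ins, not code from this commit):

def consume_tts_chunks(handler, sentence):
    """Illustrative consumer for the dict-based chunks yielded by the TTS handlers."""
    for chunk in handler.process(sentence):
        audio = chunk["audio"]             # np.int16 array, padded to chunk_size samples
        if "text" in chunk:                # present only on the first chunk
            print("TEXT:", chunk["text"])
        if chunk.get("visemes"):           # list of {'viseme', 'timestamp'} dicts, or None
            for v in chunk["visemes"]:
                print("VISEME:", v["viseme"], "at", v["timestamp"])
        yield audio                        # pass the raw audio on for playback or sending

The MeloTTSHandler and ParlerTTSHandler diffs below follow the same contract.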

TTS/melo_handler.py

+27 -5

@@ -6,6 +6,8 @@
 from rich.console import Console
 import torch
 
+from .STV.speech_to_visemes import SpeechToVisemes
+
 logger = logging.getLogger(__name__)
 
 console = Console()
@@ -28,7 +30,6 @@
     "ko": "KR",
 }
 
-
 class MeloTTSHandler(BaseHandler):
     def setup(
         self,
@@ -38,6 +39,7 @@ def setup(
         speaker_to_id="en",
         gen_kwargs={}, # Unused
         blocksize=512,
+        viseme_flag = True # To obtain timestamped visemes
     ):
         self.should_listen = should_listen
         self.device = device
@@ -49,6 +51,11 @@ def setup(
             WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
         ]
         self.blocksize = blocksize
+
+        self.viseme_flag = viseme_flag
+        if self.viseme_flag:
+            self.speech_to_visemes = SpeechToVisemes()
+
         self.warmup()
 
     def warmup(self):
@@ -100,10 +107,25 @@ def process(self, llm_sentence):
             return
         audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
         audio_chunk = (audio_chunk * 32768).astype(np.int16)
+
+        if self.viseme_flag:
+            visemes = self.speech_to_visemes.process(audio_chunk)
+            for viseme in visemes:
+                console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
+        else:
+            visemes = None
+
         for i in range(0, len(audio_chunk), self.blocksize):
-            yield np.pad(
-                audio_chunk[i : i + self.blocksize],
-                (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
-            )
+            chunk_data = {
+                "audio": np.pad(
+                    audio_chunk[i : i + self.blocksize],
+                    (0, self.blocksize - len(audio_chunk[i : i + self.blocksize]))
+                )
+            }
+            # For the first chunk, include text and visemes
+            if i == 0:
+                chunk_data["text"] = llm_sentence
+                chunk_data["visemes"] = visemes
+            yield chunk_data
 
         self.should_listen.set()

TTS/parler_handler.py

+25 -4

@@ -14,6 +14,7 @@
 from transformers.utils.import_utils import (
     is_flash_attn_2_available,
 )
+from .STV.speech_to_visemes import SpeechToVisemes
 
 torch._inductor.config.fx_graph_cache = True
 # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
@@ -47,6 +48,7 @@ def setup(
         ),
         play_steps_s=1,
         blocksize=512,
+        viseme_flag = True
     ):
         self.should_listen = should_listen
         self.device = device
@@ -78,6 +80,10 @@ def setup(
                 self.model.forward, mode=self.compile_mode, fullgraph=True
             )
 
+        self.viseme_flag = viseme_flag
+        if self.viseme_flag:
+            self.speech_to_visemes = SpeechToVisemes()
+
         self.warmup()
 
     def prepare_model_inputs(
@@ -182,10 +188,25 @@ def process(self, llm_sentence):
                 )
             audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
             audio_chunk = (audio_chunk * 32768).astype(np.int16)
+
+            if self.viseme_flag:
+                visemes = self.speech_to_visemes.process(audio_chunk)
+                for viseme in visemes:
+                    console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
+            else:
+                visemes = None
+
             for i in range(0, len(audio_chunk), self.blocksize):
-                yield np.pad(
-                    audio_chunk[i : i + self.blocksize],
-                    (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
-                )
+                chunk_data = {
+                    "audio": np.pad(
+                        audio_chunk[i : i + self.blocksize],
+                        (0, self.blocksize - len(audio_chunk[i : i + self.blocksize]))
+                    )
+                }
+                # For the first chunk, include text and visemes
+                if i == 0:
+                    chunk_data["text"] = llm_sentence
+                    chunk_data["visemes"] = visemes
+                yield chunk_data
 
         self.should_listen.set()

arguments_classes/parler_tts_arguments.py

+1 -1

@@ -36,7 +36,7 @@ class ParlerTTSHandlerArguments:
     tts_gen_max_new_tokens: int = field(
         default=512,
         metadata={
-            "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~6 secs"
         },
     )
     description: str = field(

connections/local_audio_streamer.py

+12 -1

@@ -27,7 +27,18 @@ def callback(indata, outdata, frames, time, status):
                 self.input_queue.put(indata.copy())
                 outdata[:] = 0 * outdata
             else:
-                outdata[:] = self.output_queue.get()[:, np.newaxis]
+                data = self.output_queue.get()
+                """
+                # Check if text data is present and log it
+                if data.get('text') is not None:
+                    text = data['text']
+                    logger.info(f"Text: {text}")
+                # Check if viseme data is present and log it
+                if data.get('visemes') is not None:
+                    visemes = data['visemes']
+                    logger.info(f"Visemes: {visemes}")
+                """
+                outdata[:] = data['audio'][:, np.newaxis]
 
         logger.debug("Available devices:")
         logger.debug(sd.query_devices())

connections/socket_sender.py

+28 -5

@@ -1,6 +1,8 @@
 import socket
 from rich.console import Console
 import logging
+import pickle
+import struct
 
 logger = logging.getLogger(__name__)
 
@@ -11,7 +13,6 @@ class SocketSender:
     """
     Handles sending generated audio packets to the clients.
     """
-
    def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
        self.stop_event = stop_event
        self.queue_in = queue_in
@@ -28,9 +29,31 @@ def run(self):
         logger.info("sender connected")
 
         while not self.stop_event.is_set():
-            audio_chunk = self.queue_in.get()
-            self.conn.sendall(audio_chunk)
-            if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
-                break
+            data = self.queue_in.get()
+            packet = {}
+            if 'audio' in data and data['audio'] is not None:
+                audio_chunk = data['audio']
+                packet['audio'] = data['audio']
+            if 'text' in data and data['text'] is not None:
+                packet['text'] = data['text']
+            if 'visemes' in data and data['visemes'] is not None:
+                packet['visemes'] = data['visemes']
+
+            # Serialize the packet using pickle
+            serialized_packet = pickle.dumps(packet)
+
+            # Compute the length of the serialized packet
+            packet_length = len(serialized_packet)
+
+            # Send the packet length as a 4-byte integer using struct
+            self.conn.sendall(struct.pack('!I', packet_length))
+
+            # Send the serialized packet
+            self.conn.sendall(serialized_packet)
+
+            if 'audio' in data and data['audio'] is not None:
+                if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
+                    break
+
         self.conn.close()
         logger.info("Sender closed")

listen_and_play.py

+29 -8

@@ -4,15 +4,16 @@
 from dataclasses import dataclass, field
 import sounddevice as sd
 from transformers import HfArgumentParser
-
+import struct
+import pickle
 
 @dataclass
 class ListenAndPlayArguments:
     send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
     recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
     list_play_chunk_size: int = field(
-        default=1024,
-        metadata={"help": "The size of data chunks (in bytes). Default is 1024."},
+        default=512,
+        metadata={"help": "The size of data chunks (in bytes). Default is 512."},
     )
     host: str = field(
         default="localhost",
@@ -33,7 +34,7 @@ class ListenAndPlayArguments:
 def listen_and_play(
     send_rate=16000,
     recv_rate=44100,
-    list_play_chunk_size=1024,
+    list_play_chunk_size=512,
     host="localhost",
     send_port=12345,
     recv_port=12346,
@@ -79,9 +80,29 @@ def receive_full_chunk(conn, chunk_size):
             return data
 
         while not stop_event.is_set():
-            data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
-            if data:
-                recv_queue.put(data)
+            # Step 1: Receive the first 4 bytes to get the packet length
+            length_data = receive_full_chunk(recv_socket, 4)
+            if not length_data:
+                continue  # Handle disconnection or data not available
+
+            # Step 2: Unpack the length (4 bytes)
+            packet_length = struct.unpack('!I', length_data)[0]
+
+            # Step 3: Receive the full packet based on the length
+            serialized_packet = receive_full_chunk(recv_socket, packet_length)
+            if serialized_packet:
+                # Step 4: Deserialize the packet using pickle
+                packet = pickle.loads(serialized_packet)
+                # Step 5: Extract the packet contents
+                if 'text' in packet:
+                    pass
+                    # print(packet['text'])
+                if 'visemes' in packet:
+                    pass
+                    # print(packet['visemes'])
+
+                # Step 6: Put the packet audio data into the queue for sending
+                recv_queue.put(packet['audio'].tobytes())
 
     try:
         send_stream = sd.RawInputStream(
@@ -123,4 +144,4 @@ def receive_full_chunk(conn, chunk_size):
 if __name__ == "__main__":
     parser = HfArgumentParser((ListenAndPlayArguments,))
     (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
-    listen_and_play(**vars(listen_and_play_kwargs))
+    listen_and_play(**vars(listen_and_play_kwargs))
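The receiving side relies on receive_full_chunk, whose body sits outside this hunk. The length-prefix protocol only works if that helper reads exactly chunk_size bytes, looping on recv until the buffer is full. A sketch of such a helper together with the unframing step (an illustration consistent with the calls above, not the committed body):

import pickle
import struct

def receive_full_chunk(conn, chunk_size):
    """Read exactly chunk_size bytes from the socket; return None if the peer closes early."""
    data = b""
    while len(data) < chunk_size:
        part = conn.recv(chunk_size - len(data))
        if not part:
            return None  # connection closed before the chunk was complete
        data += part
    return data

def receive_packet(conn):
    """Counterpart to the sender's framing: read the 4-byte length, then the pickle payload."""
    length_data = receive_full_chunk(conn, 4)
    if not length_data:
        return None
    (packet_length,) = struct.unpack('!I', length_data)
    payload = receive_full_chunk(conn, packet_length)
    return pickle.loads(payload) if payload else None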

0 commit comments