Skip to content

Commit 75b364b

Browse files
committed
dont let the thread die
1 parent 66533a2 commit 75b364b

File tree

4 files changed

+78
-30
lines changed

4 files changed

+78
-30
lines changed

arguments_classes/module_arguments.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class ModuleArguments:
1111
mode: Optional[str] = field(
1212
default="socket",
1313
metadata={
14-
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
14+
"help": "The mode to run the pipeline in. Either 'local', 'socket', or 'none'. Default is 'socket'."
1515
},
1616
)
1717
local_mac_optimal_settings: bool = field(

audio_streaming_client.py

+73-27
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,30 @@
22
from queue import Queue
33
import sounddevice as sd
44
import numpy as np
5-
import requests
6-
import base64
75
import time
86
from dataclasses import dataclass, field
97
import websocket
10-
import threading
118
import ssl
129

10+
1311
@dataclass
1412
class AudioStreamingClientArguments:
15-
sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
16-
chunk_size: int = field(default=512, metadata={"help": "The size of audio chunks in samples. Default is 512."})
17-
api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
18-
auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})
13+
sample_rate: int = field(
14+
default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."}
15+
)
16+
chunk_size: int = field(
17+
default=512,
18+
metadata={"help": "The size of audio chunks in samples. Default is 512."},
19+
)
20+
api_url: str = field(
21+
default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
22+
metadata={"help": "The URL of the API endpoint."},
23+
)
24+
auth_token: str = field(
25+
default="your_auth_token",
26+
metadata={"help": "Authentication token for the API."},
27+
)
28+
1929

2030
class AudioStreamingClient:
2131
def __init__(self, args: AudioStreamingClientArguments):
@@ -27,9 +37,11 @@ def __init__(self, args: AudioStreamingClientArguments):
2737
self.headers = {
2838
"Accept": "application/json",
2939
"Authorization": f"Bearer {self.args.auth_token}",
30-
"Content-Type": "application/json"
40+
"Content-Type": "application/json",
3141
}
32-
self.session_state = "idle" # Possible states: idle, sending, processing, waiting
42+
self.session_state = (
43+
"idle" # Possible states: idle, sending, processing, waiting
44+
)
3345
self.ws_ready = threading.Event()
3446

3547
def start(self):
@@ -43,12 +55,14 @@ def start(self):
4355
on_open=self.on_open,
4456
on_message=self.on_message,
4557
on_error=self.on_error,
46-
on_close=self.on_close
58+
on_close=self.on_close,
4759
)
4860

49-
ws_thread = threading.Thread(target=self.ws.run_forever, kwargs={'sslopt': {"cert_reqs": ssl.CERT_NONE}})
50-
ws_thread.start()
51-
61+
self.ws_thread = threading.Thread(
62+
target=self.ws.run_forever, kwargs={"sslopt": {"cert_reqs": ssl.CERT_NONE}}
63+
)
64+
self.ws_thread.start()
65+
5266
# Wait for the WebSocket to be ready
5367
self.ws_ready.wait()
5468
self.start_audio_streaming()
@@ -57,17 +71,25 @@ def start_audio_streaming(self):
5771
self.send_thread = threading.Thread(target=self.send_audio)
5872
self.play_thread = threading.Thread(target=self.play_audio)
5973

60-
with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_input_callback, blocksize=self.args.chunk_size):
74+
with sd.InputStream(
75+
samplerate=self.args.sample_rate,
76+
channels=1,
77+
dtype="int16",
78+
callback=self.audio_input_callback,
79+
blocksize=self.args.chunk_size,
80+
):
6181
self.send_thread.start()
6282
self.play_thread.start()
83+
input("Press Enter to stop streaming...")
84+
self.on_shutdown()
6385

6486
def on_open(self, ws):
6587
print("WebSocket connection opened.")
6688
self.ws_ready.set() # Signal that the WebSocket is ready
6789

6890
def on_message(self, ws, message):
6991
# message is bytes
70-
if message == b'DONE':
92+
if message == b"DONE":
7193
print("listen")
7294
self.session_state = "listen"
7395
else:
@@ -97,7 +119,7 @@ def send_audio(self):
97119
if self.session_state != "processing":
98120
self.ws.send(chunk.tobytes(), opcode=websocket.ABNF.OPCODE_BINARY)
99121
else:
100-
self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY) # handshake
122+
self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY) # handshake
101123
time.sleep(0.01)
102124

103125
def audio_input_callback(self, indata, frames, time, status):
@@ -106,33 +128,57 @@ def audio_input_callback(self, indata, frames, time, status):
106128
def audio_out_callback(self, outdata, frames, time, status):
107129
if not self.recv_queue.empty():
108130
chunk = self.recv_queue.get()
109-
131+
110132
# Ensure chunk is int16 and clip to valid range
111133
chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)
112-
134+
113135
if len(chunk_int16) < len(outdata):
114-
outdata[:len(chunk_int16), 0] = chunk_int16
115-
outdata[len(chunk_int16):] = 0
136+
outdata[: len(chunk_int16), 0] = chunk_int16
137+
outdata[len(chunk_int16) :] = 0
116138
else:
117-
outdata[:, 0] = chunk_int16[:len(outdata)]
139+
outdata[:, 0] = chunk_int16[: len(outdata)]
118140
else:
119141
outdata[:] = 0
120142

121143
def play_audio(self):
122-
with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_out_callback, blocksize=self.args.chunk_size):
144+
with sd.OutputStream(
145+
samplerate=self.args.sample_rate,
146+
channels=1,
147+
dtype="int16",
148+
callback=self.audio_out_callback,
149+
blocksize=self.args.chunk_size,
150+
):
123151
while not self.stop_event.is_set():
124152
time.sleep(0.1)
125153

154+
126155
if __name__ == "__main__":
127156
import argparse
128157

129158
parser = argparse.ArgumentParser(description="Audio Streaming Client")
130-
parser.add_argument("--sample_rate", type=int, default=16000, help="Audio sample rate in Hz. Default is 16000.")
131-
parser.add_argument("--chunk_size", type=int, default=1024, help="The size of audio chunks in samples. Default is 1024.")
132-
parser.add_argument("--api_url", type=str, required=True, help="The URL of the API endpoint.")
133-
parser.add_argument("--auth_token", type=str, required=True, help="Authentication token for the API.")
159+
parser.add_argument(
160+
"--sample_rate",
161+
type=int,
162+
default=16000,
163+
help="Audio sample rate in Hz. Default is 16000.",
164+
)
165+
parser.add_argument(
166+
"--chunk_size",
167+
type=int,
168+
default=1024,
169+
help="The size of audio chunks in samples. Default is 1024.",
170+
)
171+
parser.add_argument(
172+
"--api_url", type=str, required=True, help="The URL of the API endpoint."
173+
)
174+
parser.add_argument(
175+
"--auth_token",
176+
type=str,
177+
required=True,
178+
help="Authentication token for the API.",
179+
)
134180

135181
args = parser.parse_args()
136182
client_args = AudioStreamingClientArguments(**vars(args))
137183
client = AudioStreamingClient(client_args)
138-
client.start()
184+
client.start()

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ funasr>=1.1.6
88
faster-whisper>=1.0.3
99
modelscope>=1.17.1
1010
deepfilternet>=0.5.6
11-
openai>=1.40.1
11+
openai>=1.40.1
12+
websocket-client>=1.8.0

requirements_mac.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ funasr>=1.1.6
1010
faster-whisper>=1.0.3
1111
modelscope>=1.17.1
1212
deepfilternet>=0.5.6
13-
openai>=1.40.1
13+
openai>=1.40.1
14+
websocket-client>=1.8.0

0 commit comments

Comments
 (0)