2
2
from queue import Queue
3
3
import sounddevice as sd
4
4
import numpy as np
5
- import requests
6
- import base64
7
5
import time
8
6
from dataclasses import dataclass , field
9
7
import websocket
10
- import threading
11
8
import ssl
12
9
10
+
13
11
@dataclass
class AudioStreamingClientArguments:
    """Settings consumed by ``AudioStreamingClient``.

    Every field stores a human-readable description under the ``"help"``
    key of its dataclass metadata.
    """

    # Sampling rate, in Hz, used for both the input and output audio streams.
    sample_rate: int = field(
        default=16000,
        metadata={"help": "Audio sample rate in Hz. Default is 16000."},
    )
    # Number of samples per audio block handed to the stream callbacks.
    chunk_size: int = field(
        default=512,
        metadata={"help": "The size of audio chunks in samples. Default is 512."},
    )
    # Endpoint the client connects to.
    api_url: str = field(
        default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
        metadata={"help": "The URL of the API endpoint."},
    )
    # Bearer token sent in the Authorization header.
    # NOTE(review): the default looks like a placeholder -- confirm callers
    # always override it.
    auth_token: str = field(
        default="your_auth_token",
        metadata={"help": "Authentication token for the API."},
    )
28
+
19
29
20
30
class AudioStreamingClient :
21
31
def __init__ (self , args : AudioStreamingClientArguments ):
@@ -27,9 +37,11 @@ def __init__(self, args: AudioStreamingClientArguments):
27
37
self .headers = {
28
38
"Accept" : "application/json" ,
29
39
"Authorization" : f"Bearer { self .args .auth_token } " ,
30
- "Content-Type" : "application/json"
40
+ "Content-Type" : "application/json" ,
31
41
}
32
- self .session_state = "idle" # Possible states: idle, sending, processing, waiting
42
+ self .session_state = (
43
+ "idle" # Possible states: idle, sending, processing, waiting
44
+ )
33
45
self .ws_ready = threading .Event ()
34
46
35
47
def start (self ):
@@ -43,12 +55,14 @@ def start(self):
43
55
on_open = self .on_open ,
44
56
on_message = self .on_message ,
45
57
on_error = self .on_error ,
46
- on_close = self .on_close
58
+ on_close = self .on_close ,
47
59
)
48
60
49
- ws_thread = threading .Thread (target = self .ws .run_forever , kwargs = {'sslopt' : {"cert_reqs" : ssl .CERT_NONE }})
50
- ws_thread .start ()
51
-
61
+ self .ws_thread = threading .Thread (
62
+ target = self .ws .run_forever , kwargs = {"sslopt" : {"cert_reqs" : ssl .CERT_NONE }}
63
+ )
64
+ self .ws_thread .start ()
65
+
52
66
# Wait for the WebSocket to be ready
53
67
self .ws_ready .wait ()
54
68
self .start_audio_streaming ()
@@ -57,17 +71,25 @@ def start_audio_streaming(self):
57
71
self .send_thread = threading .Thread (target = self .send_audio )
58
72
self .play_thread = threading .Thread (target = self .play_audio )
59
73
60
- with sd .InputStream (samplerate = self .args .sample_rate , channels = 1 , dtype = 'int16' , callback = self .audio_input_callback , blocksize = self .args .chunk_size ):
74
+ with sd .InputStream (
75
+ samplerate = self .args .sample_rate ,
76
+ channels = 1 ,
77
+ dtype = "int16" ,
78
+ callback = self .audio_input_callback ,
79
+ blocksize = self .args .chunk_size ,
80
+ ):
61
81
self .send_thread .start ()
62
82
self .play_thread .start ()
83
+ input ("Press Enter to stop streaming..." )
84
+ self .on_shutdown ()
63
85
64
86
def on_open (self , ws ):
65
87
print ("WebSocket connection opened." )
66
88
self .ws_ready .set () # Signal that the WebSocket is ready
67
89
68
90
def on_message (self , ws , message ):
69
91
# message is bytes
70
- if message == b' DONE' :
92
+ if message == b" DONE" :
71
93
print ("listen" )
72
94
self .session_state = "listen"
73
95
else :
@@ -97,7 +119,7 @@ def send_audio(self):
97
119
if self .session_state != "processing" :
98
120
self .ws .send (chunk .tobytes (), opcode = websocket .ABNF .OPCODE_BINARY )
99
121
else :
100
- self .ws .send ([], opcode = websocket .ABNF .OPCODE_BINARY ) # handshake
122
+ self .ws .send ([], opcode = websocket .ABNF .OPCODE_BINARY ) # handshake
101
123
time .sleep (0.01 )
102
124
103
125
def audio_input_callback (self , indata , frames , time , status ):
@@ -106,33 +128,57 @@ def audio_input_callback(self, indata, frames, time, status):
106
128
def audio_out_callback (self , outdata , frames , time , status ):
107
129
if not self .recv_queue .empty ():
108
130
chunk = self .recv_queue .get ()
109
-
131
+
110
132
# Ensure chunk is int16 and clip to valid range
111
133
chunk_int16 = np .clip (chunk , - 32768 , 32767 ).astype (np .int16 )
112
-
134
+
113
135
if len (chunk_int16 ) < len (outdata ):
114
- outdata [:len (chunk_int16 ), 0 ] = chunk_int16
115
- outdata [len (chunk_int16 ):] = 0
136
+ outdata [: len (chunk_int16 ), 0 ] = chunk_int16
137
+ outdata [len (chunk_int16 ) :] = 0
116
138
else :
117
- outdata [:, 0 ] = chunk_int16 [:len (outdata )]
139
+ outdata [:, 0 ] = chunk_int16 [: len (outdata )]
118
140
else :
119
141
outdata [:] = 0
120
142
121
143
def play_audio(self):
    """Keep the playback stream open until ``self.stop_event`` is set.

    Opens a mono int16 ``sd.OutputStream`` driven by
    ``audio_out_callback`` and then idles; the stream is closed when the
    ``with`` block exits. Used as the target of ``self.play_thread``.
    """
    with sd.OutputStream(
        samplerate=self.args.sample_rate,
        channels=1,
        dtype="int16",
        callback=self.audio_out_callback,
        blocksize=self.args.chunk_size,
    ):
        # Audio is produced by the callback; this loop only polls for stop.
        while not self.stop_event.is_set():
            time.sleep(0.1)
125
153
154
+
126
155
if __name__ == "__main__":
    # Command-line entry point: parse the options, build the argument
    # dataclass from them, and start the client.
    import argparse

    parser = argparse.ArgumentParser(description="Audio Streaming Client")
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="Audio sample rate in Hz. Default is 16000.",
    )
    # NOTE(review): this CLI default (1024) differs from the
    # AudioStreamingClientArguments.chunk_size default (512) -- confirm
    # which value is intended.
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=1024,
        help="The size of audio chunks in samples. Default is 1024.",
    )
    parser.add_argument(
        "--api_url", type=str, required=True, help="The URL of the API endpoint."
    )
    parser.add_argument(
        "--auth_token",
        type=str,
        required=True,
        help="Authentication token for the API.",
    )

    args = parser.parse_args()
    # Option names match the dataclass field names, so the namespace can be
    # unpacked directly.
    client_args = AudioStreamingClientArguments(**vars(args))
    client = AudioStreamingClient(client_args)
    client.start()
0 commit comments