2
2
from queue import Queue
3
3
import sounddevice as sd
4
4
import numpy as np
5
- import requests
6
- import base64
7
5
import time
8
6
from dataclasses import dataclass , field
9
7
import websocket
10
- import threading
11
8
import ssl
12
9
10
+
13
11
@dataclass
class AudioStreamingClientArguments:
    """Settings consumed by ``AudioStreamingClient``.

    Every field stores its human-readable description under
    ``metadata["help"]``.
    """

    # Sample rate handed to the sounddevice input and output streams.
    sample_rate: int = field(
        default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."}
    )
    # Block size (in samples) handed to the sounddevice streams.
    chunk_size: int = field(
        default=512,
        metadata={"help": "The size of audio chunks in samples. Default is 512."},
    )
    api_url: str = field(
        default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
        metadata={"help": "The URL of the API endpoint."},
    )
    # Placeholder default; the client sends it as a Bearer token.
    auth_token: str = field(
        default="your_auth_token",
        metadata={"help": "Authentication token for the API."},
    )
28
+
19
29
20
30
class AudioStreamingClient :
21
31
def __init__ (self , args : AudioStreamingClientArguments ):
@@ -27,9 +37,11 @@ def __init__(self, args: AudioStreamingClientArguments):
27
37
self .headers = {
28
38
"Accept" : "application/json" ,
29
39
"Authorization" : f"Bearer { self .args .auth_token } " ,
30
- "Content-Type" : "application/json"
40
+ "Content-Type" : "application/json" ,
31
41
}
32
- self .session_state = "idle" # Possible states: idle, sending, processing, waiting
42
+ self .session_state = (
43
+ "idle" # Possible states: idle, sending, processing, waiting
44
+ )
33
45
self .ws_ready = threading .Event ()
34
46
35
47
def start (self ):
@@ -43,12 +55,14 @@ def start(self):
43
55
on_open = self .on_open ,
44
56
on_message = self .on_message ,
45
57
on_error = self .on_error ,
46
- on_close = self .on_close
58
+ on_close = self .on_close ,
47
59
)
48
60
49
- self .ws_thread = threading .Thread (target = self .ws .run_forever , kwargs = {'sslopt' : {"cert_reqs" : ssl .CERT_NONE }})
61
+ self .ws_thread = threading .Thread (
62
+ target = self .ws .run_forever , kwargs = {"sslopt" : {"cert_reqs" : ssl .CERT_NONE }}
63
+ )
50
64
self .ws_thread .start ()
51
-
65
+
52
66
# Wait for the WebSocket to be ready
53
67
self .ws_ready .wait ()
54
68
self .start_audio_streaming ()
@@ -57,7 +71,13 @@ def start_audio_streaming(self):
57
71
self .send_thread = threading .Thread (target = self .send_audio )
58
72
self .play_thread = threading .Thread (target = self .play_audio )
59
73
60
- with sd .InputStream (samplerate = self .args .sample_rate , channels = 1 , dtype = 'int16' , callback = self .audio_input_callback , blocksize = self .args .chunk_size ):
74
+ with sd .InputStream (
75
+ samplerate = self .args .sample_rate ,
76
+ channels = 1 ,
77
+ dtype = "int16" ,
78
+ callback = self .audio_input_callback ,
79
+ blocksize = self .args .chunk_size ,
80
+ ):
61
81
self .send_thread .start ()
62
82
self .play_thread .start ()
63
83
input ("Press Enter to stop streaming..." )
@@ -69,7 +89,7 @@ def on_open(self, ws):
69
89
70
90
def on_message (self , ws , message ):
71
91
# message is bytes
72
- if message == b' DONE' :
92
+ if message == b" DONE" :
73
93
print ("listen" )
74
94
self .session_state = "listen"
75
95
else :
@@ -99,7 +119,7 @@ def send_audio(self):
99
119
if self .session_state != "processing" :
100
120
self .ws .send (chunk .tobytes (), opcode = websocket .ABNF .OPCODE_BINARY )
101
121
else :
102
- self .ws .send ([], opcode = websocket .ABNF .OPCODE_BINARY ) # handshake
122
+ self .ws .send ([], opcode = websocket .ABNF .OPCODE_BINARY ) # handshake
103
123
time .sleep (0.01 )
104
124
105
125
def audio_input_callback (self , indata , frames , time , status ):
@@ -108,33 +128,57 @@ def audio_input_callback(self, indata, frames, time, status):
108
128
def audio_out_callback (self , outdata , frames , time , status ):
109
129
if not self .recv_queue .empty ():
110
130
chunk = self .recv_queue .get ()
111
-
131
+
112
132
# Ensure chunk is int16 and clip to valid range
113
133
chunk_int16 = np .clip (chunk , - 32768 , 32767 ).astype (np .int16 )
114
-
134
+
115
135
if len (chunk_int16 ) < len (outdata ):
116
- outdata [:len (chunk_int16 ), 0 ] = chunk_int16
117
- outdata [len (chunk_int16 ):] = 0
136
+ outdata [: len (chunk_int16 ), 0 ] = chunk_int16
137
+ outdata [len (chunk_int16 ) :] = 0
118
138
else :
119
- outdata [:, 0 ] = chunk_int16 [:len (outdata )]
139
+ outdata [:, 0 ] = chunk_int16 [: len (outdata )]
120
140
else :
121
141
outdata [:] = 0
122
142
123
143
def play_audio(self):
    """Keep an output stream open until ``self.stop_event`` is set.

    Opens a mono int16 ``sd.OutputStream`` using the configured sample
    rate and chunk size, with ``self.audio_out_callback`` as its callback,
    and closes it when the stop event fires.
    """
    with sd.OutputStream(
        samplerate=self.args.sample_rate,
        channels=1,
        dtype="int16",
        callback=self.audio_out_callback,
        blocksize=self.args.chunk_size,
    ):
        # Block until stop is requested instead of polling every 100 ms.
        # NOTE(review): assumes stop_event is a threading.Event (only
        # is_set() was used here before) — confirm against __init__.
        self.stop_event.wait()
127
153
154
+
128
155
if __name__ == "__main__":
    # Command-line entry point: parse the options, build the argument
    # dataclass from them, and start the streaming client.
    import argparse

    parser = argparse.ArgumentParser(description="Audio Streaming Client")
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="Audio sample rate in Hz. Default is 16000.",
    )
    # NOTE(review): the CLI default here is 1024 while
    # AudioStreamingClientArguments.chunk_size defaults to 512 — confirm
    # which value is intended.
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=1024,
        help="The size of audio chunks in samples. Default is 1024.",
    )
    parser.add_argument(
        "--api_url", type=str, required=True, help="The URL of the API endpoint."
    )
    parser.add_argument(
        "--auth_token",
        type=str,
        required=True,
        help="Authentication token for the API.",
    )

    args = parser.parse_args()
    # The argparse destinations match the dataclass field names, so the
    # namespace can be unpacked directly.
    client_args = AudioStreamingClientArguments(**vars(args))
    client = AudioStreamingClient(client_args)
    client.start()
0 commit comments