-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathwebsocket.py
157 lines (134 loc) · 4.67 KB
/
websocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncGenerator, AsyncIterable, Generator, Iterable
import httpx
import ormsgpack
from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws
from .exceptions import WebSocketErr
from .schemas import Backends, CloseEvent, StartEvent, TTSRequest, TextEvent
class WebSocketSession:
def __init__(
self,
apikey: str,
*,
base_url: str = "https://api.fish.audio",
max_workers: int = 10,
):
self._apikey = apikey
self._base_url = base_url
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def close(self):
self._client.close()
def tts(
self,
request: TTSRequest,
text_stream: Iterable[str],
backend: Backends = "speech-1.5",
) -> Generator[bytes, None, None]:
with connect_ws(
"/v1/tts/live",
client=self._client,
headers={"model": backend},
) as ws:
def sender():
ws.send_bytes(
ormsgpack.packb(
StartEvent(request=request).model_dump(),
)
)
for text in text_stream:
ws.send_bytes(
ormsgpack.packb(
TextEvent(text=text).model_dump(),
)
)
ws.send_bytes(
ormsgpack.packb(
CloseEvent().model_dump(),
)
)
sender_future = self._executor.submit(sender)
while True:
try:
message = ws.receive_bytes()
data = ormsgpack.unpackb(message)
match data["event"]:
case "audio":
yield data["audio"]
case "finish" if data["reason"] == "error":
raise WebSocketErr
case "finish" if data["reason"] == "stop":
break
except WebSocketDisconnect:
raise WebSocketErr
sender_future.result()
class AsyncWebSocketSession:
def __init__(
self,
apikey: str,
*,
base_url: str = "https://api.fish.audio",
):
self._apikey = apikey
self._base_url = base_url
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
async def close(self):
await self._client.aclose()
async def tts(
self,
request: TTSRequest,
text_stream: AsyncIterable[str],
backend: Backends = "speech-1.5",
) -> AsyncGenerator[bytes, None]:
async with aconnect_ws(
"/v1/tts/live",
client=self._client,
headers={"model": backend},
) as ws:
async def sender():
await ws.send_bytes(
ormsgpack.packb(
StartEvent(request=request).model_dump(),
)
)
async for text in text_stream:
await ws.send_bytes(
ormsgpack.packb(
TextEvent(text=text).model_dump(),
)
)
await ws.send_bytes(
ormsgpack.packb(
CloseEvent().model_dump(),
)
)
sender_future = asyncio.get_running_loop().create_task(sender())
while True:
try:
message = await ws.receive_bytes()
data = ormsgpack.unpackb(message)
match data["event"]:
case "audio":
yield data["audio"]
case "finish" if data["reason"] == "error":
raise WebSocketErr
case "finish" if data["reason"] == "stop":
break
except WebSocketDisconnect:
raise WebSocketErr
await sender_future