diff --git a/src/fish_audio_sdk/apis.py b/src/fish_audio_sdk/apis.py index fad1098..dfaa180 100644 --- a/src/fish_audio_sdk/apis.py +++ b/src/fish_audio_sdk/apis.py @@ -8,6 +8,7 @@ ASRRequest, ASRResponse, ModelEntity, + Backends, PackageEntity, PaginatedResponse, TTSRequest, @@ -16,11 +17,11 @@ class Session(RemoteCall): @convert_stream - def tts(self, request: TTSRequest) -> GStream: + def tts(self, request: TTSRequest, backend: Backends = "speech-1.5") -> GStream: yield Request( method="POST", url="/v1/tts", - headers={"Content-Type": "application/msgpack"}, + headers={"Content-Type": "application/msgpack", "model": backend}, content=ormsgpack.packb(request.model_dump()), ) diff --git a/src/fish_audio_sdk/schemas.py b/src/fish_audio_sdk/schemas.py index 232223d..0bdbb75 100644 --- a/src/fish_audio_sdk/schemas.py +++ b/src/fish_audio_sdk/schemas.py @@ -4,6 +4,9 @@ from pydantic import BaseModel, Field + +Backends = Literal["speech-1.5", "agent-x0"] + Item = TypeVar("Item") diff --git a/src/fish_audio_sdk/websocket.py b/src/fish_audio_sdk/websocket.py index 2f444b7..41efe33 100644 --- a/src/fish_audio_sdk/websocket.py +++ b/src/fish_audio_sdk/websocket.py @@ -8,7 +8,7 @@ from .exceptions import WebSocketErr -from .schemas import CloseEvent, StartEvent, TTSRequest, TextEvent +from .schemas import Backends, CloseEvent, StartEvent, TTSRequest, TextEvent class WebSocketSession: @@ -37,9 +37,16 @@ def close(self): self._client.close() def tts( - self, request: TTSRequest, text_stream: Iterable[str] + self, + request: TTSRequest, + text_stream: Iterable[str], + backend: Backends = "speech-1.5", ) -> Generator[bytes, None, None]: - with connect_ws("/v1/tts/live", client=self._client) as ws: + with connect_ws( + "/v1/tts/live", + client=self._client, + headers={"model": backend}, + ) as ws: def sender(): ws.send_bytes( @@ -102,9 +109,16 @@ async def close(self): await self._client.aclose() async def tts( - self, request: TTSRequest, text_stream: AsyncIterable[str] + self, + request: TTSRequest, + text_stream: AsyncIterable[str], + backend: Backends = "speech-1.5", ) -> AsyncGenerator[bytes, None]: - async with aconnect_ws("/v1/tts/live", client=self._client) as ws: + async with aconnect_ws( + "/v1/tts/live", + client=self._client, + headers={"model": backend}, + ) as ws: async def sender(): await ws.send_bytes(