|
| 1 | +import os |
| 2 | +import pydantic |
| 3 | +import time |
1 | 4 | import typing |
2 | 5 |
|
| 6 | +from magic_hour.helpers.download import download_files_async, download_files_sync |
| 7 | +from magic_hour.helpers.logger import get_sdk_logger |
3 | 8 | from magic_hour.types import models |
4 | 9 | from make_api_request import ( |
5 | 10 | AsyncBaseClient, |
|
9 | 14 | ) |
10 | 15 |
|
11 | 16 |
|
| 17 | +logger = get_sdk_logger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class V1AudioProjectsGetResponseWithDownloads(models.V1AudioProjectsGetResponse): |
| 21 | + downloaded_paths: typing.Optional[typing.List[str]] = pydantic.Field( |
| 22 | + default=None, alias="downloaded_paths" |
| 23 | + ) |
| 24 | + """ |
| 25 | + The paths to the downloaded files. |
| 26 | +
|
| 27 | + This field is only populated if `download_outputs` is True and the audio project is complete. |
| 28 | + """ |
| 29 | + |
| 30 | + |
12 | 31 | class AudioProjectsClient: |
13 | 32 | def __init__(self, *, base_client: SyncBaseClient): |
14 | 33 | self._base_client = base_client |
15 | 34 |
|
| 35 | + def check_result( |
| 36 | + self, |
| 37 | + id: str, |
| 38 | + wait_for_completion: bool, |
| 39 | + download_outputs: bool, |
| 40 | + download_directory: typing.Optional[str] = None, |
| 41 | + ) -> V1AudioProjectsGetResponseWithDownloads: |
| 42 | + """ |
| 43 | + Check the result of an audio project with optional waiting and downloading. |
| 44 | +
|
| 45 | + This method retrieves the status of an audio project and optionally waits for completion |
| 46 | + and downloads the output files. |
| 47 | +
|
| 48 | + Args: |
| 49 | + id: Unique ID of the audio project |
| 50 | + wait_for_completion: Whether to wait for the audio project to complete |
| 51 | + download_outputs: Whether to download the outputs |
| 52 | + download_directory: The directory to download the outputs to. If not provided, |
| 53 | + the outputs will be downloaded to the current working directory |
| 54 | +
|
| 55 | + Returns: |
| 56 | + V1AudioProjectsGetResponseWithDownloads: The audio project response with optional |
| 57 | + downloaded file paths included |
| 58 | + """ |
| 59 | + api_response = self.get(id=id) |
| 60 | + if not wait_for_completion: |
| 61 | + response = V1AudioProjectsGetResponseWithDownloads( |
| 62 | + **api_response.model_dump() |
| 63 | + ) |
| 64 | + return response |
| 65 | + |
| 66 | + poll_interval = float(os.getenv("MAGIC_HOUR_POLL_INTERVAL", "0.5")) |
| 67 | + |
| 68 | + status = api_response.status |
| 69 | + |
| 70 | + while status not in ["complete", "error", "canceled"]: |
| 71 | + api_response = self.get(id=id) |
| 72 | + status = api_response.status |
| 73 | + time.sleep(poll_interval) |
| 74 | + |
| 75 | + if api_response.status != "complete": |
| 76 | + log = logger.error if api_response.status == "error" else logger.info |
| 77 | + log( |
| 78 | + f"Audio project {id} has status {api_response.status}: {api_response.error}" |
| 79 | + ) |
| 80 | + return V1AudioProjectsGetResponseWithDownloads(**api_response.model_dump()) |
| 81 | + |
| 82 | + if not download_outputs: |
| 83 | + return V1AudioProjectsGetResponseWithDownloads(**api_response.model_dump()) |
| 84 | + |
| 85 | + downloaded_paths = download_files_sync( |
| 86 | + downloads=api_response.downloads, |
| 87 | + download_directory=download_directory, |
| 88 | + ) |
| 89 | + |
| 90 | + return V1AudioProjectsGetResponseWithDownloads( |
| 91 | + **api_response.model_dump(), downloaded_paths=downloaded_paths |
| 92 | + ) |
| 93 | + |
16 | 94 | def delete( |
17 | 95 | self, *, id: str, request_options: typing.Optional[RequestOptions] = None |
18 | 96 | ) -> None: |
@@ -95,6 +173,65 @@ class AsyncAudioProjectsClient: |
95 | 173 | def __init__(self, *, base_client: AsyncBaseClient): |
96 | 174 | self._base_client = base_client |
97 | 175 |
|
| 176 | + async def check_result( |
| 177 | + self, |
| 178 | + id: str, |
| 179 | + wait_for_completion: bool, |
| 180 | + download_outputs: bool, |
| 181 | + download_directory: typing.Optional[str] = None, |
| 182 | + ) -> V1AudioProjectsGetResponseWithDownloads: |
| 183 | + """ |
| 184 | + Check the result of an audio project with optional waiting and downloading. |
| 185 | +
|
| 186 | + This method retrieves the status of an audio project and optionally waits for completion |
| 187 | + and downloads the output files. |
| 188 | +
|
| 189 | + Args: |
| 190 | + id: Unique ID of the audio project |
| 191 | + wait_for_completion: Whether to wait for the audio project to complete |
| 192 | + download_outputs: Whether to download the outputs |
| 193 | + download_directory: The directory to download the outputs to. If not provided, |
| 194 | + the outputs will be downloaded to the current working directory |
| 195 | +
|
| 196 | + Returns: |
| 197 | + V1AudioProjectsGetResponseWithDownloads: The audio project response with optional |
| 198 | + downloaded file paths included |
| 199 | + """ |
| 200 | + api_response = await self.get(id=id) |
| 201 | + if not wait_for_completion: |
| 202 | + response = V1AudioProjectsGetResponseWithDownloads( |
| 203 | + **api_response.model_dump() |
| 204 | + ) |
| 205 | + return response |
| 206 | + |
| 207 | + poll_interval = float(os.getenv("MAGIC_HOUR_POLL_INTERVAL", "0.5")) |
| 208 | + |
| 209 | + status = api_response.status |
| 210 | + |
| 211 | + while status not in ["complete", "error", "canceled"]: |
| 212 | + api_response = await self.get(id=id) |
| 213 | + status = api_response.status |
| 214 | + time.sleep(poll_interval) |
| 215 | + |
| 216 | + if api_response.status != "complete": |
| 217 | + log = logger.error if api_response.status == "error" else logger.info |
| 218 | + log( |
| 219 | + f"Audio project {id} has status {api_response.status}: {api_response.error}" |
| 220 | + ) |
| 221 | + return V1AudioProjectsGetResponseWithDownloads(**api_response.model_dump()) |
| 222 | + |
| 223 | + if not download_outputs: |
| 224 | + return V1AudioProjectsGetResponseWithDownloads(**api_response.model_dump()) |
| 225 | + |
| 226 | + downloaded_paths = await download_files_async( |
| 227 | + downloads=api_response.downloads, |
| 228 | + download_directory=download_directory, |
| 229 | + ) |
| 230 | + |
| 231 | + return V1AudioProjectsGetResponseWithDownloads( |
| 232 | + **api_response.model_dump(), downloaded_paths=downloaded_paths |
| 233 | + ) |
| 234 | + |
98 | 235 | async def delete( |
99 | 236 | self, *, id: str, request_options: typing.Optional[RequestOptions] = None |
100 | 237 | ) -> None: |
|
0 commit comments