Skip to content

Commit ae051bd

Browse files
authored
Merge pull request #12 from SocAIty/dev
Dev
2 parents dd7d25a + 7a55fb4 commit ae051bd

16 files changed

Lines changed: 429 additions & 192 deletions

apipod/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from apipod.api import APIPod
2+
from apipod.engine.jobs.base_job import BaseJob, LocalJob
23
from apipod.engine.jobs.job_progress import JobProgress
3-
from apipod.engine.jobs.job_result import FileModel, JobResult
4+
from apipod.engine.jobs.job_result import FileModel, JobLinks, JobMetrics, JobResult
45
from media_toolkit import MediaFile, ImageFile, AudioFile, VideoFile, MediaList, MediaDict
56
from apipod.common import constants
67

@@ -15,4 +16,20 @@
1516
except Exception:
1617
__version__ = "0.0.0"
1718

18-
__all__ = ["APIPod", "JobProgress", "FileModel", "JobResult", "MediaFile", "ImageFile", "AudioFile", "VideoFile", "MediaList", "MediaDict", "constants"]
19+
__all__ = [
20+
"APIPod",
21+
"BaseJob",
22+
"LocalJob",
23+
"JobProgress",
24+
"FileModel",
25+
"JobLinks",
26+
"JobMetrics",
27+
"JobResult",
28+
"MediaFile",
29+
"ImageFile",
30+
"AudioFile",
31+
"VideoFile",
32+
"MediaList",
33+
"MediaDict",
34+
"constants",
35+
]

apipod/api.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ def APIPod(
2323
|------------- |----------- |---------- |--------------------------- |
2424
| socaity | dedicated | auto | FastAPI |
2525
| socaity | dedicated | localhost | FastAPI + job queue (test) |
26+
| socaity | dedicated | socaity | FastAPI + redis (prod) |
2627
| socaity | dedicated | runpod | Celery (planned) |
2728
| socaity | dedicated | scaleway | Celery (planned) |
2829
| socaity | dedicated | azure | Celery (planned) |
@@ -40,15 +41,20 @@ def APIPod(
4041
Args:
4142
orchestrator: "socaity" or "local" (default from env / local).
4243
compute: "dedicated" or "serverless" (default from env / dedicated).
43-
provider: "auto", "localhost", "runpod", "scaleway", "azure" (default from env / localhost).
44+
provider: "auto", "localhost", "socaity", "runpod", "scaleway", "azure" (default from env / localhost).
4445
"""
4546
orchestrator = _resolve_enum(orchestrator, constants.ORCHESTRATOR, APIPOD_ORCHESTRATOR, constants.ORCHESTRATOR.LOCAL)
4647
compute = _resolve_enum(compute, constants.COMPUTE, APIPOD_COMPUTE, constants.COMPUTE.DEDICATED)
4748
provider = _resolve_enum(provider, constants.PROVIDER, APIPOD_PROVIDER, constants.PROVIDER.LOCALHOST)
4849

4950
backend_class, use_job_queue = _resolve_backend(orchestrator, compute, provider)
5051

51-
job_queue = _create_job_queue() if use_job_queue else None
52+
custom_job_queue = kwargs.pop("job_queue", None)
53+
if custom_job_queue:
54+
use_job_queue = True
55+
job_queue = custom_job_queue
56+
else:
57+
job_queue = _create_job_queue() if use_job_queue else None
5258

5359
if backend_class == SocaityFastAPIRouter:
5460
return backend_class(job_queue=job_queue, *args, **kwargs)
@@ -101,6 +107,9 @@ def _raise_if_unsupported(compute: constants.COMPUTE, provider: constants.PROVID
101107

102108
def _resolve_socaity(compute: constants.COMPUTE, provider: constants.PROVIDER) -> tuple:
103109
if compute == constants.COMPUTE.DEDICATED:
110+
if provider == constants.PROVIDER.SOCAITY:
111+
return SocaityFastAPIRouter, True
112+
104113
if provider in (constants.PROVIDER.RUNPOD, constants.PROVIDER.SCALEWAY, constants.PROVIDER.AZURE):
105114
raise NotImplementedError(
106115
f"Celery backend for socaity + dedicated + {provider.value} is planned but not yet available."
@@ -135,4 +144,5 @@ def _resolve_local(compute: constants.COMPUTE, provider: constants.PROVIDER) ->
135144

136145
def _create_job_queue() -> JobQueueInterface:
    """Build and return a new default ``JobQueue`` instance.

    The ``JobQueue`` class is imported inside the function body, so the
    import only runs when a default queue is actually requested.

    Returns:
        A freshly constructed ``JobQueue``.
    """
    from apipod.engine.queue.job_queue import JobQueue as _DefaultJobQueue

    default_queue = _DefaultJobQueue()
    return default_queue

apipod/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ class COMPUTE(Enum):
1414
class PROVIDER(Enum):
1515
AUTO = "auto"
1616
LOCALHOST = "localhost"
17+
SOCAITY = "socaity"
1718
RUNPOD = "runpod"
1819
SCALEWAY = "scaleway"
1920
AZURE = "azure"

apipod/engine/backend/fastapi/file_handling_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from apipod.engine.signatures.upload import is_param_media_toolkit_file
66
from apipod.engine.jobs.job_result import FileModel, ImageFileModel, AudioFileModel, VideoFileModel
77
from apipod.engine.signatures.policies import FastAPISignaturePolicies
8-
from apipod.engine.files.base_mixin import _BaseFileHandlingMixin
8+
from apipod.engine.files.base_file_mixin import _BaseFileHandlingMixin
99
from apipod.engine.utils import replace_func_signature
1010
from media_toolkit import MediaList, MediaDict, ImageFile, AudioFile, VideoFile, MediaFile
1111
import functools

apipod/engine/backend/fastapi/llm_mixin.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from fastapi.responses import StreamingResponse
22

3-
from apipod.engine.llm.base_mixin import _BaseLLMMixin
4-
from apipod.common.settings import SERVER_DOMAIN
3+
from apipod.engine.llm.base_llm_mixin import _BaseLLMMixin
54

65

76
class _FastApiLlmMixin(_BaseLLMMixin):
@@ -24,7 +23,6 @@ async def handle_llm_request(self, func, openai_req, should_use_queue, res_model
2423
job_params={"payload": openai_req.dict()}
2524
)
2625
ret_job = JobResultFactory.from_base_job(job)
27-
ret_job.refresh_job_url = f"{SERVER_DOMAIN}/status?job_id={ret_job.id}"
2826
return ret_job
2927

3028
raw_res = await self._execute_func(func, payload=openai_req, **kwargs)

apipod/engine/backend/fastapi/router.py

Lines changed: 87 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,11 @@
44
import logging
55
from contextlib import asynccontextmanager
66
from typing import Union, Callable, get_type_hints, Generator, AsyncGenerator, Iterator, AsyncIterator
7-
from fastapi import APIRouter, FastAPI, Response
8-
from fastapi.responses import JSONResponse
7+
from fastapi import APIRouter, FastAPI, Request, Response, status
8+
from fastapi.exceptions import HTTPException
9+
from fastapi.responses import JSONResponse, StreamingResponse
910

10-
from apipod.common.settings import APIPOD_PORT, APIPOD_HOST, SERVER_DOMAIN
11+
from apipod.common.settings import APIPOD_PORT, APIPOD_HOST
1112
from apipod.common.constants import SERVER_HEALTH
1213
from apipod.engine.jobs.job_result import JobResultFactory, JobResult
1314
from apipod.engine.endpoint_config import FastApiEndpointConfigurator, EndpointExecutionPlan
@@ -50,10 +51,14 @@ def __init__(
5051
job_queue: Optional custom JobQueue implementation
5152
lifespan: Optional async context manager for custom startup/shutdown logic
5253
args: Additional arguments
53-
kwargs: Additional keyword arguments
54+
kwargs: May include ``stream_store`` (SSE backend for GET /stream/{job_id}) and
55+
``gateway_stream_url_prefix`` for absolute stream URLs in JobResult, plus
56+
additional keyword arguments for parent classes.
5457
"""
5558
# Extract user-provided lifespan (explicit param or kwarg) before parent init
5659
user_lifespan = lifespan or kwargs.pop('lifespan', None)
60+
stream_store = kwargs.pop("stream_store", None)
61+
gateway_stream_url_prefix = kwargs.pop("gateway_stream_url_prefix", "")
5762

5863
# Initialize parent classes
5964
api_router_params = inspect.signature(APIRouter.__init__).parameters
@@ -92,6 +97,8 @@ def __init__(
9297

9398
self.app: FastAPI = app
9499
self.prefix = prefix
100+
self.stream_store = stream_store
101+
self.gateway_stream_url_prefix = gateway_stream_url_prefix
95102
self.add_standard_routes()
96103

97104
self._endpoint_configurator = FastApiEndpointConfigurator(self)
@@ -173,7 +180,11 @@ def _stop_background_worker(self):
173180
def add_standard_routes(self):
174181
"""Add standard API routes for status and health checks."""
175182
if self.job_queue is not None:
176-
self.api_route(path="/status", methods=["POST"])(self.get_job)
183+
self.api_route(path="/status/{job_id}", methods=["GET"], response_model_exclude_none=True)(self.get_job)
184+
self.api_route(path="/status", methods=["POST"], response_model_exclude_none=True)(self.get_job)
185+
self.api_route(path="/cancel/{job_id}", methods=["POST"])(self.post_cancel_job)
186+
if self.stream_store is not None:
187+
self.api_route(path="/stream/{job_id}", methods=["GET"])(self.stream_job_sse)
177188
self.api_route(path="/health", methods=["GET"])(self.get_health)
178189

179190
def get_health(self) -> Response:
@@ -216,19 +227,76 @@ def get_job(self, job_id: str, return_format: str = 'json') -> JobResult:
216227
if self.job_queue is None:
217228
return JobResultFactory.job_not_found(job_id)
218229

219-
base_job = self.job_queue.get_job(job_id)
220-
if base_job is None:
230+
ret_job = self.job_queue.get_job_result(job_id)
231+
if ret_job is None:
221232
return JobResultFactory.job_not_found(job_id)
222233

223-
ret_job = JobResultFactory.from_base_job(base_job)
224-
ret_job.refresh_job_url = f"{SERVER_DOMAIN}/status?job_id={ret_job.id}"
225-
ret_job.cancel_job_url = f"{SERVER_DOMAIN}/cancel?job_id={ret_job.id}"
226-
227234
if return_format != 'json':
228235
ret_job = JobResultFactory.gzip_job_result(ret_job)
229236

230237
return ret_job
231238

239+
def post_cancel_job(self, job_id: str) -> dict:
    """Cancel a background job (gateway / orchestrator integration).

    Args:
        job_id: Identifier of the job to cancel. Surrounding whitespace and
            stray quote, ``?`` and ``#`` characters are removed before use.

    Returns:
        The return value of ``job_queue.cancel_gateway_job(job_id)`` when the
        queue exposes that callable; otherwise a dict with the keys ``id``,
        ``status`` and ``message`` once ``job_queue.cancel_job`` has returned.

    Raises:
        HTTPException: 503 when no job queue is configured; 501 when the
            queue's ``cancel_job`` raises ``NotImplementedError``.
    """
    # Trim whitespace and stray wrapper characters until nothing changes.
    # A fixed-order chain of single-character strips is order-dependent and
    # leaves residue on inputs such as '"abc"?' (the trailing quote survives).
    job_id = job_id.strip()
    while True:
        trimmed = job_id.strip("\"'?#").strip()
        if trimmed == job_id:
            break
        job_id = trimmed

    if self.job_queue is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Job queue not configured.")

    # Prefer the queue's gateway-specific hook when it provides one; its
    # return value is passed through unchanged.
    cancel_fn = getattr(self.job_queue, "cancel_gateway_job", None)
    if callable(cancel_fn):
        return cancel_fn(job_id)

    try:
        self.job_queue.cancel_job(job_id)
    except NotImplementedError:
        # "from None" keeps the internal NotImplementedError out of the
        # exception chain of the HTTP error.
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail="Cancellation is not supported for this job queue.",
        ) from None
    return {"id": job_id, "status": "cancelled", "message": "Job cancelled."}
257+
258+
async def stream_job_sse(self, job_id: str, request: Request):
    """Serve a job's streamed output as Server-Sent Events (requires stream_store).

    Args:
        job_id: Identifier of the job to stream. Surrounding whitespace and
            stray quote, ``?`` and ``#`` characters are removed before use.
        request: The incoming request; polled to detect client disconnects.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` that
        relays the chunks produced by ``stream_store.read_chunks(job_id)``.

    Raises:
        HTTPException: 503 when no stream store or no job queue is configured;
            404 when the queue reports no status for the job (this includes a
            queue without a ``get_job_status`` method); 409 when the job's
            status is not "streaming" and the stream store has no stream for it.
    """
    # Trim whitespace and stray wrapper characters until nothing changes.
    # A fixed-order chain of single-character strips is order-dependent and
    # leaves residue on inputs such as '"abc"?' (the trailing quote survives).
    job_id = job_id.strip()
    while True:
        trimmed = job_id.strip("\"'?#").strip()
        if trimmed == job_id:
            break
        job_id = trimmed

    if self.stream_store is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Streaming not configured.")
    if self.job_queue is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Job queue not configured.")

    jq = self.job_queue
    # Queues that do not implement get_job_status are treated the same as an
    # unknown job and fall through to the 404 below.
    job_data = jq.get_job_status(job_id) if hasattr(jq, "get_job_status") else None

    if job_data is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Job '{job_id}' not found.")

    # A missing or empty status is normalised to "" so .lower() is always safe.
    st = (job_data.get("status") or "").lower()
    # The stream store is only consulted when the status is not "streaming";
    # an existing stream is enough to allow the request through.
    if st != "streaming" and not self.stream_store.stream_exists(job_id):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Job '{job_id}' is not streaming (status: {job_data.get('status')}).",
        )

    async def _event_generator():
        """Yield chunks from the stream store until exhausted or the client leaves."""
        try:
            async for chunk in self.stream_store.read_chunks(job_id):
                # Checked after each chunk arrives and before it is sent, so
                # nothing is yielded to a client that has already gone away.
                if await request.is_disconnected():
                    break
                yield chunk
        except Exception:
            # Log the full traceback server-side and send the client a
            # generic SSE error event.
            self._logger.exception("Error during stream delivery | job_id=%s", job_id)
            yield 'data: {"error": "Internal stream error"}\n\n'

    return StreamingResponse(
        _event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
298+
299+
232300
def endpoint(self, path: str, methods: list[str] | None = None, max_upload_file_size_mb: int = None, queue_size: int = 500, use_queue: bool = None, *args, **kwargs):
233301
"""
234302
Unified endpoint decorator.
@@ -356,12 +424,16 @@ async def _unified_worker(*w_args, **w_kwargs):
356424
plan.max_upload_file_size_mb
357425
)
358426

427+
route_kwargs = dict(plan.route_kwargs)
428+
if plan.should_use_queue:
429+
route_kwargs["response_model_exclude_none"] = True
430+
359431
self.api_route(
360432
path=plan.path,
361433
methods=plan.active_methods,
362434
response_model=JobResult if plan.should_use_queue else res_model,
363435
*plan.route_args,
364-
**plan.route_kwargs
436+
**route_kwargs
365437
)(final_handler)
366438

367439
return final_handler
@@ -472,12 +544,14 @@ def _determine_generator_fun(self, func: Callable) -> bool:
472544
def _create_task_endpoint_decorator(self, path: str, methods: list[str] | None, max_upload_file_size_mb: int, queue_size: int, args, kwargs):
473545
"""Create a decorator for task endpoints (background job execution)."""
474546
# FastAPI route decorator (returning JobResult)
547+
task_kwargs = dict(kwargs)
548+
task_kwargs["response_model_exclude_none"] = True
475549
fastapi_route_decorator = self.api_route(
476550
path=path,
477551
methods=["POST"] if methods is None else methods,
478552
response_model=JobResult,
479553
*args,
480-
**kwargs
554+
**task_kwargs
481555
)
482556

483557
# Queue decorator

apipod/engine/backend/runpod/llm_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from apipod.engine.llm.base_mixin import _BaseLLMMixin
1+
from apipod.engine.llm.base_llm_mixin import _BaseLLMMixin
22

33

44
class _RunPodLLMMixin(_BaseLLMMixin):

apipod/engine/backend/runpod/router.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from apipod.engine.jobs.job_progress import JobProgressRunpod, JobProgress
1111
from apipod.engine.jobs.job_result import JobResultFactory, JobResult
1212
from apipod.engine.base_backend import _BaseBackend
13-
from apipod.engine.files.base_mixin import _BaseFileHandlingMixin
13+
from apipod.engine.files.base_file_mixin import _BaseFileHandlingMixin
1414
from apipod.engine.backend.runpod.llm_mixin import _RunPodLLMMixin
1515

1616
from apipod.engine.utils import normalize_name
@@ -195,7 +195,10 @@ def _router(self, path, job, **kwargs):
195195

196196
# Prepare result tracking
197197
start_time = datetime.now(timezone.utc)
198-
result = JobResult(id=job['id'], execution_started_at=start_time.strftime(DEFAULT_DATE_TIME_FORMAT))
198+
result = JobResult(
199+
job_id=job["id"],
200+
created_at=start_time.strftime(DEFAULT_DATE_TIME_FORMAT),
201+
)
199202

200203
try:
201204
# Execute the function (Sync or Async Handling)
@@ -218,7 +221,7 @@ def _router(self, path, job, **kwargs):
218221
print(f"Job {job['id']} failed: {str(e)}")
219222
traceback.print_exc()
220223
finally:
221-
result.execution_finished_at = datetime.now(timezone.utc).strftime(DEFAULT_DATE_TIME_FORMAT)
224+
result.updated_at = datetime.now(timezone.utc).strftime(DEFAULT_DATE_TIME_FORMAT)
222225

223226
result = result.model_dump_json()
224227
return result

0 commit comments

Comments
 (0)