Cuda oom handler [continuation] #78

Open · wants to merge 4 commits into main
1 change: 1 addition & 0 deletions clearml_serving/serving/entrypoint.sh
@@ -2,6 +2,7 @@

# print configuration
echo CLEARML_SERVING_TASK_ID="$CLEARML_SERVING_TASK_ID"
echo CLEARML_INFERENCE_TASK_ID="$CLEARML_INFERENCE_TASK_ID"
echo CLEARML_SERVING_PORT="$CLEARML_SERVING_PORT"
echo CLEARML_USE_GUNICORN="$CLEARML_USE_GUNICORN"
echo CLEARML_EXTRA_PYTHON_PACKAGES="$CLEARML_EXTRA_PYTHON_PACKAGES"
2 changes: 2 additions & 0 deletions clearml_serving/serving/init.py
@@ -6,6 +6,7 @@

def setup_task(force_threaded_logging=None):
serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
    inference_service_task_id = os.environ.get("CLEARML_INFERENCE_TASK_ID", False)  # according to Task.init() docs, False means do not continue a previous task

# always use background thread, it requires less memory
if force_threaded_logging or os.environ.get("CLEARML_BKG_THREAD_REPORT") in ("1", "Y", "y", "true"):
@@ -24,6 +25,7 @@ def setup_task(force_threaded_logging=None):
project_name=serving_task.get_project_name(),
task_name="{} - serve instance".format(serving_task.name),
task_type="inference", # noqa
continue_last_task=inference_service_task_id,
)
instance_task.set_system_tags(["service"])
# make sure we start logging thread/process
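Note (not part of the diff): a minimal sketch of the continue_last_task behaviour that the new environment variable feeds into, assuming CLEARML_INFERENCE_TASK_ID carries the task ID of a previous run; the project and task names here are illustrative. Passing a task-ID string to Task.init(continue_last_task=...) resumes reporting to that task instead of creating a new one, which is what lets a restarted serving instance continue the task of the instance that crashed.

import os
from clearml import Task

# False -> start a fresh task; a task-ID string -> continue that existing task
previous_task_id = os.environ.get("CLEARML_INFERENCE_TASK_ID", False)

instance_task = Task.init(
    project_name="serving examples",      # illustrative project name
    task_name="demo - serve instance",    # illustrative task name
    task_type="inference",
    continue_last_task=previous_task_id,
)
print("reporting to task:", instance_task.id)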
26 changes: 25 additions & 1 deletion clearml_serving/serving/main.py
@@ -1,9 +1,13 @@
import os
import traceback
import gzip
import asyncio

from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
from fastapi.routing import APIRoute
from fastapi.responses import PlainTextResponse

from starlette.background import BackgroundTask

from typing import Optional, Dict, Any, Callable, Union

@@ -48,6 +52,9 @@ async def custom_route_handler(request: Request) -> Response:
except (ValueError, TypeError):
pass

class CUDAException(Exception):
def __init__(self, exception: str):
self.exception = exception

# start FastAPI app
app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Service Service router")
@@ -70,6 +77,20 @@ async def startup_event():
processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)


@app.on_event('shutdown')
def shutdown_event():
print('RESTARTING INFERENCE SERVICE!')

async def exit_app():
loop = asyncio.get_running_loop()
loop.stop()

@app.exception_handler(CUDAException)
async def cuda_exception_handler(request, exc):
task = BackgroundTask(exit_app)
return PlainTextResponse("CUDA out of memory. Restarting service", status_code=500, background=task)


router = APIRouter(
prefix="/serve",
tags=["models"],
@@ -102,7 +123,10 @@ async def serve_model(model_id: str, version: Optional[str] = None, request: Uni
except ValueError as ex:
session_logger.report_text("[{}] Exception [{}] {} while processing request: {}\n{}".format(
instance_id, type(ex), ex, request, "".join(traceback.format_exc())))
raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
if "CUDA out of memory. " in str(ex) or "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(ex):
raise CUDAException(exception=ex)
else:
raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
except Exception as ex:
session_logger.report_text("[{}] Exception [{}] {} while processing request: {}\n{}".format(
instance_id, type(ex), ex, request, "".join(traceback.format_exc())))
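Note (not part of the diff): a self-contained sketch of how the pieces above fit together, with a simulated OOM error and an illustrative route. The handler answers the failing request with a 500, then stops the event loop in a background task; the worker process exits, the container entrypoint brings up a fresh instance, and that instance resumes the same ClearML task via CLEARML_INFERENCE_TASK_ID.

import asyncio

from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from starlette.background import BackgroundTask

app = FastAPI()

class CUDAException(Exception):
    def __init__(self, exception: str):
        self.exception = exception

async def exit_app():
    # stopping the running loop ends this (single-process) worker
    asyncio.get_running_loop().stop()

@app.exception_handler(CUDAException)
async def cuda_exception_handler(request, exc):
    # reply first, then shut down in the background
    return PlainTextResponse(
        "CUDA out of memory. Restarting service",
        status_code=500,
        background=BackgroundTask(exit_app),
    )

@app.post("/serve/demo")  # illustrative endpoint
async def serve_demo():
    try:
        raise ValueError("CUDA out of memory. Tried to allocate ...")  # simulated OOM
    except ValueError as ex:
        if "CUDA out of memory. " in str(ex) or "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(ex):
            raise CUDAException(exception=str(ex))
        raise

This is also why the docker-compose change below pins CLEARML_SERVING_NUM_PROCESS to 1: loop.stop() only takes down the worker it runs in, so the shutdown only restarts the whole service when there is a single worker process.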
7 changes: 7 additions & 0 deletions clearml_serving/serving/model_request_processor.py
@@ -1,5 +1,7 @@
import json
import os
import gc
import torch
from collections import deque
from pathlib import Path
from random import random
@@ -915,7 +917,12 @@ def _sync_daemon(self, poll_frequency_sec: float = 300) -> None:
for k in list(self._engine_processor_lookup.keys()):
if k not in self._endpoints:
# atomic
self._engine_processor_lookup[k]._model = None
self._engine_processor_lookup[k]._preprocess = None
del self._engine_processor_lookup[k]
self._engine_processor_lookup.pop(k, None)
gc.collect()
torch.cuda.empty_cache()
cleanup = False
model_monitor_update = False
except Exception as ex:
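Note (not part of the diff): a minimal sketch of the GPU-memory teardown sequence used above, assuming a CUDA-enabled torch build; the lookup dict and key are illustrative stand-ins, and the _model/_preprocess attributes mirror the processor objects in the diff. Dropping every Python reference first lets gc.collect() reclaim the tensors, and torch.cuda.empty_cache() then returns the cached blocks to the driver so a newly deployed endpoint can use that memory.

import gc
import torch

def release_endpoint(engine_processor_lookup: dict, key: str) -> None:
    processor = engine_processor_lookup.pop(key, None)  # drop the registry entry
    if processor is not None:
        processor._model = None                         # drop the model reference
        processor._preprocess = None                    # drop the preprocess module
    gc.collect()                                        # collect the orphaned objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()                        # release cached GPU blocks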
1 change: 1 addition & 0 deletions clearml_serving/serving/requirements.txt
@@ -18,3 +18,4 @@ lightgbm>=3.3.2,<3.4
requests>=2.31.0
kafka-python>=2.0.2,<2.1
lz4>=4.0.0,<5
torch>=2.1.2
3 changes: 2 additions & 1 deletion docker/docker-compose.yml
@@ -96,7 +96,8 @@ services:
CLEARML_DEFAULT_KAFKA_SERVE_URL: ${CLEARML_DEFAULT_KAFKA_SERVE_URL:-clearml-serving-kafka:9092}
CLEARML_DEFAULT_TRITON_GRPC_ADDR: ${CLEARML_DEFAULT_TRITON_GRPC_ADDR:-}
CLEARML_USE_GUNICORN: ${CLEARML_USE_GUNICORN:-}
CLEARML_SERVING_NUM_PROCESS: ${CLEARML_SERVING_NUM_PROCESS:-}
# CLEARML_SERVING_NUM_PROCESS has to be 1 to activate CUDA OOM handler
CLEARML_SERVING_NUM_PROCESS: "1"
CLEARML_EXTRA_PYTHON_PACKAGES: ${CLEARML_EXTRA_PYTHON_PACKAGES:-}
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}