Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Generated by Django 5.2.14 on 2026-05-29 15:23

from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
(
"algorithms",
"0098_endpoint_runtime_metrics_endpoint_stderr_and_more",
),
("auth", "0012_alter_user_first_name_max_length"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]

operations = [
migrations.RemoveConstraint(
model_name="endpoint",
name="endpoint_status_valid",
),
migrations.AlterField(
model_name="endpoint",
name="status",
field=models.CharField(
choices=[
("QUEUED", "Queued"),
("STARTED", "Started"),
("RUNNING", "Running"),
("FAILED", "Failed"),
("STOPPED", "Stopped"),
],
default="QUEUED",
max_length=17,
),
),
migrations.AddConstraint(
model_name="endpoint",
constraint=models.CheckConstraint(
condition=models.Q(
(
"status__in",
["QUEUED", "STARTED", "RUNNING", "FAILED", "STOPPED"],
)
),
name="endpoint_status_valid",
),
),
]
4 changes: 3 additions & 1 deletion app/grandchallenge/algorithms/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,14 +1544,16 @@ class Meta:

class EndpointStatusChoices(TextChoices):
QUEUED = "QUEUED", _("Queued")
STARTED = "STARTED", _("Started")
RUNNING = "RUNNING", _("Running")
STOPPED = "STOPPED", _("Stopped")
FAILED = "FAILED", _("Failed")
STOPPED = "STOPPED", _("Stopped")

@classmethod
def get_active_choices(cls):
return {
cls.QUEUED,
cls.STARTED,
cls.RUNNING,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)


class InvocationParams(NamedTuple):
class ObjectParams(NamedTuple):
app_label: str
model_name: str
pk: UUID
Expand Down Expand Up @@ -248,6 +248,44 @@ def attempt(method):

self.deprovision_auxiliary_data()

@staticmethod
def get_endpoint_name(*, event):
return event["EndpointName"]

@staticmethod
def _get_endpoint_status(*, event):
return event["EndpointStatus"]

@staticmethod
def get_endpoint_params(*, endpoint_name):
prefix_regex = re.escape(settings.COMPONENTS_REGISTRY_PREFIX)
pattern = rf"^{prefix_regex}\-AE\-(?P<pk>{UUID4_REGEX})$"

result = re.match(pattern, endpoint_name)

if result is None:
raise ValueError("Invalid endpoint name")
else:
pk = result.group("pk")
return ObjectParams(
app_label="algorithms",
model_name="endpoint",
pk=pk,
)

def handle_status_event(self, *, event):
endpoint_status = self._get_endpoint_status(event=event)

if endpoint_status == "IN_SERVICE":
return
elif endpoint_status == "FAILED":
# Requires investigation
task_logger.info(event)
task_logger.error("Starting endpoint failed")
raise ComponentException(SystemErrorMessages.UNEXPECTED_ERROR)
else:
raise ValueError("Invalid endpoint status")

def provision_invocation_input_data(self, *, input_civs):
self._executor.provision(input_civs=input_civs, input_prefixes={})

Expand Down Expand Up @@ -279,7 +317,7 @@ def get_invocation_params(*, inference_id):
raise ValueError("Invalid inference id")
else:
pk = result.group("pk")
return InvocationParams(
return ObjectParams(
app_label="algorithms",
model_name="invocation",
pk=pk,
Expand Down
46 changes: 46 additions & 0 deletions app/grandchallenge/components/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,52 @@ def start_endpoint(*, pk: uuid.UUID, app_label: str, model_name: str):
error_message=SystemErrorMessages.UNEXPECTED_ERROR,
)

else:
endpoint.update_status(status=endpoint.StatusChoices.STARTED)


@lambda_task(retry_on=(LockNotAcquiredException,))
def handle_endpoint_status_event(*, event: dict):
from grandchallenge.components.backends.amazon_sagemaker_endpoint import (
EndpointOrchestrator,
)

endpoint_name = EndpointOrchestrator.get_endpoint_name(event=event)
params = EndpointOrchestrator.get_endpoint_params(
endpoint_name=endpoint_name
)

model = apps.get_model(
app_label=params.app_label,
model_name=params.model_name,
)

with check_lock_acquired():
endpoint = model.objects.select_for_update(nowait=True).get(
pk=params.pk
)

if endpoint.status != endpoint.StatusChoices.STARTED:
# Nothing to do
return

orchestrator = endpoint.orchestrator

try:
orchestrator.handle_status_event(event=event)
except ComponentException as error:
orchestrator.deprovision()
endpoint.update_status(
status=endpoint.StatusChoices.FAILED,
error_message=str(error),
)
except Exception:
logger.error("Could not start endpoint", exc_info=True)
orchestrator.deprovision()
endpoint.update_status(
status=endpoint.StatusChoices.FAILED,
error_message=SystemErrorMessages.UNEXPECTED_ERROR,
)
else:
endpoint.update_status(status=endpoint.StatusChoices.RUNNING)

Expand Down
85 changes: 84 additions & 1 deletion app/tests/components_tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
encode_b64j,
execute_job,
handle_endpoint_invocation_event,
handle_endpoint_status_event,
parse_endpoint_invocation_outputs,
preload_interactive_algorithms,
remove_container_image_from_registry,
Expand Down Expand Up @@ -1499,7 +1500,7 @@ def test_start_endpoint(mocker):

for mock_method in mock_start_methods:
mock_method.assert_called_once()
assert endpoint.status == endpoint.StatusChoices.RUNNING
assert endpoint.status == endpoint.StatusChoices.STARTED


@pytest.mark.django_db
Expand Down Expand Up @@ -1661,6 +1662,88 @@ def test_stop_expired_endpoints(
assert endpoint_to_stop.status == EndpointStatusChoices.STOPPED


@pytest.mark.django_db
def test_handle_endpoint_status_in_service_event(settings):
endpoint = EndpointFactory(
status=EndpointStatusChoices.STARTED,
)
event = {
"EndpointName": f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{endpoint.pk}",
"EndpointStatus": "IN_SERVICE",
}

handle_endpoint_status_event(event=event)
endpoint.refresh_from_db()

assert endpoint.status == EndpointStatusChoices.RUNNING


@pytest.mark.django_db
def test_handle_endpoint_status_failed_events(settings, mocker):
endpoint = EndpointFactory(
status=EndpointStatusChoices.STARTED,
)
event = {
"EndpointName": f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{endpoint.pk}",
"EndpointStatus": "FAILED",
}
mock_deprovision = mocker.patch.object(
EndpointOrchestrator,
"deprovision",
)

handle_endpoint_status_event(event=event)
endpoint.refresh_from_db()

mock_deprovision.assert_called_once()
assert endpoint.status == EndpointStatusChoices.FAILED
assert endpoint.error_message == SystemErrorMessages.UNEXPECTED_ERROR


@pytest.mark.django_db
def test_handle_endpoint_status_invalid_events(settings, mocker):
endpoint = EndpointFactory(
status=EndpointStatusChoices.STARTED,
)
event = {
"EndpointName": f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{endpoint.pk}",
"EndpointStatus": "some invalid status",
}
mock_deprovision = mocker.patch.object(
EndpointOrchestrator,
"deprovision",
)

handle_endpoint_status_event(event=event)
endpoint.refresh_from_db()

mock_deprovision.assert_called_once()
assert endpoint.status == EndpointStatusChoices.FAILED
assert endpoint.error_message == SystemErrorMessages.UNEXPECTED_ERROR


@pytest.mark.parametrize(
"status",
set(EndpointStatusChoices).difference([EndpointStatusChoices.STARTED]),
)
@pytest.mark.django_db
def test_handle_endpoint_status_wrong_state_ignored(mocker, settings, status):
endpoint = EndpointFactory(status=status)
event = {
"EndpointName": f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{endpoint.pk}",
}
mock_handle_status_event = mocker.patch.object(
EndpointOrchestrator,
"handle_status_event",
)

handle_endpoint_status_event(event=event)
endpoint.refresh_from_db()

mock_handle_status_event.assert_not_called()
assert endpoint.status == status


@pytest.mark.django_db
def test_handle_endpoint_invocation_completed_event(settings):
invocation = InvocationFactory(
Expand Down
10 changes: 8 additions & 2 deletions app/tests/core_tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,13 @@ def test_get_metrics():
"Unit": "Count",
},
{
"MetricName": "EndpointsRunning",
"MetricName": "EndpointsStarted",
"Dimensions": [{"Name": "Model", "Value": "Endpoint"}],
"Value": 0,
"Unit": "Count",
},
{
"MetricName": "EndpointsStopped",
"MetricName": "EndpointsRunning",
"Dimensions": [{"Name": "Model", "Value": "Endpoint"}],
"Value": 0,
"Unit": "Count",
Expand All @@ -336,6 +336,12 @@ def test_get_metrics():
"Value": 0,
"Unit": "Count",
},
{
"MetricName": "EndpointsStopped",
"Dimensions": [{"Name": "Model", "Value": "Endpoint"}],
"Value": 0,
"Unit": "Count",
},
{
"MetricName": "OldestActiveAlgorithmImage",
"Value": 0,
Expand Down
Loading