Skip to content

Commit 9911e6b

Browse files
committed
Add handle endpoint status event task
1 parent c6767cd commit 9911e6b

3 files changed

Lines changed: 169 additions & 2 deletions

File tree

app/grandchallenge/components/backends/amazon_sagemaker_endpoint.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
logger = logging.getLogger(__name__)
2222

2323

24-
class InvocationParams(NamedTuple):
24+
class ObjectParams(NamedTuple):
2525
app_label: str
2626
model_name: str
2727
pk: UUID
@@ -248,6 +248,44 @@ def attempt(method):
248248

249249
self.deprovision_auxiliary_data()
250250

251+
@staticmethod
252+
def get_endpoint_name(*, event):
253+
return event["EndpointName"]
254+
255+
@staticmethod
256+
def _get_endpoint_status(*, event):
257+
return event["EndpointStatus"]
258+
259+
@staticmethod
def get_endpoint_params(*, endpoint_name):
    """Parse an endpoint name into the parameters identifying its object.

    The whole name must consist of ``settings.COMPONENTS_REGISTRY_PREFIX``,
    the literal ``-AE-`` and a primary key matching ``UUID4_REGEX``.

    Returns an ``ObjectParams`` with ``app_label="algorithms"``,
    ``model_name="endpoint"`` and ``pk`` set to the matched primary key.

    Raises ``ValueError`` when the name does not have that form.
    """
    prefix_regex = re.escape(settings.COMPONENTS_REGISTRY_PREFIX)
    pattern = rf"{prefix_regex}\-AE\-(?P<pk>{UUID4_REGEX})"

    # fullmatch instead of match with "^...$": "$" also matches just before
    # a trailing newline, so a name with a stray newline appended would
    # otherwise be accepted as valid.
    result = re.fullmatch(pattern, endpoint_name)

    if result is None:
        raise ValueError("Invalid endpoint name")

    # NOTE(review): pk is handed over as the matched text, not converted
    # to a UUID instance - confirm this is what callers expect.
    return ObjectParams(
        app_label="algorithms",
        model_name="endpoint",
        pk=result.group("pk"),
    )
275+
276+
def handle_status_event(self, *, event):
    """Check the endpoint status reported by *event*.

    Returns ``None`` when the status is ``"IN_SERVICE"``.

    Raises ``ComponentException`` (after logging the event) when the
    status is ``"FAILED"``, and ``ValueError`` for any other status.
    """
    status = self._get_endpoint_status(event=event)

    if status == "FAILED":
        # Requires investigation
        task_logger.info(event)
        task_logger.error("Starting endpoint failed")
        raise ComponentException(SystemErrorMessages.UNEXPECTED_ERROR)

    if status != "IN_SERVICE":
        raise ValueError("Invalid endpoint status")
288+
251289
def provision_invocation_input_data(self, *, input_civs):
252290
self._executor.provision(input_civs=input_civs, input_prefixes={})
253291

@@ -279,7 +317,7 @@ def get_invocation_params(*, inference_id):
279317
raise ValueError("Invalid inference id")
280318
else:
281319
pk = result.group("pk")
282-
return InvocationParams(
320+
return ObjectParams(
283321
app_label="algorithms",
284322
model_name="invocation",
285323
pk=pk,

app/grandchallenge/components/tasks.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,6 +1681,52 @@ def start_endpoint(*, pk: uuid.UUID, app_label: str, model_name: str):
16811681
error_message=SystemErrorMessages.UNEXPECTED_ERROR,
16821682
)
16831683

1684+
else:
1685+
endpoint.update_status(status=endpoint.StatusChoices.STARTING)
1686+
1687+
1688+
@lambda_task(retry_on=(LockNotAcquiredException,))
def handle_endpoint_status_event(*, event: dict):
    """Apply an endpoint status event to the endpoint it refers to.

    The endpoint is looked up from the name carried in *event* and its row
    is fetched with ``select_for_update(nowait=True)``. Endpoints that are
    not in the STARTING state are left untouched. Otherwise the
    orchestrator inspects the event: on success the endpoint becomes
    RUNNING; on any exception the orchestrator is deprovisioned and the
    endpoint is marked FAILED with an error message.
    """
    from grandchallenge.components.backends.amazon_sagemaker_endpoint import (
        EndpointOrchestrator,
    )

    params = EndpointOrchestrator.get_endpoint_params(
        endpoint_name=EndpointOrchestrator.get_endpoint_name(event=event)
    )
    model = apps.get_model(
        app_label=params.app_label,
        model_name=params.model_name,
    )

    # NOTE(review): only the row lookup is placed under
    # check_lock_acquired(); the diff this was taken from lost its
    # indentation - confirm against the repository.
    with check_lock_acquired():
        queryset = model.objects.select_for_update(nowait=True)
        endpoint = queryset.get(pk=params.pk)

    if endpoint.status != endpoint.StatusChoices.STARTING:
        # Nothing to do
        return

    orchestrator = endpoint.orchestrator

    try:
        orchestrator.handle_status_event(event=event)
    except Exception as error:
        if isinstance(error, ComponentException):
            # Component errors carry a message that is stored as-is
            failure_message = str(error)
        else:
            logger.error("Could not start endpoint", exc_info=True)
            failure_message = SystemErrorMessages.UNEXPECTED_ERROR

        orchestrator.deprovision()
        endpoint.update_status(
            status=endpoint.StatusChoices.FAILED,
            error_message=failure_message,
        )
    else:
        endpoint.update_status(status=endpoint.StatusChoices.RUNNING)
16861732

app/tests/components_tests/test_tasks.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
encode_b64j,
5252
execute_job,
5353
handle_endpoint_invocation_event,
54+
handle_endpoint_status_event,
5455
parse_endpoint_invocation_outputs,
5556
preload_interactive_algorithms,
5657
remove_container_image_from_registry,
@@ -1661,6 +1662,88 @@ def test_stop_expired_endpoints(
16611662
assert endpoint_to_stop.status == EndpointStatusChoices.STOPPED
16621663

16631664

1665+
@pytest.mark.django_db
def test_handle_endpoint_status_in_service_event(settings):
    """An IN_SERVICE event moves a STARTING endpoint to RUNNING."""
    starting_endpoint = EndpointFactory(status=EndpointStatusChoices.STARTING)
    name = f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{starting_endpoint.pk}"

    handle_endpoint_status_event(
        event={"EndpointName": name, "EndpointStatus": "IN_SERVICE"}
    )

    starting_endpoint.refresh_from_db()
    assert starting_endpoint.status == EndpointStatusChoices.RUNNING
1679+
1680+
1681+
@pytest.mark.django_db
def test_handle_endpoint_status_failed_events(settings, mocker):
    """A FAILED event deprovisions once and marks the endpoint FAILED."""
    starting_endpoint = EndpointFactory(status=EndpointStatusChoices.STARTING)
    name = f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{starting_endpoint.pk}"
    deprovision_mock = mocker.patch.object(EndpointOrchestrator, "deprovision")

    handle_endpoint_status_event(
        event={"EndpointName": name, "EndpointStatus": "FAILED"}
    )

    starting_endpoint.refresh_from_db()
    deprovision_mock.assert_called_once()
    assert starting_endpoint.status == EndpointStatusChoices.FAILED
    assert (
        starting_endpoint.error_message == SystemErrorMessages.UNEXPECTED_ERROR
    )
1701+
1702+
1703+
@pytest.mark.django_db
def test_handle_endpoint_status_invalid_events(settings, mocker):
    """An unrecognised status deprovisions once and marks the endpoint FAILED."""
    starting_endpoint = EndpointFactory(status=EndpointStatusChoices.STARTING)
    name = f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{starting_endpoint.pk}"
    deprovision_mock = mocker.patch.object(EndpointOrchestrator, "deprovision")

    handle_endpoint_status_event(
        event={"EndpointName": name, "EndpointStatus": "some invalid status"}
    )

    starting_endpoint.refresh_from_db()
    deprovision_mock.assert_called_once()
    assert starting_endpoint.status == EndpointStatusChoices.FAILED
    assert (
        starting_endpoint.error_message == SystemErrorMessages.UNEXPECTED_ERROR
    )
1723+
1724+
1725+
@pytest.mark.parametrize(
    "status",
    # A list keeps a stable parameter order. Iterating a set does not
    # guarantee one, and pytest-xdist needs every worker to collect the
    # same test ids in the same order.
    [
        choice
        for choice in EndpointStatusChoices
        if choice != EndpointStatusChoices.STARTING
    ],
)
@pytest.mark.django_db
def test_handle_endpoint_status_wrong_state_ignored(mocker, settings, status):
    """Events for endpoints that are not STARTING are ignored.

    The orchestrator's ``handle_status_event`` must not be called and the
    stored endpoint status must stay as it was.
    """
    endpoint = EndpointFactory(status=status)
    event = {
        "EndpointName": f"{settings.COMPONENTS_REGISTRY_PREFIX}-AE-{endpoint.pk}",
    }
    mock_handle_status_event = mocker.patch.object(
        EndpointOrchestrator,
        "handle_status_event",
    )

    handle_endpoint_status_event(event=event)
    endpoint.refresh_from_db()

    mock_handle_status_event.assert_not_called()
    assert endpoint.status == status
1745+
1746+
16641747
@pytest.mark.django_db
16651748
def test_handle_endpoint_invocation_completed_event(settings):
16661749
invocation = InvocationFactory(

0 commit comments

Comments
 (0)