Skip to content

Commit 67e768d

Browse files
committed
Add create method on serializer
1 parent 9a230f7 commit 67e768d

6 files changed

Lines changed: 174 additions & 10 deletions

File tree

app/grandchallenge/algorithms/models.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1762,7 +1762,7 @@ class InvocationStatusChoices(TextChoices):
17621762
CANCELLED = "CANCELLED", _("Cancelled")
17631763

17641764

1765-
class Invocation(UUIDModel):
1765+
class Invocation(CIVForObjectMixin, UUIDModel):
17661766
StatusChoices = InvocationStatusChoices
17671767

17681768
endpoint = models.ForeignKey(Endpoint, on_delete=models.PROTECT)
@@ -1861,6 +1861,13 @@ def orchestrator_kwargs(self):
18611861
def orchestrator(self):
18621862
return EndpointOrchestrator(**self.orchestrator_kwargs)
18631863

1864+
@property
1865+
def is_editable(self):
1866+
if self.status == InvocationStatusChoices.VALIDATING_INPUTS:
1867+
return True
1868+
else:
1869+
return False
1870+
18641871
@cached_property
18651872
def inputs_complete(self):
18661873
# check if all inputs are present and if they all have a value
@@ -1894,3 +1901,31 @@ def update_status(
18941901
self.invoke_duration = invoke_duration
18951902

18961903
self.save()
1904+
1905+
def add_civ(self, *, civ):
1906+
super().add_civ(civ=civ)
1907+
return self.inputs.add(civ)
1908+
1909+
def remove_civ(self, *, civ):
1910+
super().remove_civ(civ=civ)
1911+
return self.inputs.remove(civ)
1912+
1913+
def get_civ_for_interface(self, interface):
1914+
return self.inputs.get(interface=interface)
1915+
1916+
def process_civ_data_objects_and_execute_linked_task(
1917+
self, *, civ_data_objects, user, linked_task=None
1918+
):
1919+
from grandchallenge.algorithms.tasks import (
1920+
execute_invocation_for_inputs,
1921+
)
1922+
1923+
linked_task = execute_invocation_for_inputs.signature(
1924+
kwargs={"invocation_pk": str(self.pk)}, immutable=True
1925+
)
1926+
1927+
return super().process_civ_data_objects_and_execute_linked_task(
1928+
civ_data_objects=civ_data_objects,
1929+
user=user,
1930+
linked_task=linked_task,
1931+
)

app/grandchallenge/algorithms/serializers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
AlgorithmModel,
2222
Endpoint,
2323
Invocation,
24+
InvocationStatusChoices,
2425
Job,
2526
annotate_input_output_counts,
2627
)
@@ -399,3 +400,31 @@ def validate(self, data):
399400
)
400401

401402
return data
403+
404+
def create(self, validated_data):
405+
civ_data_objects = validated_data.pop("civ_data_objects", [])
406+
407+
invocation = Invocation.objects.create(
408+
**validated_data,
409+
status=InvocationStatusChoices.VALIDATING_INPUTS,
410+
)
411+
412+
try:
413+
invocation.process_civ_data_objects_and_execute_linked_task(
414+
civ_data_objects=civ_data_objects,
415+
user=self.context["request"].user,
416+
)
417+
except CIVNotEditableException as e:
418+
invocation.refresh_from_db()
419+
if invocation.status == invocation.StatusChoices.CANCELLED:
420+
# this can happen for jobs with multiple inputs
421+
# if one of them fails validation
422+
pass
423+
else:
424+
error_handler = invocation.get_error_handler()
425+
error_handler.handle_error(
426+
error_message=SystemErrorMessages.UNEXPECTED_ERROR,
427+
)
428+
logger.error(e, exc_info=True)
429+
430+
return invocation

app/grandchallenge/algorithms/tasks.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from grandchallenge.algorithms.exceptions import TooManyJobsScheduled
1313
from grandchallenge.components.schemas import GPUTypeChoices
1414
from grandchallenge.components.tasks import (
15+
provision_invocation_input_data,
1516
remove_container_image_from_registry,
1617
)
1718
from grandchallenge.core.celery import (
@@ -325,3 +326,31 @@ def deactivate_old_algorithm_images():
325326
}
326327
).apply_async
327328
)
329+
330+
331+
@acks_late_micro_short_task(retry_on=(LockNotAcquiredException,))
332+
@transaction.atomic
333+
def execute_invocation_for_inputs(*, invocation_pk):
334+
from grandchallenge.algorithms.models import Invocation
335+
336+
with check_lock_acquired():
337+
invocation = Invocation.objects.select_for_update(nowait=True).get(
338+
pk=invocation_pk
339+
)
340+
341+
if not invocation.inputs_complete:
342+
# Nothing to do
343+
return
344+
345+
if invocation.status != Invocation.StatusChoices.VALIDATING_INPUTS:
346+
# this task can be called multiple times with complete inputs,
347+
# and might have been queued for execution already, so ignore
348+
return
349+
350+
invocation.update_status(status=Invocation.StatusChoices.QUEUED)
351+
352+
on_commit(
353+
provision_invocation_input_data.signature(
354+
kwargs=invocation.task_kwargs
355+
).apply_async
356+
)

app/grandchallenge/components/models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
DICOMImageSetUploadErrorHandler,
7474
EvaluationCIVErrorHandler,
7575
FallbackCIVValidationErrorHandler,
76+
InvocationCIVErrorHandler,
7677
JobCIVErrorHandler,
7778
RawImageUploadSessionErrorHandler,
7879
UserUploadCIVErrorHandler,
@@ -2841,7 +2842,7 @@ def get_current_value_for_interface(self, *, interface, user):
28412842

28422843
def get_error_handler(self, *, linked_object=None):
28432844
# local imports to prevent circular dependency
2844-
from grandchallenge.algorithms.models import Job
2845+
from grandchallenge.algorithms.models import Invocation, Job
28452846
from grandchallenge.archives.models import ArchiveItem
28462847
from grandchallenge.evaluation.models import Evaluation
28472848
from grandchallenge.reader_studies.models import DisplaySet
@@ -2860,6 +2861,8 @@ def get_error_handler(self, *, linked_object=None):
28602861
return JobCIVErrorHandler(job=self)
28612862
elif isinstance(self, Evaluation):
28622863
return EvaluationCIVErrorHandler(job=self)
2864+
elif isinstance(self, Invocation):
2865+
return InvocationCIVErrorHandler(invocation=self)
28632866
elif linked_object and isinstance(linked_object, UserUpload):
28642867
return UserUploadCIVErrorHandler(
28652868
user_upload=linked_object,

app/grandchallenge/core/error_handlers.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,41 @@ def handle_error(self, *, error_message, interface=None, user=None):
8686
)
8787

8888

89+
class InvocationCIVErrorHandler(ErrorHandler):
90+
"""
91+
Error handler for CIV validation errors on invocation creation.
92+
Handle_error() updates an algorithm endpoint invocation.
93+
"""
94+
95+
def __init__(self, *args, invocation, **kwargs):
96+
from grandchallenge.algorithms.models import Invocation
97+
98+
if not invocation or not isinstance(invocation, Invocation):
99+
raise RuntimeError(
100+
"You need to provide an Invocation instance to this error handler."
101+
)
102+
103+
super().__init__(*args, **kwargs)
104+
self._invocation = invocation
105+
106+
def handle_error(self, *, error_message, interface=None, user=None):
107+
if interface:
108+
detailed_error_message = copy.deepcopy(
109+
self._invocation.detailed_error_message
110+
)
111+
detailed_error_message[interface.title] = error_message
112+
self._invocation.update_status(
113+
status=self._invocation.StatusChoices.CANCELLED,
114+
error_message="One or more of the inputs failed validation.",
115+
detailed_error_message=detailed_error_message,
116+
)
117+
else:
118+
self._invocation.update_status(
119+
status=self._invocation.StatusChoices.CANCELLED,
120+
error_message=error_message,
121+
)
122+
123+
89124
class RawImageUploadSessionErrorHandler(ErrorHandler):
90125
"""
91126
Error handler for image imports and image validation.

app/tests/algorithms_tests/test_serializers.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -758,24 +758,57 @@ def test_time_limit_validation_on_invocation_post_serializer(settings):
758758

759759

760760
@pytest.mark.django_db
761-
def test_invocation_post_serializer_create(request):
762-
endpoint = EndpointFactory(status=Endpoint.StatusChoices.RUNNING)
763-
request.user = endpoint.creator
764-
socket = ComponentInterfaceFactory(kind=ComponentInterface.Kind.STRING)
765-
interface = AlgorithmInterfaceFactory(inputs=[socket])
761+
def test_invocation_post_serializer_create(
762+
request, settings, django_capture_on_commit_callbacks
763+
):
764+
settings.CELERY_TASK_ALWAYS_EAGER = True
765+
settings.CELERY_TASK_EAGER_PROPAGATES = True
766+
767+
user = UserFactory()
768+
request.user = user
769+
endpoint = EndpointFactory.create(
770+
creator=user,
771+
status=Endpoint.StatusChoices.RUNNING,
772+
)
773+
ci_string = ComponentInterfaceFactory.create(
774+
kind=ComponentInterface.Kind.STRING
775+
)
776+
ci_img1, ci_img2 = ComponentInterfaceFactory.create_batch(
777+
2, kind=ComponentInterface.Kind.PANIMG_IMAGE
778+
)
779+
interface = AlgorithmInterfaceFactory(inputs=[ci_string, ci_img2, ci_img1])
766780
endpoint.algorithm_image.algorithm.interfaces.add(interface)
781+
upload = RawImageUploadSessionFactory(creator=user)
782+
image1, image2 = ImageFactory.create_batch(2)
783+
for im in [image1, image2]:
784+
assign_perm("view_image", user, im)
785+
upload.image_set.set([image1])
767786

768787
serializer = InvocationPostSerializer(
769788
data={
770789
"endpoint": endpoint.api_url,
771-
"inputs": [{"interface": socket.slug, "value": "dummy"}],
790+
"inputs": [
791+
{"interface": ci_string.slug, "value": "foo"},
792+
{"interface": ci_img1.slug, "upload_session": upload.api_url},
793+
{"interface": ci_img2.slug, "image": image2.api_url},
794+
],
772795
},
773796
context={"request": request},
774797
)
775798

776-
assert serializer.is_valid()
777-
serializer.create(serializer.validated_data)
799+
assert serializer.is_valid(), serializer.errors
800+
801+
# fake successful upload
802+
upload.status = RawImageUploadSession.SUCCESS
803+
upload.save()
804+
805+
with django_capture_on_commit_callbacks(execute=True):
806+
serializer.create(serializer.validated_data)
807+
808+
assert Invocation.objects.count() == 1
778809

779810
invocation = Invocation.objects.get()
780811

781812
assert invocation.endpoint == endpoint
813+
assert invocation.algorithm_interface == interface
814+
assert invocation.inputs.count() == 3

0 commit comments

Comments
 (0)