Skip to content

Commit 8e050ee

Browse files
committed
Add serializer validation
1 parent f99d218 commit 8e050ee

2 files changed

Lines changed: 127 additions & 1 deletion

File tree

app/grandchallenge/algorithms/serializers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,40 @@ def __init__(self, *args, **kwargs):
374374
user=user,
375375
codename="invoke_endpoint",
376376
)
377+
378+
def validate(self, data):
379+
self._endpoint = data.pop("endpoint")
380+
inputs = data.pop("inputs")
381+
data["algorithm_interface"] = (
382+
self.validate_inputs_and_return_matching_interface(inputs=inputs)
383+
)
384+
data["civ_data_objects"] = convert_deserialized_civ_data(
385+
deserialized_civ_data=inputs
386+
)
387+
388+
return data
389+
390+
def validate_inputs_and_return_matching_interface(self, *, inputs):
391+
"""
392+
Validates that the provided inputs match one of the endpoint's algorithm's configured interfaces
393+
"""
394+
provided_inputs = {i["interface"] for i in inputs}
395+
allowed_algorithm_interfaces = (
396+
self._endpoint.algorithm_image.algorithm.interfaces.all()
397+
)
398+
annotated_qs = annotate_input_output_counts(
399+
allowed_algorithm_interfaces, inputs=provided_inputs
400+
)
401+
try:
402+
interface = annotated_qs.get(
403+
relevant_input_count=len(provided_inputs),
404+
input_count=len(provided_inputs),
405+
)
406+
return interface
407+
except ObjectDoesNotExist:
408+
raise serializers.ValidationError(
409+
f"The set of inputs provided does not match "
410+
f"any of the endpoint's algorithm's interfaces. This algorithm supports the "
411+
f"following input combinations: "
412+
f"{oxford_comma([f'Interface {n}: {oxford_comma(interface.inputs.all())}' for n, interface in enumerate(allowed_algorithm_interfaces, start=1)])}"
413+
)

app/tests/algorithms_tests/test_serializers.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from guardian.shortcuts import assign_perm
55
from rest_framework.exceptions import ErrorDetail
66

7-
from grandchallenge.algorithms.models import Job
7+
from grandchallenge.algorithms.models import Endpoint, Job
88
from grandchallenge.algorithms.serializers import (
99
HyperlinkedJobSerializer,
10+
InvocationPostSerializer,
1011
JobPostSerializer,
1112
)
1213
from grandchallenge.cases.models import RawImageUploadSession
@@ -16,6 +17,7 @@
1617
AlgorithmImageFactory,
1718
AlgorithmInterfaceFactory,
1819
AlgorithmJobFactory,
20+
EndpointFactory,
1921
)
2022
from tests.cases_tests.factories import RawImageUploadSessionFactory
2123
from tests.components_tests.factories import (
@@ -586,3 +588,90 @@ def test_validate_inputs_on_job_serializer(inputs, interface, rf):
586588
in str(serializer.errors)
587589
)
588590
assert "algorithm_interface" not in serializer.validated_data
591+
592+
593+
@pytest.mark.parametrize(
594+
"inputs, interface",
595+
(
596+
([1], 1), # matches interface 1 of algorithm
597+
([1, 2], 2), # matches interface 2 of algorithm
598+
([3, 4, 5], 3), # matches interface 3 of algorithm
599+
([4], None), # matches interface 4, but not configured for algorithm
600+
(
601+
[1, 2, 3],
602+
None,
603+
), # matches interface 5, but not configured for algorithm
604+
([2], None), # matches no interface (implements part of interface 2)
605+
(
606+
[1, 3, 4],
607+
None,
608+
), # matches no interface (implements interface 3 and an additional input)
609+
),
610+
)
611+
@pytest.mark.django_db
612+
def test_input_validation_on_invocation_serializer(inputs, interface, rf):
613+
user = UserFactory()
614+
algorithm = AlgorithmFactory()
615+
algorithm.add_editor(user)
616+
ai = AlgorithmImageFactory(
617+
algorithm=algorithm,
618+
is_desired_version=True,
619+
is_manifest_valid=True,
620+
is_in_registry=True,
621+
)
622+
endpoint = EndpointFactory(
623+
algorithm_image=ai, creator=user, status=Endpoint.StatusChoices.RUNNING
624+
)
625+
626+
io1, io2, io3, io4, io5 = AlgorithmInterfaceFactory.create_batch(5)
627+
ci1, ci2, ci3, ci4, ci5, ci6 = ComponentInterfaceFactory.create_batch(
628+
6, kind=ComponentInterface.Kind.STRING
629+
)
630+
631+
interfaces = [io1, io2, io3]
632+
cis = [ci1, ci2, ci3, ci4, ci5, ci6]
633+
634+
io1.inputs.set([ci1])
635+
io2.inputs.set([ci1, ci2])
636+
io3.inputs.set([ci3, ci4, ci5])
637+
io4.inputs.set([ci1, ci2, ci3])
638+
io5.inputs.set([ci4])
639+
io1.outputs.set([ci6])
640+
io2.outputs.set([ci3])
641+
io3.outputs.set([ci1])
642+
io4.outputs.set([ci1])
643+
io5.outputs.set([ci1])
644+
645+
algorithm.interfaces.add(io1)
646+
algorithm.interfaces.add(io2)
647+
algorithm.interfaces.add(io3)
648+
649+
algorithm_interface = interfaces[interface - 1] if interface else None
650+
inputs = [cis[i - 1] for i in inputs]
651+
652+
invocation = {
653+
"endpoint": endpoint.api_url,
654+
"inputs": [
655+
{"interface": int.slug, "value": "dummy"} for int in inputs
656+
],
657+
}
658+
659+
request = rf.get("/foo")
660+
request.user = user
661+
serializer = InvocationPostSerializer(
662+
data=invocation, context={"request": request}
663+
)
664+
665+
if interface:
666+
assert serializer.is_valid()
667+
assert (
668+
serializer.validated_data["algorithm_interface"]
669+
== algorithm_interface
670+
)
671+
else:
672+
assert not serializer.is_valid()
673+
assert (
674+
"The set of inputs provided does not match any of the endpoint's algorithm's interfaces."
675+
in str(serializer.errors)
676+
)
677+
assert "algorithm_interface" not in serializer.validated_data

0 commit comments

Comments
 (0)