|
4 | 4 | from guardian.shortcuts import assign_perm |
5 | 5 | from rest_framework.exceptions import ErrorDetail |
6 | 6 |
|
7 | | -from grandchallenge.algorithms.models import Job |
| 7 | +from grandchallenge.algorithms.models import Endpoint, Job |
8 | 8 | from grandchallenge.algorithms.serializers import ( |
9 | 9 | HyperlinkedJobSerializer, |
| 10 | + InvocationPostSerializer, |
10 | 11 | JobPostSerializer, |
11 | 12 | ) |
12 | 13 | from grandchallenge.cases.models import RawImageUploadSession |
|
16 | 17 | AlgorithmImageFactory, |
17 | 18 | AlgorithmInterfaceFactory, |
18 | 19 | AlgorithmJobFactory, |
| 20 | + EndpointFactory, |
19 | 21 | ) |
20 | 22 | from tests.cases_tests.factories import RawImageUploadSessionFactory |
21 | 23 | from tests.components_tests.factories import ( |
@@ -586,3 +588,90 @@ def test_validate_inputs_on_job_serializer(inputs, interface, rf): |
586 | 588 | in str(serializer.errors) |
587 | 589 | ) |
588 | 590 | assert "algorithm_interface" not in serializer.validated_data |
| 591 | + |
| 592 | + |
| 593 | +@pytest.mark.parametrize( |
| 594 | + "inputs, interface", |
| 595 | + ( |
| 596 | + ([1], 1), # matches interface 1 of algorithm |
| 597 | + ([1, 2], 2), # matches interface 2 of algorithm |
| 598 | + ([3, 4, 5], 3), # matches interface 3 of algorithm |
| 599 | + ([4], None), # matches interface 4, but not configured for algorithm |
| 600 | + ( |
| 601 | + [1, 2, 3], |
| 602 | + None, |
| 603 | + ), # matches interface 5, but not configured for algorithm |
| 604 | + ([2], None), # matches no interface (implements part of interface 2) |
| 605 | + ( |
| 606 | + [1, 3, 4], |
| 607 | + None, |
| 608 | + ), # matches no interface (implements interface 3 and an additional input) |
| 609 | + ), |
| 610 | +) |
| 611 | +@pytest.mark.django_db |
| 612 | +def test_input_validation_on_invocation_serializer(inputs, interface, rf): |
| 613 | + user = UserFactory() |
| 614 | + algorithm = AlgorithmFactory() |
| 615 | + algorithm.add_editor(user) |
| 616 | + ai = AlgorithmImageFactory( |
| 617 | + algorithm=algorithm, |
| 618 | + is_desired_version=True, |
| 619 | + is_manifest_valid=True, |
| 620 | + is_in_registry=True, |
| 621 | + ) |
| 622 | + endpoint = EndpointFactory( |
| 623 | + algorithm_image=ai, creator=user, status=Endpoint.StatusChoices.RUNNING |
| 624 | + ) |
| 625 | + |
| 626 | + io1, io2, io3, io4, io5 = AlgorithmInterfaceFactory.create_batch(5) |
| 627 | + ci1, ci2, ci3, ci4, ci5, ci6 = ComponentInterfaceFactory.create_batch( |
| 628 | + 6, kind=ComponentInterface.Kind.STRING |
| 629 | + ) |
| 630 | + |
| 631 | + interfaces = [io1, io2, io3] |
| 632 | + cis = [ci1, ci2, ci3, ci4, ci5, ci6] |
| 633 | + |
| 634 | + io1.inputs.set([ci1]) |
| 635 | + io2.inputs.set([ci1, ci2]) |
| 636 | + io3.inputs.set([ci3, ci4, ci5]) |
| 637 | + io4.inputs.set([ci1, ci2, ci3]) |
| 638 | + io5.inputs.set([ci4]) |
| 639 | + io1.outputs.set([ci6]) |
| 640 | + io2.outputs.set([ci3]) |
| 641 | + io3.outputs.set([ci1]) |
| 642 | + io4.outputs.set([ci1]) |
| 643 | + io5.outputs.set([ci1]) |
| 644 | + |
| 645 | + algorithm.interfaces.add(io1) |
| 646 | + algorithm.interfaces.add(io2) |
| 647 | + algorithm.interfaces.add(io3) |
| 648 | + |
| 649 | + algorithm_interface = interfaces[interface - 1] if interface else None |
| 650 | + inputs = [cis[i - 1] for i in inputs] |
| 651 | + |
| 652 | + invocation = { |
| 653 | + "endpoint": endpoint.api_url, |
| 654 | + "inputs": [ |
| 655 | + {"interface": int.slug, "value": "dummy"} for int in inputs |
| 656 | + ], |
| 657 | + } |
| 658 | + |
| 659 | + request = rf.get("/foo") |
| 660 | + request.user = user |
| 661 | + serializer = InvocationPostSerializer( |
| 662 | + data=invocation, context={"request": request} |
| 663 | + ) |
| 664 | + |
| 665 | + if interface: |
| 666 | + assert serializer.is_valid() |
| 667 | + assert ( |
| 668 | + serializer.validated_data["algorithm_interface"] |
| 669 | + == algorithm_interface |
| 670 | + ) |
| 671 | + else: |
| 672 | + assert not serializer.is_valid() |
| 673 | + assert ( |
| 674 | + "The set of inputs provided does not match any of the endpoint's algorithm's interfaces." |
| 675 | + in str(serializer.errors) |
| 676 | + ) |
| 677 | + assert "algorithm_interface" not in serializer.validated_data |
0 commit comments