From 99638dadba6b2aac444feb42cf3c1a0b66922e0d Mon Sep 17 00:00:00 2001 From: Michael Abashian Date: Tue, 11 Feb 2025 12:03:46 -0500 Subject: [PATCH] More test fixes --- ansible_ai_connect/ai/api/serializers.py | 15 +++++---------- .../ai/api/tests/test_role_explanation_view.py | 13 +++++++++++-- .../v1/ai/tests/test_role_explanation_view.py | 2 +- ansible_ai_connect/ai/api/views.py | 8 ++++---- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/ansible_ai_connect/ai/api/serializers.py b/ansible_ai_connect/ai/api/serializers.py index cca627fd9..329e4386b 100644 --- a/ansible_ai_connect/ai/api/serializers.py +++ b/ansible_ai_connect/ai/api/serializers.py @@ -654,22 +654,17 @@ class ExplanationRoleRequestSerializer(Metadata): label="Files", help_text="A list of role files to be explained.", ) - role_name = serializers.CharField( + roleName = serializers.CharField( required=True, label="Role name", help_text="The name of the role.", ) - model_id = serializers.CharField(required=False, allow_blank=True, default="") - focus_on_file = serializers.CharField(required=False, allow_blank=True, default="") + model = serializers.CharField(required=False, allow_blank=True, default="") + focusOnFile = serializers.CharField(required=False, allow_blank=True, default="") def validate(self, attrs): - attrs = super().validate(attrs) - - if "model_id" in attrs and not attrs["model_id"].strip(): - del attrs["model_id"] - if "focus_on_file" in attrs and not attrs["focus_on_file"].strip(): - del attrs["focus_on_file"] - return attrs + data = super().validate(attrs) + return data class ContentMatchRequestSerializer(Metadata): diff --git a/ansible_ai_connect/ai/api/tests/test_role_explanation_view.py b/ansible_ai_connect/ai/api/tests/test_role_explanation_view.py index 1361a1c0e..fbb3298f1 100644 --- a/ansible_ai_connect/ai/api/tests/test_role_explanation_view.py +++ b/ansible_ai_connect/ai/api/tests/test_role_explanation_view.py @@ -32,12 +32,21 @@ class TestRoleExplanationView( APIVersionTestCaseBase, WisdomAppsBackendMocking, WisdomServiceAPITestCaseBase ): def test_ok(self): - payload = {} + payload = { + "files": [ + { + "path": "dummy_path", + "content": "dummy_content", + "file_type": "dummy_file_type", + } + ], + "roleName": "dummy_role", + } self.client.force_authenticate(user=self.user) r = self.client.post(self.api_version_reverse("explanations/role"), payload, format="json") self.assertEqual(r.status_code, HTTPStatus.OK) self.assertIsNotNone(r.data) - self.assertEqual(r.data, {}) + self.assertEqual(r.data["format"], "markdown") def test_unauthorized(self): payload = {} diff --git a/ansible_ai_connect/ai/api/versions/v1/ai/tests/test_role_explanation_view.py b/ansible_ai_connect/ai/api/versions/v1/ai/tests/test_role_explanation_view.py index b8d5d0ec4..b3a35cb43 100644 --- a/ansible_ai_connect/ai/api/versions/v1/ai/tests/test_role_explanation_view.py +++ b/ansible_ai_connect/ai/api/versions/v1/ai/tests/test_role_explanation_view.py @@ -18,7 +18,7 @@ from ansible_ai_connect.ai.api.versions.v1.test_base import API_VERSION -class TestRoleGenerationViewVersion1(TestRoleExplanationView): +class TestRoleExplanationViewVersion1(TestRoleExplanationView): api_version = API_VERSION def test_explanation_role_version_url(self): diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index d15c7b5ea..02bcc18e2 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -861,15 +861,16 @@ class ExplanationRole(AACSAPIView): summary="Inline code suggestions", ) def post(self, request) -> Response: - self.event.explanationId = self.validated_data["explanationId"] llm: ModelPipelineRoleExplanation = apps.get_app_config("ai").get_model_pipeline( ModelPipelineRoleExplanation ) explanation = llm.invoke( RoleExplanationParameters.init( request=request, - content=self.validated_data["content"], - explanation_id=self.validated_data["explanationId"], + files=self.validated_data["files"], + role_name=self.validated_data["roleName"], + model_id=self.validated_data["model"], + focus_on_file=self.validated_data["focusOnFile"], ) ) @@ -882,7 +883,6 @@ def post(self, request) -> Response: answer = { "content": anonymized_explanation, "format": "markdown", - "explanationId": self.validated_data["explanationId"], } return Response(