Skip to content

Commit fdeb3da

Browse files
acosferreirade1987
authored andcommitted
add validation to csrf values before response
1 parent 962d831 commit fdeb3da

File tree

4 files changed

+167
-1
lines changed

4 files changed

+167
-1
lines changed

.github/workflows/pip_compile.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
echo "ARCH=$(uname -m)" >> "$GITHUB_ENV"
4747
- name: install
4848
run: |
49+
pip install --upgrade "pip<24.1"
4950
pip install pip-tools
5051
pip-compile --quiet -o requirements-${{ env.ARCH }}.txt requirements.in
5152
pip-compile --quiet -o requirements-dev-${{ env.ARCH }}.txt -c requirements-${{ env.ARCH }}.txt requirements-dev.in

ansible_ai_connect/ai/api/tests/test_chat_view.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,3 +860,154 @@ def test(self):
860860
self.api_version_reverse("streaming_chat"), TestChatView.VALID_PAYLOAD, format="json"
861861
)
862862
self.assertEqual(response.status_code, 401)
863+
864+
865+
class TestStreamingChatViewCSRF(APIVersionTestCaseBase, WisdomServiceAPITestCaseBase):
866+
"""Test CSRF validation for the streaming_chat endpoint."""
867+
868+
api_version = "v1"
869+
870+
def setUp(self):
871+
super().setUp()
872+
(org, _) = Organization.objects.get_or_create(id=123, telemetry_opt_out=False)
873+
self.user.organization = org
874+
self.user.rh_internal = True
875+
self.user.save()
876+
877+
@override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom")
878+
@mock.patch(
879+
"ansible_ai_connect.ai.api.model_pipelines.http.pipelines."
880+
"HttpStreamingChatBotPipeline.get_streaming_http_response",
881+
)
882+
def test_csrf_validation_fails_without_token(self, mock_response):
883+
"""Test that CSRF validation fails when no CSRF token is provided."""
884+
885+
mock_response.return_value = TestStreamingChatView.mocked_response(
886+
json=TestChatView.VALID_PAYLOAD
887+
)
888+
889+
with patch.object(
890+
apps.get_app_config("ai"),
891+
"get_model_pipeline",
892+
Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))),
893+
):
894+
from rest_framework.test import APIClient
895+
896+
csrf_client = APIClient(enforce_csrf_checks=True)
897+
csrf_client.force_authenticate(user=self.user)
898+
899+
response = csrf_client.post(
900+
self.api_version_reverse("streaming_chat"),
901+
TestChatView.VALID_PAYLOAD,
902+
format="json",
903+
)
904+
905+
self.assertEqual(response.status_code, 403)
906+
self.assertIn("CSRF", str(response.data))
907+
908+
@override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom")
909+
@mock.patch(
910+
"ansible_ai_connect.ai.api.model_pipelines.http.pipelines."
911+
"HttpStreamingChatBotPipeline.get_streaming_http_response",
912+
)
913+
def test_csrf_validation_succeeds_with_valid_token(self, mock_response):
914+
"""Test that CSRF validation succeeds when a valid CSRF token is provided."""
915+
mock_response.return_value = TestStreamingChatView.mocked_response(
916+
json=TestChatView.VALID_PAYLOAD
917+
)
918+
919+
with patch.object(
920+
apps.get_app_config("ai"),
921+
"get_model_pipeline",
922+
Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))),
923+
):
924+
from django.middleware.csrf import _get_new_csrf_string
925+
from rest_framework.test import APIClient
926+
927+
csrf_client = APIClient(enforce_csrf_checks=True)
928+
csrf_client.force_authenticate(user=self.user)
929+
930+
token = _get_new_csrf_string()
931+
csrf_client.cookies["csrftoken"] = token
932+
933+
response = csrf_client.post(
934+
self.api_version_reverse("streaming_chat"),
935+
TestChatView.VALID_PAYLOAD,
936+
format="json",
937+
HTTP_X_CSRFTOKEN=token,
938+
)
939+
940+
self.assertNotEqual(response.status_code, 403)
941+
942+
@override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom")
943+
@mock.patch(
944+
"ansible_ai_connect.ai.api.model_pipelines.http.pipelines."
945+
"HttpStreamingChatBotPipeline.get_streaming_http_response",
946+
)
947+
def test_csrf_validation_fails_with_invalid_token(self, mock_response):
948+
"""Test that CSRF validation fails when an invalid CSRF token is provided."""
949+
mock_response.return_value = TestStreamingChatView.mocked_response(
950+
json=TestChatView.VALID_PAYLOAD
951+
)
952+
953+
with patch.object(
954+
apps.get_app_config("ai"),
955+
"get_model_pipeline",
956+
Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))),
957+
):
958+
from rest_framework.test import APIClient
959+
960+
csrf_client = APIClient(enforce_csrf_checks=True)
961+
csrf_client.force_authenticate(user=self.user)
962+
963+
response = csrf_client.post(
964+
self.api_version_reverse("streaming_chat"),
965+
TestChatView.VALID_PAYLOAD,
966+
format="json",
967+
HTTP_X_CSRFTOKEN="invalid_token_12345",
968+
)
969+
970+
self.assertEqual(response.status_code, 403)
971+
self.assertIn("CSRF", str(response.data))
972+
973+
@override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom")
974+
@mock.patch(
975+
"ansible_ai_connect.ai.api.model_pipelines.http.pipelines."
976+
"HttpStreamingChatBotPipeline.get_streaming_http_response",
977+
)
978+
def test_csrf_validation_fails_with_valid_session_and_invalid_token(self, mock_response):
979+
"""Test that CSRF validation fails when a valid session present but the token is invalid.
980+
981+
This reproduces the real-world scenario described in the PR: sessionid is valid (user
982+
has a session) but the X-CSRFToken value does not match the csrftoken cookie.
983+
"""
984+
mock_response.return_value = TestStreamingChatView.mocked_response(
985+
json=TestChatView.VALID_PAYLOAD
986+
)
987+
988+
with patch.object(
989+
apps.get_app_config("ai"),
990+
"get_model_pipeline",
991+
Mock(return_value=HttpStreamingChatBotPipeline(mock_pipeline_config("http"))),
992+
):
993+
from rest_framework.test import APIClient
994+
995+
csrf_client = APIClient(enforce_csrf_checks=True)
996+
csrf_client.force_authenticate(user=self.user)
997+
998+
# Simulate a valid session cookie being present
999+
csrf_client.cookies["sessionid"] = "validsession123"
1000+
1001+
# Simulate a csrftoken cookie that does NOT match the header token
1002+
csrf_client.cookies["csrftoken"] = "expected_token_abc"
1003+
1004+
# Send a mismatched/invalid header token
1005+
response = csrf_client.post(
1006+
self.api_version_reverse("streaming_chat"),
1007+
TestChatView.VALID_PAYLOAD,
1008+
format="json",
1009+
HTTP_X_CSRFTOKEN="invalid_token_12345",
1010+
)
1011+
1012+
self.assertEqual(response.status_code, 403)
1013+
self.assertIn("CSRF", str(response.data))

ansible_ai_connect/ai/api/views.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from ansible_ai_connect.users.models import User
103103

104+
from ...main.middleware import check_csrf
104105
from ...main.permissions import IsAAPUser, IsRHInternalUser, IsTestUser
105106
from ...users.throttling import EndpointRateThrottle
106107
from ..feature_flags import FeatureFlags
@@ -1202,6 +1203,7 @@ def __init__(self):
12021203
summary="Streaming chat request",
12031204
)
12041205
def post(self, request) -> StreamingHttpResponse:
1206+
check_csrf(request)
12051207
if not self.chatbot_enabled:
12061208
raise ChatbotNotEnabledException()
12071209

ansible_ai_connect/main/middleware.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
from ansible_anonymizer import anonymizer
2121
from django.conf import settings
22-
from rest_framework.exceptions import ErrorDetail
22+
from django.middleware.csrf import CsrfViewMiddleware
23+
from rest_framework.exceptions import ErrorDetail, PermissionDenied
2324
from segment import analytics
2425
from social_django.middleware import SocialAuthExceptionMiddleware
2526

@@ -40,6 +41,17 @@
4041
version_info = VersionInfo()
4142

4243

44+
def check_csrf(request):
45+
django_request = getattr(request, "_request", request)
46+
47+
reason = CsrfViewMiddleware(get_response=lambda r: None).process_view(
48+
django_request, None, (), {}
49+
)
50+
51+
if reason:
52+
raise PermissionDenied(detail="CSRF validation failed")
53+
54+
4355
def on_segment_error(error, _):
4456
logger.error(f"An error occurred in sending data to Segment: {error}")
4557

0 commit comments

Comments
 (0)