Skip to content

Commit 5399e7d

Browse files
authored
397 cross project data leak (#427)
* Refactored filtering logic for retrieving chunks for a given course and conversation * adding fixes to retrieve chunks for a particular conversation id and course prevent cross project data leak * adding the missed must condition
1 parent 05e2c5d commit 5399e7d

File tree

2 files changed

+78
-49
lines changed

2 files changed

+78
-49
lines changed

ai_ta_backend/database/vector.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,14 @@ def process_clinicaltrials_results(results):
379379
def _create_search_filter(self, course_name: str, doc_groups: List[str], admin_disabled_doc_groups: List[str],
380380
public_doc_groups: List[dict]) -> models.Filter:
381381
"""
382-
Create search conditions for the vector search.
382+
Create search conditions for regular searches (no conversation filtering).
383+
Excludes chunks with any conversation_id.
384+
385+
Args:
386+
course_name: The course/project name to filter by
387+
doc_groups: List of document groups to include
388+
admin_disabled_doc_groups: List of document groups to exclude
389+
public_doc_groups: List of public document groups that can be accessed
383390
"""
384391

385392
must_conditions = []
@@ -390,6 +397,12 @@ def _create_search_filter(self, course_name: str, doc_groups: List[str], admin_d
390397
if admin_disabled_doc_groups:
391398
must_not_conditions.append(FieldCondition(key='doc_groups', match=MatchAny(any=admin_disabled_doc_groups)))
392399

400+
# For regular searches, only include chunks that have NO conversation_id field
401+
# This ensures we only get regular course chunks and prevents cross-conversation leaks
402+
must_conditions.append(models.IsEmptyCondition(
403+
is_empty={"key": "conversation_id"} # Only include chunks where conversation_id field is empty/missing
404+
))
405+
393406
# Handle public_doc_groups
394407
if public_doc_groups:
395408
for public_doc_group in public_doc_groups:
@@ -411,12 +424,31 @@ def _create_search_filter(self, course_name: str, doc_groups: List[str], admin_d
411424
# Add the own_course_condition to should_conditions
412425
should_conditions.append(own_course_condition)
413426

414-
# Construct the final filter
415-
vector_search_filter = models.Filter(should=should_conditions, must_not=must_not_conditions)
427+
# Construct the final filter (apply must to enforce no conversation_id)
428+
vector_search_filter = models.Filter(must=must_conditions, should=should_conditions, must_not=must_not_conditions)
416429

417430
print(f"Vector search filter: {vector_search_filter}")
418431
return vector_search_filter
419432

433+
def _create_conversation_search_filter(self, conversation_id: str) -> models.Filter:
434+
"""
435+
Create search conditions for conversation-specific chunks.
436+
Only includes chunks with the specified conversation_id.
437+
438+
Args:
439+
conversation_id: The specific conversation ID to filter by
440+
"""
441+
442+
must_conditions = []
443+
444+
# Conversation ID filter - this is sufficient since conversation_id is unique
445+
must_conditions.append(FieldCondition(
446+
key='conversation_id',
447+
match=MatchValue(value=conversation_id)
448+
))
449+
450+
return models.Filter(must=must_conditions)
451+
420452
def delete_data(self, collection_name: str, key: str, value: str):
421453
"""
422454
Delete data from the vector database.
@@ -460,14 +492,35 @@ def _create_conversation_filter(self, conversation_id: str) -> models.Filter:
460492
]
461493
)
462494

463-
def _combine_filters(self, search_filter: models.Filter, conversation_filter: models.Filter) -> models.Filter:
495+
def _combine_filters(self, search_filter: models.Filter, conversation_filter: models.Filter = None) -> models.Filter:
464496
"""
465-
Combine search filter with conversation filter using OR logic.
466-
This allows searching both regular course documents AND conversation-specific documents.
497+
Combine search filter with conversation filter using AND logic.
498+
499+
Args:
500+
search_filter: The main search filter (course_name, doc_groups, etc.)
501+
conversation_filter: The conversation-specific filter (optional)
502+
503+
Returns:
504+
Combined filter using AND logic for security
467505
"""
468-
return models.Filter(
469-
should=[search_filter, conversation_filter]
470-
)
506+
combined_conditions = []
507+
508+
# Add conditions from search filter
509+
if search_filter.must:
510+
combined_conditions.extend(search_filter.must)
511+
512+
# Add conditions from conversation filter if provided
513+
if conversation_filter and conversation_filter.must:
514+
combined_conditions.extend(conversation_filter.must)
515+
516+
# Combine must_not conditions
517+
combined_must_not = []
518+
if search_filter.must_not:
519+
combined_must_not.extend(search_filter.must_not)
520+
if conversation_filter and conversation_filter.must_not:
521+
combined_must_not.extend(conversation_filter.must_not)
522+
523+
return models.Filter(must=combined_conditions, must_not=combined_must_not)
471524

472525
def vector_search_with_filter(self, search_query, course_name, doc_groups: List[str],
473526
user_query_embedding, top_n, disabled_doc_groups: List[str],

ai_ta_backend/service/retrieval_service.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# from ai_ta_backend.service.nomic_service import NomicService
3434
from ai_ta_backend.service.posthog_service import PosthogService
3535
from ai_ta_backend.service.sentry_service import SentryService
36-
36+
from qdrant_client.http import models
3737

3838
# Qwen query instruction for Illinois Chat retrieval.
3939
# Docs are embedded without instruction during ingest; only queries get this prefix.
@@ -543,10 +543,19 @@ def vector_search(self,
543543
else:
544544
# Handle conversation filtering for normal courses
545545
if conversation_id:
546-
conversation_filter = self._create_conversation_filter(conversation_id)
547-
combined_filter = self._combine_filters(
548-
self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups),
549-
conversation_filter
546+
# For chat conversations: get BOTH regular course documents AND conversation-specific documents
547+
548+
# Get regular course documents (course_name + no conversation_id)
549+
regular_filter = self.vdb._create_search_filter(
550+
course_name, doc_groups, disabled_doc_groups, public_doc_groups
551+
)
552+
553+
# Get conversation-specific documents (this conversation_id)
554+
chat_filter = self.vdb._create_conversation_search_filter(conversation_id)
555+
556+
# Combine both filters with OR logic to get both types of documents
557+
combined_filter = models.Filter(
558+
should=[regular_filter, chat_filter]
550559
)
551560

552561
search_results = self.vdb.vector_search_with_filter(
@@ -821,41 +830,7 @@ def _create_conversation_filter(self, conversation_id: str):
821830
]
822831
)
823832

824-
def _combine_filters(self, filter1, filter2):
825-
"""Combine two Qdrant filters with AND logic."""
826-
from qdrant_client.http import models
827-
combined_conditions = []
828-
829-
# Add conditions from first filter
830-
if filter1.must:
831-
combined_conditions.extend(filter1.must)
832-
833-
# Add conditions from second filter
834-
if filter2.must:
835-
combined_conditions.extend(filter2.must)
836-
837-
return models.Filter(must=combined_conditions)
838-
839-
def _create_search_filter(self, course_name, doc_groups, disabled_doc_groups, public_doc_groups):
840-
"""
841-
Create a Qdrant filter for course, doc groups, and public/disabled doc groups.
842-
"""
843-
from qdrant_client.http import models
844-
845-
must_conditions = []
846-
if course_name:
847-
must_conditions.append(models.FieldCondition(
848-
key="course_name",
849-
match=models.MatchValue(value=course_name)
850-
))
851-
if doc_groups and 'All Documents' not in doc_groups:
852-
must_conditions.append(models.FieldCondition(
853-
key="doc_groups",
854-
match=models.MatchAny(any=doc_groups) # Fixed: use 'any' parameter instead of 'value'
855-
))
856-
# Optionally, you can add filters for disabled/public doc groups if needed
857-
# (depends on your schema and use case)
858-
return models.Filter(must=must_conditions)
833+
# Removed duplicate methods - now using consolidated methods from VectorDatabase
859834

860835
# Add all these methods at the end of the RetrievalService class
861836

@@ -1116,6 +1091,7 @@ def _store_conversation_content(self, text_content: str, conversation_id: str,
11161091
documents.append(doc)
11171092

11181093
except Exception as e:
1094+
print("Error in _store_conversation_content: ", e)
11191095
pass
11201096
continue
11211097

0 commit comments

Comments
 (0)