Skip to content

Commit d4204a1

Browse files
committed
fix(storage): address PR #359 review on tenancy, scans, and migrations
Restore per-embeddings-table cascade predicates, pin route collection in sparse fallback scans, forward user scope through async FTS, batch unbounded file-status document lookups, and always schedule migration compatibility checks while keeping backfill gated by LANCEDB_AUTO_MIGRATE.
1 parent 51cbcad commit d4204a1

10 files changed

Lines changed: 215 additions & 32 deletions

File tree

src/xagent/core/tools/core/RAG_tools/retrieval/search_sparse.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,18 @@ def search_sparse(
201201
)
202202

203203

204+
def _build_substring_scan_filters(
205+
*,
206+
collection: str,
207+
filters: Optional[Dict[str, Any]],
208+
) -> Dict[str, Any]:
209+
"""Build iter_batches filters with a non-overridable route collection."""
210+
scan_filters: Dict[str, Any] = dict(filters) if filters else {}
211+
scan_filters.pop("collection", None)
212+
scan_filters["collection"] = collection
213+
return scan_filters
214+
215+
204216
def _substring_fallback(
205217
*,
206218
model_tag: str,
@@ -218,9 +230,9 @@ def _substring_fallback(
218230
vector_store = get_vector_index_store()
219231
results: List[SearchResult] = []
220232

221-
query_filters: Dict[str, Any] = {"collection": collection}
222-
if filters:
223-
query_filters.update(filters)
233+
query_filters = _build_substring_scan_filters(
234+
collection=collection, filters=filters
235+
)
224236

225237
_table = None
226238
try:
@@ -413,6 +425,8 @@ async def search_sparse_async(
413425
top_k=top_k,
414426
filters=filter_expr,
415427
text_column_name="text",
428+
user_id=user_id,
429+
is_admin=is_admin,
416430
)
417431

418432
if not raw_results:
@@ -512,10 +526,9 @@ async def _substring_fallback_async(
512526
vector_store = get_vector_index_store()
513527
results: List[SearchResult] = []
514528

515-
# Build query filters
516-
query_filters: Dict[str, Any] = {"collection": collection}
517-
if filters:
518-
query_filters.update(filters)
529+
query_filters = _build_substring_scan_filters(
530+
collection=collection, filters=filters
531+
)
519532

520533
_table = None
521534
try:

src/xagent/core/tools/core/RAG_tools/storage/contracts.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,8 @@ async def search_fts_async(
10431043
top_k: int,
10441044
filters: Optional[FilterExpression] = None,
10451045
text_column_name: str = "text",
1046+
user_id: Optional[int] = None,
1047+
is_admin: bool = False,
10461048
) -> List[Dict[str, Any]]:
10471049
"""Execute full-text search (async).
10481050
@@ -1052,6 +1054,8 @@ async def search_fts_async(
10521054
top_k: Number of top results to return.
10531055
filters: Optional abstract filter expression.
10541056
text_column_name: Name of text column with FTS index (default "text").
1057+
user_id: Optional user ID for multi-tenancy filtering.
1058+
is_admin: Whether the user has admin privileges.
10551059
10561060
Returns:
10571061
List of search result dictionaries with keys:
@@ -1073,6 +1077,8 @@ async def search_fts_by_model_async(
10731077
top_k: int,
10741078
filters: Optional[FilterExpression] = None,
10751079
text_column_name: str = "text",
1080+
user_id: Optional[int] = None,
1081+
is_admin: bool = False,
10761082
) -> List[Dict[str, Any]]:
10771083
"""Convenience method: search FTS by model_tag with automatic table resolution.
10781084
@@ -1085,6 +1091,8 @@ async def search_fts_by_model_async(
10851091
top_k: Number of top results to return.
10861092
filters: Optional abstract filter expression.
10871093
text_column_name: Name of text column with FTS index (default "text").
1094+
user_id: Optional user ID for multi-tenancy filtering.
1095+
is_admin: Whether the user has admin privileges.
10881096
10891097
Returns:
10901098
List of search result dictionaries with keys:
@@ -1105,6 +1113,8 @@ async def search_fts_by_model_async(
11051113
top_k=top_k,
11061114
filters=filters,
11071115
text_column_name=text_column_name,
1116+
user_id=user_id,
1117+
is_admin=is_admin,
11081118
)
11091119
finally:
11101120
_release_embeddings_table_probe(_table)

src/xagent/core/tools/core/RAG_tools/storage/lancedb_stores.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1867,6 +1867,8 @@ async def search_fts_async(
18671867
top_k: int,
18681868
filters: Optional[FilterExpression] = None,
18691869
text_column_name: str = "text",
1870+
user_id: Optional[int] = None,
1871+
is_admin: bool = False,
18701872
) -> List[Dict[str, Any]]:
18711873
"""Execute full-text search using async LanceDB FTS API.
18721874
@@ -1881,7 +1883,7 @@ async def search_fts_async(
18811883

18821884
# Build filter expression
18831885
backend_filter = self.build_filter_expression(
1884-
filters, user_id=None, is_admin=False
1886+
filters, user_id=user_id, is_admin=is_admin
18851887
)
18861888

18871889
# Build FTS search query

src/xagent/core/tools/core/RAG_tools/version_management/cascade_cleaner.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -269,30 +269,25 @@ def cascade_delete(
269269
if t.startswith("embeddings_")
270270
and (model_tag is None or t == f"embeddings_{model_tag}")
271271
]
272-
if target_embeddings_tables:
273-
sample_embeddings_table = target_embeddings_tables[0]
272+
for t in target_embeddings_tables:
274273
if target == "collection":
275-
embed_filter = _build_collection_filter(
274+
predicates[t] = _build_collection_filter(
276275
conn=conn,
277-
table_name=sample_embeddings_table,
276+
table_name=t,
278277
collection=collection,
279278
user_id=user_id,
280279
is_admin=is_admin,
281280
)
282281
else:
283-
embed_filter = _build_document_filter(
282+
predicates[t] = _build_document_filter(
284283
conn=conn,
285-
table_name=sample_embeddings_table,
284+
table_name=t,
286285
collection=collection,
287286
doc_id=str(doc_id),
288287
user_id=user_id,
289288
is_admin=is_admin,
290289
)
291290

292-
# Reuse one computed embeddings filter across all target embeddings_* tables.
293-
for t in target_embeddings_tables:
294-
predicates[t] = embed_filter
295-
296291
if preview_only and not confirm:
297292
return _plan_by_predicates(conn, predicates, model_tag=model_tag)
298293

src/xagent/web/services/kb_file_service.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import OrderedDict
99
from datetime import datetime, timedelta, timezone
1010
from pathlib import Path
11-
from typing import Any, Dict, List, Optional, Union
11+
from typing import Any, Dict, List, Optional, Sequence, Union
1212

1313
from sqlalchemy.orm import Session
1414

@@ -81,6 +81,33 @@ def clear(self) -> None:
8181
_file_status_cache = _FileStatusCache(ttl_seconds=5)
8282

8383

84+
def _list_document_records_for_file_ids(
85+
store: VectorIndexStore,
86+
*,
87+
file_ids: Sequence[str],
88+
user_id: int,
89+
is_admin: bool,
90+
) -> List[DocumentRecord]:
91+
"""Load document rows for many file IDs without scan-limit truncation."""
92+
unique_ids = sorted({file_id for file_id in file_ids if file_id})
93+
if not unique_ids:
94+
return []
95+
96+
records: List[DocumentRecord] = []
97+
for offset in range(0, len(unique_ids), _FILE_STATUS_BATCH_SIZE):
98+
batch = unique_ids[offset : offset + _FILE_STATUS_BATCH_SIZE]
99+
records.extend(
100+
store.list_document_records(
101+
collection_name=None,
102+
user_id=user_id,
103+
is_admin=is_admin,
104+
file_ids=batch,
105+
max_results=-1,
106+
)
107+
)
108+
return records
109+
110+
84111
def upsert_uploaded_file_record(
85112
db: Session,
86113
*,
@@ -376,11 +403,11 @@ def aggregate_uploaded_file_statuses(
376403

377404
# Cache miss - compute from database via abstraction layer
378405
store = get_vector_index_store()
379-
records = store.list_document_records(
380-
collection_name=None,
406+
records = _list_document_records_for_file_ids(
407+
store,
408+
file_ids=normalized_file_ids,
381409
user_id=user_id,
382410
is_admin=is_admin,
383-
file_ids=normalized_file_ids,
384411
)
385412

386413
doc_refs_by_file_id: Dict[str, List[tuple[str, str]]] = {

src/xagent/web/services/rag_storage_migration_service.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,10 @@
1616
class RAGStorageMigrationService:
1717
"""Run storage compatibility checks and background migrations."""
1818

19-
async def start_background_migrations(self) -> Optional[asyncio.Task[None]]:
20-
"""Create a non-blocking background task when auto-migrate is enabled."""
21-
if not self._should_start_background_task():
22-
logger.info("Skipping background LanceDB migration task (not needed)")
23-
return None
19+
async def start_background_migrations(self) -> asyncio.Task[None]:
20+
"""Schedule compatibility checks; backfill execution is env-gated."""
2421
return asyncio.create_task(self._run_migrations())
2522

26-
def _should_start_background_task(self) -> bool:
27-
"""Return whether startup should create a migration background task."""
28-
return os.getenv("LANCEDB_AUTO_MIGRATE", "true").lower() == "true"
29-
3023
async def _run_migrations(self) -> None:
3124
"""Run migration checks and backfills in background."""
3225
auto_migrate = os.getenv("LANCEDB_AUTO_MIGRATE", "true").lower() == "true"

tests/core/tools/core/RAG_tools/retrieval/test_search_sparse.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,47 @@ def test_search_sparse_fts_fallback_warning_content(self) -> None:
401401
assert "Check FTS tokenizer configuration" in warning.message
402402
assert "update LanceDB to ensure proper tokenisation" in warning.message
403403

404+
def test_substring_fallback_pins_route_collection_over_caller_override(
405+
self,
406+
) -> None:
407+
"""Route collection must not be replaced by caller filters with the same key."""
408+
from xagent.core.tools.core.RAG_tools.retrieval.search_sparse import (
409+
_substring_fallback,
410+
)
411+
412+
mock_batch = Mock()
413+
mock_batch.to_pandas.return_value = pd.DataFrame(
414+
{
415+
"doc_id": ["doc1"],
416+
"chunk_id": ["chunk1"],
417+
"text": ["needle"],
418+
"parse_hash": ["hash1"],
419+
"created_at": [pd.Timestamp.now()],
420+
"metadata": [None],
421+
}
422+
)
423+
424+
mock_vector_store = Mock()
425+
mock_vector_store.open_embeddings_table.return_value = (Mock(), "embeddings_x")
426+
mock_vector_store.iter_batches.return_value = [mock_batch]
427+
428+
with patch(
429+
"xagent.core.tools.core.RAG_tools.retrieval.search_sparse.get_vector_index_store",
430+
return_value=mock_vector_store,
431+
):
432+
_substring_fallback(
433+
model_tag="test_model",
434+
collection="route_collection",
435+
query_text="needle",
436+
top_k=5,
437+
filters={"collection": "evil_override", "doc_id": "d1"},
438+
current_warnings=[],
439+
)
440+
441+
iter_kwargs = mock_vector_store.iter_batches.call_args.kwargs
442+
assert iter_kwargs["filters"]["collection"] == "route_collection"
443+
assert iter_kwargs["filters"]["doc_id"] == "d1"
444+
404445

405446
@pytest.mark.asyncio
406447
class TestSearchSparseAsync:
@@ -450,6 +491,8 @@ async def test_search_sparse_async_success_forwards_user_scope(self) -> None:
450491
"filters"
451492
],
452493
text_column_name="text",
494+
user_id=9,
495+
is_admin=False,
453496
)
454497

455498
async def test_search_sparse_async_triggers_fallback(self) -> None:

tests/core/tools/core/RAG_tools/version_management/test_cascade_cleaner.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,3 +1177,49 @@ def _delete(_: str) -> None:
11771177
"ingestion_runs",
11781178
"documents",
11791179
]
1180+
1181+
1182+
@patch(
1183+
"xagent.core.tools.core.RAG_tools.version_management.cascade_cleaner._plan_by_predicates"
1184+
)
1185+
@patch(
1186+
"xagent.core.tools.core.RAG_tools.version_management.cascade_cleaner._build_collection_filter"
1187+
)
1188+
@patch(
1189+
"xagent.core.tools.core.RAG_tools.version_management.cascade_cleaner.get_vector_store_raw_connection"
1190+
)
1191+
def test_cascade_delete_builds_distinct_embeddings_predicates_per_table(
1192+
mock_get_conn: MagicMock,
1193+
mock_build_collection_filter: MagicMock,
1194+
mock_plan: MagicMock,
1195+
) -> None:
1196+
"""Each embeddings_* table must get its own predicate (schema may differ during migration)."""
1197+
mock_build_collection_filter.side_effect = (
1198+
lambda *, table_name, **_: f"expr:{table_name}"
1199+
)
1200+
mock_plan.return_value = {}
1201+
1202+
conn = MagicMock(spec=["table_names"])
1203+
conn.table_names.return_value = ["embeddings_m1", "embeddings_m2"]
1204+
mock_get_conn.return_value = conn
1205+
1206+
cascade_delete(
1207+
target="collection",
1208+
collection="kb",
1209+
user_id=1,
1210+
is_admin=False,
1211+
preview_only=True,
1212+
confirm=False,
1213+
)
1214+
1215+
predicates = mock_plan.call_args[0][1]
1216+
assert predicates["embeddings_m1"] == "expr:embeddings_m1"
1217+
assert predicates["embeddings_m2"] == "expr:embeddings_m2"
1218+
assert predicates["embeddings_m1"] != predicates["embeddings_m2"]
1219+
1220+
embeddings_calls = [
1221+
call.kwargs["table_name"]
1222+
for call in mock_build_collection_filter.call_args_list
1223+
if str(call.kwargs["table_name"]).startswith("embeddings_")
1224+
]
1225+
assert sorted(embeddings_calls) == ["embeddings_m1", "embeddings_m2"]

tests/web/test_kb_file_service.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,38 @@ def test_list_documents_for_user_delegates_to_vector_store_without_raw_lancedb()
4242
"source_path": "/tmp/demo.pdf",
4343
}
4444
]
45+
46+
47+
def test_aggregate_uploaded_file_statuses_chunks_file_ids_with_unbounded_limit() -> (
48+
None
49+
):
50+
"""Large file-id lists must be queried in batches with max_results=-1."""
51+
file_ids = [f"file-{index}" for index in range(250)]
52+
list_calls: list[dict[str, object]] = []
53+
54+
def _fake_list_document_records(**kwargs: object) -> list[object]:
55+
list_calls.append(dict(kwargs))
56+
return []
57+
58+
fake_store = SimpleNamespace(list_document_records=_fake_list_document_records)
59+
60+
with (
61+
patch.object(
62+
kb_file_service, "get_vector_index_store", return_value=fake_store
63+
),
64+
patch.object(kb_file_service, "load_ingestion_status", return_value=[]),
65+
patch.object(kb_file_service, "_load_indexed_doc_refs", return_value=set()),
66+
):
67+
status_map = kb_file_service.aggregate_uploaded_file_statuses(
68+
file_ids=file_ids,
69+
user_id=7,
70+
is_admin=False,
71+
use_cache=False,
72+
)
73+
74+
assert len(list_calls) == 2
75+
assert len(list_calls[0]["file_ids"]) == 200
76+
assert len(list_calls[1]["file_ids"]) == 50
77+
assert all(call["max_results"] == -1 for call in list_calls)
78+
assert all(call["user_id"] == 7 for call in list_calls)
79+
assert status_map == {file_id: "UNKNOWN" for file_id in file_ids}

0 commit comments

Comments
 (0)