Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/badges/coverage.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"schemaVersion":1,"label":"coverage","message":"80.91%","color":"green"}
{"schemaVersion":1,"label":"coverage","message":"80.98%","color":"green"}
10 changes: 7 additions & 3 deletions api/endpoints/collections.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional

from fastapi import APIRouter, Body, Depends, Path, Query, Request, Response, Security
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession

from api.helpers._accesscontroller import AccessController
from api.schemas.collections import Collection, CollectionRequest, Collections, CollectionUpdateRequest
from api.schemas.collections import Collection, CollectionRequest, Collections, CollectionUpdateRequest, CollectionVisibility
from api.sql.session import get_db_session
from api.utils.context import global_context, request_context
from api.utils.exceptions import CollectionNotFoundException
Expand Down Expand Up @@ -52,7 +54,6 @@ async def get_collection(
session=session,
collection_id=collection,
user_id=request_context.get().user_info.id,
include_public=True,
)

return JSONResponse(status_code=200, content=collections[0].model_dump())
Expand All @@ -61,6 +62,8 @@ async def get_collection(
@router.get(path=ENDPOINT__COLLECTIONS, dependencies=[Security(dependency=AccessController())], status_code=200, response_model=Collections)
async def get_collections(
request: Request,
name: str = Query(default=None, description="Filter by collection name."),
visibility: Optional[CollectionVisibility] = Query(default=None, description="Filter by collection visibility."),
offset: int = Query(default=0, ge=0, description="The offset of the collections to get."),
limit: int = Query(default=10, ge=1, le=100, description="The limit of the collections to get."),
session: AsyncSession = Depends(get_db_session),
Expand All @@ -74,7 +77,8 @@ async def get_collections(
data = await global_context.document_manager.get_collections(
session=session,
user_id=request_context.get().user_info.id,
include_public=True,
collection_name=name,
visibility=visibility,
offset=offset,
limit=limit,
)
Expand Down
2 changes: 2 additions & 0 deletions api/endpoints/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ async def get_document(
@router.get(path=ENDPOINT__DOCUMENTS, dependencies=[Security(dependency=AccessController())], status_code=200)
async def get_documents(
request: Request,
name: Optional[str] = Query(default=None, description="Filter documents by name."),
collection: Optional[int] = Query(default=None, description="Filter documents by collection ID"),
limit: Optional[int] = Query(default=10, ge=1, le=100, description="The number of documents to return"),
offset: Union[int, UUID] = Query(default=0, description="The offset of the first document to return"),
Expand All @@ -152,6 +153,7 @@ async def get_documents(
data = await global_context.document_manager.get_documents(
session=session,
collection_id=collection,
document_name=name,
limit=limit,
offset=offset,
user_id=request_context.get().user_info.id,
Expand Down
21 changes: 17 additions & 4 deletions api/helpers/_documentmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,16 @@ async def update_collection(self, session: AsyncSession, user_id: int, collectio
await session.commit()

@check_dependencies(dependencies=["vector_store"])
async def get_collections(self, session: AsyncSession, user_id: int, collection_id: Optional[int] = None, include_public: bool = True, offset: int = 0, limit: int = 10) -> List[Collection]: # fmt: off
async def get_collections(
self,
session: AsyncSession,
user_id: int,
collection_id: Optional[int] = None,
collection_name: Optional[str] = None,
visibility: Optional[CollectionVisibility] = None,
offset: int = 0,
limit: int = 10,
) -> List[Collection]:
# Query basic collection data
statement = (
select(
Expand All @@ -159,10 +168,12 @@ async def get_collections(self, session: AsyncSession, user_id: int, collection_

if collection_id:
statement = statement.where(CollectionTable.id == collection_id)
if include_public:
if collection_name:
statement = statement.where(CollectionTable.name == collection_name)
if visibility is None:
statement = statement.where(or_(CollectionTable.user_id == user_id, CollectionTable.visibility == CollectionVisibility.PUBLIC))
else:
statement = statement.where(CollectionTable.user_id == user_id)
statement = statement.where(CollectionTable.user_id == user_id, CollectionTable.visibility == visibility)

result = await session.execute(statement=statement)
collections = [Collection(**row._asdict()) for row in result.all()]
Expand Down Expand Up @@ -241,7 +252,7 @@ async def create_document(
return document_id

@check_dependencies(dependencies=["vector_store"])
async def get_documents(self, session: AsyncSession, user_id: int, collection_id: Optional[int] = None, document_id: Optional[int] = None, offset: int = 0, limit: int = 10) -> List[Document]: # fmt: off
async def get_documents(self, session: AsyncSession, user_id: int, collection_id: Optional[int] = None, document_id: Optional[int] = None, document_name: Optional[str] = None, offset: int = 0, limit: int = 10) -> List[Document]: # fmt: off
statement = (
select(
DocumentTable.id,
Expand All @@ -256,6 +267,8 @@ async def get_documents(self, session: AsyncSession, user_id: int, collection_id
)
if collection_id:
statement = statement.where(DocumentTable.collection_id == collection_id)
if document_name:
statement = statement.where(DocumentTable.name == document_name)
if document_id:
statement = statement.where(DocumentTable.id == document_id)

Expand Down
2 changes: 1 addition & 1 deletion api/sql/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from http import HTTPMethod

from sqlalchemy import Column, DateTime, Enum, Float, ForeignKey, Integer, String, Boolean, UniqueConstraint, func
from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String, UniqueConstraint, func
from sqlalchemy.orm import backref, declarative_base, relationship

from api.schemas.admin.roles import LimitType, PermissionType
Expand Down
166 changes: 166 additions & 0 deletions api/tests/unit/test_helpers/test_documentmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from api.helpers._documentmanager import DocumentManager
from api.schemas.collections import CollectionVisibility
from api.schemas.documents import Chunker
from api.schemas.parse import ParsedDocument, ParsedDocumentMetadata, ParsedDocumentPage
from api.utils.exceptions import CollectionNotFoundException
Expand Down Expand Up @@ -183,3 +184,168 @@ def side_effect(*args, **kwargs):
separators=["\n\n", "\n", " "],
chunk_min_size=50,
)


@pytest.mark.asyncio
async def test_get_collections_filter_by_visibility():
    """Test that get_collections forwards the visibility filter to the SQL query.

    The database session is mocked, so the returned rows only prove that results
    are mapped into Collection objects. The filter itself is verified by
    inspecting the bound parameters of the statement handed to session.execute.
    """

    def build_result(**fields):
        """Return a mocked execute() result holding one collection row."""
        row = MagicMock()
        row._asdict.return_value = {
            "owner": "test_user",
            "created_at": 1697000000,
            "updated_at": 1697000000,
            **fields,
        }
        result = MagicMock()
        result.all.return_value = [row]
        return result

    def last_bound_values(session):
        """Return the bound parameter values of the last executed statement."""
        statement = session.execute.call_args.kwargs["statement"]
        return list(statement.compile().params.values())

    # Mock dependencies
    mock_vector_store = AsyncMock()
    mock_vector_store_model = AsyncMock()
    mock_parser = AsyncMock()
    mock_session = AsyncMock(spec=AsyncSession)

    # Create DocumentManager instance
    document_manager = DocumentManager(vector_store=mock_vector_store, vector_store_model=mock_vector_store_model, parser_manager=mock_parser)

    # Test filtering by PRIVATE visibility
    mock_session.execute.return_value = build_result(
        id=1,
        name="Private Collection",
        visibility=CollectionVisibility.PRIVATE,
        description="A private collection",
        documents=5,
    )

    collections = await document_manager.get_collections(session=mock_session, user_id=1, visibility=CollectionVisibility.PRIVATE, offset=0, limit=10)

    assert len(collections) == 1, "Should return exactly one private collection"
    assert collections[0].visibility == CollectionVisibility.PRIVATE, "Collection should be private"
    assert collections[0].name == "Private Collection"
    assert collections[0].id == 1

    # Verify the SQL query was called with the correct filter
    assert mock_session.execute.called
    statement_str = str(mock_session.execute.call_args.kwargs["statement"])
    assert "visibility" in statement_str.lower()
    # The column name alone also appears in the unfiltered query, so check that
    # the requested value is actually bound into the statement.
    assert CollectionVisibility.PRIVATE in last_bound_values(mock_session), "PRIVATE should be bound as a query parameter"

    # Test filtering by PUBLIC visibility
    mock_session.execute.return_value = build_result(
        id=2,
        name="Public Collection",
        visibility=CollectionVisibility.PUBLIC,
        description="A public collection",
        documents=10,
    )

    collections = await document_manager.get_collections(session=mock_session, user_id=1, visibility=CollectionVisibility.PUBLIC, offset=0, limit=10)

    assert len(collections) == 1, "Should return exactly one public collection"
    assert collections[0].visibility == CollectionVisibility.PUBLIC, "Collection should be public"
    assert collections[0].name == "Public Collection"
    assert collections[0].id == 2

    assert mock_session.execute.call_count == 2, "Each get_collections call should hit the database once"
    assert CollectionVisibility.PUBLIC in last_bound_values(mock_session), "PUBLIC should be bound as a query parameter"


@pytest.mark.asyncio
async def test_get_collections_filter_by_collection_name():
    """Test that get_collections applies the collection name filter.

    DocumentManager.get_collections filters with an equality comparison on the
    collection name (exact match, not LIKE/ILIKE), so the fixtures below only
    contain rows whose name equals the requested one. The database session is
    mocked; the filter itself is verified through the bound parameters of the
    statement handed to session.execute.
    """

    def build_result(*rows):
        """Return a mocked execute() result holding the given row dicts."""
        mocked_rows = []
        for fields in rows:
            row = MagicMock()
            row._asdict.return_value = {"created_at": 1697000000, "updated_at": 1697000000, **fields}
            mocked_rows.append(row)
        result = MagicMock()
        result.all.return_value = mocked_rows
        return result

    def last_bound_values(session):
        """Return the bound parameter values of the last executed statement."""
        statement = session.execute.call_args.kwargs["statement"]
        return list(statement.compile().params.values())

    # Mock dependencies
    mock_vector_store = AsyncMock()
    mock_vector_store_model = AsyncMock()
    mock_parser = AsyncMock()
    mock_session = AsyncMock(spec=AsyncSession)

    # Create DocumentManager instance
    document_manager = DocumentManager(vector_store=mock_vector_store, vector_store_model=mock_vector_store_model, parser_manager=mock_parser)

    # Several collections can share a name: the user's own collection and a
    # public collection owned by somebody else.
    mock_session.execute.return_value = build_result(
        {
            "id": 1,
            "name": "test_collection",
            "owner": "test_user",
            "visibility": CollectionVisibility.PRIVATE,
            "description": "First test collection",
            "documents": 3,
        },
        {
            "id": 2,
            "name": "test_collection",
            "owner": "other_user",
            "visibility": CollectionVisibility.PUBLIC,
            "description": "Second test collection",
            "documents": 7,
        },
    )

    collections = await document_manager.get_collections(session=mock_session, user_id=1, collection_name="test_collection", offset=0, limit=10)

    assert len(collections) == 2, "Should return two matching collections"
    assert all(col.name == "test_collection" for col in collections), "All collections should be named 'test_collection'"
    assert [col.id for col in collections] == [1, 2]

    # Verify the SQL query was called with the name filter
    assert mock_session.execute.called
    statement_str = str(mock_session.execute.call_args.kwargs["statement"])
    assert "name" in statement_str.lower()
    # The name column is always selected, so check that the requested value is
    # actually bound into the statement.
    assert "test_collection" in last_bound_values(mock_session), "Collection name should be bound as a query parameter"

    # Test filtering by non-existent name
    mock_session.execute.return_value = build_result()

    collections = await document_manager.get_collections(
        session=mock_session, user_id=1, collection_name="nonexistent_collection_xyz", offset=0, limit=10
    )

    assert len(collections) == 0, "Should return empty list for non-existent collection name"
    assert "nonexistent_collection_xyz" in last_bound_values(mock_session)

    # Test filtering by a name that matches a single collection
    mock_session.execute.return_value = build_result(
        {
            "id": 5,
            "name": "exact_match_collection",
            "owner": "test_user",
            "visibility": CollectionVisibility.PUBLIC,
            "description": "Exact match collection",
            "documents": 1,
        }
    )

    collections = await document_manager.get_collections(
        session=mock_session, user_id=1, collection_name="exact_match_collection", offset=0, limit=10
    )

    assert len(collections) == 1, "Should return exactly one collection for exact match"
    assert collections[0].name == "exact_match_collection"
    assert collections[0].id == 5
    assert "exact_match_collection" in last_bound_values(mock_session)