Skip to content

Commit

Permalink
Merge pull request #84 from tjmlabs/filter
Browse files Browse the repository at this point in the history
add filter endpoint
  • Loading branch information
Jonathan-Adly authored Nov 14, 2024
2 parents 93a9025 + a5da505 commit a9729ae
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 7 deletions.
6 changes: 5 additions & 1 deletion web/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from django_stubs_ext.db.models import TypedModelMeta
from pdf2image import convert_from_bytes
from pgvector.django import HalfVectorField
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
wait_fixed)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,6 +112,9 @@ class Document(models.Model):
def __str__(self) -> str:
return self.name

async def page_count(self) -> int:
    """Return the result of an asynchronous count over ``self.pages``."""
    total = await self.pages.acount()
    return total

class Meta(TypedModelMeta):
constraints = [
models.UniqueConstraint(
Expand Down
217 changes: 211 additions & 6 deletions web/api/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from accounts.models import CustomUser
from api.middleware import add_slash
from api.models import Collection, Document, Page, PageEmbedding
from api.views import Bearer, QueryFilter, QueryIn, filter_query, router
from api.views import (Bearer, QueryFilter, QueryIn, filter_collections,
filter_documents, filter_query, router)
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.files.uploadedfile import SimpleUploadedFile
from ninja.testing import TestAsyncClient
Expand Down Expand Up @@ -71,6 +72,7 @@ async def document(user, collection):
name="Test Document Fixture",
collection=collection,
url="https://www.example.com",
metadata={"important": True},
)
# create a page for the document
page = await Page.objects.acreate(
Expand Down Expand Up @@ -670,7 +672,7 @@ async def test_get_document_by_name(async_client, user, collection, document):
assert response.json() == {
"id": 1,
"name": "Test Document Fixture",
"metadata": {},
"metadata": {"important": True},
"url": "https://www.example.com",
"num_pages": 1,
"collection_name": "Test Collection Fixture",
Expand Down Expand Up @@ -718,7 +720,7 @@ async def test_get_documents(async_client, user, collection, document):
{
"id": 1,
"name": "Test Document Fixture",
"metadata": {},
"metadata": {"important": True},
"url": "https://www.example.com",
"num_pages": 1,
"collection_name": "Test Collection Fixture",
Expand Down Expand Up @@ -748,7 +750,7 @@ async def test_patch_document_no_embed(async_client, user, collection, document)
assert response.json() == {
"id": 1,
"name": "Test Document Update",
"metadata": {},
"metadata": {"important": True},
"url": "https://www.example.com",
"num_pages": 1,
"collection_name": "Test Collection Fixture",
Expand All @@ -765,7 +767,7 @@ async def test_patch_document_no_embed(async_client, user, collection, document)
"id": 1,
"name": "Test Document Update",
"url": "https://www.example.com",
"metadata": {},
"metadata": {"important": True},
"num_pages": 1,
"collection_name": "Test Collection Fixture",
"pages": None,
Expand Down Expand Up @@ -849,7 +851,7 @@ async def test_patch_document_url(async_client, user, collection, document):
response_data = response.json()
assert response_data["id"] == 1
assert response_data["name"] == "Test Document Update"
assert response_data["metadata"] == {}
assert response_data["metadata"] == {'important': True}
assert response_data["url"] == "https://www.w3schools.com/w3css/img_lights.jpg"
assert response_data["num_pages"] == 1
assert response_data["collection_name"] == "Test Collection Fixture"
Expand Down Expand Up @@ -941,6 +943,55 @@ async def test_search_documents(async_client, user, collection, document):
assert response.json() != []


async def test_filter_collections(async_client, user, collection, document):
    """POST a collection-level filter to ``/filter/`` and expect a non-empty 200."""
    payload = {"on": "collection", "key": "key", "value": "value"}
    auth_headers = {"Authorization": f"Bearer {user.token}"}

    response = await async_client.post("/filter/", json=payload, headers=auth_headers)

    assert response.status_code == 200
    # Only non-emptiness is checked here; the exact payload is not pinned.
    assert response.json() != []


async def test_filter_documents(async_client, user, collection, document):
    """POST a document-level filter to ``/filter/`` and expect a non-empty 200."""
    payload = {"on": "document", "key": "important", "value": True}
    auth_headers = {"Authorization": f"Bearer {user.token}"}

    response = await async_client.post("/filter/", json=payload, headers=auth_headers)

    assert response.status_code == 200
    # Only non-emptiness is checked here; the exact payload is not pinned.
    assert response.json() != []


async def test_filter_documents_expand(async_client, user, collection, document):
    """Filter documents with ``?expand=pages`` and compare the full JSON body."""
    expected_page = {
        "document_name": "Test Document Fixture",
        "img_base64": "base64_string",
        "page_number": 1,
    }
    expected_document = {
        "id": 1,
        "name": "Test Document Fixture",
        "metadata": {"important": True},
        "url": "https://www.example.com",
        "num_pages": 1,
        "collection_name": "Test Collection Fixture",
        "pages": [expected_page],
    }

    response = await async_client.post(
        "/filter/?expand=pages",
        json={"on": "document", "key": "important", "value": True},
        headers={"Authorization": f"Bearer {user.token}"},
    )

    assert response.status_code == 200
    # Exactly one document is expected, with its pages list present.
    assert response.json() == [expected_document]


""" Search Filtering tests """


Expand Down Expand Up @@ -1026,6 +1077,52 @@ async def test_search_filter_key_equals(async_client, user, search_filter_fixtur
assert page.document == document_1


async def test_filter_documents_key_equals(async_client, user, search_filter_fixture):
    """``key_lookup`` on ``important=True`` should yield only ``document_1``."""
    _, document_1, _ = search_filter_fixture

    criteria = QueryFilter(
        on="document", key="important", value=True, lookup="key_lookup"
    )
    matches = await filter_documents(criteria, user)

    # filter_documents must hand back the same type a Document queryset has.
    queryset_cls = Document.objects.all().__class__
    assert isinstance(matches, queryset_cls)

    # Exactly one row is expected, and it must be document_1.
    assert await matches.acount() == 1
    assert await matches.afirst() == document_1


async def test_filter_collections_key_equals(async_client, user, search_filter_fixture):
    """``key_lookup`` on ``type="AI papers"`` should yield only the fixture collection."""
    collection, _, _ = search_filter_fixture

    criteria = QueryFilter(
        on="collection", key="type", value="AI papers", lookup="key_lookup"
    )
    matches = await filter_collections(criteria, user)

    # filter_collections must hand back the same type a Collection queryset has.
    queryset_cls = Collection.objects.all().__class__
    assert isinstance(matches, queryset_cls)

    # Exactly one row is expected, and it must be the fixture collection.
    assert await matches.acount() == 1
    assert await matches.afirst() == collection


async def test_filter_query_document_contains(search_filter_fixture, user):
collection, document_1, _ = search_filter_fixture
query_in = QueryIn(
Expand All @@ -1042,6 +1139,32 @@ async def test_filter_query_document_contains(search_filter_fixture, user):
assert page.document == document_1


async def test_filter_documents_contains(search_filter_fixture, user):
    """``contains`` lookup on ``important=True`` should yield only ``document_1``."""
    _, document_1, _ = search_filter_fixture
    criteria = QueryFilter(
        on="document", key="important", value=True, lookup="contains"
    )

    matches = await filter_documents(criteria, user)

    assert await matches.acount() == 1
    assert await matches.afirst() == document_1


async def test_filter_collections_contains(async_client, user, search_filter_fixture):
    """``contains`` lookup on ``type="AI papers"`` should yield the fixture collection."""
    collection, _, _ = search_filter_fixture
    criteria = QueryFilter(
        on="collection", key="type", value="AI papers", lookup="contains"
    )

    matches = await filter_collections(criteria, user)

    assert await matches.acount() == 1
    assert await matches.afirst() == collection


async def test_filter_query_collection_metadata(search_filter_fixture, user):
collection, _, _ = search_filter_fixture
query_in = QueryIn(
Expand Down Expand Up @@ -1089,6 +1212,34 @@ async def test_filter_query_has_key(search_filter_fixture, user):
assert count == 0


async def test_filter_documents_has_key(search_filter_fixture, user):
    """``has_key`` should match 2 documents for a present key and 0 for an absent one."""
    _, _, _ = search_filter_fixture

    # (metadata key, expected number of matching documents)
    cases = (("important", 2), ("not_there", 0))
    for key, expected in cases:
        criteria = QueryFilter(on="document", key=key, lookup="has_key")
        matches = await filter_documents(criteria, user)
        assert await matches.acount() == expected


async def test_filter_collections_has_key(async_client, user, search_filter_fixture):
    """``has_key`` should match 1 collection for a present key and 0 for an absent one."""
    _, _, _ = search_filter_fixture

    # (metadata key, expected number of matching collections)
    cases = (("type", 1), ("not_there", 0))
    for key, expected in cases:
        criteria = QueryFilter(on="collection", key=key, lookup="has_key")
        matches = await filter_collections(criteria, user)
        assert await matches.acount() == expected


async def test_filter_query_has_keys(search_filter_fixture, user):
collection, _, _ = search_filter_fixture
query_in = QueryIn(
Expand All @@ -1111,6 +1262,34 @@ async def test_filter_query_has_keys(search_filter_fixture, user):
assert count == 0


async def test_filter_documents_has_keys(search_filter_fixture, user):
    """``has_keys`` should match 2 documents for a present key list and 0 otherwise."""
    _, _, _ = search_filter_fixture

    # (list of metadata keys, expected number of matching documents)
    cases = ((["important"], 2), (["not_there"], 0))
    for keys, expected in cases:
        criteria = QueryFilter(on="document", key=keys, lookup="has_keys")
        matches = await filter_documents(criteria, user)
        assert await matches.acount() == expected


async def test_filter_collections_has_keys(async_client, user, search_filter_fixture):
    """``has_keys`` should match 1 collection for a present key list and 0 otherwise."""
    _, _, _ = search_filter_fixture

    # (list of metadata keys, expected number of matching collections)
    cases = ((["type"], 1), (["not_there"], 0))
    for keys, expected in cases:
        criteria = QueryFilter(on="collection", key=keys, lookup="has_keys")
        matches = await filter_collections(criteria, user)
        assert await matches.acount() == expected


async def test_filter_query_document_contained_by(search_filter_fixture, user):
collection, document_1, _ = search_filter_fixture
query_in = QueryIn(
Expand All @@ -1127,6 +1306,32 @@ async def test_filter_query_document_contained_by(search_filter_fixture, user):
assert page.document == document_1


async def test_filter_documents_contained_by(search_filter_fixture, user):
    """``contained_by`` lookup on ``important=True`` should yield only ``document_1``."""
    _, document_1, _ = search_filter_fixture
    criteria = QueryFilter(
        on="document", key="important", value=True, lookup="contained_by"
    )

    matches = await filter_documents(criteria, user)

    assert await matches.acount() == 1
    assert await matches.afirst() == document_1


async def test_filter_collections_contained_by(
    async_client, user, search_filter_fixture
):
    """``contained_by`` on ``type="AI papers"`` should yield the fixture collection."""
    collection, _, _ = search_filter_fixture
    criteria = QueryFilter(
        on="collection", key="type", value="AI papers", lookup="contained_by"
    )

    matches = await filter_collections(criteria, user)

    assert await matches.acount() == 1
    assert await matches.afirst() == collection


@pytest.mark.parametrize(
"on, key, value, lookup, should_raise",
[
Expand Down
Loading

0 comments on commit a9729ae

Please sign in to comment.