Skip to content

Commit a9729ae

Browse files
Merge pull request #84 from tjmlabs/filter
add filter endpoint
2 parents 93a9025 + a5da505 commit a9729ae

File tree

3 files changed

+352
-7
lines changed

3 files changed

+352
-7
lines changed

web/api/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from django_stubs_ext.db.models import TypedModelMeta
2121
from pdf2image import convert_from_bytes
2222
from pgvector.django import HalfVectorField
23-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
23+
from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
24+
wait_fixed)
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -111,6 +112,9 @@ class Document(models.Model):
111112
def __str__(self) -> str:
112113
return self.name
113114

115+
async def page_count(self) -> int:
116+
return await self.pages.acount()
117+
114118
class Meta(TypedModelMeta):
115119
constraints = [
116120
models.UniqueConstraint(

web/api/tests/tests.py

Lines changed: 211 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from accounts.models import CustomUser
77
from api.middleware import add_slash
88
from api.models import Collection, Document, Page, PageEmbedding
9-
from api.views import Bearer, QueryFilter, QueryIn, filter_query, router
9+
from api.views import (Bearer, QueryFilter, QueryIn, filter_collections,
10+
filter_documents, filter_query, router)
1011
from django.core.exceptions import ValidationError as DjangoValidationError
1112
from django.core.files.uploadedfile import SimpleUploadedFile
1213
from ninja.testing import TestAsyncClient
@@ -71,6 +72,7 @@ async def document(user, collection):
7172
name="Test Document Fixture",
7273
collection=collection,
7374
url="https://www.example.com",
75+
metadata={"important": True},
7476
)
7577
# create a page for the document
7678
page = await Page.objects.acreate(
@@ -670,7 +672,7 @@ async def test_get_document_by_name(async_client, user, collection, document):
670672
assert response.json() == {
671673
"id": 1,
672674
"name": "Test Document Fixture",
673-
"metadata": {},
675+
"metadata": {"important": True},
674676
"url": "https://www.example.com",
675677
"num_pages": 1,
676678
"collection_name": "Test Collection Fixture",
@@ -718,7 +720,7 @@ async def test_get_documents(async_client, user, collection, document):
718720
{
719721
"id": 1,
720722
"name": "Test Document Fixture",
721-
"metadata": {},
723+
"metadata": {"important": True},
722724
"url": "https://www.example.com",
723725
"num_pages": 1,
724726
"collection_name": "Test Collection Fixture",
@@ -748,7 +750,7 @@ async def test_patch_document_no_embed(async_client, user, collection, document)
748750
assert response.json() == {
749751
"id": 1,
750752
"name": "Test Document Update",
751-
"metadata": {},
753+
"metadata": {"important": True},
752754
"url": "https://www.example.com",
753755
"num_pages": 1,
754756
"collection_name": "Test Collection Fixture",
@@ -765,7 +767,7 @@ async def test_patch_document_no_embed(async_client, user, collection, document)
765767
"id": 1,
766768
"name": "Test Document Update",
767769
"url": "https://www.example.com",
768-
"metadata": {},
770+
"metadata": {"important": True},
769771
"num_pages": 1,
770772
"collection_name": "Test Collection Fixture",
771773
"pages": None,
@@ -849,7 +851,7 @@ async def test_patch_document_url(async_client, user, collection, document):
849851
response_data = response.json()
850852
assert response_data["id"] == 1
851853
assert response_data["name"] == "Test Document Update"
852-
assert response_data["metadata"] == {}
854+
assert response_data["metadata"] == {'important': True}
853855
assert response_data["url"] == "https://www.w3schools.com/w3css/img_lights.jpg"
854856
assert response_data["num_pages"] == 1
855857
assert response_data["collection_name"] == "Test Collection Fixture"
@@ -941,6 +943,55 @@ async def test_search_documents(async_client, user, collection, document):
941943
assert response.json() != []
942944

943945

946+
async def test_filter_collections(async_client, user, collection, document):
947+
response = await async_client.post(
948+
"/filter/",
949+
json={"on": "collection", "key": "key", "value": "value"},
950+
headers={"Authorization": f"Bearer {user.token}"},
951+
)
952+
953+
assert response.status_code == 200
954+
assert response.json() != []
955+
956+
957+
async def test_filter_documents(async_client, user, collection, document):
958+
response = await async_client.post(
959+
"/filter/",
960+
json={"on": "document", "key": "important", "value": True},
961+
headers={"Authorization": f"Bearer {user.token}"},
962+
)
963+
964+
assert response.status_code == 200
965+
assert response.json() != []
966+
967+
968+
async def test_filter_documents_expand(async_client, user, collection, document):
969+
response = await async_client.post(
970+
"/filter/?expand=pages",
971+
json={"on": "document", "key": "important", "value": True},
972+
headers={"Authorization": f"Bearer {user.token}"},
973+
)
974+
975+
assert response.status_code == 200
976+
assert response.json() == [
977+
{
978+
"id": 1,
979+
"name": "Test Document Fixture",
980+
"metadata": {"important": True},
981+
"url": "https://www.example.com",
982+
"num_pages": 1,
983+
"collection_name": "Test Collection Fixture",
984+
"pages": [
985+
{
986+
"document_name": "Test Document Fixture",
987+
"img_base64": "base64_string",
988+
"page_number": 1,
989+
}
990+
],
991+
}
992+
]
993+
994+
944995
""" Search Filtering tests """
945996

946997

@@ -1026,6 +1077,52 @@ async def test_search_filter_key_equals(async_client, user, search_filter_fixtur
10261077
assert page.document == document_1
10271078

10281079

1080+
async def test_filter_documents_key_equals(async_client, user, search_filter_fixture):
1081+
collection, document_1, document_2 = search_filter_fixture
1082+
1083+
# Create a QueryFilter object
1084+
query_filter = QueryFilter(
1085+
on="document", key="important", value=True, lookup="key_lookup"
1086+
)
1087+
1088+
# Call the query_filter function
1089+
result = await filter_documents(query_filter, user)
1090+
1091+
# Check if the result is a QuerySet
1092+
assert isinstance(result, Document.objects.all().__class__)
1093+
1094+
# Check if only one document is returned (the one with important=True)
1095+
count = await result.acount()
1096+
assert count == 1
1097+
# get the document from the queryset
1098+
document = await result.afirst()
1099+
# check if the document is the correct document
1100+
assert document == document_1
1101+
1102+
1103+
async def test_filter_collections_key_equals(async_client, user, search_filter_fixture):
1104+
collection, document_1, document_2 = search_filter_fixture
1105+
1106+
# Create a QueryFilter object
1107+
query_filter = QueryFilter(
1108+
on="collection", key="type", value="AI papers", lookup="key_lookup"
1109+
)
1110+
1111+
# Call the query_filter function
1112+
result = await filter_collections(query_filter, user)
1113+
1114+
# Check if the result is a QuerySet
1115+
assert isinstance(result, Collection.objects.all().__class__)
1116+
1117+
# Check if only one collection is returned (the one with type=AI papers)
1118+
count = await result.acount()
1119+
assert count == 1
1120+
# get the collection from the queryset
1121+
col = await result.afirst()
1122+
# check if the collection is the correct collection
1123+
assert col == collection
1124+
1125+
10291126
async def test_filter_query_document_contains(search_filter_fixture, user):
10301127
collection, document_1, _ = search_filter_fixture
10311128
query_in = QueryIn(
@@ -1042,6 +1139,32 @@ async def test_filter_query_document_contains(search_filter_fixture, user):
10421139
assert page.document == document_1
10431140

10441141

1142+
async def test_filter_documents_contains(search_filter_fixture, user):
1143+
collection, document_1, _ = search_filter_fixture
1144+
query_filter = QueryFilter(
1145+
on="document", key="important", value=True, lookup="contains"
1146+
)
1147+
1148+
result = await filter_documents(query_filter, user)
1149+
count = await result.acount()
1150+
assert count == 1
1151+
document = await result.afirst()
1152+
assert document == document_1
1153+
1154+
1155+
async def test_filter_collections_contains(async_client, user, search_filter_fixture):
1156+
collection, document_1, document_2 = search_filter_fixture
1157+
query_filter = QueryFilter(
1158+
on="collection", key="type", value="AI papers", lookup="contains"
1159+
)
1160+
1161+
result = await filter_collections(query_filter, user)
1162+
count = await result.acount()
1163+
assert count == 1
1164+
col = await result.afirst()
1165+
assert col == collection
1166+
1167+
10451168
async def test_filter_query_collection_metadata(search_filter_fixture, user):
10461169
collection, _, _ = search_filter_fixture
10471170
query_in = QueryIn(
@@ -1089,6 +1212,34 @@ async def test_filter_query_has_key(search_filter_fixture, user):
10891212
assert count == 0
10901213

10911214

1215+
async def test_filter_documents_has_key(search_filter_fixture, user):
1216+
collection, _, _ = search_filter_fixture
1217+
query_filter = QueryFilter(on="document", key="important", lookup="has_key")
1218+
result = await filter_documents(query_filter, user)
1219+
count = await result.acount()
1220+
assert count == 2
1221+
1222+
# test if key is not there
1223+
query_filter = QueryFilter(on="document", key="not_there", lookup="has_key")
1224+
result = await filter_documents(query_filter, user)
1225+
count = await result.acount()
1226+
assert count == 0
1227+
1228+
1229+
async def test_filter_collections_has_key(async_client, user, search_filter_fixture):
1230+
collection, document_1, document_2 = search_filter_fixture
1231+
query_filter = QueryFilter(on="collection", key="type", lookup="has_key")
1232+
result = await filter_collections(query_filter, user)
1233+
count = await result.acount()
1234+
assert count == 1
1235+
1236+
# test if key is not there
1237+
query_filter = QueryFilter(on="collection", key="not_there", lookup="has_key")
1238+
result = await filter_collections(query_filter, user)
1239+
count = await result.acount()
1240+
assert count == 0
1241+
1242+
10921243
async def test_filter_query_has_keys(search_filter_fixture, user):
10931244
collection, _, _ = search_filter_fixture
10941245
query_in = QueryIn(
@@ -1111,6 +1262,34 @@ async def test_filter_query_has_keys(search_filter_fixture, user):
11111262
assert count == 0
11121263

11131264

1265+
async def test_filter_documents_has_keys(search_filter_fixture, user):
1266+
collection, _, _ = search_filter_fixture
1267+
query_filter = QueryFilter(on="document", key=["important"], lookup="has_keys")
1268+
result = await filter_documents(query_filter, user)
1269+
count = await result.acount()
1270+
assert count == 2
1271+
1272+
# test if key is not there
1273+
query_filter = QueryFilter(on="document", key=["not_there"], lookup="has_keys")
1274+
result = await filter_documents(query_filter, user)
1275+
count = await result.acount()
1276+
assert count == 0
1277+
1278+
1279+
async def test_filter_collections_has_keys(async_client, user, search_filter_fixture):
1280+
collection, document_1, document_2 = search_filter_fixture
1281+
query_filter = QueryFilter(on="collection", key=["type"], lookup="has_keys")
1282+
result = await filter_collections(query_filter, user)
1283+
count = await result.acount()
1284+
assert count == 1
1285+
1286+
# test if key is not there
1287+
query_filter = QueryFilter(on="collection", key=["not_there"], lookup="has_keys")
1288+
result = await filter_collections(query_filter, user)
1289+
count = await result.acount()
1290+
assert count == 0
1291+
1292+
11141293
async def test_filter_query_document_contained_by(search_filter_fixture, user):
11151294
collection, document_1, _ = search_filter_fixture
11161295
query_in = QueryIn(
@@ -1127,6 +1306,32 @@ async def test_filter_query_document_contained_by(search_filter_fixture, user):
11271306
assert page.document == document_1
11281307

11291308

1309+
async def test_filter_documents_contained_by(search_filter_fixture, user):
1310+
collection, document_1, _ = search_filter_fixture
1311+
query_filter = QueryFilter(
1312+
on="document", key="important", value=True, lookup="contained_by"
1313+
)
1314+
result = await filter_documents(query_filter, user)
1315+
count = await result.acount()
1316+
assert count == 1
1317+
document = await result.afirst()
1318+
assert document == document_1
1319+
1320+
1321+
async def test_filter_collections_contained_by(
1322+
async_client, user, search_filter_fixture
1323+
):
1324+
collection, document_1, document_2 = search_filter_fixture
1325+
query_filter = QueryFilter(
1326+
on="collection", key="type", value="AI papers", lookup="contained_by"
1327+
)
1328+
result = await filter_collections(query_filter, user)
1329+
count = await result.acount()
1330+
assert count == 1
1331+
col = await result.afirst()
1332+
assert col == collection
1333+
1334+
11301335
@pytest.mark.parametrize(
11311336
"on, key, value, lookup, should_raise",
11321337
[

0 commit comments

Comments
 (0)