Skip to content

Commit 00a8d14

Browse files
authored
Add thresholding for Semantic indexes
1 parent 9b722e6 commit 00a8d14

File tree

9 files changed

+251
-14
lines changed

9 files changed

+251
-14
lines changed

doc/source/semantic_indexes.rst

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ Can be queried using :class:`~neomodel.semantic_filters.FulltextFilter`. Such as
5050

5151
Where the result will be a list of length topk of nodes with the form (ProductNode, score).
5252

53+
If you would like to filter the nodes based on a threshold (ie. nodes similarity >= threshold), you can use the following::
54+
55+
from neomodel.semantic_filters import FulltextFilter
56+
result = Product.nodes.filter(
57+
fulltext_filter=FulltextFilter(
58+
topk=10,
59+
fulltext_attribute_name="description",
60+
query_string="product",
61+
threshold=0.08)).all()
62+
63+
Only nodes above threshold = 0.08 will be returned.
64+
5365
The :class:`~neomodel.semantic_filters.FulltextFilter` can be used in conjunction with the normal filter types.
5466

5567
.. attention::
@@ -103,10 +115,26 @@ The following node vector index property::
103115
Can be queried using :class:`~neomodel.semantic_filters.VectorFilter`. Such as::
104116

105117
from neomodel.semantic_filters import VectorFilter
106-
result = someNode.nodes.filter(vector_filter=VectorFilter(topk=3, vector_attribute_name="vector")).all()
118+
result = someNode.nodes.filter(
119+
vector_filter=VectorFilter(
120+
topk=3,
121+
vector_attribute_name="vector",
122+
candidate_vector=[0.25, 0.25])).all()
107123

108124
Where the result will be a list of length topk of tuples having the form (someNode, score).
109125

126+
If you would like to filter the nodes based on a threshold (ie. nodes similarity >= threshold), you can use the following::
127+
128+
from neomodel.semantic_filters import VectorFilter
129+
result = someNode.nodes.filter(
130+
vector_filter=VectorFilter(
131+
topk=3,
132+
vector_attribute_name="vector",
133+
candidate_vector=[0.25, 0.25],
134+
threshold=0.85)).all()
135+
136+
Only nodes above threshold = 0.85 will be returned.
137+
110138
The :class:`~neomodel.semantic_filters.VectorFilter` can be used in conjunction with the normal filter types.
111139

112140
.. attention::

neomodel/async_/match.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -577,13 +577,17 @@ def build_vector_query(self):
577577
f"Attribute {vector_filter.vector_attribute_name} is not declared with a vector index."
578578
)
579579

580+
if type(vector_filter.threshold) not in [float, type(None)]:
581+
raise ValueError(f"Vector Filter Threshold must be a float or None.")
582+
580583
vector_filter.index_name = f"vector_index_{source_class.__label__}_{vector_filter.vector_attribute_name}"
581584
vector_filter.node_set_label = source_class.__label__.lower()
582585

583586
self._ast.vector_index_query = vector_filter
584587
self._ast.return_clause = f"{vector_filter.node_set_label}, score"
585588
self._ast.result_class = source_class.__class__
586589

590+
587591
def build_fulltext_query(self):
588592
"""
589593
Query a free text indexed property on the node.
@@ -604,6 +608,9 @@ def build_fulltext_query(self):
604608
f"Attribute {full_text_filter.fulltext_attribute_name} is not declared with a full text index."
605609
)
606610

611+
if type(full_text_filter.threshold) not in [float, type(None)]:
612+
raise ValueError(f"Full Text Filter Threshold must be a float or None.")
613+
607614
full_text_filter.index_name = f"fulltext_index_{source_class.__label__}_{full_text_filter.fulltext_attribute_name}"
608615
full_text_filter.node_set_label = source_class.__label__.lower()
609616

@@ -1005,19 +1012,33 @@ def build_query(self) -> str:
10051012
if self._ast.vector_index_query:
10061013
query += f"""CALL () {{
10071014
CALL db.index.vector.queryNodes("{self._ast.vector_index_query.index_name}", {self._ast.vector_index_query.topk}, {self._ast.vector_index_query.vector})
1008-
YIELD node AS {self._ast.vector_index_query.node_set_label}, score
1015+
YIELD node AS {self._ast.vector_index_query.node_set_label}, score """
1016+
1017+
if self._ast.vector_index_query.threshold:
1018+
query += f"""
1019+
WHERE score >= {self._ast.vector_index_query.threshold}
1020+
"""
1021+
1022+
query += f"""
10091023
RETURN {self._ast.vector_index_query.node_set_label}, score
1010-
}}"""
1024+
}}"""
10111025

10121026
# This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering
10131027
query += f""" WITH {self._ast.vector_index_query.node_set_label}, score"""
10141028

10151029
if self._ast.fulltext_index_query:
10161030
query += f"""CALL () {{
10171031
CALL db.index.fulltext.queryNodes("{self._ast.fulltext_index_query.index_name}", "{self._ast.fulltext_index_query.query_string}")
1018-
YIELD node AS {self._ast.fulltext_index_query.node_set_label}, score
1032+
YIELD node AS {self._ast.fulltext_index_query.node_set_label}, score"""
1033+
1034+
if self._ast.fulltext_index_query.threshold:
1035+
query += f"""
1036+
WHERE score >= {self._ast.fulltext_index_query.threshold}
1037+
"""
1038+
1039+
query += f"""
10191040
RETURN {self._ast.fulltext_index_query.node_set_label}, score LIMIT {self._ast.fulltext_index_query.topk}
1020-
}}
1041+
}}
10211042
"""
10221043
# This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering
10231044
query += f""" WITH {self._ast.fulltext_index_query.node_set_label}, score"""

neomodel/semantic_filters.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
class VectorFilter(object):
24
"""
35
Represents a CALL db.index.vector.query* neo functions call within the OGM
@@ -6,35 +8,42 @@ class VectorFilter(object):
68
:type topk: int
79
:param vector_attribute_name: The property name for vector indexed property on the searched object.
810
:type vector_attribute_name: str
11+
:param threshold: Threshold for vector similarity.
12+
:type threshold: float
913
:param candidate_vector: The vector you are finding the nearest topk neighbours for.
1014
:type candidate_vector: list[float]
1115
1216
"""
1317

1418
def __init__(
15-
self, topk: int, vector_attribute_name: str, candidate_vector: list[float]
19+
self, topk: int, vector_attribute_name: str, candidate_vector: list[float], threshold: Union[float, None] = None
1620
):
1721
self.topk = topk
1822
self.vector_attribute_name = vector_attribute_name
23+
self.threshold = threshold
1924
self.index_name = None
2025
self.node_set_label = None
2126
self.vector = candidate_vector
2227

2328
class FulltextFilter(object):
2429
"""
25-
Represents a CALL db.index.fulltext.query* neo function call within the OGM.
30+
Represents a CALL db.index.fulltext.query* neo functon call within the OGM.
2631
:param query_strng: The string you are finding the nearest
2732
:type query_string: str
2833
:param freetext_attribute_name: The property name for the free text indexed property.
2934
:type fulltext_attribute_name: str
35+
:param threshold: Threshold for vector similarity.
36+
:type threshold: float
3037
:param topk: Amount to nodes to return
3138
:type topk: int
3239
3340
"""
3441

35-
def __init__(self, query_string: str, fulltext_attribute_name: str, topk: int):
42+
def __init__(self, query_string: str, fulltext_attribute_name: str, topk: int, threshold: Union[float, None] = None
43+
):
44+
self.topk = topk
3645
self.query_string = query_string
3746
self.fulltext_attribute_name = fulltext_attribute_name
47+
self.threshold = threshold
3848
self.index_name = None
3949
self.node_set_label = None
40-
self.topk = topk

neomodel/sync_/match.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,9 @@ def build_vector_query(self):
575575
f"Attribute {vector_filter.vector_attribute_name} is not declared with a vector index."
576576
)
577577

578+
if type(vector_filter.threshold) not in [float, type(None)]:
579+
raise ValueError(f"Vector Filter Threshold must be a float or None.")
580+
578581
vector_filter.index_name = f"vector_index_{source_class.__label__}_{vector_filter.vector_attribute_name}"
579582
vector_filter.node_set_label = source_class.__label__.lower()
580583

@@ -602,6 +605,9 @@ def build_fulltext_query(self):
602605
f"Attribute {full_text_filter.fulltext_attribute_name} is not declared with a full text index."
603606
)
604607

608+
if type(full_text_filter.threshold) not in [float, type(None)]:
609+
raise ValueError(f"Full Text Filter Threshold must be a float or None.")
610+
605611
full_text_filter.index_name = f"fulltext_index_{source_class.__label__}_{full_text_filter.fulltext_attribute_name}"
606612
full_text_filter.node_set_label = source_class.__label__.lower()
607613

@@ -1003,19 +1009,33 @@ def build_query(self) -> str:
10031009
if self._ast.vector_index_query:
10041010
query += f"""CALL () {{
10051011
CALL db.index.vector.queryNodes("{self._ast.vector_index_query.index_name}", {self._ast.vector_index_query.topk}, {self._ast.vector_index_query.vector})
1006-
YIELD node AS {self._ast.vector_index_query.node_set_label}, score
1012+
YIELD node AS {self._ast.vector_index_query.node_set_label}, score """
1013+
1014+
if self._ast.vector_index_query.threshold:
1015+
query += f"""
1016+
WHERE score >= {self._ast.vector_index_query.threshold}
1017+
"""
1018+
1019+
query += f"""
10071020
RETURN {self._ast.vector_index_query.node_set_label}, score
1008-
}}"""
1021+
}}"""
10091022

10101023
# This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering
10111024
query += f""" WITH {self._ast.vector_index_query.node_set_label}, score"""
10121025

10131026
if self._ast.fulltext_index_query:
10141027
query += f"""CALL () {{
10151028
CALL db.index.fulltext.queryNodes("{self._ast.fulltext_index_query.index_name}", "{self._ast.fulltext_index_query.query_string}")
1016-
YIELD node AS {self._ast.fulltext_index_query.node_set_label}, score
1029+
YIELD node AS {self._ast.fulltext_index_query.node_set_label}, score"""
1030+
1031+
if self._ast.fulltext_index_query.threshold:
1032+
query += f"""
1033+
WHERE score >= {self._ast.fulltext_index_query.threshold}
1034+
"""
1035+
1036+
query += f"""
10171037
RETURN {self._ast.fulltext_index_query.node_set_label}, score LIMIT {self._ast.fulltext_index_query.topk}
1018-
}}
1038+
}}
10191039
"""
10201040
# This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering
10211041
query += f""" WITH {self._ast.fulltext_index_query.node_set_label}, score"""

test/async_/test_fulltextfilter.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,46 @@ class fulltextNodeBis(AsyncStructuredNode):
122122
assert all(isinstance(x[1], float) for x in result)
123123

124124

125+
@mark_async_test
126+
async def test_fulltextfilter_threshold():
127+
"""
128+
Tests that the fulltext query is run, and only nodes above threshold returns.
129+
"""
130+
131+
if not await adb.version_is_higher_than("5.16"):
132+
pytest.skip("Not supported before 5.16")
133+
134+
class fulltextNodeThresh(AsyncStructuredNode):
135+
description = StringProperty(
136+
fulltext_index=FulltextIndex(
137+
analyzer="standard-no-stop-words", eventually_consistent=False
138+
)
139+
)
140+
other = StringProperty()
141+
142+
await adb.install_labels(fulltextNodeThresh)
143+
144+
node1 = await fulltextNodeThresh(other="thing", description="Another thing").save()
145+
146+
node2 = await fulltextNodeThresh(
147+
other="other thing", description="Another other thing"
148+
).save()
149+
150+
fulltextFilterThresh= fulltextNodeThresh.nodes.filter(
151+
fulltext_filter=FulltextFilter(
152+
topk=3, fulltext_attribute_name="description", query_string="thing", threshold=0.09
153+
),
154+
other="thing",
155+
)
156+
157+
result = await fulltextFilterThresh.all()
158+
159+
print(result)
160+
assert len(result) == 1
161+
assert all(isinstance(x[0], fulltextNodeThresh) for x in result)
162+
assert result[0][0].other == "thing"
163+
assert all(x[1] >= 0.09 for x in result)
164+
125165
@mark_async_test
126166
async def test_dont_duplicate_fulltext_filter_node():
127167
"""

test/async_/test_vectorfilter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,38 @@ class someNode(AsyncStructuredNode):
4747
assert all(isinstance(x[0], someNode) for x in result)
4848
assert all(isinstance(x[1], float) for x in result)
4949

50+
@mark_async_test
51+
async def test_vectorfilter_thresholding():
52+
"""
53+
Tests that the vector query is run, and only node above threshold returns.
54+
"""
55+
# Vector Indexes only exist from 5.13 onwards
56+
if not await adb.version_is_higher_than("5.13"):
57+
pytest.skip("Vector Index not Generally Available in Neo4j.")
58+
59+
class someNodeThresh(AsyncStructuredNode):
60+
name = StringProperty()
61+
vector = ArrayProperty(
62+
base_property=FloatProperty(), vector_index=VectorIndex(2, "cosine")
63+
)
64+
65+
await adb.install_labels(someNodeThresh)
66+
67+
john = await someNodeThresh(name="John", vector=[float(0.5), float(0.5)]).save()
68+
fred = await someNodeThresh(name="Fred", vector=[float(1.0), float(0.0)]).save()
69+
70+
vectorsearchFilterThreshold = someNodeThresh.nodes.filter(
71+
vector_filter=VectorFilter(
72+
topk=3, vector_attribute_name="vector", candidate_vector=[0.25, 0], threshold=0.8
73+
),
74+
name="John",
75+
)
76+
result = await vectorsearchFilterThreshold.all()
77+
78+
assert len(result) == 1
79+
assert all(isinstance(x[0], someNodeThresh) for x in result)
80+
assert result[0][0].name == "John"
81+
assert all(x[1] >= 0.8 for x in result)
5082

5183
@mark_async_test
5284
async def test_vectorfilter_with_node_propertyfilter():

test/sync_/test_fulltextfilter.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,50 @@ class fulltextNodeBis(StructuredNode):
122122
assert all(isinstance(x[1], float) for x in result)
123123

124124

125+
@mark_sync_test
126+
def test_fulltextfilter_threshold():
127+
"""
128+
Tests that the fulltext query is run, and only nodes above threshold returns.
129+
"""
130+
131+
if not db.version_is_higher_than("5.16"):
132+
pytest.skip("Not supported before 5.16")
133+
134+
class fulltextNodeThresh(StructuredNode):
135+
description = StringProperty(
136+
fulltext_index=FulltextIndex(
137+
analyzer="standard-no-stop-words", eventually_consistent=False
138+
)
139+
)
140+
other = StringProperty()
141+
142+
db.install_labels(fulltextNodeThresh)
143+
144+
node1 = fulltextNodeThresh(other="thing", description="Another thing").save()
145+
146+
node2 = fulltextNodeThresh(
147+
other="other thing", description="Another other thing"
148+
).save()
149+
150+
fulltextFilterThresh = fulltextNodeThresh.nodes.filter(
151+
fulltext_filter=FulltextFilter(
152+
topk=3,
153+
fulltext_attribute_name="description",
154+
query_string="thing",
155+
threshold=0.09,
156+
),
157+
other="thing",
158+
)
159+
160+
result = fulltextFilterThresh.all()
161+
162+
print(result)
163+
assert len(result) == 1
164+
assert all(isinstance(x[0], fulltextNodeThresh) for x in result)
165+
assert result[0][0].other == "thing"
166+
assert all(x[1] >= 0.09 for x in result)
167+
168+
125169
@mark_sync_test
126170
def test_dont_duplicate_fulltext_filter_node():
127171
"""

test/sync_/test_indexing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44
from pytest import raises
55

6-
from neomodel import IntegerProperty, StringProperty, StructuredNode, UniqueProperty, db
6+
from neomodel import (
7+
IntegerProperty,
8+
StringProperty,
9+
StructuredNode,
10+
UniqueProperty,
11+
db,
12+
)
713
from neomodel.exceptions import ConstraintValidationFailed
814

915

0 commit comments

Comments
 (0)