Skip to content

Commit 871334d

Browse files
Merge pull request #931 from neo4j-contrib/feature/new-exists-operator-in-filter
New exists operator for filtering operations.
2 parents 7139fc2 + b7c1fb4 commit 871334d

File tree

4 files changed

+304
-136
lines changed

4 files changed

+304
-136
lines changed

neomodel/async_/match.py

Lines changed: 122 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from neomodel.properties import AliasProperty, ArrayProperty, Property
1515
from neomodel.semantic_filters import FulltextFilter, VectorFilter
1616
from neomodel.typing import Subquery, Transformation
17-
from neomodel.util import RelationshipDirection
17+
from neomodel.util import RelationshipDirection, deprecated
1818

1919
CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
2020

@@ -160,6 +160,7 @@ def _rel_merge_helper(
160160
_SPECIAL_OPERATOR_ISNULL = "IS NULL"
161161
_SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL"
162162
_SPECIAL_OPERATOR_REGEX = "=~"
163+
_SPECIAL_OPERATOR_EXISTS = "EXISTS"
163164

164165
_UNARY_OPERATORS = (_SPECIAL_OPERATOR_ISNULL, _SPECIAL_OPERATOR_ISNOTNULL)
165166

@@ -196,6 +197,7 @@ def _rel_merge_helper(
196197
"isnull": _SPECIAL_OPERATOR_ISNULL,
197198
"regex": _SPECIAL_OPERATOR_REGEX,
198199
"exact": "=",
200+
"exists": "EXISTS",
199201
}
200202
# add all regex operators
201203
OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE)
@@ -241,6 +243,13 @@ def _handle_special_operators(
241243
raise ValueError(f"Value must be a bool for isnull operation on {key}")
242244
operator = "IS NULL" if value else "IS NOT NULL"
243245
deflated_value = None
246+
elif operator == _SPECIAL_OPERATOR_EXISTS:
247+
if not isinstance(value, bool):
248+
raise ValueError(f"Value must be a bool for exists operation on {key}")
249+
operator = (
250+
f"NOT {_SPECIAL_OPERATOR_EXISTS}" if not value else _SPECIAL_OPERATOR_EXISTS
251+
)
252+
deflated_value = value
244253
elif operator in _REGEX_OPERATOR_TABLE.values():
245254
deflated_value = property_obj.deflate(value)
246255
if not isinstance(deflated_value, str):
@@ -307,6 +316,7 @@ def _process_filter_key(
307316
prop,
308317
) = _initialize_filter_args_variables(cls, key)
309318

319+
hop_name = None
310320
for part in re.split(path_split_regex, key):
311321
defined_props = current_class.defined_properties(rels=True)
312322
# update defined props dictionary with relationship properties if
@@ -320,6 +330,7 @@ def _process_filter_key(
320330
defined_props[part].lookup_node_class()
321331
current_class = defined_props[part].definition["node_class"]
322332
current_rel_model = defined_props[part].definition["model"]
333+
hop_name = part
323334
elif part in OPERATOR_TABLE:
324335
operator = OPERATOR_TABLE[part]
325336
prop, _ = prop.rsplit("__", 1)
@@ -332,7 +343,11 @@ def _process_filter_key(
332343

333344
if leaf_prop is None:
334345
raise ValueError(f"Badly formed filter, no property found in {key}")
335-
if is_rel_property and current_rel_model:
346+
347+
if hop_name == leaf_prop:
348+
# Path ended on a hop, not a property
349+
property_obj = None
350+
elif is_rel_property and current_rel_model:
336351
property_obj = getattr(current_rel_model, leaf_prop)
337352
else:
338353
property_obj = getattr(current_class, leaf_prop)
@@ -389,6 +404,71 @@ def process_has_args(
389404
return match, dont_match
390405

391406

407+
def generate_traversal_from_path(
408+
relation: "Path",
409+
source_class: Any,
410+
create_ids: bool = False,
411+
node_id_generator=None,
412+
rel_id_generator=None,
413+
namespace: str | None = None,
414+
):
415+
"""
416+
Generator function to construct a cypher traversal from the given path.
417+
"""
418+
path: str = relation.value
419+
stmt: str = ""
420+
source_class_iterator = source_class
421+
parts = re.split(path_split_regex, path)
422+
rel_iterator: str = ""
423+
for index, part in enumerate(parts):
424+
relationship = getattr(source_class_iterator, part)
425+
if rel_iterator:
426+
rel_iterator += "__"
427+
rel_iterator += part
428+
# build source
429+
if "node_class" not in relationship.definition:
430+
relationship.lookup_node_class()
431+
lhs_name = None
432+
if not stmt:
433+
lhs_label = source_class_iterator.__label__
434+
lhs_name = lhs_label.lower()
435+
if create_ids and not namespace:
436+
lhs_ident = f"{lhs_name}:{lhs_label}"
437+
else:
438+
lhs_ident = lhs_name
439+
else:
440+
lhs_ident = stmt
441+
442+
rel_ident = None
443+
rhs_name = None
444+
rhs_label = relationship.definition["node_class"].__label__
445+
if create_ids:
446+
rel_ident = rel_id_generator()
447+
if relation.relation_filtering:
448+
rhs_name = rel_ident
449+
rhs_ident = f":{rhs_label}"
450+
else:
451+
if index + 1 == len(parts) and relation.alias:
452+
# If an alias is defined, use it to store the last hop in the path
453+
rhs_name = relation.alias
454+
else:
455+
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
456+
rhs_name = node_id_generator(rhs_name, rel_iterator)
457+
rhs_ident = f"{rhs_name}:{rhs_label}"
458+
else:
459+
rhs_ident = f":{rhs_label}"
460+
461+
stmt = _rel_helper(
462+
lhs=lhs_ident,
463+
rhs=rhs_ident,
464+
ident=rel_ident,
465+
direction=relationship.definition["direction"],
466+
relation_type=relationship.definition["relation_type"],
467+
)
468+
yield stmt, lhs_name, rhs_name, rel_ident, part, source_class_iterator
469+
source_class_iterator = relationship.definition["node_class"]
470+
471+
392472
class QueryAST:
393473
match: list[str]
394474
optional_match: list[str]
@@ -657,54 +737,27 @@ def _additional_return(self, name: str) -> None:
657737
def build_traversal_from_path(
658738
self, relation: "Path", source_class: Any
659739
) -> tuple[str, Any]:
660-
path: str = relation.value
661-
stmt: str = ""
662-
source_class_iterator = source_class
663-
parts = re.split(path_split_regex, path)
664740
subgraph = self._ast.subgraph
665-
rel_iterator: str = ""
666-
already_present = False
667-
existing_rhs_name = ""
668-
for index, part in enumerate(parts):
741+
generator = generate_traversal_from_path(
742+
relation,
743+
source_class,
744+
True,
745+
self.create_node_identifier,
746+
self.create_relation_identifier,
747+
self._subquery_namespace,
748+
)
749+
for index, items in enumerate(generator):
750+
stmt, lhs_name, rhs_name, rel_ident, part, source_class_iterator = items
669751
relationship = getattr(source_class_iterator, part)
670-
if rel_iterator:
671-
rel_iterator += "__"
672-
rel_iterator += part
673-
# build source
674-
if "node_class" not in relationship.definition:
675-
relationship.lookup_node_class()
676-
if not stmt:
677-
lhs_label = source_class_iterator.__label__
678-
lhs_name = lhs_label.lower()
679-
lhs_ident = f"{lhs_name}:{lhs_label}"
680-
if not index:
681-
# This is the first one, we make sure that 'return'
682-
# contains the primary node so _contains() works
683-
# as usual
684-
self._ast.return_clause = lhs_name
685-
if self._subquery_namespace:
686-
# Don't include label in identifier if we are in a subquery
687-
lhs_ident = lhs_name
688-
elif relation.include_nodes_in_return:
689-
self._additional_return(lhs_name)
690-
else:
691-
lhs_ident = stmt
692752

693-
already_present = part in subgraph
694-
rel_ident = self.create_relation_identifier()
695-
rhs_label = relationship.definition["node_class"].__label__
696-
if relation.relation_filtering:
697-
rhs_name = rel_ident
698-
rhs_ident = f":{rhs_label}"
699-
else:
700-
if index + 1 == len(parts) and relation.alias:
701-
# If an alias is defined, use it to store the last hop in the path
702-
rhs_name = relation.alias
703-
else:
704-
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
705-
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
706-
rhs_ident = f"{rhs_name}:{rhs_label}"
753+
if not index:
754+
# This is the first one, we make sure that 'return'
755+
# contains the primary node so _contains() works
756+
# as usual
757+
self._ast.return_clause = lhs_name
758+
self._additional_return(lhs_name)
707759

760+
already_present = part in subgraph
708761
if relation.include_nodes_in_return and not already_present:
709762
self._additional_return(rhs_name)
710763

@@ -725,14 +778,6 @@ def build_traversal_from_path(
725778
]
726779
if relation.include_rels_in_return and not already_present:
727780
self._additional_return(rel_ident)
728-
stmt = _rel_helper(
729-
lhs=lhs_ident,
730-
rhs=rhs_ident,
731-
ident=rel_ident,
732-
direction=relationship.definition["direction"],
733-
relation_type=relationship.definition["relation_type"],
734-
)
735-
source_class_iterator = relationship.definition["node_class"]
736781
subgraph = subgraph[part]["children"]
737782

738783
if not already_present:
@@ -840,6 +885,11 @@ def _finalize_filter_statement(
840885
if operator in _UNARY_OPERATORS:
841886
# unary operators do not have a parameter
842887
statement = f"{ident}.{prop} {operator}"
888+
elif _SPECIAL_OPERATOR_EXISTS in operator:
889+
statement = list(
890+
generate_traversal_from_path(Path(prop), self.node_set.source)
891+
)[-1][0]
892+
statement = f"{'NOT ' if not val else ''}EXISTS {{ {statement} }}"
843893
else:
844894
place_holder = self._register_place_holder(ident + "_" + prop)
845895
if operator == _SPECIAL_OPERATOR_ARRAY_IN:
@@ -862,21 +912,22 @@ def _build_filter_statements(
862912
source_class: type[AsyncStructuredNode],
863913
) -> None:
864914
for prop, op_and_val in filters.items():
865-
is_rel_filter = "|" in prop
866-
target_class = source_class
867-
is_optional_relation = False
868-
if "__" in prop or is_rel_filter:
869-
(
870-
ident,
871-
prop,
872-
target_class,
873-
is_optional_relation,
874-
) = self._parse_path(source_class, prop)
875915
operator, val = op_and_val
876-
if not is_rel_filter:
877-
prop = target_class.defined_properties(rels=False)[
878-
prop
879-
].get_db_property_name(prop)
916+
is_optional_relation = False
917+
if _SPECIAL_OPERATOR_EXISTS not in operator:
918+
is_rel_filter = "|" in prop
919+
target_class = source_class
920+
if "__" in prop or is_rel_filter:
921+
(
922+
ident,
923+
prop,
924+
target_class,
925+
is_optional_relation,
926+
) = self._parse_path(source_class, prop)
927+
if not is_rel_filter:
928+
prop = target_class.defined_properties(rels=False)[
929+
prop
930+
].get_db_property_name(prop)
880931
statement = self._finalize_filter_statement(operator, ident, prop, val)
881932
target.append((statement, is_optional_relation))
882933

@@ -1714,6 +1765,9 @@ def exclude(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
17141765
self.q_filters = Q(self.q_filters & ~Q(*args, **kwargs))
17151766
return self
17161767

1768+
@deprecated(
1769+
"This method is deprecated and set to be removed in a future release. Please use .filter(has_rel__exists=True) instead."
1770+
)
17171771
def has(self, **kwargs: Any) -> "AsyncBaseSet":
17181772
must_match, dont_match = process_has_args(self.source_class, kwargs)
17191773
self.must_match.update(must_match)

0 commit comments

Comments
 (0)