Skip to content

Commit 22b50d4

Browse files
committed
Added method to generate unique variable names for specific pathes.
1 parent 0d38b1a commit 22b50d4

File tree

4 files changed

+86
-8
lines changed

4 files changed

+86
-8
lines changed

neomodel/async_/match.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,11 @@ def create_relation_identifier(self) -> str:
503503
self._relation_identifier_count += 1
504504
return f"r{self._relation_identifier_count}"
505505

506-
def create_node_identifier(self, prefix: str) -> str:
507-
self._node_identifier_count += 1
508-
return f"{prefix}{self._node_identifier_count}"
506+
def create_node_identifier(self, prefix: str, path: str) -> str:
507+
if path not in self.node_set._unique_variables:
508+
self._node_identifier_count += 1
509+
return f"{prefix}{self._node_identifier_count}"
510+
return prefix
509511

510512
def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
511513
if "?" in source.order_by_elements:
@@ -619,7 +621,7 @@ def build_traversal_from_path(
619621
rhs_name = relation["alias"]
620622
else:
621623
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
622-
rhs_name = self.create_node_identifier(rhs_name)
624+
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
623625
rhs_ident = f"{rhs_name}:{rhs_label}"
624626
if relation["include_in_return"] and not already_present:
625627
self._additional_return(rhs_name)
@@ -1381,6 +1383,7 @@ def __init__(self, source: Any) -> None:
13811383
self._extra_results: list = []
13821384
self._subqueries: list[Subquery] = []
13831385
self._intermediate_transforms: list = []
1386+
self._unique_variables: list[str] = []
13841387

13851388
def __await__(self) -> Any:
13861389
return self.all().__await__() # type: ignore[attr-defined]
@@ -1552,6 +1555,11 @@ def _register_relation_to_fetch(
15521555
item["alias"] = alias
15531556
return item
15541557

1558+
def unique_variables(self, *pathes: tuple[str, ...]) -> "AsyncNodeSet":
1559+
"""Generate unique variable names for the given pathes."""
1560+
self._unique_variables = pathes
1561+
return self
1562+
15551563
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
15561564
"""Specify a set of relations to traverse and return."""
15571565
relations = []

neomodel/sync_/match.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,11 @@ def create_relation_identifier(self) -> str:
501501
self._relation_identifier_count += 1
502502
return f"r{self._relation_identifier_count}"
503503

504-
def create_node_identifier(self, prefix: str) -> str:
505-
self._node_identifier_count += 1
506-
return f"{prefix}{self._node_identifier_count}"
504+
def create_node_identifier(self, prefix: str, path: str) -> str:
505+
if path not in self.node_set._unique_variables:
506+
self._node_identifier_count += 1
507+
return f"{prefix}{self._node_identifier_count}"
508+
return prefix
507509

508510
def build_order_by(self, ident: str, source: "NodeSet") -> None:
509511
if "?" in source.order_by_elements:
@@ -617,7 +619,7 @@ def build_traversal_from_path(
617619
rhs_name = relation["alias"]
618620
else:
619621
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
620-
rhs_name = self.create_node_identifier(rhs_name)
622+
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
621623
rhs_ident = f"{rhs_name}:{rhs_label}"
622624
if relation["include_in_return"] and not already_present:
623625
self._additional_return(rhs_name)
@@ -1377,6 +1379,7 @@ def __init__(self, source: Any) -> None:
13771379
self._extra_results: list = []
13781380
self._subqueries: list[Subquery] = []
13791381
self._intermediate_transforms: list = []
1382+
self._unique_variables: list[str] = []
13801383

13811384
def __await__(self) -> Any:
13821385
return self.all().__await__() # type: ignore[attr-defined]
@@ -1548,6 +1551,11 @@ def _register_relation_to_fetch(
15481551
item["alias"] = alias
15491552
return item
15501553

1554+
def unique_variables(self, *pathes: tuple[str, ...]) -> "NodeSet":
1555+
"""Generate unique variable names for the given pathes."""
1556+
self._unique_variables = pathes
1557+
return self
1558+
15511559
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet":
15521560
"""Specify a set of relations to traverse and return."""
15531561
relations = []

test/async_/test_match_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,37 @@ async def test_in_filter_with_array_property():
11551155
), "Species found by tags with not match tags given"
11561156

11571157

1158+
@mark_async_test
1159+
async def test_unique_variables():
1160+
arabica = await Species(name="Arabica").save()
1161+
nescafe = await Coffee(name="Nescafe", price=99).save()
1162+
supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
1163+
supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
1164+
1165+
await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1166+
await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
1167+
await nescafe.species.connect(arabica)
1168+
1169+
nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
1170+
coffees__name="Nescafe"
1171+
)
1172+
ast = await nodeset.query_cls(nodeset).build_ast()
1173+
query = ast.build_query()
1174+
assert "coffee_coffees1" in query
1175+
assert "coffee_coffees2" in query
1176+
1177+
nodeset = (
1178+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
1179+
.filter(coffees__name="Nescafe")
1180+
.unique_variables("coffees")
1181+
)
1182+
ast = await nodeset.query_cls(nodeset).build_ast()
1183+
query = ast.build_query()
1184+
assert "coffee_coffees" in query
1185+
assert "coffee_coffees1" not in query
1186+
assert "coffee_coffees2" not in query
1187+
1188+
11581189
@mark_async_test
11591190
async def test_async_iterator():
11601191
n = 10

test/sync_/test_match_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,37 @@ def test_in_filter_with_array_property():
11391139
), "Species found by tags with not match tags given"
11401140

11411141

1142+
@mark_sync_test
1143+
def test_unique_variables():
1144+
arabica = Species(name="Arabica").save()
1145+
nescafe = Coffee(name="Nescafe", price=99).save()
1146+
supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
1147+
supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
1148+
1149+
nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
1150+
nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
1151+
nescafe.species.connect(arabica)
1152+
1153+
nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
1154+
coffees__name="Nescafe"
1155+
)
1156+
ast = nodeset.query_cls(nodeset).build_ast()
1157+
query = ast.build_query()
1158+
assert "coffee_coffees1" in query
1159+
assert "coffee_coffees2" in query
1160+
1161+
nodeset = (
1162+
Supplier.nodes.fetch_relations("coffees", "coffees__species")
1163+
.filter(coffees__name="Nescafe")
1164+
.unique_variables("coffees")
1165+
)
1166+
ast = nodeset.query_cls(nodeset).build_ast()
1167+
query = ast.build_query()
1168+
assert "coffee_coffees" in query
1169+
assert "coffee_coffees1" not in query
1170+
assert "coffee_coffees2" not in query
1171+
1172+
11421173
@mark_sync_test
11431174
def test_async_iterator():
11441175
n = 10

0 commit comments

Comments
 (0)