Skip to content

Commit 1499ac6

Browse files
committed
fix: support renamed cypher reentry alias
1 parent d8f34fd commit 1499ac6

3 files changed

Lines changed: 80 additions & 17 deletions

File tree

graphistry/compute/gfql/cypher/lowering.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5953,7 +5953,7 @@ def _bounded_reentry_carry_columns(
59535953
*,
59545954
query: CypherQuery,
59555955
prefix_stage: ProjectionStage,
5956-
) -> Tuple[str, ...]:
5956+
) -> Tuple[str, Tuple[str, ...]]:
59575957
whole_row_columns = [column for column in prefix_projection.columns if column.kind == "whole_row"]
59585958
if len(whole_row_columns) != 1:
59595959
raise _unsupported(
@@ -5964,7 +5964,7 @@ def _bounded_reentry_carry_columns(
59645964
column=prefix_stage.span.column,
59655965
)
59665966
if len(prefix_projection.columns) == 1:
5967-
return ()
5967+
return whole_row_columns[0].output_name, ()
59685968
seed_alias = _single_node_seed_alias(query.matches[0]) if len(query.matches) == 1 else None
59695969
if seed_alias is None or seed_alias != prefix_projection.alias:
59705970
raise _unsupported(
@@ -5998,7 +5998,7 @@ def _bounded_reentry_carry_columns(
59985998
)
59995999
hidden_names.add(hidden_column)
60006000
carried_columns.append(column.output_name)
6001-
return tuple(carried_columns)
6001+
return whole_row_columns[0].output_name, tuple(carried_columns)
60026002

60036003

60046004
def _literal_limit_value(limit_clause: Optional[LimitClause]) -> Optional[int]:
@@ -6101,7 +6101,7 @@ def _compile_bounded_reentry_query(
61016101
line=prefix_stage.span.line,
61026102
column=prefix_stage.span.column,
61036103
)
6104-
carry_columns = _bounded_reentry_carry_columns(
6104+
reentry_alias, carry_columns = _bounded_reentry_carry_columns(
61056105
prefix_projection,
61066106
query=query,
61076107
prefix_stage=prefix_stage,
@@ -6137,7 +6137,7 @@ def _compile_bounded_reentry_query(
61376137

61386138
reentry_match = query.reentry_matches[0]
61396139
first_alias = _first_pattern_node_alias(reentry_match)
6140-
if first_alias is None or first_alias != prefix_projection.alias:
6140+
if first_alias is None or first_alias != reentry_alias:
61416141
raise _unsupported(
61426142
"Cypher MATCH after WITH currently requires the trailing MATCH to start from the same carried node alias",
61436143
field="match",
@@ -6152,7 +6152,7 @@ def _compile_bounded_reentry_query(
61526152
reentry_where,
61536153
expr=_rewrite_reentry_expr_to_hidden_properties(
61546154
reentry_where.expr,
6155-
carried_alias=prefix_projection.alias,
6155+
carried_alias=reentry_alias,
61566156
carried_columns=carry_columns,
61576157
field="where",
61586158
),
@@ -6171,7 +6171,7 @@ def _compile_bounded_reentry_query(
61716171
for rewritten_expr in (
61726172
_rewrite_reentry_expr_to_hidden_properties(
61736173
item.expression,
6174-
carried_alias=prefix_projection.alias,
6174+
carried_alias=reentry_alias,
61756175
carried_columns=carry_columns,
61766176
field=query.return_.kind,
61776177
),
@@ -6187,7 +6187,7 @@ def _compile_bounded_reentry_query(
61876187
item,
61886188
expression=_rewrite_reentry_expr_to_hidden_properties(
61896189
item.expression,
6190-
carried_alias=prefix_projection.alias,
6190+
carried_alias=reentry_alias,
61916191
carried_columns=carry_columns,
61926192
field="order_by",
61936193
),
@@ -6219,7 +6219,7 @@ def _compile_bounded_reentry_query(
62196219
column=reentry_match.span.column,
62206220
)
62216221
result_projection = suffix_compiled.result_projection
6222-
if result_projection is not None and result_projection.alias == prefix_projection.alias and carry_columns:
6222+
if result_projection is not None and result_projection.alias == reentry_alias and carry_columns:
62236223
result_projection = replace(
62246224
result_projection,
62256225
exclude_columns=tuple(

graphistry/tests/compute/gfql/cypher/test_lowering.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5046,14 +5046,15 @@ def test_string_cypher_executes_with_match_reentry_multihop_shape() -> None:
50465046

50475047

50485048
@pytest.mark.parametrize(
5049-
("query", "expected_columns"),
5049+
("query", "expected_whole_row_output", "expected_columns"),
50505050
[
50515051
(
50525052
"MATCH (a:A) "
50535053
"WITH a, a.num AS property "
50545054
"MATCH (a)-->(b) "
50555055
"RETURN property "
50565056
"ORDER BY property DESC",
5057+
"a",
50575058
("property",),
50585059
),
50595060
(
@@ -5062,21 +5063,43 @@ def test_string_cypher_executes_with_match_reentry_multihop_shape() -> None:
50625063
"MATCH (a)-->(b) "
50635064
"RETURN property, property2 "
50645065
"ORDER BY property DESC",
5066+
"a",
50655067
("property", "property2"),
50665068
),
5069+
(
5070+
"MATCH (a:A) "
5071+
"WITH a AS x, a.num AS property "
5072+
"MATCH (x)-->(b) "
5073+
"RETURN property "
5074+
"ORDER BY property DESC",
5075+
"x",
5076+
("property",),
5077+
),
50675078
],
50685079
)
5069-
def test_compile_cypher_tracks_reentry_carried_scalar_columns(query: str, expected_columns: Tuple[str, ...]) -> None:
5080+
def test_compile_cypher_tracks_reentry_carried_scalar_columns(
5081+
query: str,
5082+
expected_whole_row_output: str,
5083+
expected_columns: Tuple[str, ...],
5084+
) -> None:
50705085
compiled = _compile_query(query)
50715086
whole_row_output, carried_columns = _compiled_reentry_projection_outputs(compiled)
50725087

5073-
assert whole_row_output == "a"
5088+
assert whole_row_output == expected_whole_row_output
50745089
assert carried_columns == expected_columns
50755090

50765091

50775092
@pytest.mark.parametrize(
50785093
("query", "expected"),
50795094
[
5095+
(
5096+
"MATCH (a:A) "
5097+
"WITH a AS x "
5098+
"MATCH (x)-->(b) "
5099+
"RETURN b.id AS bid "
5100+
"ORDER BY bid",
5101+
[{"bid": "b1"}, {"bid": "b2"}],
5102+
),
50805103
(
50815104
"MATCH (a:A) "
50825105
"WITH a, a.num AS property "
@@ -5101,6 +5124,14 @@ def test_compile_cypher_tracks_reentry_carried_scalar_columns(query: str, expect
51015124
"ORDER BY property DESC",
51025125
[{"property": 2, "property2": 12}, {"property": 1, "property2": 11}],
51035126
),
5127+
(
5128+
"MATCH (a:A) "
5129+
"WITH a AS x, a.num AS property "
5130+
"MATCH (x)-->(b) "
5131+
"RETURN x, property "
5132+
"ORDER BY property DESC",
5133+
[{"x": "(:A {num: 2})", "property": 2}, {"x": "(:A {num: 1})", "property": 1}],
5134+
),
51045135
],
51055136
)
51065137
def test_string_cypher_executes_with_match_reentry_carried_scalar_shapes(query: str, expected: List[Dict[str, Any]]) -> None:

graphistry/tests/compute/test_gfql.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -488,22 +488,32 @@ def test_gfql_chain_dict_envelope(self):
488488
class TestGFQLCypherReentryCarrier:
489489

490490
@staticmethod
491-
def _compile_reentry_query(with_clause: str = "a, a.num AS property"):
491+
def _compile_reentry_query(
492+
with_clause: str = "a, a.num AS property",
493+
*,
494+
match_alias: str = "a",
495+
):
492496
return compile_cypher(
493497
"MATCH (a:A) "
494498
f"WITH {with_clause} "
495-
"MATCH (a)-->(b) "
499+
f"MATCH ({match_alias})-->(b) "
496500
"RETURN b.id AS bid"
497501
)
498502

499503
@staticmethod
500-
def _bind_reentry_prefix_result(g, rows: Dict[str, List[Any]], ids: List[Any]):
504+
def _bind_reentry_prefix_result(
505+
g,
506+
rows: Dict[str, List[Any]],
507+
ids: List[Any],
508+
*,
509+
output_name: str = "a",
510+
):
501511
prefix_result = g.bind()
502512
prefix_result._nodes = pd.DataFrame(rows)
503513
prefix_result._cypher_entity_projection_meta = {
504-
"a": {
514+
output_name: {
505515
"table": "nodes",
506-
"alias": "a",
516+
"alias": output_name,
507517
"id_column": "id",
508518
"ids": pd.Series(ids, name="id"),
509519
}
@@ -682,3 +692,25 @@ def test_reentry_state_overrides_internal_hidden_column_collisions(self):
682692
)
683693

684694
assert g._nodes["__cypher_reentry_property__"].tolist() == ["orig1", "orig2", None, None]
695+
696+
def test_reentry_state_uses_projected_whole_row_alias_for_contract(self):
697+
g = _mk_reentry_scalar_graph()
698+
self._assert_reentry_state(
699+
g=g,
700+
compiled=self._compile_reentry_query("a AS x, a.num AS property", match_alias="x"),
701+
prefix_result=self._bind_reentry_prefix_result(
702+
g,
703+
rows={"property": [2, 1]},
704+
ids=["a2", "a1"],
705+
output_name="x",
706+
),
707+
expected_rows=[
708+
{"id": "a2", "label__A": True, "num": 2, "__cypher_reentry_property__": 2},
709+
{"id": "a1", "label__A": True, "num": 1, "__cypher_reentry_property__": 1},
710+
],
711+
expected_carry={
712+
"a1": {"__cypher_reentry_property__": 1},
713+
"a2": {"__cypher_reentry_property__": 2},
714+
},
715+
expect_same_graph=False,
716+
)

0 commit comments

Comments
 (0)