diff --git a/Changelog b/Changelog index 705f048f..349f712c 100644 --- a/Changelog +++ b/Changelog @@ -4,6 +4,7 @@ Version 6.0.0 2025-xx * Introduces merge_by parameter for batch operations to customize merge behaviour (label and property keys) * Enforce strict cardinality check by default * Refactor internal code: core.py file is now split into smaller files for database, node, transaction +* Fix object resolution for maps and lists Cypher objects, even when nested. This changes the way you can access lists in your Cypher results, see documentation for more info * Make AsyncDatabase / Database a true singleton for clarity * Remove deprecated methods (including fetch_relations & traverse_relations, replaced with traverse ; database operations like clear_neo4j_database or change_neo4j_password have been moved to db/adb singleton internal methods) * Housekeeping and bug fixes diff --git a/README.md b/README.md index c67d39ac..bc62c0ae 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,11 @@ See the [documentation](https://neomodel.readthedocs.io/en/latest/configuration. [Semantic Indexes](https://neomodel.readthedocs.io/en/latest/semantic_indexes.html#) (Vector and Full-text) are now natively supported so you do not have to use a custom Cypher query. Special thanks to @greengori11a for this. +### Breaking changes + +* List object resolution from Cypher was creating "2-depth" lists for no apparent reason. This release fixes this so that, for example "RETURN collect(node)" will return the nodes directly as a list in the result. In other words, you can extract this list at `results[0][0]` instead of `results[0][0][0]` +* See more breaking changes in the [documentation](http://neomodel.readthedocs.org) + # Installation Install from pypi (recommended): diff --git a/doc/source/index.rst b/doc/source/index.rst index 48837415..3dfcac70 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -56,6 +56,7 @@ To install from github:: **Breaking changes in 6.0** - The soft cardinality check is now available for all cardinalities, and strict check is enabled by default. + - List object resolution from Cypher was creating "2-depth" lists for no apparent reason. This release fixes this so that, for example "RETURN collect(node)" will return the nodes directly as a list in the result. In other words, you can extract this list at `results[0][0]` instead of `results[0][0][0]` - AsyncDatabase / Database are now true singletons for clarity - Standalone methods moved into the Database() class have been removed outside of the Database() class : - change_neo4j_password diff --git a/neomodel/async_/database.py b/neomodel/async_/database.py index bb0672d9..964f3cb3 100644 --- a/neomodel/async_/database.py +++ b/neomodel/async_/database.py @@ -596,7 +596,13 @@ def _object_resolution(self, object_to_resolve: Any) -> Any: return AsyncNeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): - return self._result_resolution([object_to_resolve]) + return [self._object_resolution(item) for item in object_to_resolve] + + if isinstance(object_to_resolve, dict): + return { + key: self._object_resolution(value) + for key, value in object_to_resolve.items() + } return object_to_resolve diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index d5a40d38..2ac7d37f 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1787,9 +1787,6 @@ async def resolve_subgraph(self) -> list: if node.__class__ is self.source and "_" not in name: root_node = node continue - if isinstance(node, list) and isinstance(node[0], list): - other_nodes[name] = node[0] - continue other_nodes[name] = node results.append( self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) diff --git a/neomodel/sync_/database.py b/neomodel/sync_/database.py index aac8631e..48ba76cf 100644 --- a/neomodel/sync_/database.py +++ b/neomodel/sync_/database.py @@ -592,7 +592,13 @@ def _object_resolution(self, object_to_resolve: Any) -> Any: return NeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): - return self._result_resolution([object_to_resolve]) + return [self._object_resolution(item) for item in object_to_resolve] + + if isinstance(object_to_resolve, dict): + return { + key: self._object_resolution(value) + for key, value in object_to_resolve.items() + } return object_to_resolve diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 51c3c1fe..4aace76d 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1781,9 +1781,6 @@ def resolve_subgraph(self) -> list: if node.__class__ is self.source and "_" not in name: root_node = node continue - if isinstance(node, list) and isinstance(node[0], list): - other_nodes[name] = node[0] - continue other_nodes[name] = node results.append( self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index ddbd6808..3ee29f18 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -5,10 +5,11 @@ More information about the same issue at: https://github.com/aanastasiou/neomodelInheritanceTest -The following example uses a recursive relationship for economy, but the -idea remains the same: "Instantiate the correct type of node at the end of +The following example uses a recursive relationship for economy, but the +idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ + import random from test._async_compat import mark_async_test @@ -123,56 +124,6 @@ async def test_automatic_result_resolution(): assert type((await A.friends_with)[0]) is TechnicalPerson -@mark_async_test -async def test_recursive_automatic_result_resolution(): - """ - Node objects are instantiated to native Python objects, both at the top - level of returned results and in the case where they are returned within - lists. - """ - - # Create a few entities - A = ( - await TechnicalPerson.get_or_create( - {"name": "Grumpier", "expertise": "Grumpiness"} - ) - )[0] - B = ( - await TechnicalPerson.get_or_create( - {"name": "Happier", "expertise": "Grumpiness"} - ) - )[0] - C = ( - await TechnicalPerson.get_or_create( - {"name": "Sleepier", "expertise": "Pillows"} - ) - )[0] - D = ( - await TechnicalPerson.get_or_create( - {"name": "Sneezier", "expertise": "Pillows"} - ) - )[0] - - # Retrieve mixed results, both at the top level and nested - L, _ = await adb.cypher_query( - "MATCH (a:TechnicalPerson) " - "WHERE a.expertise='Grumpiness' " - "WITH collect(a) as Alpha " - "MATCH (b:TechnicalPerson) " - "WHERE b.expertise='Pillows' " - "WITH Alpha, collect(b) as Beta " - "RETURN [Alpha, [Beta, [Beta, ['Banana', " - "Alpha]]]]", - resolve_objects=True, - ) - - # Assert that a Node returned deep in a nested list structure is of the - # correct type - assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson - # Assert that primitive data types remain primitive data types - assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - - @mark_async_test async def test_validation_with_inheritance_from_db(): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 9ec63a9a..cc27e194 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -793,7 +793,7 @@ async def test_annotate_and_collect(): .all() ) assert len(result) == 1 - assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + assert len(result[0][1]) == 3 # 3 species must be there (with 2 duplicates) result = ( await Supplier.nodes.traverse( @@ -806,7 +806,7 @@ async def test_annotate_and_collect(): .annotate(Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( await Supplier.nodes.traverse( @@ -832,7 +832,7 @@ async def test_annotate_and_collect(): .annotate(all_species=Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( await Supplier.nodes.traverse( @@ -850,8 +850,8 @@ async def test_annotate_and_collect(): ) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there - assert len(result[0][2][0]) == 3 # 3 species relations must be there + assert len(result[0][1]) == 2 # 2 species must be there + assert len(result[0][2]) == 3 # 3 species relations must be there @mark_async_test diff --git a/test/async_/test_object_resolution.py b/test/async_/test_object_resolution.py new file mode 100644 index 00000000..2eff1daf --- /dev/null +++ b/test/async_/test_object_resolution.py @@ -0,0 +1,554 @@ +""" +Test cases for object resolution with resolve_objects=True in raw Cypher queries. + +This test file covers various scenarios for automatic class resolution, +including the issues identified in GitHub issues #905 and #906: +- Issue #905: Nested lists in results of raw Cypher queries with collect keyword +- Issue #906: Automatic class resolution for raw queries with nodes nested in maps + +Additional scenarios tested: +- Basic object resolution +- Nested structures (lists, maps, mixed) +- Path resolution +- Relationship resolution +- Complex nested scenarios with collect() and other Cypher functions +""" + +from test._async_compat import mark_async_test + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, + adb, +) + + +class ResolutionRelationship(AsyncStructuredRel): + """Test relationship with properties.""" + + weight = IntegerProperty(default=1) + description = StringProperty(default="test") + + +class ResolutionNode(AsyncStructuredNode): + """Base test node class.""" + + name = StringProperty(required=True) + value = IntegerProperty(default=0) + related = AsyncRelationshipTo( + "ResolutionNode", "RELATED_TO", model=ResolutionRelationship + ) + + +class ResolutionSpecialNode(AsyncStructuredNode): + """Specialized test node class.""" + + name = StringProperty(required=True) + special_value = IntegerProperty(default=42) + related = AsyncRelationshipTo( + ResolutionNode, "RELATED_TO", model=ResolutionRelationship + ) + + +class ResolutionContainerNode(AsyncStructuredNode): + """Container node for testing nested structures.""" + + name = StringProperty(required=True) + items = AsyncRelationshipTo( + ResolutionNode, "CONTAINS", model=ResolutionRelationship + ) + + +@mark_async_test +async def test_basic_object_resolution(): + """Test basic object resolution for nodes and relationships.""" + # Create test data + await ResolutionNode(name="Node1", value=10).save() + await ResolutionNode(name="Node2", value=20).save() + + # Test basic node resolution + results, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "Node1"}, + resolve_objects=True, + ) + + assert len(results) == 1 + assert len(results[0]) == 1 + resolved_node = results[0][0] + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "Node1" + assert resolved_node.value == 10 + + +@mark_async_test +async def test_relationship_resolution(): + """Test relationship resolution in queries.""" + # Create test data with relationships + node1 = await ResolutionNode(name="Source", value=100).save() + node2 = await ResolutionNode(name="Target", value=200).save() + + # Create relationship + await node1.related.connect(node2, {"weight": 5, "description": "test_rel"}) + + # Test relationship resolution + results, _ = await adb.cypher_query( + "MATCH (a:ResolutionNode)-[r:RELATED_TO]->(b:ResolutionNode) RETURN a, r, b", + resolve_objects=True, + ) + + assert len(results) == 1 + source, rel, target = results[0] + + assert isinstance(source, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert isinstance(target, ResolutionNode) + + assert source.name == "Source" + assert target.name == "Target" + assert rel.weight == 5 + assert rel.description == "test_rel" + + +@mark_async_test +async def test_path_resolution(): + """Test path resolution in queries.""" + # Create test data + node1 = await ResolutionNode(name="Start", value=1).save() + node2 = await ResolutionNode(name="Middle", value=2).save() + node3 = await ResolutionNode(name="End", value=3).save() + + # Create path + await node1.related.connect(node2, {"weight": 1}) + await node2.related.connect(node3, {"weight": 2}) + + # Test path resolution + results, _ = await adb.cypher_query( + "MATCH p=(a:ResolutionNode)-[:RELATED_TO*2]->(c:ResolutionNode) RETURN p", + resolve_objects=True, + ) + + assert len(results) == 1 + path = results[0][0] + + # Path should be resolved to AsyncNeomodelPath + from neomodel.async_.path import AsyncNeomodelPath + + assert isinstance(path, AsyncNeomodelPath) + assert len(path._nodes) == 3 # pylint: disable=protected-access + assert len(path._relationships) == 2 # pylint: disable=protected-access + + +@mark_async_test +async def test_nested_lists_basic(): + """Test basic nested list resolution (Issue #905 - basic case).""" + # Create test data + nodes = [] + for i in range(3): + node = await ResolutionNode(name=f"Node{i}", value=i * 10).save() + nodes.append(node) + + # Test nested list resolution + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + RETURN collect(n) as nodes + """, + resolve_objects=True, + ) + + assert len(results) == 1 + collected_nodes = results[0][0] + + assert isinstance(collected_nodes, list) + assert len(collected_nodes) == 3 + + for i, node in enumerate(collected_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Node{i}" + assert node.value == i * 10 + + +@mark_async_test +async def test_nested_lists_complex(): + """Test complex nested list resolution with collect() (Issue #905 - complex case).""" + # Create test data with relationships + container = await ResolutionContainerNode(name="Container").save() + items = [] + for i in range(2): + item = await ResolutionNode(name=f"Item{i}", value=i * 5).save() + items.append(item) + await container.items.connect(item, {"weight": i + 1}) + + # Test complex nested list with collect + results, _ = await adb.cypher_query( + """ + MATCH (c:ResolutionContainerNode)-[r:CONTAINS]->(i:ResolutionNode) + WITH c, r, i ORDER BY i.name + WITH c, collect({item: i, rel: r}) as items + RETURN c, items + """, + resolve_objects=True, + ) + + assert len(results) == 1 + container_result, items_result = results[0] + + assert isinstance(container_result, ResolutionContainerNode) + assert container_result.name == "Container" + + assert isinstance(items_result, list) + assert len(items_result) == 2 + + for i, item_data in enumerate(items_result): + assert isinstance(item_data, dict) + assert "item" in item_data + assert "rel" in item_data + + item = item_data["item"] + rel = item_data["rel"] + + assert isinstance(item, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert item.name == f"Item{i}" + assert rel.weight == i + 1 + + +@mark_async_test +async def test_nodes_nested_in_maps(): + """Test nodes nested in maps (Issue #906).""" + # Create test data + await ResolutionNode(name="Node1", value=100).save() + await ResolutionNode(name="Node2", value=200).save() + + # Test nodes nested in maps + results, _ = await adb.cypher_query( + """ + MATCH (n1:ResolutionNode), (n2:ResolutionNode) + WHERE n1.name = 'Node1' AND n2.name = 'Node2' + RETURN { + first: n1, + second: n2, + metadata: { + count: 2, + description: 'test map' + } + } as result_map + """, + resolve_objects=True, + ) + + assert len(results) == 1 + result_map = results[0][0] + + assert isinstance(result_map, dict) + assert "first" in result_map + assert "second" in result_map + assert "metadata" in result_map + + # Check that nodes are properly resolved + first_node = result_map["first"] + second_node = result_map["second"] + + assert isinstance(first_node, ResolutionNode) + assert isinstance(second_node, ResolutionNode) + assert first_node.name == "Node1" + assert second_node.name == "Node2" + + # Check metadata (should remain as primitive types) + metadata = result_map["metadata"] + assert isinstance(metadata, dict) + assert metadata["count"] == 2 + assert metadata["description"] == "test map" + + +@mark_async_test +async def test_mixed_nested_structures(): + """Test mixed nested structures with lists, maps, and nodes.""" + # Create test data + special = await ResolutionSpecialNode(name="Special", special_value=999).save() + test_nodes = [] + for i in range(2): + node = await ResolutionNode(name=f"Test{i}", value=i * 100).save() + test_nodes.append(node) + await special.related.connect(node, {"weight": i + 10}) + + # Test complex mixed structure + results, _ = await adb.cypher_query( + """ + MATCH (s:ResolutionSpecialNode)-[r:RELATED_TO]->(t:ResolutionNode) + WITH s, r, t ORDER BY t.name + WITH s, collect({node: t, rel: r}) as related_items + RETURN { + special_node: s, + related: related_items, + summary: { + total_relations: size(related_items), + node_names: [item in related_items | item.node.name] + } + } as complex_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + complex_result = results[0][0] + + assert isinstance(complex_result, dict) + assert "special_node" in complex_result + assert "related" in complex_result + assert "summary" in complex_result + + # Check special node resolution + special_node = complex_result["special_node"] + assert isinstance(special_node, ResolutionSpecialNode) + assert special_node.name == "Special" + assert special_node.special_value == 999 + + # Check related items (list of dicts with nodes and relationships) + related = complex_result["related"] + assert isinstance(related, list) + assert len(related) == 2 + + for i, item in enumerate(related): + assert isinstance(item, dict) + assert "node" in item + assert "rel" in item + + node = item["node"] + rel = item["rel"] + + assert isinstance(node, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert node.name == f"Test{i}" + assert rel.weight == i + 10 + + # Check summary (should remain as primitive types) + summary = complex_result["summary"] + assert isinstance(summary, dict) + assert summary["total_relations"] == 2 + assert isinstance(summary["node_names"], list) + assert summary["node_names"] == ["Test0", "Test1"] + + +@mark_async_test +async def test_deeply_nested_structures(): + """Test deeply nested structures to ensure recursive resolution works.""" + # Create test data + nodes = [] + for i in range(3): + node = await ResolutionNode(name=f"Deep{i}", value=i * 50).save() + nodes.append(node) + + # Test deeply nested structure + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + WITH collect(n) as level1 + RETURN { + level1: level1, + level2: { + nodes: level1, + metadata: { + level3: { + count: size(level1), + items: level1 + } + } + } + } as deep_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + deep_result = results[0][0] + + assert isinstance(deep_result, dict) + assert "level1" in deep_result + assert "level2" in deep_result + + # Check level1 (direct list of nodes) + level1 = deep_result["level1"] + assert isinstance(level1, list) + assert len(level1) == 3 + for i, node in enumerate(level1): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check level2 (nested structure) + level2 = deep_result["level2"] + assert isinstance(level2, dict) + assert "nodes" in level2 + assert "metadata" in level2 + + # Check nodes in level2 + level2_nodes = level2["nodes"] + assert isinstance(level2_nodes, list) + assert len(level2_nodes) == 3 + for i, node in enumerate(level2_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check metadata in level2 + metadata = level2["metadata"] + assert isinstance(metadata, dict) + assert "level3" in metadata + + level3 = metadata["level3"] + assert isinstance(level3, dict) + assert "count" in level3 + assert "items" in level3 + + assert level3["count"] == 3 + level3_items = level3["items"] + assert isinstance(level3_items, list) + assert len(level3_items) == 3 + for i, node in enumerate(level3_items): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + +@mark_async_test +async def test_collect_with_aggregation(): + """Test collect() with aggregation functions.""" + # Create test data + for i in range(5): + node = await ResolutionNode(name=f"AggNode{i}", value=i * 10).save() + + # Test collect with aggregation + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) + WHERE n.name STARTS WITH 'Agg' + WITH n ORDER BY n.name + WITH collect(n) as all_nodes + RETURN { + nodes: all_nodes, + count: size(all_nodes), + total_value: reduce(total = 0, n in all_nodes | total + n.value), + names: [n in all_nodes | n.name] + } as aggregated_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + aggregated_result = results[0][0] + + assert isinstance(aggregated_result, dict) + assert "nodes" in aggregated_result + assert "count" in aggregated_result + assert "total_value" in aggregated_result + assert "names" in aggregated_result + + # Check nodes are resolved + nodes = aggregated_result["nodes"] + assert isinstance(nodes, list) + assert len(nodes) == 5 + for i, node in enumerate(nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"AggNode{i}" + assert node.value == i * 10 + + # Check aggregated values + assert aggregated_result["count"] == 5 + assert aggregated_result["total_value"] == 100 # 0+10+20+30+40 + assert aggregated_result["names"] == [ + "AggNode0", + "AggNode1", + "AggNode2", + "AggNode3", + "AggNode4", + ] + + +@mark_async_test +async def test_resolve_objects_false_comparison(): + """Test that resolve_objects=False returns raw Neo4j objects.""" + # Create test data + await ResolutionNode(name="RawNode", value=123).save() + + # Test with resolve_objects=False + results_false, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=False, + ) + + # Test with resolve_objects=True + results_true, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=True, + ) + + # Compare results + raw_node = results_false[0][0] + resolved_node = results_true[0][0] + + # Raw node should be a Neo4j Node object + from neo4j.graph import Node + + assert isinstance(raw_node, Node) + assert raw_node["name"] == "RawNode" + assert raw_node["value"] == 123 + + # Resolved node should be a ResolutionNode instance + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "RawNode" + assert resolved_node.value == 123 + + +@mark_async_test +async def test_empty_results(): + """Test object resolution with empty results.""" + # Test empty results + results, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = 'NonExistent' RETURN n", + resolve_objects=True, + ) + + assert len(results) == 0 + + +@mark_async_test +async def test_primitive_types_preserved(): + """Test that primitive types are preserved during object resolution.""" + # Create test data + await ResolutionNode(name="PrimitiveTest", value=456).save() + + # Test with mixed primitive and node types + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) WHERE n.name = $name + RETURN n, n.value as int_val, n.name as str_val, true as bool_val, 3.14 as float_val + """, + {"name": "PrimitiveTest"}, + resolve_objects=True, + ) + + assert len(results) == 1 + node_result, int_val, str_val, bool_val, float_val = results[0] + + # Node should be resolved + assert isinstance(node_result, ResolutionNode) + assert node_result.name == "PrimitiveTest" + + # Primitives should remain primitive + assert isinstance(int_val, int) + assert int_val == 456 + + assert isinstance(str_val, str) + assert str_val == "PrimitiveTest" + + assert isinstance(bool_val, bool) + assert bool_val is True + + assert isinstance(float_val, float) + assert float_val == 3.14 diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index fab4f0d7..3b30c14f 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -5,8 +5,8 @@ More information about the same issue at: https://github.com/aanastasiou/neomodelInheritanceTest -The following example uses a recursive relationship for economy, but the -idea remains the same: "Instantiate the correct type of node at the end of +The following example uses a recursive relationship for economy, but the +idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ @@ -118,44 +118,6 @@ def test_automatic_result_resolution(): assert type((A.friends_with)[0]) is TechnicalPerson -@mark_sync_test -def test_recursive_automatic_result_resolution(): - """ - Node objects are instantiated to native Python objects, both at the top - level of returned results and in the case where they are returned within - lists. - """ - - # Create a few entities - A = ( - TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"}) - )[0] - B = (TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"}))[ - 0 - ] - C = (TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"}))[0] - D = (TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"}))[0] - - # Retrieve mixed results, both at the top level and nested - L, _ = db.cypher_query( - "MATCH (a:TechnicalPerson) " - "WHERE a.expertise='Grumpiness' " - "WITH collect(a) as Alpha " - "MATCH (b:TechnicalPerson) " - "WHERE b.expertise='Pillows' " - "WITH Alpha, collect(b) as Beta " - "RETURN [Alpha, [Beta, [Beta, ['Banana', " - "Alpha]]]]", - resolve_objects=True, - ) - - # Assert that a Node returned deep in a nested list structure is of the - # correct type - assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson - # Assert that primitive data types remain primitive data types - assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - - @mark_sync_test def test_validation_with_inheritance_from_db(): """ diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index c27d299c..ba2d20a3 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -783,7 +783,7 @@ def test_annotate_and_collect(): .all() ) assert len(result) == 1 - assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + assert len(result[0][1]) == 3 # 3 species must be there (with 2 duplicates) result = ( Supplier.nodes.traverse( @@ -796,7 +796,7 @@ def test_annotate_and_collect(): .annotate(Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( Supplier.nodes.traverse( @@ -822,7 +822,7 @@ def test_annotate_and_collect(): .annotate(all_species=Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( Supplier.nodes.traverse( @@ -840,8 +840,8 @@ def test_annotate_and_collect(): ) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there - assert len(result[0][2][0]) == 3 # 3 species relations must be there + assert len(result[0][1]) == 2 # 2 species must be there + assert len(result[0][2]) == 3 # 3 species relations must be there @mark_sync_test diff --git a/test/sync_/test_object_resolution.py b/test/sync_/test_object_resolution.py new file mode 100644 index 00000000..6ea63201 --- /dev/null +++ b/test/sync_/test_object_resolution.py @@ -0,0 +1,550 @@ +""" +Test cases for object resolution with resolve_objects=True in raw Cypher queries. + +This test file covers various scenarios for automatic class resolution, +including the issues identified in GitHub issues #905 and #906: +- Issue #905: Nested lists in results of raw Cypher queries with collect keyword +- Issue #906: Automatic class resolution for raw queries with nodes nested in maps + +Additional scenarios tested: +- Basic object resolution +- Nested structures (lists, maps, mixed) +- Path resolution +- Relationship resolution +- Complex nested scenarios with collect() and other Cypher functions +""" + +from test._async_compat import mark_sync_test + +from neomodel import ( + IntegerProperty, + RelationshipTo, + StringProperty, + StructuredNode, + StructuredRel, + db, +) + + +class ResolutionRelationship(StructuredRel): + """Test relationship with properties.""" + + weight = IntegerProperty(default=1) + description = StringProperty(default="test") + + +class ResolutionNode(StructuredNode): + """Base test node class.""" + + name = StringProperty(required=True) + value = IntegerProperty(default=0) + related = RelationshipTo( + "ResolutionNode", "RELATED_TO", model=ResolutionRelationship + ) + + +class ResolutionSpecialNode(StructuredNode): + """Specialized test node class.""" + + name = StringProperty(required=True) + special_value = IntegerProperty(default=42) + related = RelationshipTo(ResolutionNode, "RELATED_TO", model=ResolutionRelationship) + + +class ResolutionContainerNode(StructuredNode): + """Container node for testing nested structures.""" + + name = StringProperty(required=True) + items = RelationshipTo(ResolutionNode, "CONTAINS", model=ResolutionRelationship) + + +@mark_sync_test +def test_basic_object_resolution(): + """Test basic object resolution for nodes and relationships.""" + # Create test data + ResolutionNode(name="Node1", value=10).save() + ResolutionNode(name="Node2", value=20).save() + + # Test basic node resolution + results, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "Node1"}, + resolve_objects=True, + ) + + assert len(results) == 1 + assert len(results[0]) == 1 + resolved_node = results[0][0] + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "Node1" + assert resolved_node.value == 10 + + +@mark_sync_test +def test_relationship_resolution(): + """Test relationship resolution in queries.""" + # Create test data with relationships + node1 = ResolutionNode(name="Source", value=100).save() + node2 = ResolutionNode(name="Target", value=200).save() + + # Create relationship + node1.related.connect(node2, {"weight": 5, "description": "test_rel"}) + + # Test relationship resolution + results, _ = db.cypher_query( + "MATCH (a:ResolutionNode)-[r:RELATED_TO]->(b:ResolutionNode) RETURN a, r, b", + resolve_objects=True, + ) + + assert len(results) == 1 + source, rel, target = results[0] + + assert isinstance(source, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert isinstance(target, ResolutionNode) + + assert source.name == "Source" + assert target.name == "Target" + assert rel.weight == 5 + assert rel.description == "test_rel" + + +@mark_sync_test +def test_path_resolution(): + """Test path resolution in queries.""" + # Create test data + node1 = ResolutionNode(name="Start", value=1).save() + node2 = ResolutionNode(name="Middle", value=2).save() + node3 = ResolutionNode(name="End", value=3).save() + + # Create path + node1.related.connect(node2, {"weight": 1}) + node2.related.connect(node3, {"weight": 2}) + + # Test path resolution + results, _ = db.cypher_query( + "MATCH p=(a:ResolutionNode)-[:RELATED_TO*2]->(c:ResolutionNode) RETURN p", + resolve_objects=True, + ) + + assert len(results) == 1 + path = results[0][0] + + # Path should be resolved to AsyncNeomodelPath + from neomodel.sync_.path import NeomodelPath + + assert isinstance(path, NeomodelPath) + assert len(path._nodes) == 3 # pylint: disable=protected-access + assert len(path._relationships) == 2 # pylint: disable=protected-access + + +@mark_sync_test +def test_nested_lists_basic(): + """Test basic nested list resolution (Issue #905 - basic case).""" + # Create test data + nodes = [] + for i in range(3): + node = ResolutionNode(name=f"Node{i}", value=i * 10).save() + nodes.append(node) + + # Test nested list resolution + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + RETURN collect(n) as nodes + """, + resolve_objects=True, + ) + + assert len(results) == 1 + collected_nodes = results[0][0] + + assert isinstance(collected_nodes, list) + assert len(collected_nodes) == 3 + + for i, node in enumerate(collected_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Node{i}" + assert node.value == i * 10 + + +@mark_sync_test +def test_nested_lists_complex(): + """Test complex nested list resolution with collect() (Issue #905 - complex case).""" + # Create test data with relationships + container = ResolutionContainerNode(name="Container").save() + items = [] + for i in range(2): + item = ResolutionNode(name=f"Item{i}", value=i * 5).save() + items.append(item) + container.items.connect(item, {"weight": i + 1}) + + # Test complex nested list with collect + results, _ = db.cypher_query( + """ + MATCH (c:ResolutionContainerNode)-[r:CONTAINS]->(i:ResolutionNode) + WITH c, r, i ORDER BY i.name + WITH c, collect({item: i, rel: r}) as items + RETURN c, items + """, + resolve_objects=True, + ) + + assert len(results) == 1 + container_result, items_result = results[0] + + assert isinstance(container_result, ResolutionContainerNode) + assert container_result.name == "Container" + + assert isinstance(items_result, list) + assert len(items_result) == 2 + + for i, item_data in enumerate(items_result): + assert isinstance(item_data, dict) + assert "item" in item_data + assert "rel" in item_data + + item = item_data["item"] + rel = item_data["rel"] + + assert isinstance(item, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert item.name == f"Item{i}" + assert rel.weight == i + 1 + + +@mark_sync_test +def test_nodes_nested_in_maps(): + """Test nodes nested in maps (Issue #906).""" + # Create test data + ResolutionNode(name="Node1", value=100).save() + ResolutionNode(name="Node2", value=200).save() + + # Test nodes nested in maps + results, _ = db.cypher_query( + """ + MATCH (n1:ResolutionNode), (n2:ResolutionNode) + WHERE n1.name = 'Node1' AND n2.name = 'Node2' + RETURN { + first: n1, + second: n2, + metadata: { + count: 2, + description: 'test map' + } + } as result_map + """, + resolve_objects=True, + ) + + assert len(results) == 1 + result_map = results[0][0] + + assert isinstance(result_map, dict) + assert "first" in result_map + assert "second" in result_map + assert "metadata" in result_map + + # Check that nodes are properly resolved + first_node = result_map["first"] + second_node = result_map["second"] + + assert isinstance(first_node, ResolutionNode) + assert isinstance(second_node, ResolutionNode) + assert first_node.name == "Node1" + assert second_node.name == "Node2" + + # Check metadata (should remain as primitive types) + metadata = result_map["metadata"] + assert isinstance(metadata, dict) + assert metadata["count"] == 2 + assert metadata["description"] == "test map" + + +@mark_sync_test +def test_mixed_nested_structures(): + """Test mixed nested structures with lists, maps, and nodes.""" + # Create test data + special = ResolutionSpecialNode(name="Special", special_value=999).save() + test_nodes = [] + for i in range(2): + node = ResolutionNode(name=f"Test{i}", value=i * 100).save() + test_nodes.append(node) + special.related.connect(node, {"weight": i + 10}) + + # Test complex mixed structure + results, _ = db.cypher_query( + """ + MATCH (s:ResolutionSpecialNode)-[r:RELATED_TO]->(t:ResolutionNode) + WITH s, r, t ORDER BY t.name + WITH s, collect({node: t, rel: r}) as related_items + RETURN { + special_node: s, + related: related_items, + summary: { + total_relations: size(related_items), + node_names: [item in related_items | item.node.name] + } + } as complex_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + complex_result = results[0][0] + + assert isinstance(complex_result, dict) + assert "special_node" in complex_result + assert "related" in complex_result + assert "summary" in complex_result + + # Check special node resolution + special_node = complex_result["special_node"] + assert isinstance(special_node, ResolutionSpecialNode) + assert special_node.name == "Special" + assert special_node.special_value == 999 + + # Check related items (list of dicts with nodes and relationships) + related = complex_result["related"] + assert isinstance(related, list) + assert len(related) == 2 + + for i, item in enumerate(related): + assert isinstance(item, dict) + assert "node" in item + assert "rel" in item + + node = item["node"] + rel = item["rel"] + + assert isinstance(node, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert node.name == f"Test{i}" + assert rel.weight == i + 10 + + # Check summary (should remain as primitive types) + summary = complex_result["summary"] + assert isinstance(summary, dict) + assert summary["total_relations"] == 2 + assert isinstance(summary["node_names"], list) + assert summary["node_names"] == ["Test0", "Test1"] + + +@mark_sync_test +def test_deeply_nested_structures(): + """Test deeply nested structures to ensure recursive resolution works.""" + # Create test data + nodes = [] + for i in range(3): + node = ResolutionNode(name=f"Deep{i}", value=i * 50).save() + nodes.append(node) + + # Test deeply nested structure + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + WITH collect(n) as level1 + RETURN { + level1: level1, + level2: { + nodes: level1, + metadata: { + level3: { + count: size(level1), + items: level1 + } + } + } + } as deep_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + deep_result = results[0][0] + + assert isinstance(deep_result, dict) + assert "level1" in deep_result + assert "level2" in deep_result + + # Check level1 (direct list of nodes) + level1 = deep_result["level1"] + assert isinstance(level1, list) + assert len(level1) == 3 + for i, node in enumerate(level1): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check level2 (nested structure) + level2 = deep_result["level2"] + assert isinstance(level2, dict) + assert "nodes" in level2 + assert "metadata" in level2 + + # Check nodes in level2 + level2_nodes = level2["nodes"] + assert isinstance(level2_nodes, list) + assert len(level2_nodes) == 3 + for i, node in enumerate(level2_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check metadata in level2 + metadata = level2["metadata"] + assert isinstance(metadata, dict) + assert "level3" in metadata + + level3 = metadata["level3"] + assert isinstance(level3, dict) + assert "count" in level3 + assert "items" in level3 + + assert level3["count"] == 3 + level3_items = level3["items"] + assert isinstance(level3_items, list) + assert len(level3_items) == 3 + for i, node in enumerate(level3_items): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + +@mark_sync_test +def test_collect_with_aggregation(): + """Test collect() with aggregation functions.""" + # Create test data + for i in range(5): + node = ResolutionNode(name=f"AggNode{i}", value=i * 10).save() + + # Test collect with aggregation + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) + WHERE n.name STARTS WITH 'Agg' + WITH n ORDER BY n.name + WITH collect(n) as all_nodes + RETURN { + nodes: all_nodes, + count: size(all_nodes), + total_value: reduce(total = 0, n in all_nodes | total + n.value), + names: [n in all_nodes | n.name] + } as aggregated_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + aggregated_result = results[0][0] + + assert isinstance(aggregated_result, dict) + assert "nodes" in aggregated_result + assert "count" in aggregated_result + assert "total_value" in aggregated_result + assert "names" in aggregated_result + + # Check nodes are resolved + nodes = aggregated_result["nodes"] + assert isinstance(nodes, list) + assert len(nodes) == 5 + for i, node in enumerate(nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"AggNode{i}" + assert node.value == i * 10 + + # Check aggregated values + assert aggregated_result["count"] == 5 + assert aggregated_result["total_value"] == 100 # 0+10+20+30+40 + assert aggregated_result["names"] == [ + "AggNode0", + "AggNode1", + "AggNode2", + "AggNode3", + "AggNode4", + ] + + +@mark_sync_test +def test_resolve_objects_false_comparison(): + """Test that resolve_objects=False returns raw Neo4j objects.""" + # Create test data + ResolutionNode(name="RawNode", value=123).save() + + # Test with resolve_objects=False + results_false, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=False, + ) + + # Test with resolve_objects=True + results_true, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=True, + ) + + # Compare results + raw_node = results_false[0][0] + resolved_node = results_true[0][0] + + # Raw node should be a Neo4j Node object + from neo4j.graph import Node + + assert isinstance(raw_node, Node) + assert raw_node["name"] == "RawNode" + assert raw_node["value"] == 123 + + # Resolved node should be a ResolutionNode instance + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "RawNode" + assert resolved_node.value == 123 + + +@mark_sync_test +def test_empty_results(): + """Test object resolution with empty results.""" + # Test empty results + results, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = 'NonExistent' RETURN n", + resolve_objects=True, + ) + + assert len(results) == 0 + + +@mark_sync_test +def test_primitive_types_preserved(): + """Test that primitive types are preserved during object resolution.""" + # Create test data + ResolutionNode(name="PrimitiveTest", value=456).save() + + # Test with mixed primitive and node types + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) WHERE n.name = $name + RETURN n, n.value as int_val, n.name as str_val, true as bool_val, 3.14 as float_val + """, + {"name": "PrimitiveTest"}, + resolve_objects=True, + ) + + assert len(results) == 1 + node_result, int_val, str_val, bool_val, float_val = results[0] + + # Node should be resolved + assert isinstance(node_result, ResolutionNode) + assert node_result.name == "PrimitiveTest" + + # Primitives should remain primitive + assert isinstance(int_val, int) + assert int_val == 456 + + assert isinstance(str_val, str) + assert str_val == "PrimitiveTest" + + assert isinstance(bool_val, bool) + assert bool_val is True + + assert isinstance(float_val, float) + assert float_val == 3.14