Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Changelog - [0.3.0] - yyyy-mm-dd

### New Features

1. Add a new get_relationship_properties_keys tool.

### Bug Fixes
1. Return node names in several path algorithms that only returned node ids.

* Fix a bug with loading node properties correctly.

Expand Down
28 changes: 22 additions & 6 deletions mcp_server/src/mcp_server_neo4j_gds/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def projected_graph(gds, undirected=False):

# Use separate data and additional configuration parameters
if additional_config:
G, _ = gds.graph.cypher.project(
f"""
project_query = f"""
MATCH (n)-[r]->(m)
WITH n, r, m
RETURN gds.graph.project(
Expand All @@ -169,12 +168,14 @@ def projected_graph(gds, undirected=False):
{{{data_config}}},
{{{additional_config}}}
)
""",
"""
logger.info(f"Project query: '{project_query}'")
G, _ = gds.graph.cypher.project(
project_query,
graph_name=graph_name,
)
else:
G, _ = gds.graph.cypher.project(
f"""
projection_query = f"""
MATCH (n)-[r]->(m)
WITH n, r, m
RETURN gds.graph.project(
Expand All @@ -183,7 +184,10 @@ def projected_graph(gds, undirected=False):
m,
{{{data_config}}}
)
""",
"""
logger.info(f"Projection query: '{projection_query}'")
G, _ = gds.graph.cypher.project(
projection_query,
graph_name=graph_name,
)
yield G
Expand All @@ -206,3 +210,15 @@ def get_node_properties_keys(gds: GraphDataScience):
if df.empty:
return []
return df["properties_keys"].iloc[0]


def get_relationship_properties_keys(gds: GraphDataScience):
with projected_graph(gds):
query = """
MATCH (n)-[r]->(m)
RETURN DISTINCT keys(properties(r)) AS properties_keys
"""
df = gds.run_cypher(query)
if df.empty:
return []
return df["properties_keys"].iloc[0]
20 changes: 19 additions & 1 deletion mcp_server/src/mcp_server_neo4j_gds/path_algorithm_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,10 +562,16 @@ def minimum_directed_steiner_tree(

total_weight += weight

# Get node names using GDS utility function
node_name = self.gds.util.asNode(node_id)
parent_name = self.gds.util.asNode(parent_id)

edges.append(
{
"nodeId": node_id,
"parentId": parent_id,
"nodeName": node_name,
"parentName": parent_name,
"weight": weight,
}
)
Expand All @@ -590,7 +596,7 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
class PrizeCollectingSteinerTreeHandler(AlgorithmHandler):
def prize_collecting_steiner_tree(self, **kwargs):
with projected_graph(self.gds, undirected=True) as G:
# If any optional parameter is not None, use that parameter
# Prepare parameters for the algorithm
params = {k: v for k, v in kwargs.items() if v is not None}
logger.info(f"Prize-Collecting Steiner Tree parameters: {params}")

Expand Down Expand Up @@ -618,10 +624,16 @@ def prize_collecting_steiner_tree(self, **kwargs):

total_weight += weight

# Get node names using GDS utility function if available
node_name = self.gds.util.asNode(node_id)
parent_name = self.gds.util.asNode(parent_id)

edges.append(
{
"nodeId": node_id,
"parentId": parent_id,
"nodeName": node_name,
"parentName": parent_name,
"weight": weight,
}
)
Expand Down Expand Up @@ -660,10 +672,16 @@ def all_pairs_shortest_paths(self, **kwargs):
target_id = int(row["targetNodeId"])
distance = float(row["distance"])

# Get node names using GDS utility function
source_name = self.gds.util.asNode(source_id)
target_name = self.gds.util.asNode(target_id)

paths.append(
{
"sourceNodeId": source_id,
"targetNodeId": target_id,
"sourceNodeName": source_name,
"targetNodeName": target_name,
"distance": distance,
}
)
Expand Down
13 changes: 12 additions & 1 deletion mcp_server/src/mcp_server_neo4j_gds/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .community_algorithm_specs import community_tool_definitions
from .path_algorithm_specs import path_tool_definitions
from .registry import AlgorithmRegistry
from .gds import count_nodes, get_node_properties_keys
from .gds import count_nodes, get_node_properties_keys, get_relationship_properties_keys

logger = logging.getLogger("mcp_server_neo4j_gds")

Expand Down Expand Up @@ -80,6 +80,13 @@ async def handle_list_tools() -> list[types.Tool]:
"type": "object",
},
),
types.Tool(
name="get_relationship_properties_keys",
description="""Get all relationship properties keys in the database""",
inputSchema={
"type": "object",
},
),
]
+ centrality_tool_definitions
+ community_tool_definitions
Expand All @@ -105,6 +112,10 @@ async def handle_call_tool(
result = get_node_properties_keys(gds)
return [types.TextContent(type="text", text=serialize_result(result))]

elif name == "get_relationship_properties_keys":
result = get_relationship_properties_keys(gds)
return [types.TextContent(type="text", text=serialize_result(result))]

else:
handler = AlgorithmRegistry.get_handler(name, gds)
result = handler.execute(arguments or {})
Expand Down
12 changes: 12 additions & 0 deletions mcp_server/tests/test_basic_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def test_list_tools(mcp_client):
# Basic tools
"count_nodes",
"get_node_properties_keys",
"get_relationship_properties_keys",
# Centrality algorithms
"article_rank",
"articulation_points",
Expand Down Expand Up @@ -97,3 +98,14 @@ async def test_get_node_properties_keys(mcp_client):
"display_name",
"longitude",
]


@pytest.mark.asyncio
async def test_get_relationship_properties_keys(mcp_client):
result = await mcp_client.call_tool("get_relationship_properties_keys")

assert len(result) == 1
result_text = result[0]["text"]
properties_keys = json.loads(result_text)

assert properties_keys == ["distance", "line", "time"]
6 changes: 6 additions & 0 deletions mcp_server/tests/test_path_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ async def test_minimum_directed_steiner_tree(mcp_client):
first_edge = edges[0]
assert "nodeId" in first_edge
assert "parentId" in first_edge
assert "nodeName" in first_edge
assert "parentName" in first_edge
assert "weight" in first_edge
assert first_edge["weight"] > 0

Expand Down Expand Up @@ -360,6 +362,8 @@ async def test_prize_collecting_steiner_tree(mcp_client):
first_edge = edges[0]
assert "nodeId" in first_edge
assert "parentId" in first_edge
assert "nodeName" in first_edge
assert "parentName" in first_edge
assert "weight" in first_edge
assert first_edge["weight"] > 0

Expand Down Expand Up @@ -388,6 +392,8 @@ async def test_all_pairs_shortest_paths(mcp_client):
first_path = paths[0]
assert "sourceNodeId" in first_path
assert "targetNodeId" in first_path
assert "sourceNodeName" in first_path
assert "targetNodeName" in first_path
assert "distance" in first_path

distance = first_path["distance"]
Expand Down