Skip to content

Commit 8dbc313

Browse files
committed
wip use snowgraph degree centrality
1 parent e182a3c commit 8dbc313

File tree

4 files changed

+83
-33
lines changed

4 files changed

+83
-33
lines changed

mcp_server/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ classifiers = [
1616
dependencies = [
1717
"graphdatascience>=1.16",
1818
"mcp[cli]>=1.11.0",
19-
"snowflake-connector-python==3.17.3",
19+
"snowflake-snowpark-python==1.38.0",
2020
]
2121

2222
[project.urls]

mcp_server/src/mcp_server_neo4j_gds/centrality_algorithm_handlers.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -187,32 +187,82 @@ def execute(self, arguments: Dict[str, Any]) -> Any:
187187

188188
class DegreeCentralityHandler(AlgorithmHandler):
189189
def degree_centrality(self, **kwargs):
190-
with projected_graph(self.gds) as G:
191-
params = {
192-
k: v
193-
for k, v in kwargs.items()
194-
if v is not None and k not in ["nodes", "nodeIdentifierProperty"]
190+
# with projected_graph(self.gds) as G:
191+
# params = {
192+
# k: v
193+
# for k, v in kwargs.items()
194+
# if v is not None and k not in ["nodes", "nodeIdentifierProperty"]
195+
# }
196+
# logger.info(f"Degree centrality parameters: {params}")
197+
# centrality = self.gds.degree.stream(G, **params)
198+
199+
# # Add node names to the results if nodeIdentifierProperty is provided
200+
# node_identifier_property = kwargs.get("nodeIdentifierProperty")
201+
# translate_ids_to_identifiers(self.gds, node_identifier_property, centrality)
202+
203+
# # Filter results by node names if provided
204+
# node_names = kwargs.get("nodes", None)
205+
# centrality = filter_identifiers(
206+
# self.gds, node_identifier_property, node_names, centrality
207+
# )
208+
orientation = kwargs.get("orientation")
209+
if orientation is None:
210+
relationships_dict = """
211+
{
212+
'RELATIONSHIPS': {
213+
'sourceTable': 'NODES',
214+
'targetTable': 'NODES',
215+
}
195216
}
196-
logger.info(f"Degree centrality parameters: {params}")
197-
centrality = self.gds.degree.stream(G, **params)
198-
199-
# Add node names to the results if nodeIdentifierProperty is provided
200-
node_identifier_property = kwargs.get("nodeIdentifierProperty")
201-
translate_ids_to_identifiers(self.gds, node_identifier_property, centrality)
202-
203-
# Filter results by node names if provided
204-
node_names = kwargs.get("nodes", None)
205-
centrality = filter_identifiers(
206-
self.gds, node_identifier_property, node_names, centrality
207-
)
208-
209-
return centrality
217+
"""
218+
else:
219+
relationships_dict = f"""
220+
{{
221+
'RELATIONSHIPS': {{
222+
'sourceTable': 'NODES',
223+
'targetTable': 'NODES',
224+
'orientation': '{orientation}'
225+
}}
226+
}}
227+
"""
228+
229+
relationshipWeightProperty = kwargs.get("relationshipWeightProperty")
230+
if relationshipWeightProperty is None:
231+
compute_dict = """
232+
{ }
233+
"""
234+
else:
235+
compute_dict = f"""
236+
{{
237+
'relationshipWeightProperty': '{relationshipWeightProperty}'
238+
}}
239+
"""
240+
241+
res = self.gds.sql(
242+
f"""
243+
CALL Neo4j_Graph_Analytics.graph.degree('CPU_X64_XS', {{
244+
'defaultTablePrefix': 'EXAMPLE_DB.PUBLIC',
245+
'project': {{
246+
'nodeTables': [ 'NODES' ],
247+
'relationshipTables': {relationships_dict}
248+
}},
249+
'compute': {compute_dict},
250+
'write': [{{
251+
'nodeLabel': 'NODES',
252+
'outputTable': 'NODES_DEGREE_CENTRALITY'
253+
}}]
254+
}});
255+
"""
256+
).collect()
257+
258+
logger.info(f"Degree centrality execution: {res}")
259+
output_table = self.gds.table("EXAMPLE_DB.PUBLIC.NODES_DEGREE_CENTRALITY")
260+
return output_table.to_pandas()
210261

211262
def execute(self, arguments: Dict[str, Any]) -> Any:
212263
return self.degree_centrality(
213-
nodes=arguments.get("nodes"),
214-
nodeIdentifierProperty=arguments.get("nodeIdentifierProperty"),
215264
orientation=arguments.get("orientation"),
265+
relationshipWeightProperty=arguments.get("relationshipWeightProperty"),
216266
)
217267

218268

mcp_server/src/mcp_server_neo4j_gds/centrality_algorithm_specs.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,19 +177,14 @@
177177
inputSchema={
178178
"type": "object",
179179
"properties": {
180-
"nodes": {
181-
"type": "array",
182-
"items": {"type": "string"},
183-
"description": "List of node names to filter degree centrality results for.",
184-
},
185-
"nodeIdentifierProperty": {
186-
"type": "string",
187-
"description": "Property name to use for identifying nodes (e.g., 'name', 'Name', 'title'). Use get_node_properties_keys to find available properties.",
188-
},
189180
"orientation": {
190181
"type": "string",
191182
"description": "The orientation used to compute node degrees. Supported orientations are NATURAL (for out-degree), REVERSE (for in-degree) and UNDIRECTED (for both in-degree and out-degree) ",
192183
},
184+
"relationshipWeightProperty": {
185+
"type": "string",
186+
"description": "Property of the relationship to use for weighting. If not specified, all relationships are treated equally.",
187+
},
193188
},
194189
"required": [],
195190
},

mcp_server/src/mcp_server_neo4j_gds/server.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import json
1010
from graphdatascience import GraphDataScience
11-
11+
from snowflake.snowpark import Session
1212
from .similarity_algorithm_specs import similarity_tool_definitions
1313
from .centrality_algorithm_specs import centrality_tool_definitions
1414
from .community_algorithm_specs import community_tool_definitions
@@ -56,7 +56,12 @@ async def main(db_url: str, username: str, password: str, database: str = None):
5656
db_url, auth=(username, password), aura_ds=False, database=database
5757
)
5858
else:
59-
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
59+
if db_url:
60+
gds = GraphDataScience(db_url, auth=(username, password), aura_ds=False)
61+
else:
62+
gds = Session.builder.config("connection_name", "snowflake-gds-mcp").create()
63+
print(gds.sql("SELECT 1;").collect())
64+
logger.info("Successfully connected to Snowflake database")
6065
logger.info("Successfully connected to Neo4j database")
6166
except Exception as e:
6267
logger.error(f"Failed to connect to Neo4j database: {e}")

0 commit comments

Comments
 (0)