diff --git a/notebooks/import_databricks_demo.ipynb b/notebooks/import_databricks_demo.ipynb new file mode 100644 index 00000000..ac8c56fd --- /dev/null +++ b/notebooks/import_databricks_demo.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Importing Databricks Data into Neptune Analytics via Athena Federation\n", + "This notebook demonstrates how to connect to PaySim data stored in Databricks and, using Athena Federated Query, create a graph view of the data in Neptune Analytics. \n", + "\n", + "\n", + "### Prerequisite\n", + "\n", + "To query Databricks from Amazon Athena, the Athena Databricks Connector must first be deployed in your AWS account.\n", + "Deployment and setup instructions are available in the following resource:\n", + "\n", + "**Installation guide**\n", + "\n", + "../connectors/athena-databricks-connector/README.md\n", + "\n", + "\n", + "### What this notebook covers:\n", + "1. Download the [Kaggle PaySim1 dataset](https://www.kaggle.com/datasets/ealaxi/paysim1), a synthetic financial dataset simulating mobile money transactions, and upload it to Databricks.\n", + "\n", + "2. Use Amazon Athena to federate queries against the Databricks table and generate vertex/edge projections\n", + "3. Import the projections into Amazon Neptune Analytics\n", + "4. Run community detection (Louvain) to identify transaction clusters\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Import the necessary libraries and set up logging." 
# Standard library
from pathlib import Path

# Third-party
import dotenv
import kagglehub
from databricks import sql
from databricks.sdk import WorkspaceClient

# nx_neptune
from nx_neptune import (
    NeptuneGraph,
    empty_s3_bucket,
    instance_management,
    set_config_graph_id,
)
from nx_neptune.instance_management import (
    # NOTE(review): _clean_s3_path is a private helper and no cell visible in
    # this notebook references it — confirm whether the import is still needed.
    _clean_s3_path,
    execute_athena_query,
    get_athena_query_results,
)
from nx_neptune.utils.utils import get_stdout_logger, validate_and_get_env

# Configure logging to see detailed information about the instance creation
# process: the nx_neptune modules listed here plus this notebook's own logger.
logged_modules = [
    "nx_neptune.instance_management",
    "nx_neptune.utils.task_future",
    "nx_neptune.interface",
    __name__,
]
logger = get_stdout_logger(__name__, logged_modules)
# Load variables from a local .env file (if any) before validating them.
dotenv.load_dotenv()

# Required environment variables for Neptune Analytics and Athena.
# The unpacking below is positional, so this list and the tuple of target names
# must stay in the same order (assumes validate_and_get_env returns the values
# in the order requested — TODO confirm).
NEPTUNE_ATHENA_ENV_KEYS = [
    "NETWORKX_S3_DATA_LAKE_BUCKET_PATH",
    "NETWORKX_S3_NA_IMPORT_BUCKET_PATH",
    "NETWORKX_S3_LOG_BUCKET_PATH",
    "NETWORKX_S3_TABLES_DATABASE",
    "NETWORKX_S3_TABLES_TABLENAME",
    "NETWORKX_GRAPH_ID",
]
env_vars = validate_and_get_env(NEPTUNE_ATHENA_ENV_KEYS)
(
    s3_location_data_lake,
    s3_location_na_import,
    s3_location_log,
    s3_tables_database,
    s3_tables_tablename,
    graph_id,
) = env_vars.values()

# Optional — only needed to upload test data to Databricks (skip if table already exists).
# NOTE(review): these go through the same validator as the required variables
# above; confirm whether a missing value makes this cell fail.
DATABRICKS_ENV_KEYS = [
    "DATABRICKS_HOST",
    "DATABRICKS_TOKEN",
    "DATABRICKS_HTTP_PATH",
]
db_env = validate_and_get_env(DATABRICKS_ENV_KEYS)
db_host, db_token, db_http_path = db_env.values()
# Unity Catalog locations for the demo data. The namespace and short table name
# are kept separately so the existence check and the CREATE TABLE statement
# cannot drift apart.
table_namespace = "workspace.default"
table_short_name = "paysim_transactions"
table_name = f"{table_namespace}.{table_short_name}"
volume_path = "/Volumes/workspace/default/test_paysim_vol"

with sql.connect(
    server_hostname=db_host,
    http_path=db_http_path,
    access_token=db_token,
) as conn:
    # Using the cursor as a context manager closes it even when a statement
    # below raises (the explicit close() was only reached on success).
    with conn.cursor() as cursor:
        cursor.execute(f"SHOW TABLES IN {table_namespace} LIKE '{table_short_name}'")
        if cursor.fetchone():
            print(f"{table_name} already exists, skipping")
        else:
            print(f"{table_name} not found, loading PaySim data")

            # Download from Kaggle
            paysim_path = Path(kagglehub.dataset_download("ealaxi/paysim1"))
            csv_file = next(paysim_path.glob("*.csv"), None)
            if csv_file is None:
                # Fail with a readable message instead of a bare StopIteration.
                raise FileNotFoundError(f"No CSV file found in {paysim_path}")

            # Upload to Volume
            staged_csv = f"{volume_path}/{csv_file.name}"
            workspace = WorkspaceClient(host=f"https://{db_host}", token=db_token)
            with open(csv_file, "rb") as csv_stream:
                workspace.files.upload(staged_csv, csv_stream, overwrite=True)

            # Create table from the staged CSV
            cursor.execute(f"""
                CREATE TABLE {table_name} AS
                SELECT * FROM read_files(
                    '{staged_csv}',
                    format => 'csv', header => true, inferSchema => true)""")

            print(f"Created {table_name}")
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = 'SELECT * FROM \"lambda:databricks\".\"default\".\"paysim_transactions\" LIMIT 1'\n", + "\n", + "result = await execute_athena_query(query, s3_location_na_import, database=s3_tables_database)\n", + "query_id = result[0].task_id\n", + "\n", + "rows = get_athena_query_results(query_id)\n", + "assert len(rows) == 2, f\"Expected 2 rows (1 header + 1 data), got {len(rows)}\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Transformation and Graph Import\n", + "\n", + "In this step, Amazon Athena federates queries against the Databricks Unity Catalog table via \n", + "the Athena-Databricks connector. Two projections are generated:\n", + "\n", + "1. Vertex CSV — Extracts distinct customer IDs (both source and destination) from the \n", + "transaction dataset to create graph nodes.\n", + "2. Edge CSV — Maps each transaction as a directed edge between customers, carrying transaction \n", + "attributes (type, amount, balances, fraud flags) as edge properties.\n", + "\n", + "Both projections are written to S3 in Neptune Analytics' CSV import format, cleaned of Athena \n", + "metadata files, and then bulk-imported into the graph.\n", + "\n", + "After completion, the graph contains customer nodes connected by transaction edges — ready for\n", + "graph analytics (e.g., fraud ring detection, centrality analysis).\n", + "\n", + "│ **Troubleshooting:** If the Athena federated connector is not configured properly (e.g., the Lambda \n", + "function does not exist or the connector name is incorrect), you will receive a \n", + "`GENERIC_USER_ERROR` with a `ResourceNotFoundException`. 
Ensure the connector Lambda function is \n", + "deployed and the catalog name in your query (e.g., lambda:databricks) matches the registered \n", + "connector name.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Clear import directory\n", + "empty_s3_bucket(s3_location_na_import)\n", + "\n", + "# Generate vertex and edge projections from Databricks table\n", + "databricks_table_ref=f'\"lambda:databricks\".\"default\".\"paysim_transactions\"'\n", + "\n", + "SOURCE_AND_DESTINATION_CUSTOMERS = f\"\"\"\n", + "SELECT DISTINCT \"~id\", 'customer' AS \"~label\"\n", + "FROM (\n", + " SELECT NAMEORIG as \"~id\" FROM {databricks_table_ref} WHERE NAMEORIG IS NOT NULL\n", + " UNION ALL\n", + " SELECT NAMEDEST as \"~id\" FROM {databricks_table_ref} WHERE NAMEDEST IS NOT NULL\n", + ")\n", + "\"\"\"\n", + "\n", + "BANK_TRANSACTIONS = f\"\"\"\n", + "SELECT\n", + " NAMEORIG as \"~from\",\n", + " NAMEDEST as \"~to\",\n", + " TYPE AS \"~label\",\n", + " STEP AS \"step:Int\",\n", + " AMOUNT AS \"amount:Float\",\n", + " OLDBALANCEORG AS \"oldbalanceOrg:Float\",\n", + " NEWBALANCEORIG AS \"newbalanceOrig:Float\",\n", + " OLDBALANCEDEST AS \"oldbalanceDest:Float\",\n", + " NEWBALANCEDEST AS \"newbalanceDest:Float\",\n", + " ISFRAUD AS \"isFraud:Int\",\n", + " ISFLAGGEDFRAUD AS \"isFlaggedFraud:Int\"\n", + "FROM {databricks_table_ref} WHERE NAMEORIG IS NOT NULL AND NAMEDEST IS NOT NULL\n", + "\"\"\"\n", + "\n", + "await execute_athena_query(SOURCE_AND_DESTINATION_CUSTOMERS, s3_location_na_import, database=s3_tables_database, polling_interval=15)\n", + "await execute_athena_query(BANK_TRANSACTIONS, s3_location_na_import, database=s3_tables_database, polling_interval=15)\n", + "\n", + "# # Remove unnecessary .csv.metadata file generated by Athena. 
\n", + "empty_s3_bucket(s3_location_na_import, file_extension=\".csv.metadata\")\n", + "\n", + "task_id = await instance_management.import_csv_from_s3(\n", + " NeptuneGraph.from_config(set_config_graph_id(graph_id)),\n", + " s3_location_na_import)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Graph Analytics\n", + "\n", + "With the transaction graph loaded, we run community detection to identify clusters of \n", + "customers with dense transaction patterns — a common technique for fraud ring detection.\n", + "\n", + "1. Verify Import — Confirm nodes (customers) and edges (transactions) were loaded correctly.\n", + "2. Community Detection — Run the Louvain algorithm on Neptune Analytics to partition the graph \n", + "into communities, writing the result as a community property on each node.\n", + "3. Analyze Communities — Query the top 10 communities by size to understand the graph's \n", + "structure and identify unusually large or tightly connected groups for further investigation.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = set_config_graph_id(graph_id)\n", + "na_graph = NeptuneGraph.from_config(config)\n", + "\n", + "# Verify nodes\n", + "all_nodes = na_graph.execute_call(\"MATCH (n) RETURN n LIMIT 3\")\n", + "print(\"Sample Nodes:\")\n", + "for n in all_nodes:\n", + " print(f\" {n['n']['~id']} ({n['n']['~labels'][0]})\")\n", + "\n", + "# Verify edges\n", + "all_edges = na_graph.execute_call(\"MATCH ()-[r]-() RETURN r LIMIT 5\")\n", + "print(\"\\nSample Edges:\")\n", + "for e in all_edges:\n", + " r = e[\"r\"]\n", + " print(f\" {r['~start']} --[{r['~type']}, amount: {r['~properties']['amount']}]--> {r['~end']}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run Louvain algorithm and mutate graph with community property\n", + "louvain_result = 
na_graph.execute_call(\n", + " 'CALL neptune.algo.louvain.mutate({iterationTolerance:1e-07, writeProperty:\"community\"}) '\n", + " 'YIELD success AS success RETURN success'\n", + ")\n", + "print(f\"Louvain result: {louvain_result}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Find the top 10 communities by size\n", + "top_communities = na_graph.execute_call(\"\"\"\n", + "MATCH (n)\n", + "WHERE n.community IS NOT NULL\n", + "RETURN n.community AS community, count(*) AS community_size\n", + "ORDER BY community_size DESC\n", + "LIMIT 10\n", + "\"\"\")\n", + "\n", + "print(\"Top 10 Communities:\")\n", + "print(f\" {'Community ID':>14} {'Size':>6}\")\n", + "print(f\" {'─' * 14} {'─' * 6}\")\n", + "for c in top_communities:\n", + " print(f\" {c['community']:>14} {c['community_size']:>6}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This notebook demonstrated an end-to-end workflow for federated graph analytics: \n", + "Sourcing transaction data from Databricks via the Athena-Databricks connector, transforming it into a graph-compatible format with Athena, importing it into Neptune Analytics, and running community detection to surface transaction clusters.\n", + "\n", + "This pattern enables teams to leverage existing data in Databricks without data duplication — Athena federates the query at runtime, and Neptune Analytics provides the graph compute layer for analytics that relational engines aren't designed for." 
def get_athena_query_results(
    query_execution_id: str,
    client: Optional[BaseClient] = None,
) -> list[list[Optional[str]]]:
    """Fetch all result rows of a completed Athena query.

    Pages through the ``get_query_results`` paginator and collects every row
    from every page, in the order the pages and rows are returned.

    Args:
        query_execution_id (str): The Athena query execution ID.
        client (Optional[BaseClient]): Pre-configured Athena client.
            If None, creates a new client instance.

    Returns:
        list[list[Optional[str]]]: Raw rows, unmodified, so any header row
        Athena emits is included as the first row. Each cell is its
        ``VarCharValue`` string, or ``None`` when the cell carries no
        ``VarCharValue`` (e.g. a NULL column). A row without a ``Data`` entry
        yields an empty list.
    """
    if client is None:
        client = boto3.client("athena")

    paginator = client.get_paginator("get_query_results")
    return [
        # .get() on both levels: a missing value becomes None / an empty row
        # instead of raising KeyError halfway through pagination.
        [cell.get("VarCharValue") for cell in row.get("Data", [])]
        for page in paginator.paginate(QueryExecutionId=query_execution_id)
        for row in page["ResultSet"]["Rows"]
    ]
def _athena_client_returning(pages):
    """Build a mocked Athena client whose ``get_query_results`` paginator yields ``pages``."""
    mock_athena_client = MagicMock()
    mock_athena_client.get_paginator.return_value.paginate.return_value = pages
    return mock_athena_client


@patch("nx_neptune.instance_management.boto3.client")
def test_get_athena_query_results_multiple_pages(mock_boto3_client):
    """Test fetching Athena query results across multiple pages."""
    mock_athena_client = _athena_client_returning(
        [
            {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "id"}]},
                        {"Data": [{"VarCharValue": "1"}]},
                    ]
                }
            },
            {
                "ResultSet": {
                    "Rows": [
                        {"Data": [{"VarCharValue": "2"}]},
                    ]
                }
            },
        ]
    )

    rows = get_athena_query_results("test-query-id", client=mock_athena_client)

    # Rows from later pages are appended after rows from earlier pages.
    assert rows == [["id"], ["1"], ["2"]]
    # An explicit client must be used as-is; no default client is created.
    mock_boto3_client.assert_not_called()


@patch("nx_neptune.instance_management.boto3.client")
def test_get_athena_query_results_empty(mock_boto3_client):
    """Test fetching Athena query results with no rows."""
    mock_athena_client = _athena_client_returning([{"ResultSet": {"Rows": []}}])

    rows = get_athena_query_results("test-query-id", client=mock_athena_client)

    assert rows == []
    mock_boto3_client.assert_not_called()


@patch("nx_neptune.instance_management.boto3.client")
def test_get_athena_query_results_creates_default_client(mock_boto3_client):
    """Test that a default Athena client is created when none is provided."""
    mock_athena_client = _athena_client_returning([{"ResultSet": {"Rows": []}}])
    mock_boto3_client.return_value = mock_athena_client

    rows = get_athena_query_results("test-query-id")

    assert rows == []
    mock_boto3_client.assert_called_once_with("athena")
    # The default client, not some other object, must serve the query.
    mock_athena_client.get_paginator.assert_called_once_with("get_query_results")