diff --git a/.env.example b/.env.example
index d50ee456..1a7957e6 100644
--- a/.env.example
+++ b/.env.example
@@ -141,10 +141,15 @@ VECTOR_DB__QDRANT_VECTOR_SIZE=768
 
 # --- Graph Database ---
 # Connection URI for the Neo4j database.
-GRAPH_DB__NEO4J_URI=neo4j://localhost:7687
+NEO4J_URI="bolt://localhost:7687"
 
 # The name of the specific Neo4j database.
-# GRAPH_DB__NEO4J_DATABASE=neo4j
+NEO4J_DATABASE="neo4j"
+# Username for the Neo4j database.
+NEO4J_USERNAME="neo4j"
+
+# Password for the Neo4j database.
+NEO4J_PASSWORD="password"
 
 # ==============================================================================
 # SOCIAL MEDIA
diff --git a/.python-version b/.python-version
index e4fba218..56bb6605 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.12
+3.12.7
diff --git a/examples/graph_query_example.py b/examples/graph_query_example.py
new file mode 100644
index 00000000..65aa1774
--- /dev/null
+++ b/examples/graph_query_example.py
@@ -0,0 +1,70 @@
+"""Ingest a few Flare blocks into Neo4j and print sample graph queries."""
+
+import asyncio
+
+from flare_ai_kit.rag.graph.engine import GraphQueryEngine
+from flare_ai_kit.rag.graph.indexers.neo4j_indexer import Neo4jIngester
+from flare_ai_kit.rag.graph.settings import GraphDbSettings
+
+
+async def main():
+    """Ingest three blocks, then print what the query engine returns."""
+    settings = GraphDbSettings()
+    ingester = Neo4jIngester(settings)
+    engine = GraphQueryEngine(settings)
+
+    try:
+        # Ingest some blocks
+        await ingester.batch_ingest(start_block=45476458, count=3)
+
+        # Query recent transactions
+        recent_txs = engine.get_recent_transactions(limit=5)
+        print(" Recent transactions:")
+        for tx in recent_txs:
+            print(tx)
+
+        # Stays empty when nothing was found, so the lookups below are
+        # skipped instead of failing on an unbound name.
+        tx_details = {}
+
+        if recent_txs:
+            print("\n Transactions Details:")
+            for tx in recent_txs:
+                tx_details = engine.get_transaction_by_hash(tx["hash"])
+                print(tx_details, end="\n")
+                print()
+
+                # The engine reports endpoints as from_address / to_address.
+                if tx_details.get("from_address"):
+                    print(
+                        f"Sender: {tx_details['from_address']}, "
+                        f"Receiver: {tx_details['to_address']}"
+                    )
+        else:
+            print("No recent transactions found to show details for.")
+
+        # Account balance of sender
+        if tx_details.get("from_address"):
+            balance = engine.get_account_balance(tx_details["from_address"])
+            print("\n Account balance:")
+            print(balance)
+
+        # List some contracts
+        contracts = engine.get_contracts(limit=5)
+        print("\n Example contracts:")
+        for c in contracts:
+            print(c)
+
+        # Account profile of receiver
+        if tx_details.get("to_address"):
+            profile = engine.get_account_profile(tx_details["to_address"])
+            print("\n Account profile:")
+            print(profile)
+    finally:
+        # Release both drivers even if ingestion or a query raised.
+        ingester.close()
+        engine.close()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index 5f5537c6..2877752f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,8 @@ flare-ai-kit = "flare_ai_kit.main:start"
 [project.optional-dependencies]
 rag = [
     "qdrant-client>=1.13.3",
-    "dulwich>=0.23.2"
+    "dulwich>=0.23.2",
+    "neo4j>=5.0.0"
 ]
 social = [
     "python-telegram-bot>=22.0",
diff --git a/src/flare_ai_kit/rag/graph/engine.py b/src/flare_ai_kit/rag/graph/engine.py
new file mode 100644
index 00000000..512f1037
--- /dev/null
+++ b/src/flare_ai_kit/rag/graph/engine.py
@@ -0,0 +1,176 @@
+"""Read-only query helpers over the Neo4j transaction graph."""
+
+from typing import Any
+
+from neo4j import GraphDatabase, ManagedTransaction
+
+from flare_ai_kit.rag.graph.settings import GraphDbSettings
+
+
+def _normalize_hash(tx_hash: str) -> str:
+    """Return a hash as lowercase hex without the 0x prefix (the stored form)."""
+    return tx_hash.lower().removeprefix("0x")
+
+
+def _format_hash(raw: Any) -> str:
+    """Return a stored hash as a lowercase 0x-prefixed string."""
+    if isinstance(raw, (bytes, bytearray)):
+        return "0x" + raw.hex()
+    return "0x" + _normalize_hash(str(raw))
+
+
+def _format_timestamp(raw: Any) -> str | None:
+    """Return the ISO 8601 form of a timestamp value, or None when unset."""
+    return raw.isoformat() if raw else None
+
+
+class GraphQueryEngine:
+    """Runs read queries against the account/transaction graph in Neo4j."""
+
+    def __init__(self, settings: GraphDbSettings):
+        """
+        Open a driver for the configured Neo4j instance.
+
+        Raises:
+            ValueError: If no Neo4j password is configured.
+
+        """
+        if settings.neo4j_password is None:
+            raise ValueError("Neo4j password must be set")
+
+        self._driver = GraphDatabase.driver(
+            settings.neo4j_uri,
+            # Use the configured user instead of assuming the default account.
+            auth=(settings.neo4j_username, settings.neo4j_password),
+            database=settings.neo4j_database,
+        )
+
+    def close(self):
+        """Close the underlying driver."""
+        self._driver.close()
+
+    def get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]:
+        """
+        Return one transaction with its endpoints, or {} when it is unknown.
+
+        The hash is matched case-insensitively, with or without a 0x prefix.
+        """
+        query = """
+        MATCH (tx:Transaction {hash: $tx_hash})
+        OPTIONAL MATCH (from:Account)-[:FROM]->(tx)
+        OPTIONAL MATCH (tx)-[:TO]->(to:Account)
+        RETURN tx.hash AS hash,
+               tx.blockNumber AS blockNumber,
+               tx.value AS value,
+               tx.timestamp AS timestamp,
+               from.address AS from_address,
+               to.address AS to_address
+        LIMIT 1
+        """
+        # The ingester stores hashes without the 0x prefix, so a canonical
+        # 0x-prefixed hash would otherwise never match.
+        stored_hash = _normalize_hash(tx_hash)
+        with self._driver.session() as session:
+            result = session.execute_read(
+                lambda tx: tx.run(query, tx_hash=stored_hash).data()
+            )
+        if not result:
+            return {}
+        r = result[0]
+        return {
+            "hash": _format_hash(r["hash"]),
+            "blockNumber": r["blockNumber"],
+            "value": r["value"],
+            "timestamp": _format_timestamp(r["timestamp"]),
+            "from_address": r.get("from_address"),
+            "to_address": r.get("to_address"),
+        }
+
+    def get_recent_transactions(self, limit: int = 10) -> list[dict[str, Any]]:
+        """Return up to *limit* transactions, newest timestamp first."""
+        query = """
+        MATCH (tx:Transaction)
+        RETURN tx.hash AS hash,
+               tx.blockNumber AS blockNumber,
+               tx.value AS value,
+               tx.timestamp AS timestamp
+        ORDER BY tx.timestamp DESC
+        LIMIT $limit
+        """
+        with self._driver.session() as session:
+            results = session.execute_read(lambda tx: tx.run(query, limit=limit).data())
+        return [
+            {
+                "hash": _format_hash(r["hash"]),
+                "blockNumber": r["blockNumber"],
+                "value": r["value"],
+                "timestamp": _format_timestamp(r["timestamp"]),
+            }
+            for r in results
+        ]
+
+    def get_account_balance(self, address: str) -> dict[str, Any]:
+        """Return address and balance; balance is None for unknown accounts."""
+        query = """
+        MATCH (account:Account {address: toLower($address)})
+        RETURN account.balance AS balance
+        """
+
+        def fetch_tx(tx: ManagedTransaction) -> Any:
+            return tx.run(query, address=address).single()
+
+        with self._driver.session() as session:
+            result = session.execute_read(fetch_tx)
+
+        return {
+            "address": address,
+            "balance": result["balance"] if result else None,
+        }
+
+    def get_contracts(self, limit: int = 20) -> list[dict[str, Any]]:
+        """Return the address and balance of up to *limit* contract accounts."""
+        query = """
+        MATCH (c:Contract)
+        RETURN c.address AS address, c.balance AS balance
+        LIMIT $limit
+        """
+
+        def fetch_tx(tx: ManagedTransaction) -> Any:
+            return tx.run(query, limit=limit).data()
+
+        with self._driver.session() as session:
+            result = session.execute_read(fetch_tx)
+
+        return result
+
+    def get_account_profile(self, address: str) -> dict[str, Any]:
+        """
+        Return balance, sent/received counts and contract flag for *address*.
+
+        An unknown address yields {"address": address, "profile": "not found"}.
+        """
+        query = """
+        MATCH (account:Account {address: toLower($address)})
+        OPTIONAL MATCH (account)-[:FROM]->(tx_out:Transaction)
+        OPTIONAL MATCH (tx_in:Transaction)-[:TO]->(account)
+        RETURN account,
+               count(DISTINCT tx_out) AS total_sent,
+               count(DISTINCT tx_in) AS total_received
+        """
+
+        def fetch_tx(tx: ManagedTransaction) -> Any:
+            return tx.run(query, address=address).single()
+
+        with self._driver.session() as session:
+            result = session.execute_read(fetch_tx)
+
+        if result:
+            account_node = result["account"]
+            return {
+                "address": account_node.get("address"),
+                "balance": account_node.get("balance"),
+                "total_sent": result["total_sent"],
+                "total_received": result["total_received"],
+                "is_contract": "Contract" in account_node.labels,
+            }
+        return {"address": address, "profile": "not found"}
diff --git a/src/flare_ai_kit/rag/graph/indexers/neo4j_indexer.py b/src/flare_ai_kit/rag/graph/indexers/neo4j_indexer.py
new file mode 100644
index 00000000..4591dabb
--- /dev/null
+++ b/src/flare_ai_kit/rag/graph/indexers/neo4j_indexer.py
@@ -0,0 +1,202 @@
+"""Ingest Flare block transactions into Neo4j as an account/transaction graph."""
+
+from collections.abc import Mapping
+from typing import Any, cast
+
+from eth_typing import ChecksumAddress
+from neo4j import GraphDatabase, ManagedTransaction
+from web3 import AsyncHTTPProvider, AsyncWeb3
+from web3.middleware import (
+    ExtraDataToPOAMiddleware,  # pyright: ignore[reportUnknownVariableType]
+)
+from web3.types import BlockData, TxData
+
+from flare_ai_kit.rag.graph.settings import GraphDbSettings
+
+
+def _hex_0x(raw: bytes | bytearray) -> str:
+    """Return *raw* as lowercase 0x-prefixed hex; empty input gives "0x"."""
+    # Go through bytes() so the result never depends on whether a bytes
+    # subclass (such as HexBytes) adds its own prefix in hex().
+    return "0x" + bytes(raw).hex()
+
+
+class Neo4jIngester:
+    """Fetches block transactions over Web3 and writes them to Neo4j."""
+
+    def __init__(self, settings: GraphDbSettings):
+        """
+        Open the Neo4j driver and the async Web3 client.
+
+        Raises:
+            ValueError: If no Neo4j password is configured.
+
+        """
+        if settings.neo4j_password is None:
+            raise ValueError("Neo4j password must be set")
+        self.driver = GraphDatabase.driver(
+            settings.neo4j_uri,
+            auth=(settings.neo4j_username, settings.neo4j_password),
+        )
+        # Write to the configured database rather than a hard-coded name.
+        self._database = settings.neo4j_database
+        self.web3 = AsyncWeb3(
+            AsyncHTTPProvider(str(settings.web3_provider_url)),
+            middleware=[ExtraDataToPOAMiddleware],
+        )
+
+    def close(self):
+        """Close the Neo4j driver."""
+        self.driver.close()
+
+    def ingest_transactions(self, transactions: list[dict[str, Any]]) -> None:
+        """Write *transactions* to Neo4j in a single write transaction."""
+        with self.driver.session(database=self._database) as session:
+            session.execute_write(self._create_transaction_nodes, transactions)
+
+    @staticmethod
+    def _create_transaction_nodes(
+        tx: ManagedTransaction, transactions: list[dict[str, Any]]
+    ) -> None:
+        """Merge accounts, transactions and FROM/TO relationships into the graph."""
+        # Hashes are stored as lowercase hex without the 0x prefix. Work on
+        # copies so the caller's dicts are left untouched.
+        rows: list[dict[str, Any]] = []
+        for t in transactions:
+            row = dict(t)
+            raw_hash = row.get("hash")
+            if isinstance(raw_hash, (bytes, bytearray)):
+                row["hash"] = bytes(raw_hash).hex()
+            elif isinstance(raw_hash, str):
+                row["hash"] = raw_hash.lower().removeprefix("0x")
+            rows.append(row)
+
+        tx.run(
+            """
+            UNWIND $transactions AS tx_data
+
+            // Create sender node (User or Contract)
+            MERGE (from:Account {address: toLower(tx_data.from)})
+            ON CREATE SET from.created_at = timestamp()
+            SET from.balance = tx_data.from_balance
+
+            // Create recipient node and check if it's a contract
+            MERGE (to:Account {address: toLower(tx_data.to)})
+            ON CREATE SET to.created_at = timestamp()
+            SET to.balance = tx_data.to_balance
+
+            // Optional: Label contract if code is present
+            FOREACH (_ IN CASE WHEN tx_data.to_code IS NOT NULL AND tx_data.to_code <> '0x' THEN [1] ELSE [] END |
+                SET to:Contract
+            )
+            FOREACH (_ IN CASE WHEN tx_data.from_code IS NOT NULL AND tx_data.from_code <> '0x' THEN [1] ELSE [] END |
+                SET from:Contract
+            )
+
+            // Transaction node
+            MERGE (tx:Transaction {hash: tx_data.hash})
+            SET tx.blockNumber = tx_data.blockNumber,
+                tx.timestamp = datetime({ epochMillis: tx_data.timestamp }),
+                tx.value = tx_data.value
+
+            MERGE (from)-[:FROM]->(tx)
+            MERGE (tx)-[:TO]->(to)
+            """,
+            parameters={"transactions": rows},
+        )
+
+    async def fetch_block_transactions(self, block_number: int) -> list[dict[str, Any]]:
+        """
+        Fetch one block and return its transactions as ingestable dicts.
+
+        Each dict carries the hash, endpoints, block number, timestamp in
+        milliseconds, value, and the balance and code of both endpoints.
+        Transactions without a recipient are skipped.
+        """
+        block: BlockData | None = await self.web3.eth.get_block(
+            block_number, full_transactions=True
+        )
+        transactions: list[dict[str, Any]] = []
+
+        if not block:
+            return transactions
+
+        block_timestamp = block.get("timestamp")
+        if block_timestamp is None:
+            return transactions
+
+        tx_list = block.get("transactions", [])
+        for tx_raw in tx_list:
+            # Accept any mapping, not only plain dicts.
+            if not isinstance(tx_raw, Mapping):
+                continue
+
+            tx = cast("TxData", tx_raw)
+
+            hash_val = tx.get("hash")
+            block_number_val = tx.get("blockNumber")
+            value = tx.get("value")
+            from_raw = tx.get("from")
+            to_raw = tx.get("to")
+
+            # Test for None explicitly: a value of 0 and block number 0 are
+            # valid and must not cause the transaction to be dropped.
+            required = (hash_val, block_number_val, value, from_raw)
+            if any(field is None for field in required):
+                continue
+
+            # No recipient (presumably a contract creation): there is no
+            # address to look up or to merge as the TO account.
+            if to_raw is None:
+                continue
+
+            from_address = cast("ChecksumAddress", from_raw)
+            to_address = cast("ChecksumAddress", to_raw)
+
+            # Normalize hash into canonical 0x-prefixed string
+            if isinstance(hash_val, (bytes, bytearray)):
+                hash_str = _hex_0x(hash_val)
+            else:
+                hash_str = str(hash_val).lower()
+                if not hash_str.startswith("0x"):
+                    hash_str = "0x" + hash_str
+
+            from_balance = await self.web3.eth.get_balance(from_address)
+            from_code = await self.web3.eth.get_code(from_address)
+
+            to_balance = await self.web3.eth.get_balance(to_address)
+            to_code = await self.web3.eth.get_code(to_address)
+
+            transactions.append(
+                {
+                    "hash": hash_str,
+                    "from": from_address,
+                    "to": to_address,
+                    "blockNumber": block_number_val,
+                    "timestamp": block_timestamp * 1000,
+                    "value": str(value),
+                    "from_balance": str(from_balance),
+                    "to_balance": str(to_balance),
+                    # Empty code must serialize as exactly "0x": the ingest
+                    # query labels every other value as a contract.
+                    "from_code": _hex_0x(from_code),
+                    "to_code": _hex_0x(to_code),
+                }
+            )
+
+        return transactions
+
+    async def batch_ingest(self, start_block: int, count: int) -> None:
+        """
+        Fetch and ingest *count* consecutive blocks starting at *start_block*.
+
+        The driver stays open so the ingester can be reused; call close()
+        when finished.
+        """
+        for i in range(start_block, start_block + count):
+            txs = await self.fetch_block_transactions(i)
+            if txs:
+                print(f"Ingesting block {i} with {len(txs)} transactions")
+                self.ingest_transactions(txs)
+            else:
+                print(f"Block {i} has no transactions, skipping")
diff --git a/src/flare_ai_kit/rag/graph/settings.py b/src/flare_ai_kit/rag/graph/settings.py
index baedf1d4..770da6c2 100644
--- a/src/flare_ai_kit/rag/graph/settings.py
+++ b/src/flare_ai_kit/rag/graph/settings.py
@@ -1,6 +1,6 @@
 """Settings for GraphRAG."""
 
-from pydantic import Field
+from pydantic import Field, HttpUrl
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
@@ -12,11 +12,24 @@ class GraphDbSettings(BaseSettings):
         env_file=".env",
         extra="ignore",
     )
+    web3_provider_url: HttpUrl = Field(
+        default=HttpUrl(
+            "https://stylish-light-theorem.flare-mainnet.quiknode.pro/ext/bc/C/rpc"
+        ),
+        description="Flare RPC endpoint URL.",
+    )
     neo4j_uri: str = Field(
-        default="neo4j://localhost:7687",
+        default="bolt://localhost:7687",
         description="Connection URI for the Neo4j database.",
     )
+    neo4j_username: str = Field(
+        default="neo4j", description="Username for the Neo4j database."
+    )
     neo4j_database: str = Field(
         default="neo4j",  # Default database name in Neo4j v4+
         description="The name of the specific Neo4j database.",
     )
+    neo4j_password: str | None = Field(
+        default=None,
+        description="password for the Neo4j database.",
+    )
diff --git a/uv.lock b/uv.lock
index 7dbe07f1..bfdb24e1 100644
--- a/uv.lock
+++ b/uv.lock
@@ -848,6 +848,7 @@ ingestion = [
 ]
 rag = [
     { name = "dulwich" },
+    { name = "neo4j" },
     { name = "qdrant-client" },
 ]
 social = [
@@ -895,6 +896,7 @@ requires-dist = [
     { name = "google-genai", specifier = ">=1.8.0" },
     { name = "httpx", specifier = ">=0.28.1" },
     { name = "httpx", marker = "extra == 'wallet'", specifier = ">=0.28.1" },
+    { name = "neo4j", marker = "extra == 'rag'", specifier = ">=5.0.0" },
     { name = "pillow", specifier = ">=11.3.0" },
     { name = "pillow", marker = "extra == 'ingestion'", specifier = ">=11.3.0" },
     { name = "pydantic", specifier = ">=2.11.1" },
@@ -1974,6 +1976,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d8/30/9aec301e9772b098c1f5c0ca0279237c9766d94b97802e9888010c64b0ed/multidict-6.6.3-py3-none-any.whl", hash = "sha256:8db10f29c7541fc5da4defd8cd697e1ca429db743fa716325f236079b96f775a", size = 12313, upload-time = "2025-06-30T15:53:45.437Z" },
 ]
 
+[[package]]
+name = "neo4j"
+version = "5.28.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytz" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/50/69/4862fabc082f2447131aada5c91736155349d77ebf443af7f59553b7b789/neo4j-5.28.2.tar.gz", hash = "sha256:7d38e27e4f987a45cc9052500c6ee27325cb23dae6509037fe31dd7ddaed70c7", size = 231874, upload-time = "2025-07-30T06:04:34.669Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/04/00/1f74089c06aec1fac9390e2300a6a6b2381e0dac281783d64ccca9d681fd/neo4j-5.28.2-py3-none-any.whl", hash = "sha256:5c53b5c3eee6dee7e920c9724391aa38d7135a651e71b766da00533b92a91a94", size = 313156, upload-time = "2025-07-30T06:04:31.438Z" },
+]
+
 [[package]]
 name = "nodeenv"
 version = "1.9.1"
@@ -2756,6 +2770,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e5/54/0955bd46a1e046169500e129c7883664b6675d580074d68823485e4d5de1/python_telegram_bot-22.3-py3-none-any.whl", hash = "sha256:88fab2d1652dbfd5379552e8b904d86173c524fdb9270d3a8685f599ffe0299f", size = 717115, upload-time = "2025-07-20T20:03:07.261Z" },
 ]
 
+[[package]]
+name = "pytz"
+version = "2025.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" },
+]
+
 [[package]]
 name = "pyunormalize"
 version = "16.0.0"