Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,15 @@ VECTOR_DB__QDRANT_VECTOR_SIZE=768
# --- Graph Database ---

# Connection URI for the Neo4j database.
GRAPH_DB__NEO4J_URI=neo4j://localhost:7687
NEO4J_URI="bolt://localhost:7687"

# The name of the specific Neo4j database.
# GRAPH_DB__NEO4J_DATABASE=neo4j
NEO4J_DATABASE="neo4j"
# Username for the Neo4j database.
NEO4J_USERNAME="neo4j"

# Password for the Neo4j database.
NEO4J_PASSWORD="password"

# ==============================================================================
# SOCIAL MEDIA
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.12
3.12.7
59 changes: 59 additions & 0 deletions examples/graph_query_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio

from flare_ai_kit.rag.graph.engine import GraphQueryEngine
from flare_ai_kit.rag.graph.indexers.neo4j_indexer import Neo4jIngester
from flare_ai_kit.rag.graph.settings import GraphDbSettings


async def main() -> None:
    """Ingest a few blocks into Neo4j, then run each sample graph query.

    The driver handles held by the ingester and the query engine are always
    released, even when one of the steps raises.
    """
    settings = GraphDbSettings()
    ingester = Neo4jIngester(settings)
    engine = GraphQueryEngine(settings)

    try:
        # Ingest some blocks
        await ingester.batch_ingest(start_block=45476458, count=3)

        # Query recent transactions
        recent_txs = engine.get_recent_transactions(limit=5)
        print(" Recent transactions:")
        for tx in recent_txs:
            print(tx)

        # Details of the last transaction looked up; stays empty when nothing
        # was found so the account lookups below are skipped instead of
        # failing on an unbound name.
        tx_details: dict[str, object] = {}

        # Only try to fetch a tx by hash if we actually got some
        if recent_txs:
            print("\n Transactions Details:")
            for tx in recent_txs:
                tx_details = engine.get_transaction_by_hash(tx["hash"])
                print(tx_details, end="\n")
                print()

                # The engine exposes the endpoints as from_address/to_address.
                if tx_details.get("from_address"):
                    print(
                        f"Sender: {tx_details['from_address']}, "
                        f"Receiver: {tx_details['to_address']}"
                    )
        else:
            print("No recent transactions found to show details for.")

        # Account balance of sender
        if tx_details.get("from_address"):
            balance = engine.get_account_balance(tx_details["from_address"])
            print("\n Account balance:")
            print(balance)

        # List some contracts
        contracts = engine.get_contracts(limit=5)
        print("\n Example contracts:")
        for c in contracts:
            print(c)

        # Account profile of receiver
        if tx_details.get("to_address"):
            profile = engine.get_account_profile(tx_details["to_address"])
            print("\n Account profile:")
            print(profile)
    finally:
        ingester.close()
        engine.close()


if __name__ == "__main__":
    asyncio.run(main())
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ flare-ai-kit = "flare_ai_kit.main:start"
[project.optional-dependencies]
rag = [
"qdrant-client>=1.13.3",
"dulwich>=0.23.2"
"dulwich>=0.23.2",
"neo4j>=5.0.0"
]
social = [
"python-telegram-bot>=22.0",
Expand Down
138 changes: 138 additions & 0 deletions src/flare_ai_kit/rag/graph/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import Any

from neo4j import GraphDatabase, ManagedTransaction

from flare_ai_kit.rag.graph.settings import GraphDbSettings


class GraphQueryEngine:
def __init__(self, settings: GraphDbSettings):
if settings.neo4j_password is None:
raise ValueError("Neo4j password must be set")

self._driver = GraphDatabase.driver(
settings.neo4j_uri,
auth=("neo4j", settings.neo4j_password),
database=settings.neo4j_database,
)

def close(self):
self._driver.close()

def get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]:
query = """
MATCH (tx:Transaction {hash: $tx_hash})
OPTIONAL MATCH (from:Account)-[:FROM]->(tx)
OPTIONAL MATCH (tx)-[:TO]->(to:Account)
RETURN tx.hash AS hash,
tx.blockNumber AS blockNumber,
tx.value AS value,
tx.timestamp AS timestamp,
from.address AS from_address,
to.address AS to_address
LIMIT 1
"""
with self._driver.session() as session:
result = session.execute_read(
lambda tx: tx.run(query, tx_hash=tx_hash).data()
)
if not result:
return {}
r = result[0]
return {
"hash": "0x" + r["hash"].hex()
if isinstance(r["hash"], (bytes, bytearray))
else str(r["hash"]).lower(),
"blockNumber": r["blockNumber"],
"value": r["value"],
"timestamp": r["timestamp"].isoformat() if r["timestamp"] else None,
"from_address": r.get("from_address"),
"to_address": r.get("to_address"),
}

def get_recent_transactions(self, limit: int = 10) -> list[dict[str, Any]]:
query = """
MATCH (tx:Transaction)
RETURN tx.hash AS hash,
tx.blockNumber AS blockNumber,
tx.value AS value,
tx.timestamp AS timestamp
ORDER BY tx.timestamp DESC
LIMIT $limit
"""
with self._driver.session() as session:
results = session.execute_read(lambda tx: tx.run(query, limit=limit).data())
cleaned: list[dict[str, Any]] = []
for r in results:
cleaned.append(
{
"hash": "0x" + r["hash"].hex()
if isinstance(r["hash"], (bytes, bytearray))
else str(r["hash"]).lower(),
"blockNumber": r["blockNumber"],
"value": r["value"],
"timestamp": r["timestamp"].isoformat()
if r["timestamp"]
else None,
}
)
return cleaned

def get_account_balance(self, address: str) -> dict[str, Any]:
query = """
MATCH (account:Account {address: toLower($address)})
RETURN account.balance AS balance
"""

def fetch_tx(tx: ManagedTransaction) -> Any:
return tx.run(query, address=address).single()

with self._driver.session() as session:
result = session.execute_read(fetch_tx)

return {
"address": address,
"balance": result["balance"] if result else None,
}

def get_contracts(self, limit: int = 20) -> list[dict[str, Any]]:
query = """
MATCH (c:Contract)
RETURN c.address AS address, c.balance AS balance
LIMIT $limit
"""

def fetch_tx(tx: ManagedTransaction) -> Any:
return tx.run(query, limit=limit).data()

with self._driver.session() as session:
result = session.execute_read(fetch_tx)

return result

def get_account_profile(self, address: str) -> dict[str, Any]:
query = """
MATCH (account:Account {address: toLower($address)})
OPTIONAL MATCH (account)-[:FROM]->(tx_out:Transaction)
OPTIONAL MATCH (tx_in:Transaction)-[:TO]->(account)
RETURN account,
count(DISTINCT tx_out) AS total_sent,
count(DISTINCT tx_in) AS total_received
"""

def fetch_tx(tx: ManagedTransaction) -> Any:
return tx.run(query, address=address).single()

with self._driver.session() as session:
result = session.execute_read(fetch_tx)

if result:
account_node = result["account"]
return {
"address": account_node.get("address"),
"balance": account_node.get("balance"),
"total_sent": result["total_sent"],
"total_received": result["total_received"],
"is_contract": "Contract" in account_node.labels,
}
return {"address": address, "profile": "not found"}
148 changes: 148 additions & 0 deletions src/flare_ai_kit/rag/graph/indexers/neo4j_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Any, cast

from eth_typing import ChecksumAddress
from neo4j import GraphDatabase, ManagedTransaction
from web3 import AsyncHTTPProvider, AsyncWeb3
from web3.middleware import (
ExtraDataToPOAMiddleware, # pyright: ignore[reportUnknownVariableType]
)
from web3.types import BlockData, TxData

from flare_ai_kit.rag.graph.settings import GraphDbSettings


class Neo4jIngester:
    """Fetch blocks over Web3 and write their transactions into Neo4j.

    Each transaction becomes a ``Transaction`` node linked to its sender and
    recipient ``Account`` nodes; accounts with deployed code are additionally
    labelled ``Contract``.
    """

    def __init__(self, settings: "GraphDbSettings"):
        """Create the Neo4j driver and the async Web3 client from *settings*.

        Raises:
            ValueError: If ``settings.neo4j_password`` is ``None``.
        """
        if settings.neo4j_password is None:
            raise ValueError("Neo4j password must be set")

        # Use a configured username when the settings object exposes one and
        # fall back to the previous hard-coded "neo4j" otherwise.
        # NOTE(review): confirm GraphDbSettings defines `neo4j_username`.
        username = getattr(settings, "neo4j_username", None) or "neo4j"
        self.driver = GraphDatabase.driver(
            settings.neo4j_uri, auth=(username, settings.neo4j_password)
        )
        # Write to the same database the query engine reads from, keeping the
        # former literal "neo4j" as the fallback.
        self._database = getattr(settings, "neo4j_database", None) or "neo4j"
        self.web3 = AsyncWeb3(
            AsyncHTTPProvider(str(settings.web3_provider_url)),
            middleware=[ExtraDataToPOAMiddleware],
        )

    def close(self) -> None:
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def ingest_transactions(self, transactions: list[dict[str, Any]]) -> None:
        """Write *transactions* to Neo4j in one managed write transaction."""
        if not transactions:
            return  # nothing to write; skip the round-trip
        with self.driver.session(database=self._database) as session:
            session.execute_write(self._create_transaction_nodes, transactions)

    @staticmethod
    def _normalize_hash(raw: Any) -> Any:
        """Return *raw* as lower-case hex without a ``0x`` prefix.

        Values that are neither bytes nor str are returned unchanged.
        """
        if isinstance(raw, (bytes, bytearray)):
            # bytes(...) first: plain bytes.hex() never emits a "0x" prefix,
            # whatever bytes subclass was passed in.
            return bytes(raw).hex()
        if isinstance(raw, str):
            return raw.lower().removeprefix("0x")
        return raw

    @staticmethod
    def _create_transaction_nodes(
        tx: "ManagedTransaction", transactions: list[dict[str, Any]]
    ) -> None:
        """Merge account and transaction nodes for every entry in *transactions*.

        Hashes are normalised on copies, so the caller's dictionaries are left
        untouched (the driver may invoke this function more than once when it
        retries the transaction).
        """
        rows: list[dict[str, Any]] = []
        for item in transactions:
            row = dict(item)
            if "hash" in row:
                row["hash"] = Neo4jIngester._normalize_hash(row["hash"])
            rows.append(row)

        tx.run(
            """
            UNWIND $transactions AS tx_data

            // Create sender node (User or Contract)
            MERGE (from:Account {address: toLower(tx_data.from)})
            ON CREATE SET from.created_at = timestamp()
            SET from.balance = tx_data.from_balance

            // Create recipient node and check if it's a contract
            MERGE (to:Account {address: toLower(tx_data.to)})
            ON CREATE SET to.created_at = timestamp()
            SET to.balance = tx_data.to_balance

            // Optional: Label contract if code is present
            FOREACH (_ IN CASE WHEN tx_data.to_code IS NOT NULL AND tx_data.to_code <> '0x' THEN [1] ELSE [] END |
            SET to:Contract
            )
            FOREACH (_ IN CASE WHEN tx_data.from_code IS NOT NULL AND tx_data.from_code <> '0x' THEN [1] ELSE [] END |
            SET from:Contract
            )

            // Transaction node
            MERGE (tx:Transaction {hash: tx_data.hash})
            SET tx.blockNumber = tx_data.blockNumber,
            tx.timestamp = datetime({ epochMillis: tx_data.timestamp }),
            tx.value = tx_data.value

            MERGE (from)-[:FROM]->(tx)
            MERGE (tx)-[:TO]->(to)
            """,
            parameters={"transactions": rows},
        )

    async def fetch_block_transactions(self, block_number: int) -> list[dict[str, Any]]:
        """Fetch one block and return its transactions as ingestible dicts.

        Each dict carries ``hash`` (``0x``-prefixed), ``from``, ``to``,
        ``blockNumber``, ``timestamp`` (epoch milliseconds), ``value``,
        ``from_balance``, ``to_balance``, ``from_code`` and ``to_code``.
        Balances and the value are stringified; the code fields are
        ``0x``-prefixed hex, so an account without code yields exactly
        ``"0x"``, the value the ingest query compares against.

        Entries that are not full transaction mappings, or that lack a hash,
        block number, value, sender or recipient, are skipped. Zero-value
        transactions are kept.
        """
        from collections.abc import Mapping

        block: BlockData | None = await self.web3.eth.get_block(
            block_number, full_transactions=True
        )
        transactions: list[dict[str, Any]] = []

        if not block:
            return transactions

        block_timestamp = block.get("timestamp")
        if block_timestamp is None:
            return transactions

        for tx_raw in block.get("transactions", []):
            # Accept any mapping, not just dict, so attribute-dict style
            # results are not silently dropped.
            if not isinstance(tx_raw, Mapping):
                continue

            hash_val = tx_raw.get("hash")
            block_number_val = tx_raw.get("blockNumber")
            value = tx_raw.get("value")
            from_address = cast("ChecksumAddress", tx_raw.get("from"))
            to_address = cast("ChecksumAddress", tx_raw.get("to"))

            # Compare against None explicitly: a value of 0 (plain contract
            # calls) and block number 0 are valid and must not be dropped.
            if (
                hash_val is None
                or block_number_val is None
                or value is None
                or from_address is None
            ):
                continue

            # Contract-creation transactions have no recipient; the ingest
            # query needs one, and the balance lookup would fail on None.
            # TODO(review): model contract creation explicitly if required.
            if to_address is None:
                continue

            # Normalize hash into canonical 0x-prefixed string
            if isinstance(hash_val, (bytes, bytearray)):
                hash_str = "0x" + bytes(hash_val).hex()
            else:
                hash_str = str(hash_val).lower()
                if not hash_str.startswith("0x"):
                    hash_str = "0x" + hash_str

            from_balance = await self.web3.eth.get_balance(from_address)
            from_code = await self.web3.eth.get_code(from_address)

            to_balance = await self.web3.eth.get_balance(to_address)
            to_code = await self.web3.eth.get_code(to_address)

            transactions.append(
                {
                    "hash": hash_str,
                    "from": from_address,
                    "to": to_address,
                    "blockNumber": block_number_val,
                    "timestamp": block_timestamp * 1000,
                    "value": str(value),
                    "from_balance": str(from_balance),
                    "to_balance": str(to_balance),
                    # Always "0x"-prefixed, independent of how the returned
                    # bytes type renders .hex().
                    "from_code": "0x" + bytes(from_code).hex(),
                    "to_code": "0x" + bytes(to_code).hex(),
                }
            )

        return transactions

    async def batch_ingest(self, start_block: int, count: int) -> None:
        """Ingest *count* consecutive blocks starting at *start_block*.

        The driver is left open so further batches can run on the same
        ingester; call ``close()`` once all ingestion is finished.
        """
        for i in range(start_block, start_block + count):
            txs = await self.fetch_block_transactions(i)
            if txs:
                print(f"Ingesting block {i} with {len(txs)} transactions")
                self.ingest_transactions(txs)
            else:
                print(f"Block {i} has no transactions, skipping")
Loading