Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 123 additions & 55 deletions src/xian/services/bds/bds.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

from copy import deepcopy
from loguru import logger
from datetime import datetime
from xian.services.bds import sql
Expand All @@ -9,28 +10,60 @@
from xian.services.bds.database import DB, result_to_json
from xian_py.wallet import key_is_valid
from timeit import default_timer as timer
from decimal import Decimal


# Custom JSON encoder for our own objects
def strip_trailing_zeros(s: str) -> str:
if '.' in s:
s = s.rstrip('0').rstrip('.')
return s


def set_nested_dict_value(d, keys, value):
"""Set a value in a nested dictionary using a list of keys."""
for key in keys[:-1]:
d = d.setdefault(key, {})
if keys:
d[keys[-1]] = value
else:
# No keys, set value at the current level
d.update(value)


def merge_dicts(a, b):
"""Recursively merge dictionary b into dictionary a."""
for key in b:
if key in a and isinstance(a[key], dict) and isinstance(b[key], dict):
merge_dicts(a[key], b[key])
else:
a[key] = b[key]


# Encodes everything to string - except for unknown objects
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, ContractingDecimal):
v = float(str(obj))
return int(v) if v.is_integer() else v
if isinstance(obj, Datetime):
# Convert to ISO 8601 format string with microseconds
return strip_trailing_zeros(str(obj))
elif isinstance(obj, Decimal):
return strip_trailing_zeros(str(obj))
elif isinstance(obj, Datetime):
return obj._datetime.isoformat(timespec='microseconds')
if isinstance(obj, Timedelta):
# Convert to total seconds with microseconds
return obj._timedelta.total_seconds()
return super().default(obj)

elif isinstance(obj, Timedelta):
total_seconds = str(obj._timedelta.total_seconds())
return strip_trailing_zeros(total_seconds)
elif isinstance(obj, int):
return str(obj)
else:
return super().default(obj)

# To recursively process and handle custom types within nested structures
def encode(self, obj):
def process(o):
if isinstance(o, dict):
if len(o) == 1:
if '__fixed__' in o:
return float(o['__fixed__'])
return strip_trailing_zeros(str(o['__fixed__']))
elif '__time__' in o:
# Convert __time__ list to ISO 8601 string
time_list = o['__time__']
Expand All @@ -39,18 +72,28 @@ def process(o):
dt_obj = datetime(*time_list)
# Convert to ISO 8601 string with microseconds
return dt_obj.isoformat(timespec='microseconds')
else:
return {k: process(v) for k, v in o.items()}
else:
return {k: process(v) for k, v in o.items()}
# Process nested dictionaries and convert keys to strings
return {str(k): process(v) for k, v in o.items()}
elif isinstance(o, list):
# Process each item in the list
return [process(v) for v in o]
elif isinstance(o, ContractingDecimal):
return strip_trailing_zeros(str(o))
elif isinstance(o, Decimal):
return strip_trailing_zeros(str(o))
elif isinstance(o, Datetime):
# Serialize datetime as ISO formatted string
return o._datetime.isoformat(timespec='microseconds')
elif isinstance(o, Timedelta):
return o._timedelta.total_seconds()
# Serialize total seconds as a string
total_seconds = str(o._timedelta.total_seconds())
return strip_trailing_zeros(total_seconds)
elif isinstance(o, int):
return str(o)
else:
# Return the object as-is if it doesn't match any custom types
return o
# Encode the processed object
return super().encode(process(obj))


Expand All @@ -72,13 +115,14 @@ async def init(self, cometbft_genesis: dict):

async def process_genesis_block(self, cometbft_genesis: dict):
start_time = timer()

genesis_state = cometbft_genesis["abci_genesis"]["genesis"]

# insert genesis txn
await self.insert_genesis_txn(genesis_state)

# process each item in the genesis block
for index, state in enumerate(genesis_state):
for index, state in enumerate(genesis_state):
logger.debug(f"processing item {index} from genesis_state")
parts = state["key"].split(".")

Expand All @@ -89,7 +133,7 @@ async def process_genesis_block(self, cometbft_genesis: dict):
await self.insert_genesis_state_change(state["key"], state["value"])
await self.insert_genesis_state(state["key"], state["value"])

logger.debug(f'Processed genesis block in {timer() - start_time:.3f} seconds')
logger.debug(f'Saved genesis block to BDS in {timer() - start_time:.3f} seconds')

async def __init_tables(self):
try:
Expand All @@ -105,42 +149,20 @@ async def __init_tables(self):
logger.exception(e)


async def insert_full_data(self, tx: dict, block_time: datetime):
total_time = timer()

# Tx
start_time = timer()
async def add_to_batch(self, tx: dict, block_time: datetime):
await self._insert_tx(tx, block_time)
logger.debug(f'Saved tx in {timer() - start_time:.3f} seconds')

# State
start_time = timer()
await self._insert_state(tx, block_time)
logger.debug(f'Saved contracts in {timer() - start_time:.3f} seconds')

# State changes
start_time = timer()
await self._insert_state_changes(tx, block_time)
logger.debug(f'Saved state changes in {timer() - start_time:.3f} seconds')

# Rewards
start_time = timer()
await self._insert_rewards(tx, block_time)
logger.debug(f'Saved rewards in {timer() - start_time:.3f} seconds')

# Addresses
start_time = timer()
await self._insert_addresses(tx, block_time)
logger.debug(f'Saved addresses in {timer() - start_time:.3f} seconds')
await self._insert_contracts(tx, block_time)

# Contracts
# Only save contracts if tx was successful
if tx["tx_result"]["status"] == 0:
start_time = timer()
await self._insert_contracts(tx, block_time)
logger.debug(f'Saved contracts in {timer() - start_time:.3f} seconds')
async def commit_batch(self):
if len(self.db.batch) == 0: return

logger.debug(f'Processed tx {tx["tx_result"]["hash"]} in {timer() - total_time:.3f} seconds')
start_time = timer()
await self.db.commit_batch_to_disk()
logger.debug(f'Saved block to BDS in {timer() - start_time:.3f} seconds')

async def _insert_tx(self, tx: dict, block_time: datetime):
status = True if tx['tx_result']['status'] == 0 else False
Expand Down Expand Up @@ -180,16 +202,59 @@ async def _insert_state_changes(self, tx: dict, block_time: datetime):
logger.exception(e)

async def _insert_state(self, tx: dict, block_time: datetime):
# Collect state changes by contract
contract_states = {}

for state_change in tx['tx_result']['state']:
key = state_change['key']
value = state_change['value']

# Parse the key to get contract name and variable path
parts = key.split('.', 1)
if len(parts) == 2:
contract_name, rest_of_key = parts
else:
# Handle keys without a dot (unlikely but possible)
contract_name = parts[0]
rest_of_key = ''

key_path = rest_of_key.split(':') if rest_of_key else []

# Initialize the contract state dictionary
contract_state = contract_states.setdefault(contract_name, {})

# Set the nested value in the contract's state dictionary
set_nested_dict_value(contract_state, key_path, value)

# For each contract, merge state and update the 'state' table
for contract_name, state_dict in contract_states.items():
try:
# Fetch existing state from the database
existing_state_row = await self.db.fetch_one(sql.select_state_by_key(), [contract_name])
if existing_state_row and existing_state_row['value'] not in [None, 'null']:
existing_state_json = existing_state_row['value']
# Ensure existing_state is a dictionary
if isinstance(existing_state_json, str):
existing_state = json.loads(existing_state_json)
else:
existing_state = existing_state_json
else:
existing_state = {}

# Deep copy to avoid mutating the original
merged_state = deepcopy(existing_state)

# Merge the new state changes into the existing state
merge_dicts(merged_state, state_dict)

# Update the 'state' table with the merged state
self.db.add_query_to_batch(sql.insert_or_update_state(), [
state_change['key'],
json.dumps(state_change['value'], cls=CustomEncoder),
contract_name,
json.dumps(merged_state, cls=CustomEncoder),
block_time
])

except Exception as e:
logger.exception(e)
logger.exception(f"Error updating state for contract '{contract_name}': {e}")

async def _insert_rewards(self, tx: dict, block_time: datetime):
async def insert(type, key, value):
Expand All @@ -198,7 +263,7 @@ async def insert(type, key, value):
tx['tx_result']['hash'],
type,
key,
json.dumps(value, cls=CustomEncoder),
strip_trailing_zeros(str(value)),
block_time
])

Expand All @@ -208,21 +273,21 @@ async def insert(type, key, value):
# Developer reward
for address, reward in rewards['developer_reward'].items():
try:
await insert('developer', address, float(reward))
await insert('developer', address, reward)
except Exception as e:
logger.exception(e)

# Masternode reward
for address, reward in rewards['masternode_reward'].items():
try:
await insert('masternode', address, float(reward))
await insert('masternode', address, reward)
except Exception as e:
logger.exception(e)

# Foundation reward
for address, reward in rewards['foundation_reward'].items():
try:
await insert('foundation', address, float(reward))
await insert('foundation', address, reward)
except Exception as e:
logger.exception(e)

Expand All @@ -241,6 +306,9 @@ async def _insert_addresses(self, tx: dict, block_time: datetime):
logger.exception(e)

async def _insert_contracts(self, tx: dict, block_time: datetime):
# Only save contracts if tx was successful
if tx["tx_result"]["status"] != 0: return

if tx['payload']['contract'] == 'submission' and tx['payload']['function'] == 'submit_contract':
try:
self.db.add_query_to_batch(sql.insert_contracts(), [
Expand Down Expand Up @@ -363,7 +431,7 @@ async def insert_genesis_state_contract(self, contract_name, code, submission_ti
submission_time
])
except Exception as e:
logger.exception(e)
logger.exception(e)

async def insert_genesis_state_change(self, key, value):
try:
Expand Down
4 changes: 4 additions & 0 deletions src/xian/services/bds/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ async def fetch(self, query: str, params: list = []):
logger.exception(f'Error while executing SQL: {e}')
raise e

async def fetch_one(self, query: str, params: list):
async with self.pool.acquire() as connection:
return await connection.fetchrow(query, *params)

async def has_entries(self, table_name: str) -> bool:
try:
result = await self.fetch(f"SELECT COUNT(*) as count FROM {table_name}")
Expand Down
6 changes: 6 additions & 0 deletions src/xian/services/bds/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ def select_state():
"""


def select_state_by_key():
return """
SELECT value FROM state WHERE key = $1;
"""


def select_state_history():
return """
SELECT
Expand Down