diff --git a/src/xian/services/bds/bds.py b/src/xian/services/bds/bds.py index cca1af4..52b98d7 100644 --- a/src/xian/services/bds/bds.py +++ b/src/xian/services/bds/bds.py @@ -1,5 +1,6 @@ import json +from copy import deepcopy from loguru import logger from datetime import datetime from xian.services.bds import sql @@ -9,28 +10,60 @@ from xian.services.bds.database import DB, result_to_json from xian_py.wallet import key_is_valid from timeit import default_timer as timer +from decimal import Decimal # Custom JSON encoder for our own objects +def strip_trailing_zeros(s: str) -> str: + if '.' in s: + s = s.rstrip('0').rstrip('.') + return s + + +def set_nested_dict_value(d, keys, value): + """Set a value in a nested dictionary using a list of keys.""" + for key in keys[:-1]: + d = d.setdefault(key, {}) + if keys: + d[keys[-1]] = value + else: + # No keys, set value at the current level + d.update(value) + + +def merge_dicts(a, b): + """Recursively merge dictionary b into dictionary a.""" + for key in b: + if key in a and isinstance(a[key], dict) and isinstance(b[key], dict): + merge_dicts(a[key], b[key]) + else: + a[key] = b[key] + + +# Encodes everything to string - except for unknown objects class CustomEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, ContractingDecimal): - v = float(str(obj)) - return int(v) if v.is_integer() else v - if isinstance(obj, Datetime): - # Convert to ISO 8601 format string with microseconds + return strip_trailing_zeros(str(obj)) + elif isinstance(obj, Decimal): + return strip_trailing_zeros(str(obj)) + elif isinstance(obj, Datetime): return obj._datetime.isoformat(timespec='microseconds') - if isinstance(obj, Timedelta): - # Convert to total seconds with microseconds - return obj._timedelta.total_seconds() - return super().default(obj) - + elif isinstance(obj, Timedelta): + total_seconds = str(obj._timedelta.total_seconds()) + return strip_trailing_zeros(total_seconds) + elif isinstance(obj, int): + return str(obj) + else: + return super().default(obj) + + # To recursively process and handle custom types within nested structures def encode(self, obj): def process(o): if isinstance(o, dict): if len(o) == 1: if '__fixed__' in o: - return float(o['__fixed__']) + return strip_trailing_zeros(str(o['__fixed__'])) elif '__time__' in o: # Convert __time__ list to ISO 8601 string time_list = o['__time__'] @@ -39,18 +72,28 @@ def process(o): dt_obj = datetime(*time_list) # Convert to ISO 8601 string with microseconds return dt_obj.isoformat(timespec='microseconds') - else: - return {k: process(v) for k, v in o.items()} - else: - return {k: process(v) for k, v in o.items()} + # Process nested dictionaries and convert keys to strings + return {str(k): process(v) for k, v in o.items()} elif isinstance(o, list): + # Process each item in the list return [process(v) for v in o] + elif isinstance(o, ContractingDecimal): + return strip_trailing_zeros(str(o)) + elif isinstance(o, Decimal): + return strip_trailing_zeros(str(o)) elif isinstance(o, Datetime): + # Serialize datetime as ISO formatted string return o._datetime.isoformat(timespec='microseconds') elif isinstance(o, Timedelta): - return o._timedelta.total_seconds() + # Serialize total seconds as a string + total_seconds = str(o._timedelta.total_seconds()) + return strip_trailing_zeros(total_seconds) + elif isinstance(o, int): + return str(o) else: + # Return the object as-is if it doesn't match any custom types return o + # Encode the processed object return super().encode(process(obj)) @@ -72,13 +115,14 @@ async def init(self, cometbft_genesis: dict): async def process_genesis_block(self, cometbft_genesis: dict): start_time = timer() + genesis_state = cometbft_genesis["abci_genesis"]["genesis"] # insert genesis txn await self.insert_genesis_txn(genesis_state) # process each item in the genesis block - for index, state in enumerate(genesis_state): + for index, state in enumerate(genesis_state): logger.debug(f"processing item {index} from genesis_state") parts = state["key"].split(".") @@ -89,7 +133,7 @@ async def process_genesis_block(self, cometbft_genesis: dict): await self.insert_genesis_state_change(state["key"], state["value"]) await self.insert_genesis_state(state["key"], state["value"]) - logger.debug(f'Processed genesis block in {timer() - start_time:.3f} seconds') + logger.debug(f'Saved genesis block to BDS in {timer() - start_time:.3f} seconds') async def __init_tables(self): try: @@ -105,42 +149,20 @@ async def __init_tables(self): logger.exception(e) - async def insert_full_data(self, tx: dict, block_time: datetime): - total_time = timer() - - # Tx - start_time = timer() + async def add_to_batch(self, tx: dict, block_time: datetime): await self._insert_tx(tx, block_time) - logger.debug(f'Saved tx in {timer() - start_time:.3f} seconds') - - # State - start_time = timer() await self._insert_state(tx, block_time) - logger.debug(f'Saved contracts in {timer() - start_time:.3f} seconds') - - # State changes - start_time = timer() await self._insert_state_changes(tx, block_time) - logger.debug(f'Saved state changes in {timer() - start_time:.3f} seconds') - - # Rewards - start_time = timer() await self._insert_rewards(tx, block_time) - logger.debug(f'Saved rewards in {timer() - start_time:.3f} seconds') - - # Addresses - start_time = timer() await self._insert_addresses(tx, block_time) - logger.debug(f'Saved addresses in {timer() - start_time:.3f} seconds') + await self._insert_contracts(tx, block_time) - # Contracts - # Only save contracts if tx was successful - if tx["tx_result"]["status"] == 0: - start_time = timer() - await self._insert_contracts(tx, block_time) - logger.debug(f'Saved contracts in {timer() - start_time:.3f} seconds') + async def commit_batch(self): + if len(self.db.batch) == 0: return - logger.debug(f'Processed tx {tx["tx_result"]["hash"]} in {timer() - total_time:.3f} seconds') + start_time = timer() + await self.db.commit_batch_to_disk() + logger.debug(f'Saved block to BDS in {timer() - start_time:.3f} seconds') async def _insert_tx(self, tx: dict, block_time: datetime): status = True if tx['tx_result']['status'] == 0 else False @@ -180,16 +202,59 @@ async def _insert_state_changes(self, tx: dict, block_time: datetime): logger.exception(e) async def _insert_state(self, tx: dict, block_time: datetime): + # Collect state changes by contract + contract_states = {} + for state_change in tx['tx_result']['state']: + key = state_change['key'] + value = state_change['value'] + + # Parse the key to get contract name and variable path + parts = key.split('.', 1) + if len(parts) == 2: + contract_name, rest_of_key = parts + else: + # Handle keys without a dot (unlikely but possible) + contract_name = parts[0] + rest_of_key = '' + + key_path = rest_of_key.split(':') if rest_of_key else [] + + # Initialize the contract state dictionary + contract_state = contract_states.setdefault(contract_name, {}) + + # Set the nested value in the contract's state dictionary + set_nested_dict_value(contract_state, key_path, value) + + # For each contract, merge state and update the 'state' table + for contract_name, state_dict in contract_states.items(): try: + # Fetch existing state from the database + existing_state_row = await self.db.fetch_one(sql.select_state_by_key(), [contract_name]) + if existing_state_row and existing_state_row['value'] not in [None, 'null']: + existing_state_json = existing_state_row['value'] + # Ensure existing_state is a dictionary + if isinstance(existing_state_json, str): + existing_state = json.loads(existing_state_json) + else: + existing_state = existing_state_json + else: + existing_state = {} + + # Deep copy to avoid mutating the original + merged_state = deepcopy(existing_state) + + # Merge the new state changes into the existing state + merge_dicts(merged_state, state_dict) + + # Update the 'state' table with the merged state self.db.add_query_to_batch(sql.insert_or_update_state(), [ - state_change['key'], - json.dumps(state_change['value'], cls=CustomEncoder), + contract_name, + json.dumps(merged_state, cls=CustomEncoder), block_time ]) - except Exception as e: - logger.exception(e) + logger.exception(f"Error updating state for contract '{contract_name}': {e}") async def _insert_rewards(self, tx: dict, block_time: datetime): async def insert(type, key, value): @@ -198,7 +263,7 @@ async def insert(type, key, value): tx['tx_result']['hash'], type, key, - json.dumps(value, cls=CustomEncoder), + strip_trailing_zeros(str(value)), block_time ]) @@ -208,21 +273,21 @@ async def insert(type, key, value): # Developer reward for address, reward in rewards['developer_reward'].items(): try: - await insert('developer', address, float(reward)) + await insert('developer', address, reward) except Exception as e: logger.exception(e) # Masternode reward for address, reward in rewards['masternode_reward'].items(): try: - await insert('masternode', address, float(reward)) + await insert('masternode', address, reward) except Exception as e: logger.exception(e) # Foundation reward for address, reward in rewards['foundation_reward'].items(): try: - await insert('foundation', address, float(reward)) + await insert('foundation', address, reward) except Exception as e: logger.exception(e) @@ -241,6 +306,9 @@ async def _insert_addresses(self, tx: dict, block_time: datetime): logger.exception(e) async def _insert_contracts(self, tx: dict, block_time: datetime): + # Only save contracts if tx was successful + if tx["tx_result"]["status"] != 0: return + if tx['payload']['contract'] == 'submission' and tx['payload']['function'] == 'submit_contract': try: self.db.add_query_to_batch(sql.insert_contracts(), [ @@ -363,7 +431,7 @@ async def insert_genesis_state_contract(self, contract_name, code, submission_ti submission_time ]) except Exception as e: - logger.exception(e) + logger.exception(e) async def insert_genesis_state_change(self, key, value): try: diff --git a/src/xian/services/bds/database.py b/src/xian/services/bds/database.py index b0efdbb..667965b 100644 --- a/src/xian/services/bds/database.py +++ b/src/xian/services/bds/database.py @@ -94,6 +94,10 @@ async def fetch(self, query: str, params: list = []): logger.exception(f'Error while executing SQL: {e}') raise e + async def fetch_one(self, query: str, params: list): + async with self.pool.acquire() as connection: + return await connection.fetchrow(query, *params) + async def has_entries(self, table_name: str) -> bool: try: result = await self.fetch(f"SELECT COUNT(*) as count FROM {table_name}") diff --git a/src/xian/services/bds/sql.py b/src/xian/services/bds/sql.py index eb5726e..7d7e686 100644 --- a/src/xian/services/bds/sql.py +++ b/src/xian/services/bds/sql.py @@ -219,6 +219,12 @@ def select_state(): """ +def select_state_by_key(): + return """ + SELECT value FROM state WHERE key = $1; + """ + + def select_state_history(): return """ SELECT