diff --git a/Cargo.lock b/Cargo.lock
index 151d3ab34..5768fa2d0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1905,10 +1905,11 @@ dependencies = [
 [[package]]
 name = "cc"
-version = "1.2.18"
+version = "1.2.41"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c"
+checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7"
 dependencies = [
+ "find-msvc-tools",
  "jobserver",
  "libc",
  "shlex",
@@ -3416,6 +3417,12 @@ dependencies = [
  "scale-info",
 ]

+[[package]]
+name = "find-msvc-tools"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
+
 [[package]]
 name = "finito"
 version = "0.1.0"
@@ -7673,12 +7680,11 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"

 [[package]]
 name = "pest"
-version = "2.8.0"
+version = "2.8.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "198db74531d58c70a361c42201efde7e2591e976d518caf7662a47dc5720e7b6"
+checksum = "989e7521a040efde50c3ab6bbadafbe15ab6dc042686926be59ac35d74607df4"
 dependencies = [
  "memchr",
- "thiserror 2.0.18",
  "ucd-trie",
 ]
@@ -8258,8 +8264,10 @@ dependencies = [
  "scale-value",
  "serde",
  "serde_json",
+ "ss58-registry",
  "subxt",
  "subxt-rpcs",
+ "tempfile",
  "thiserror 2.0.18",
  "tokio",
  "tracing-subscriber 0.3.22",
@@ -11757,9 +11765,9 @@ dependencies = [

 [[package]]
 name = "tokio-util"
-version = "0.7.14"
+version = "0.7.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
+checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5"
 dependencies = [
  "bytes",
  "futures-core",
diff --git a/Cargo.toml b/Cargo.toml
index fd9b6dd2c..741e8f13e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,7 +33,7 @@ pin-project-lite = "0.2"
 scale-value = "0.18"
 subxt = { version = "0.44.2", features = ["reconnecting-rpc-client"] }
 subxt-rpcs = "0.44.0"
-
+ss58-registry = "1.9"

 # polkadot-sdk
 polkadot-sdk = { git = "https://github.com/paritytech/polkadot-sdk", features = [
@@ -61,10 +61,11 @@ once_cell = "1.21"
 anyhow = "1"
 assert_cmd = "2.1"
 regex = "1"
+tempfile = "3.8"

 [features]
 integration-tests = []

 [package.metadata.docs.rs]
 default-features = true
-rustdoc-args = ["--cfg", "docsrs"]
+rustdoc-args = ["--cfg", "docsrs"]
\ No newline at end of file
diff --git a/Dockerfile.README.md b/Dockerfile.README.md
index 0b8f2c9fb..6c05b334b 100644
--- a/Dockerfile.README.md
+++ b/Dockerfile.README.md
@@ -2,4 +2,4 @@

 [GitHub](https://github.com/paritytech/polkadot-staking-miner)

-Formerly known as `staking-miner-v2` historical images versions are available in the [hub.docker.com](https://hub.docker.com/r/paritytech/staking-miner-v2)
\ No newline at end of file
+Formerly known as `staking-miner-v2`; historical image versions are available on [hub.docker.com](https://hub.docker.com/r/paritytech/staking-miner-v2)
diff --git a/README.md b/README.md
index b4fec5376..257b67655 100644
--- a/README.md
+++ b/README.md
@@ -185,6 +185,7 @@ Here are some notable options you can use with the command:
 | `--min-signed-phase-blocks <number>` | Minimum number of blocks required in the signed phase before submitting a solution. | 10 |
 | `--balancing-iterations <number>` | Number of balancing iterations for the sequential phragmen algorithm. Higher values may produce better balanced solutions at the cost of more computation time. | 10 |
+| `--algorithm <algorithm>` | Election algorithm to use for mining solutions. Supported: `seq-phragmen`, `phragmms`. | `seq-phragmen` |

 Refer to `--help` for the full list of options.
@@ -239,6 +240,176 @@ The miner uses the on-chain nonce for a given user to submit solutions, which can cause
 collisions if multiple miners are running with the same account. This can cause transaction
 failures and potentially result in lost rewards or other issues.

+## Predict
+
+The `predict` command allows you to predict validator election outcomes for Substrate-based chains
+without running a full node. It fetches the necessary staking data from the chain and runs the same
+Phragmén algorithm that the chain uses to determine validator sets.
+
+### Basic Usage
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict --do-reduce
+```
+
+### Command Options
+
+| Option | Description | Default Value |
+| :------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------ | :----------------- |
+| `--desired-validators <count>` | Desired number of validators for the prediction | Fetched from chain |
+| `--overrides <path>` | Path to election overrides JSON file (see format below) | None |
+| `--output-dir <dir>` | Output directory for prediction results | `results` |
+| `--balancing-iterations <number>` | Number of balancing iterations for the sequential phragmen algorithm. Higher values may produce better balanced solutions at the cost of more computation time. | 10 |
+| `--do-reduce` | Reduce the solution to prevent further trimming. | `false` |
+| `--algorithm <algorithm>` | Election algorithm to use. Supported: `seq-phragmen`, `phragmms`. | `seq-phragmen` |
+| `--block-number <number>` | Block number at which to run the prediction. If not specified, uses the latest block. | Latest block |
+
+### Examples
+
+#### Basic Prediction
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict --do-reduce
+```
+
+#### With Desired Validators
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict --desired-validators 50 --do-reduce
+```
+
+#### Using Election Overrides File
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict --overrides overrides.json
+```
+
+#### Prediction at a Specific Block
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict --block-number 13196110 --do-reduce
+```
+
+#### Run Prediction with PhragMMS algorithm
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict --algorithm phragmms
+```
+
+### Output Files
+
+The tool generates the following JSON files in the specified output directory:
+
+1. **`validators_prediction.json`**: Contains elected validators with their stake information
+2. **`nominators_prediction.json`**: Contains nominator allocations and validator support
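+
+To sanity-check a run, both files can be inspected with standard JSON tooling. A minimal
+sketch, assuming `jq` is installed (the queried fields follow the formats shown below):
+
+```bash
+# Overall score of the mined solution
+jq '.metadata.solution_score' results/validators_prediction.json
+# Number of nominators covered by the prediction
+jq '.nominators | length' results/nominators_prediction.json
+```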
+
+#### Validators Prediction Format
+
+```json
+{
+  "metadata": {
+    "timestamp": "1765799538",
+    "desired_validators": 600,
+    "round": 40,
+    "block_number": 10803423,
+    "solution_score": {
+      "minimal_stake": 11797523289886283,
+      "sum_stake": 8372189060111758480,
+      "sum_stake_squared": 117584540059969491964159919300216042
+    },
+    "data_source": "snapshot"
+  },
+  "results": [
+    {
+      "account": "15roBmbe5NmRXb4imfmhKxSjH8k9J5xtHSrvYJKpmmCLoPqD",
+      "total_stake": "2372626.3933261476 DOT",
+      "self_stake": "0 DOT",
+      "nominator_count": 2,
+      "nominators": [
+        {
+          "address": "121GCLDNk9ErAkCovjjuF3npDB3veo3i3myY6a5v2yNEgrZw",
+          "allocated_stake": "769476 DOT"
+        },
+        {
+          "address": "14mtWxmkUHsWqJLxMiRR8qrHTHyck712E5yjWpnxPBEh8Acb",
+          "allocated_stake": "135680 DOT"
+        }
+      ]
+    }
+  ]
+}
+```
+
+#### Nominators Prediction Format
+
+```json
+{
+  "nominators": [
+    {
+      "address": "15VArSaLFf3r9MzyQjcNTexjPoRDJuVVkqUmqtuUuBcPCYrX",
+      "stake": "447.2323363908 DOT",
+      "active_validators": [
+        {
+          "validator": "15ZvLonEseaWZNy8LDkXXj3Y8bmAjxCjwvpy4pXWSL4nGSBs",
+          "allocated_stake": "447.2323363908 DOT"
+        }
+      ],
+      "inactive_validators": [
+        "1627VVB5gtHiseCV8ZdffF7P3bWrLMkU92Q6u3LsG8tGuB63"
+      ],
+      "waiting_validators": [
+        "13K6QTYBPMUFTbhZzqToKcfCiWbt4wDPHr3rUPyUessiPR61",
+        "15rb4HVycC1KLHsdaSdV1x2TJAmUkD7PhubmhL3PnGv7RiGY"
+      ]
+    }
+  ]
+}
+```
+
+### Election Overrides File Format
+
+When using `--overrides`, the file should have the following JSON structure:
+
+```json
+{
+  "candidates_include": ["15S7YtETM31QxYYqubAwRJKRSM4v4Ua6WGFYnx1VuFBnWqdG"],
+  "candidates_exclude": [],
+  "voters_include": [
+    ["15S7YtETM31QxYYqubAwRJKRSM4v4Ua6WGFYnx1VuFBnWqdG", 1000000, ["15S7YtETM31QxYYqubAwRJKRSM4v4Ua6WGFYnx1VuFBnWqdG"]]
+  ],
+  "voters_exclude": []
+}
+```
+
+**Note:** Override file paths can be nested (e.g., `data/elections/overrides.json`). The tool will
+automatically resolve relative paths from the current working directory.
+
+### How Predict Command Works
+
+1. **Data Source**: The tool first tries to fetch data from the chain's snapshot (if available),
+   then falls back to the staking pallet.
+
+2. **Overrides Application**: If `--overrides` is provided, the tool applies the specified modifications to the fetched candidates and voters:
+   - (1) Add candidates that may not exist on-chain.
+   - (2) Remove specific candidates from the election.
+   - (3) Add or override voters with custom stake amounts.
+   - (4) Remove specific voters from the election.
+
+3. **Election Algorithm**: Runs the selected election algorithm (`seq-phragmen` by default, the
+   same Phragmén algorithm used by Substrate chains) to determine:
+   - Which validators would be elected
+   - Stake distribution among validators
+   - Nominator allocations to validators
+
+4. **Output Generation**: Creates detailed JSON files with predictions, including validator and
+   nominator perspectives.
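+
+Putting these options together, a combined run might look like the following (the block
+number, overrides file, and output directory are illustrative):
+
+```bash
+cargo run -- --uri wss://westend-asset-hub-rpc.polkadot.io predict \
+  --block-number 13196110 \
+  --overrides overrides.json \
+  --output-dir my-results \
+  --do-reduce
+```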
+
 ## Update metadata

 The binary itself embeds [multi-block static metadata](./artifacts/multi_block.scale) to generate a
diff --git a/src/client.rs b/src/client.rs
index 855f29c1c..f7be13200 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -12,6 +12,7 @@ use std::{
 };
 use subxt::backend::{
     chain_head::{ChainHeadBackend, ChainHeadBackendBuilder},
+    legacy::LegacyBackend,
     rpc::reconnecting_rpc_client::{ExponentialBackoff, RpcClient as ReconnectingRpcClient},
 };
 use tokio::sync::RwLock;
@@ -52,7 +53,7 @@ pub struct Client {
 }

 impl Client {
-    /// Create a new client from a comma-separated list of RPC endpoints.
+    /// Create a new client from a comma-separated list of RPC endpoints using ChainHeadBackend.
     ///
     /// The client will try each endpoint in sequence until one connects successfully.
     /// Multiple endpoints can be specified for failover:
@@ -72,7 +73,38 @@ impl Client {
             log::info!(target: LOG_TARGET, "RPC endpoint pool: {} endpoint(s)", endpoints.len());
         }

-        let (chain_api, connected_index) = Self::connect_with_failover(&endpoints, 0).await?;
+        let (chain_api, connected_index) =
+            Self::connect_with_failover(&endpoints, 0, false).await?;
+
+        Ok(Self {
+            chain_api: Arc::new(RwLock::new(chain_api)),
+            endpoints: Arc::new(endpoints),
+            current_endpoint_index: Arc::new(AtomicUsize::new(connected_index)),
+            reconnect_generation: Arc::new(AtomicUsize::new(0)),
+        })
+    }
+
+    /// Create a new client from a comma-separated list of RPC endpoints using LegacyBackend.
+    ///
+    /// The client will try each endpoint in sequence until one connects successfully.
+    /// Multiple endpoints can be specified for failover:
+    /// "wss://rpc1.example.com,wss://rpc2.example.com"
+    pub async fn new_with_legacy_backend(uris: &str) -> Result<Self, Error> {
+        let endpoints: Vec<String> = uris
+            .split(',')
+            .map(|s| s.trim().to_string())
+            .filter(|s| !s.is_empty())
+            .collect();
+
+        if endpoints.is_empty() {
+            return Err(Error::Other("No RPC endpoints provided".into()));
+        }
+
+        if endpoints.len() > 1 {
+            log::info!(target: LOG_TARGET, "RPC endpoint pool: {} endpoint(s)", endpoints.len());
+        }
+
+        let (chain_api, connected_index) = Self::connect_with_failover(&endpoints, 0, true).await?;

         Ok(Self {
             chain_api: Arc::new(RwLock::new(chain_api)),
@@ -95,6 +127,7 @@ impl Client {
     async fn connect_with_failover(
         endpoints: &[String],
         start_index: usize,
+        use_legacy: bool,
     ) -> Result<(ChainClient, usize), Error> {
         let mut last_error = None;
         let total = endpoints.len();
@@ -112,7 +145,7 @@ impl Client {
                 "attempting to connect to {uri:?} (endpoint {endpoint_num}/{total}, attempt {attempt}/{max_attempts})"
             );

-            match Self::try_connect(uri).await {
+            match Self::try_connect(uri, use_legacy).await {
                 Ok(client) => {
                     if total > 1 {
                         log::info!(
@@ -150,8 +183,8 @@ impl Client {
     }

     /// Try to connect to a single endpoint with timeout.
-    async fn try_connect(uri: &str) -> Result<ChainClient, Error> {
-        let connect_future = async {
+    async fn try_connect(uri: &str, use_legacy: bool) -> Result<ChainClient, Error> {
+        let connect_future = async move {
             let reconnecting_rpc = ReconnectingRpcClient::builder()
                 .retry_policy(
                     ExponentialBackoff::from_millis(500).max_delay(Duration::from_secs(10)).take(3),
                 )
                 .build(uri.to_string())
                 .await
                 .map_err(|e| Error::Other(format!("Failed to connect: {e:?}")))?;

-            let backend: ChainHeadBackend =
-                ChainHeadBackendBuilder::default().build_with_background_driver(reconnecting_rpc);
-            let chain_api = ChainClient::from_backend(Arc::new(backend)).await?;
-
-            log::info!(target: LOG_TARGET, "Connected to {uri} with ChainHead backend");
+            let chain_api = if use_legacy {
+                let backend = LegacyBackend::builder().build(reconnecting_rpc.clone());
+                let client = ChainClient::from_backend(Arc::new(backend)).await?;
+                log::info!(target: LOG_TARGET, "Connected to {uri} with Legacy backend");
+                client
+            } else {
+                let backend: ChainHeadBackend = ChainHeadBackendBuilder::default()
+                    .build_with_background_driver(reconnecting_rpc);
+                let client = ChainClient::from_backend(Arc::new(backend)).await?;
+                log::info!(target: LOG_TARGET, "Connected to {uri} with ChainHead backend");
+                client
+            };

             Ok::<ChainClient, Error>(chain_api)
         };
@@ -216,7 +256,7 @@ impl Client {

         // Establish new connection before acquiring write lock
         let (new_client, connected_idx) =
-            Self::connect_with_failover(&self.endpoints, start_idx).await?;
+            Self::connect_with_failover(&self.endpoints, start_idx, false).await?;

         // Acquire write lock and check if another task already reconnected
         let mut guard = self.chain_api.write().await;
@@ -255,4 +295,9 @@ impl Client {
     pub async fn chain_api(&self) -> tokio::sync::RwLockReadGuard<'_, ChainClient> {
         self.chain_api.read().await
     }
+
+    /// Get the currently connected endpoint.
+    pub async fn current_endpoint(&self) -> String {
+        self.endpoints[self.current_endpoint_index.load(Ordering::Relaxed)].clone()
+    }
 }
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
index 0f6fc2f08..62ec1df22 100644
--- a/src/commands/mod.rs
+++ b/src/commands/mod.rs
@@ -1,4 +1,5 @@
 //! Supported commands for the polkadot-staking-miner and related types.

 pub mod multi_block;
+pub mod predict;
 pub mod types;
diff --git a/src/commands/multi_block/monitor.rs b/src/commands/multi_block/monitor.rs
index 787278658..5b2503803 100644
--- a/src/commands/multi_block/monitor.rs
+++ b/src/commands/multi_block/monitor.rs
@@ -147,17 +147,17 @@ where
         let last_round = state.prev_round;
         state.prev_round = Some(current_round);

-        if let Some(last_round) = last_round {
-            if current_round > last_round {
-                on_round_increment(
-                    last_round,
-                    current_round,
-                    &phase,
-                    &channels.miner_tx,
-                    &channels.clear_old_rounds_tx,
-                )
-                .await?;
-            }
+        if let Some(last_round) = last_round &&
+            current_round > last_round
+        {
+            on_round_increment(
+                last_round,
+                current_round,
+                &phase,
+                &channels.miner_tx,
+                &channels.clear_old_rounds_tx,
+            )
+            .await?;
         }

         let last_phase = state.prev_phase.clone();
@@ -200,17 +200,16 @@ where
                 return Ok(ListenerAction::Continue);
             },
             Phase::Signed(_) | Phase::Snapshot(_) => {
-                if let Some(last_phase) = last_phase {
-                    if matches!(&last_phase, Phase::Off) {
-                        log::debug!(target: LOG_TARGET, "Phase transition: Off → {phase:?} - stopping era pruning");
-                        // Use send() to ensure ExitOffPhase is delivered and to stop era pruning before
-                        // mining starts
-                        if let Err(e) =
-                            channels.era_pruning_tx.send(EraPruningMessage::ExitOffPhase).await
-                        {
-                            log::error!(target: LOG_TARGET, "Era pruning channel closed while sending ExitOffPhase: {e:?}");
-                            return Err(ChannelFailureError::EraPruning.into());
-                        }
+                if let Some(last_phase) = last_phase &&
+                    matches!(&last_phase, Phase::Off)
+                {
+                    log::debug!(target: LOG_TARGET, "Phase transition: Off → {phase:?} - stopping era pruning");
+                    // Use send() to ensure ExitOffPhase is delivered and to stop era pruning before
+                    // mining starts
+                    if let Err(e) = channels.era_pruning_tx.send(EraPruningMessage::ExitOffPhase).await
+                    {
+                        log::error!(target: LOG_TARGET, "Era pruning channel closed while sending ExitOffPhase: {e:?}");
+                        return Err(ChannelFailureError::EraPruning.into());
                     }
                 }
                 // Continue with mining logic for Signed/Snapshot phases
@@ -587,6 +586,7 @@ where
     T::MaxVotesPerVoter: Send + Sync + 'static,
 {
     crate::dynamic::set_balancing_iterations(config.balancing_iterations);
+    crate::dynamic::set_algorithm(config.algorithm);

     let signer = Signer::new(&config.seed_or_path)?;
diff --git a/src/commands/predict.rs b/src/commands/predict.rs
new file mode 100644
index 000000000..ff932179a
--- /dev/null
+++ b/src/commands/predict.rs
@@ -0,0 +1,207 @@
+//! Predict command implementation for election prediction.
+use polkadot_sdk::pallet_election_provider_multi_block::unsigned::miner::MinerConfig;
+
+use crate::{
+    client::Client,
+    commands::types::{ElectionDataSource, ElectionOverrides, PredictConfig},
+    dynamic::{
+        election_data::{
+            PredictionContext, apply_overrides, build_predictions_from_solution,
+            convert_election_data_to_snapshots, get_election_data,
+        },
+        multi_block::mine_solution,
+        update_metadata_constants,
+    },
+    error::Error,
+    prelude::{AccountId, LOG_TARGET},
+    runtime::multi_block::{self as runtime},
+    static_types::multi_block::Pages,
+    utils::{
+        TimedFuture, get_block_hash, get_chain_properties, read_data_from_json_file,
+        write_data_to_json_file,
+    },
+};
+
+/// Run the election prediction with the given configuration.
+pub async fn predict_cmd<T>(client: Client, config: PredictConfig) -> Result<(), Error>
+where
+    T: MinerConfig + Send + Sync + 'static,
+    T::Solution: Send,
+    T::Pages: Send,
+    T::TargetSnapshotPerBlock: Send,
+    T::VoterSnapshotPerBlock: Send,
+    T::MaxVotesPerVoter: Send,
+{
+    // Update metadata constants
+    update_metadata_constants(&*client.chain_api().await)?;
+    crate::dynamic::set_balancing_iterations(config.balancing_iterations);
+    crate::dynamic::set_algorithm(config.algorithm);
+
+    let n_pages = Pages::get();
+
+    // Determine the block number: use the provided one or the latest
+    let block_number = if let Some(block_num) = config.block_number {
+        block_num
+    } else {
+        client
+            .chain_api()
+            .await
+            .blocks()
+            .at_latest()
+            .await
+            .map_err(|e| Error::Other(format!("Failed to fetch latest block number: {e}")))?
+            .number()
+    };
+
+    log::info!(target: LOG_TARGET, "Using block number: {block_number}");
+
+    // Get storage at the specified block number
+    let storage = if let Some(block_num) = config.block_number {
+        // Get the block hash from the block number
+        let block_hash = get_block_hash(&client, block_num).await?;
+
+        crate::utils::storage_at(Some(block_hash), &*client.chain_api().await).await?
+    } else {
+        client.chain_api().await.storage().at_latest().await?
+    };
+
+    let current_round = storage
+        .fetch_or_default(&runtime::storage().multi_block_election().round())
+        .await?;
+
+    let desired_targets = match config.desired_validators {
+        Some(targets) => targets,
+        None => {
+            // Fetch from chain
+            storage
+                .fetch(&runtime::storage().staking().validator_count())
+                .await
+                .map_err(|e| {
+                    Error::Other(format!("Failed to fetch Desired Targets from chain: {e}"))
+                })?
+                .ok_or_else(|| {
+                    Error::Other("Desired validators not found in chain storage".to_string())
+                })?
+        },
+    };
+
+    // Fetch election data
+    let (candidates, nominators, data_source) =
+        get_election_data::<T>(n_pages, current_round, storage).await?;
+
+    // Apply overrides if provided
+    let (candidates, nominators) = if let Some(overrides_path) = &config.overrides {
+        log::info!(target: LOG_TARGET, "Applying overrides from {overrides_path}");
+        let overrides: ElectionOverrides = read_data_from_json_file(overrides_path).await?;
+        apply_overrides(candidates, nominators, overrides)?
+    } else {
+        (candidates, nominators)
+    };
+
+    // Convert raw data to snapshots
+    let (target_snapshot, mut voter_snapshot) =
+        convert_election_data_to_snapshots::<T>(candidates, nominators)?;
+
+    // When fetching from staking data, voters come from the BagsList in descending order (highest
+    // stake first). The SDK expects page 0 (lsp) to contain the lowest-stake voters and page n-1
+    // (msp) to contain the highest-stake voters. Reversing ensures correct page assignment during
+    // pagination.
+    if matches!(data_source, ElectionDataSource::Staking) {
+        voter_snapshot.reverse();
+    }
+
+    log::debug!(
+        target: LOG_TARGET,
+        "Mining solution with desired_targets={}, candidates={}, voter pages={}",
+        desired_targets,
+        target_snapshot.len(),
+        voter_snapshot.len()
+    );
+
+    // Use at least the actual voter page count, not just the chain's max pages:
+    // staking/overridden data may have added extra pages
+    let n_pages = n_pages.max(voter_snapshot.len() as u32);
+
+    // Mine the solution with a timeout to prevent indefinite hanging
+    const MINING_TIMEOUT_SECS: u64 = 600; // 10 minutes
+    log::debug!(target: LOG_TARGET, "Mining solution for block #{block_number} round {current_round}");
+
+    let paged_raw_solution = match tokio::time::timeout(
+        std::time::Duration::from_secs(MINING_TIMEOUT_SECS),
+        mine_solution::<T>(
+            target_snapshot.clone(),
+            voter_snapshot.clone(),
+            n_pages,
+            current_round,
+            desired_targets,
+            block_number,
+            config.do_reduce,
+        )
+        .timed(),
+    )
+    .await
+    {
+        Ok((Ok(sol), dur)) => {
+            log::info!(target: LOG_TARGET, "Mining solution took {}ms for block #{}", dur.as_millis(), block_number);
+            sol
+        },
+        Ok((Err(e), dur)) => {
+            log::error!(target: LOG_TARGET, "Mining failed after {}ms: {:?}", dur.as_millis(), e);
+            return Err(e);
+        },
+        Err(_) => {
+            log::error!(target: LOG_TARGET, "Mining solution timed out after {MINING_TIMEOUT_SECS} seconds for block #{block_number}");
+            return Err(Error::Timeout(crate::error::TimeoutError::Mining {
+                timeout_secs: MINING_TIMEOUT_SECS,
+            }));
+        },
+    };
+
+    let (ss58_prefix, token_decimals, token_symbol) = get_chain_properties(client.clone()).await?;
+
+    let prediction_ctx = PredictionContext {
+        round: current_round,
+        desired_targets,
+        block_number,
+        ss58_prefix,
+        token_decimals,
+        token_symbol: &token_symbol,
+        data_source: data_source.clone(),
+    };
+
+    let (validators_prediction, nominators_prediction) = build_predictions_from_solution::<T>(
+        &paged_raw_solution,
+        &target_snapshot,
+        &voter_snapshot,
+        &prediction_ctx,
+    )?;
+
+    // Determine output file paths, creating the output directory if it doesn't exist
+    let output_dir = std::path::Path::new(&config.output_dir);
+    std::fs::create_dir_all(output_dir).map_err(|e| {
+        Error::Other(format!("Failed to create output directory {}: {}", output_dir.display(), e))
+    })?;
+
+    let validators_output = output_dir.join("validators_prediction.json");
+    let nominators_output = output_dir.join("nominators_prediction.json");
+
+    // Save validators prediction
+    write_data_to_json_file(&validators_prediction, &validators_output).await?;
+
+    log::info!(
+        target: LOG_TARGET,
+        "Validators prediction saved to {}",
+        validators_output.display()
+    );
+
+    // Save nominators prediction
+    write_data_to_json_file(&nominators_prediction, &nominators_output).await?;
+
+    log::info!(
+        target: LOG_TARGET,
+        "Nominators prediction saved to {}",
+        nominators_output.display()
+    );
+
+    Ok(())
+}
diff --git a/src/commands/types.rs b/src/commands/types.rs
index 2caaa92f6..d3d029683 100644
--- a/src/commands/types.rs
+++ b/src/commands/types.rs
@@ -1,4 +1,7 @@
-use polkadot_sdk::sp_runtime::Perbill;
+use polkadot_sdk::{sp_npos_elections::ElectionScore, sp_runtime::Perbill};
+use serde::{Deserialize, Serialize};
+
+use crate::prelude::AccountId;

 /// Submission strategy to use.
 #[derive(Debug, Copy, Clone)]
@@ -47,6 +50,20 @@ impl std::str::FromStr for SubmissionStrategy {
     }
 }

+/// Election algorithm to use for mining solutions.
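+///
+/// Both variants map to solvers provided by `sp-npos-elections`: `seq-phragmen` is the
+/// default and mirrors what the chain itself runs, while `phragmms` offers different
+/// balancing trade-offs.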
+#[derive(
+    Debug, Copy, Clone, PartialEq, serde::Serialize, serde::Deserialize, clap::ValueEnum, Default,
+)]
+pub enum ElectionAlgorithm {
+    /// Sequential Phragmen algorithm.
+    #[default]
+    #[clap(name = "seq-phragmen")]
+    SeqPhragmen,
+    /// PhragMMS algorithm.
+    #[clap(name = "phragmms")]
+    Phragmms,
+}
+
 /// TODO: make `solver algorithm` configurable https://github.com/paritytech/polkadot-staking-miner/issues/989
 #[derive(Debug, Clone, clap::Parser)]
 #[cfg_attr(test, derive(PartialEq))]
@@ -83,4 +100,130 @@ pub struct MultiBlockMonitorConfig {
     /// Higher values may produce better balanced solutions at the cost of more computation time.
     #[clap(long, default_value_t = 10)]
     pub balancing_iterations: usize,
+
+    /// Election algorithm to use.
+    #[clap(long, value_enum, default_value_t = ElectionAlgorithm::SeqPhragmen)]
+    pub algorithm: ElectionAlgorithm,
+}
+
+/// CLI configuration for election prediction.
+#[derive(Debug, Clone, clap::Parser)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct PredictConfig {
+    /// Desired number of validators for the prediction.
+    /// [If omitted, the value is fetched from the chain]
+    #[clap(long)]
+    pub desired_validators: Option<u32>,
+
+    /// Output directory for prediction results.
+    #[clap(long, default_value = "results")]
+    pub output_dir: String,
+
+    /// Number of balancing iterations for the sequential phragmen algorithm.
+    /// Higher values may produce better balanced solutions at the cost of more computation time.
+    #[clap(long, default_value_t = 10)]
+    pub balancing_iterations: usize,
+
+    /// Reduce the solution to prevent further trimming.
+    /// [default: false]
+    #[clap(long, default_value_t = false)]
+    pub do_reduce: bool,
+
+    /// Block number at which to run the prediction.
+    /// [If omitted, uses the latest block]
+    #[clap(long)]
+    pub block_number: Option<u32>,
+
+    /// Path to election overrides JSON file.
+    #[clap(long)]
+    pub overrides: Option<String>,
+
+    /// Election algorithm to use.
+    #[clap(long, value_enum, default_value_t = ElectionAlgorithm::SeqPhragmen)]
+    pub algorithm: ElectionAlgorithm,
+}
+
+/// Validator prediction output.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct ValidatorsPrediction {
+    pub(crate) metadata: PredictionMetadata,
+    pub(crate) results: Vec<ValidatorInfo>,
+}
+
+/// Prediction metadata.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct PredictionMetadata {
+    pub(crate) timestamp: String,
+    pub(crate) desired_validators: u32,
+    pub(crate) round: u32,
+    pub(crate) block_number: u32,
+    pub(crate) solution_score: Option<ElectionScore>,
+    pub(crate) data_source: String,
+}
+
+/// Validator information.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct ValidatorInfo {
+    pub(crate) account: String,
+    pub(crate) total_stake: String, // Token amount as string
+    pub(crate) self_stake: String,  // Token amount as string
+    pub(crate) nominator_count: usize,
+    pub(crate) nominators: Vec<NominatorAllocation>,
+}
+
+/// Nominator allocation details for a validator.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct NominatorAllocation {
+    pub(crate) address: String,
+    pub(crate) allocated_stake: String, // Token amount as string
+}
+
+/// Nominator prediction output.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct NominatorsPrediction {
+    pub(crate) nominators: Vec<NominatorPrediction>,
+}
+
+/// Nominator prediction.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct NominatorPrediction {
+    pub(crate) address: String,
+    pub(crate) stake: String, // Token amount as string
+    pub(crate) active_validators: Vec<ValidatorStakeAllocation>,
+    pub(crate) inactive_validators: Vec<String>,
+    pub(crate) waiting_validators: Vec<String>,
+}
+
+/// Validator stake allocation.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct ValidatorStakeAllocation {
+    pub(crate) validator: String,
+    pub(crate) allocated_stake: String, // Token amount as string
+}
+
+pub(crate) type NominatorData = (AccountId, u64, Vec<AccountId>);
+pub(crate) type ValidatorData = (AccountId, u128);
+
+// ============================================================================
+// Custom Data File Format Types
+// ============================================================================
+
+/// JSON format for election overrides.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ElectionOverrides {
+    #[serde(default)]
+    pub candidates_include: Vec<String>,
+    #[serde(default)]
+    pub candidates_exclude: Vec<String>,
+    #[serde(default)]
+    pub voters_include: Vec<(String, u64, Vec<String>)>,
+    #[serde(default)]
+    pub voters_exclude: Vec<String>,
+}
+
+/// Data source for election data.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum ElectionDataSource {
+    Snapshot,
+    Staking,
+}
diff --git a/src/dynamic/election_data.rs b/src/dynamic/election_data.rs
new file mode 100644
index 000000000..c29c9c78a
--- /dev/null
+++ b/src/dynamic/election_data.rs
@@ -0,0 +1,491 @@
+//! Helpers for fetching and shaping election data shared by CLI commands.
+
+use polkadot_sdk::{
+    frame_election_provider_support::{BoundedSupports, Get},
+    frame_support::BoundedVec,
+    pallet_election_provider_multi_block::{
+        PagedRawSolution,
+        unsigned::miner::{BaseMiner, MinerConfig},
+    },
+    sp_npos_elections::Support,
+};
+
+use std::{
+    collections::{BTreeMap, HashMap, HashSet},
+    time::{SystemTime, UNIX_EPOCH},
+};
+
+use crate::{
+    commands::{
+        multi_block::types::{TargetSnapshotPageOf, Voter, VoterSnapshotPageOf},
+        types::{
+            ElectionDataSource, ElectionOverrides, NominatorAllocation, NominatorData,
+            NominatorPrediction, NominatorsPrediction, PredictionMetadata, ValidatorData,
+            ValidatorInfo, ValidatorStakeAllocation, ValidatorsPrediction,
+        },
+    },
+    dynamic::staking::{fetch_candidates, fetch_voters},
+    error::Error,
+    prelude::{AccountId, LOG_TARGET, Storage},
+    static_types::multi_block::VoterSnapshotPerBlock,
+    utils::{encode_account_id, planck_to_token, planck_to_token_u64},
+};
+
+use crate::dynamic::multi_block::try_fetch_snapshot;
+
+/// Context for building predictions, grouping chain metadata and election parameters.
+pub struct PredictionContext<'a> {
+    pub round: u32,
+    pub desired_targets: u32,
+    pub block_number: u32,
+    pub ss58_prefix: u16,
+    pub token_decimals: u8,
+    pub token_symbol: &'a str,
+    pub data_source: ElectionDataSource,
+}
+
+/// Convert election data into the snapshot format expected by the miner.
+///
+/// Returns a single-page target snapshot and a Vec of voter pages.
+pub(crate) fn convert_election_data_to_snapshots<T>(
+    candidates: Vec<ValidatorData>,
+    voters: Vec<NominatorData>,
+) -> Result<(TargetSnapshotPageOf<T>, Vec<VoterSnapshotPageOf<T>>), Error>
+where
+    T: MinerConfig,
+{
+    log::debug!(
+        target: LOG_TARGET,
+        "Converting election data to snapshots (candidates={}, voters={})",
+        candidates.len(),
+        voters.len()
+    );
+
+    // Extract only the accounts from the candidates
+    let target_accounts: Vec<AccountId> =
+        candidates.into_iter().map(|(account, _)| account).collect();
+    log::trace!(
+        target: LOG_TARGET,
+        "Fetched {} target accounts from candidates",
+        target_accounts.len()
+    );
+
+    let total_targets = target_accounts.len();
+    let target_snapshot: TargetSnapshotPageOf<T> = BoundedVec::truncate_from(target_accounts);
+    if target_snapshot.len() < total_targets {
+        log::warn!(
+            target: LOG_TARGET,
+            "Target snapshot truncated: kept {} of {} candidates ({} dropped)",
+            target_snapshot.len(),
+            total_targets,
+            total_targets - target_snapshot.len()
+        );
+    }
+
+    let per_voter_page = VoterSnapshotPerBlock::get();
+    let total_voters = voters.len();
+    log::trace!(
+        target: LOG_TARGET,
+        "Preparing {total_voters} voters for conversion"
+    );
+
+    let mut voter_pages_vec: Vec<VoterSnapshotPageOf<T>> = Vec::new();
+    for (stash, stake, votes) in voters {
+        let votes: BoundedVec<AccountId, T::MaxVotesPerVoter> =
+            BoundedVec::truncate_from(votes);
+
+        // NominatorData → Voter conversion
+        let voter: Voter<T> = (stash, stake, votes);
+
+        // Start a new page if we have no pages yet or the last page is full
+        if voter_pages_vec.last().is_none_or(|last| last.len() >= per_voter_page as usize) {
+            voter_pages_vec.push(BoundedVec::truncate_from(vec![voter]));
+        } else {
+            // Try to push to the last page; if it fails (unexpectedly full), start a new page
+            match voter_pages_vec.last_mut().unwrap().try_push(voter.clone()) {
+                Ok(_) => {},
+                Err(_) => {
+                    let last_idx = voter_pages_vec.len().saturating_sub(1);
+                    let last_len = voter_pages_vec.last().map(|p| p.len()).unwrap_or(0);
+                    log::warn!(
+                        target: LOG_TARGET,
+                        "Voter page {last_idx} unexpectedly full at size {last_len}; starting new page"
+                    );
starting new page" + ); + voter_pages_vec.push(BoundedVec::truncate_from(vec![voter])); + }, + } + } + } + + let n_pages = voter_pages_vec.len(); + + log::debug!( + target: LOG_TARGET, + "Converted election data: {} targets, {} voters across {} pages", + target_snapshot.len(), + total_voters, + n_pages + ); + + Ok((target_snapshot, voter_pages_vec)) +} + +/// Apply election overrides to candidates and voters. +pub(crate) fn apply_overrides( + mut candidates: Vec, + mut voters: Vec, + overrides: ElectionOverrides, +) -> Result<(Vec, Vec), Error> { + // (1) Remove specific candidates from the election + let candidates_exclude: HashSet = overrides + .candidates_exclude + .iter() + .map(|c| { + c.parse::() + .map_err(|e| Error::Other(format!("Invalid candidate exclude {c}: {e}"))) + }) + .collect::>()?; + + candidates.retain(|(account, _)| !candidates_exclude.contains(account)); + + // (2) Add candidates that may not exist on-chain + let current_candidates: HashSet = + candidates.iter().map(|(a, _)| a.clone()).collect(); + for c_str in overrides.candidates_include { + let account = c_str + .parse::() + .map_err(|e| Error::Other(format!("Invalid candidate include {c_str}: {e}")))?; + if !current_candidates.contains(&account) { + candidates.push((account, 0)); + } + } + + // (3) Remove specific voters from the election + let voters_exclude: HashSet = overrides + .voters_exclude + .iter() + .map(|v| { + v.parse::() + .map_err(|e| Error::Other(format!("Invalid voter exclude {v}: {e}"))) + }) + .collect::>()?; + + voters.retain(|(account, _, _)| !voters_exclude.contains(account)); + + // (4) Add or override voters with custom stake amounts + let voter_map: HashMap = + voters.iter().enumerate().map(|(i, (a, _, _))| (a.clone(), i)).collect(); + + for (v_str, stake, t_strs) in overrides.voters_include { + let account = v_str + .parse::() + .map_err(|e| Error::Other(format!("Invalid voter include {v_str}: {e}")))?; + let targets: Vec = t_strs + .iter() + .map(|t| { + t.parse::() + .map_err(|e| Error::Other(format!("Invalid voter target {t}: {e}"))) + }) + .collect::>()?; + + if let Some(&index) = voter_map.get(&account) { + voters[index] = (account, stake, targets); + } else { + voters.push((account, stake, targets)); + } + } + + Ok((candidates, voters)) +} + +/// Build structured predictions from the mined solution and snapshots. +pub(crate) fn build_predictions_from_solution( + solution: &PagedRawSolution, + target_snapshot: &TargetSnapshotPageOf, + voter_snapshot: &[VoterSnapshotPageOf], + ctx: &PredictionContext<'_>, +) -> Result<(ValidatorsPrediction, NominatorsPrediction), Error> +where + T: MinerConfig, +{ + // Convert slice to BoundedVec for feasibility check (truncates to T::Pages if needed) + let voter_pages_bounded: BoundedVec, T::Pages> = + BoundedVec::truncate_from(voter_snapshot.to_vec()); + + // Reuse the on-chain feasibility logic to reconstruct supports from the paged solution. 
+    let page_supports = BaseMiner::<T>::check_feasibility(
+        solution,
+        &voter_pages_bounded,
+        target_snapshot,
+        ctx.desired_targets,
+    )
+    .map_err(|err| Error::Other(format!("Failed to evaluate solution supports: {err:?}")))?;
+
+    let mut winner_support_map: BTreeMap<AccountId, Support<AccountId>> = BTreeMap::new();
+
+    for page_support in page_supports {
+        let BoundedSupports(inner) = page_support;
+        for (winner, bounded_support) in inner.into_iter() {
+            let support: Support<AccountId> = bounded_support.into();
+            let entry = winner_support_map
+                .entry(winner)
+                .or_insert_with(|| Support { total: 0, voters: Vec::new() });
+            entry.total = entry.total.saturating_add(support.total);
+            entry.voters.extend(support.voters);
+        }
+    }
+
+    // Build an allocation map per nominator for quick lookup.
+    let mut allocation_map: HashMap<AccountId, HashMap<AccountId, u128>> = HashMap::new();
+    for (validator, support) in winner_support_map.iter() {
+        for (voter, stake) in support.voters.iter() {
+            allocation_map
+                .entry(voter.clone())
+                .or_default()
+                .entry(validator.clone())
+                .and_modify(|existing| *existing = existing.saturating_add(*stake))
+                .or_insert(*stake);
+        }
+    }
+
+    // Sort winners by backing and enforce the desired_targets limit.
+    let mut winners_sorted: Vec<(AccountId, Support<AccountId>)> =
+        winner_support_map.into_iter().collect();
+    winners_sorted.sort_by(|a, b| b.1.total.cmp(&a.1.total));
+    if winners_sorted.len() > ctx.desired_targets as usize {
+        winners_sorted.truncate(ctx.desired_targets as usize);
+    }
+
+    let active_set: HashSet<AccountId> =
+        winners_sorted.iter().map(|(validator, _)| validator.clone()).collect();
+
+    // Flatten voters from the paged snapshot for the nominator perspective.
+    let all_voters: Vec<Voter<T>> =
+        voter_snapshot.iter().flat_map(|page| page.iter().cloned()).collect();
+
+    // Identify accounts to leave out of the nominator report
+    let validators_with_only_self_vote: HashSet<AccountId> = all_voters
+        .iter()
+        .filter(|(nominator, _, targets)| {
+            // Exclude an account if either:
+            // 1. it is an elected validator (present in `active_set`), or
+            // 2. its only target is itself (a pure self-vote).
+            active_set.contains(nominator) || (targets.len() == 1 && targets[0] == *nominator)
+        })
+        .map(|(nominator, _, _)| nominator.clone())
+        .collect();
+
+    let mut validator_infos: Vec<ValidatorInfo> = Vec::new();
+    for (validator, support) in winners_sorted.iter() {
+        let self_stake = support
+            .voters
+            .iter()
+            .find(|(who, _)| who == validator)
+            .map(|(_, stake)| *stake)
+            .unwrap_or(0);
+
+        // Collect nominators backing this validator (excluding self-votes)
+        let mut validator_nominators: Vec<(AccountId, u128)> = support
+            .voters
+            .iter()
+            .filter(|(who, _)| who != validator)
+            .map(|(who, stake)| (who.clone(), *stake))
+            .collect();
+        // Sort by stake descending for consistent ordering
+        validator_nominators.sort_by(|a, b| b.1.cmp(&a.1));
+
+        let nominator_allocations = validator_nominators
+            .iter()
+            .map(|(nominator, stake)| NominatorAllocation {
+                address: encode_account_id(nominator, ctx.ss58_prefix),
+                allocated_stake: planck_to_token(*stake, ctx.token_decimals, ctx.token_symbol),
+            })
+            .collect();
+
+        validator_infos.push(ValidatorInfo {
+            account: encode_account_id(validator, ctx.ss58_prefix),
+            total_stake: planck_to_token(support.total, ctx.token_decimals, ctx.token_symbol),
+            self_stake: planck_to_token(self_stake, ctx.token_decimals, ctx.token_symbol),
+            nominator_count: validator_nominators.len(),
+            nominators: nominator_allocations,
+        });
+    }
+
+    let timestamp = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .map(|d| d.as_secs().to_string())
+        .unwrap_or_else(|_| "0".to_string());
+
+    let data_source_str = match &ctx.data_source {
+        ElectionDataSource::Snapshot => "snapshot",
+        ElectionDataSource::Staking => "staking",
+    }
+    .to_string();
+
+    let metadata = PredictionMetadata {
+        timestamp,
+        desired_validators: ctx.desired_targets,
+        round: ctx.round,
+        block_number: ctx.block_number,
+        solution_score: Some(solution.score),
+        data_source: data_source_str,
+    };
+
+    let validators_prediction = ValidatorsPrediction { metadata, results: validator_infos };
+
+    // Build nominator predictions, excluding validators who only have self-votes
+    let mut nominator_predictions: Vec<NominatorPrediction> = Vec::new();
+
+    for (nominator, stake, nominated_targets) in all_voters {
+        // Skip validators who only have self-votes
+        if validators_with_only_self_vote.contains(&nominator) {
+            continue;
+        }
+
+        let nominator_encoded = encode_account_id(&nominator, ctx.ss58_prefix);
+        let allocations = allocation_map.get(&nominator);
+
+        let mut active_supported = Vec::new();
+        let mut inactive = Vec::new();
+        let mut waiting = Vec::new();
+
+        for target in nominated_targets.iter() {
+            let encoded = encode_account_id(target, ctx.ss58_prefix);
+            let is_winner = active_set.contains(target);
+            let allocated = allocations.and_then(|m| m.get(target)).copied().unwrap_or(0);
+
+            if is_winner && allocated > 0 {
+                active_supported.push(ValidatorStakeAllocation {
+                    validator: encoded,
+                    allocated_stake: planck_to_token(
+                        allocated,
+                        ctx.token_decimals,
+                        ctx.token_symbol,
+                    ),
+                });
+            } else if is_winner {
+                inactive.push(encoded);
+            } else {
+                waiting.push(encoded);
+            }
+        }
+
+        nominator_predictions.push(NominatorPrediction {
+            address: nominator_encoded,
+            stake: planck_to_token_u64(stake, ctx.token_decimals, ctx.token_symbol),
+            active_validators: active_supported,
+            inactive_validators: inactive,
+            waiting_validators: waiting,
+        });
+    }
+
+    let nominators_prediction = NominatorsPrediction { nominators: nominator_predictions };
+
+    Ok((validators_prediction, nominators_prediction))
+}
+
+/// Fetch snapshots from the chain, or synthesize them from staking storage when the snapshot is
+/// unavailable.
+pub(crate) async fn get_election_data<T>(
+    n_pages: u32,
+    round: u32,
+    storage: Storage,
+) -> Result<(Vec<ValidatorData>, Vec<NominatorData>, ElectionDataSource), Error>
+where
+    T: MinerConfig + Send + Sync + 'static,
+    T::Solution: Send,
+    T::Pages: Send,
+    T::TargetSnapshotPerBlock: Send,
+    T::VoterSnapshotPerBlock: Send,
+    T::MaxVotesPerVoter: Send,
+{
+    // Try to fetch election data from the snapshot;
+    // if the snapshot is not available, fetch from staking
+    log::info!(target: LOG_TARGET, "Trying to fetch data from snapshot");
+
+    match try_fetch_snapshot::<T>(n_pages, round, &storage).await {
+        Ok((target_snapshot, voter_pages)) => {
+            log::info!(target: LOG_TARGET, "Snapshot found");
+
+            let candidates: Vec<ValidatorData> =
+                target_snapshot.into_iter().map(|a| (a, 0)).collect();
+
+            let voters: Vec<NominatorData> = voter_pages
+                .into_iter()
+                .flat_map(|page| {
+                    page.into_iter().map(|(stash, stake, votes)| {
+                        (stash, stake, votes.into_iter().collect::<Vec<_>>())
+                    })
+                })
+                .collect();
+
+            Ok((candidates, voters, ElectionDataSource::Snapshot))
+        },
+        Err(err) => {
+            log::warn!(target: LOG_TARGET, "Fetching from snapshot failed: {err}. Falling back to staking pallet");
+
+            let candidates = fetch_candidates(&storage)
+                .await
+                .map_err(|e| Error::Other(format!("Failed to fetch candidates: {e}")))?;
+
+            let voter_limit = (T::Pages::get() * T::VoterSnapshotPerBlock::get()) as usize;
+
+            let voters = fetch_voters(voter_limit, &storage)
+                .await
+                .map_err(|e| Error::Other(format!("Failed to fetch voters: {e}")))?;
+
+            Ok((candidates, voters, ElectionDataSource::Staking))
+        },
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::commands::types::ElectionOverrides;
+    use polkadot_sdk::sp_core::crypto::Ss58Codec;
+
+    #[test]
+    fn test_apply_overrides_logic() {
+        // Create some test accounts
+        let acc1 = AccountId::from([1u8; 32]);
+        let acc2 = AccountId::from([2u8; 32]);
+        let acc3 = AccountId::from([3u8; 32]);
+        let acc4 = AccountId::from([4u8; 32]);
+
+        let s1 = acc1.to_ss58check();
+        let s2 = acc2.to_ss58check();
+        let s3 = acc3.to_ss58check();
+        let s4 = acc4.to_ss58check();
+
+        let candidates = vec![(acc1.clone(), 1000), (acc2.clone(), 2000)];
+
+        let voters = vec![(acc3.clone(), 500, vec![acc1.clone()])];
+
+        // Override:
+        // - Remove acc1 candidate
+        // - Add acc4 candidate
+        // - Remove acc3 voter
+        // - Add acc4 voter with targets [acc2, acc4]
+        let overrides = ElectionOverrides {
+            candidates_include: vec![s4.clone()],
+            candidates_exclude: vec![s1.clone()],
+            voters_include: vec![(s4.clone(), 1500, vec![s2.clone(), s4.clone()])],
+            voters_exclude: vec![s3.clone()],
+        };
+
+        let (new_candidates, new_voters) = apply_overrides(candidates, voters, overrides).unwrap();
+
+        // Check candidates
+        assert_eq!(new_candidates.len(), 2);
+        assert!(new_candidates.iter().any(|(a, _)| a == &acc2));
+        assert!(new_candidates.iter().any(|(a, _)| a == &acc4));
+        assert!(!new_candidates.iter().any(|(a, _)| a == &acc1));
+
+        // Check voters
+        assert_eq!(new_voters.len(), 1);
+        assert_eq!(new_voters[0].0, acc4);
+        assert_eq!(new_voters[0].1, 1500);
+        assert_eq!(new_voters[0].2, vec![acc2, acc4]);
+    }
+}
diff --git a/src/dynamic/mod.rs b/src/dynamic/mod.rs
index c7ce050f7..d305bd1a1 100644
--- a/src/dynamic/mod.rs
+++ b/src/dynamic/mod.rs
@@ -9,12 +9,14 @@
 use crate::{error::Error, prelude::ChainClient, static_types};

+pub mod election_data;
 pub mod multi_block;
 pub mod pallet_api;
+pub mod staking;
 pub mod utils;

 use static_types::multi_block::{
-    BalancingIterations, MaxBackersPerWinner, MaxLength, MaxWinnersPerPage, Pages,
+    Algorithm, BalancingIterations, MaxBackersPerWinner, MaxLength, MaxWinnersPerPage, Pages,
     TargetSnapshotPerBlock, VoterSnapshotPerBlock,
 };
@@ -49,3 +51,8 @@
 pub fn set_balancing_iterations(iterations: usize) {
     BalancingIterations::set(iterations);
 }
+
+/// Set the election algorithm from CLI config.
+pub fn set_algorithm(algo: crate::commands::types::ElectionAlgorithm) {
+    Algorithm::set(algo);
+}
diff --git a/src/dynamic/multi_block.rs b/src/dynamic/multi_block.rs
index d0509f4b6..584195708 100644
--- a/src/dynamic/multi_block.rs
+++ b/src/dynamic/multi_block.rs
@@ -10,12 +10,15 @@
 use crate::{
     dynamic::{
         utils::{decode_error, storage_addr, to_scale_value, tx},
     },
     error::Error,
-    prelude::{AccountId, ChainClient, Config, ExtrinsicParamsBuilder, Hash, LOG_TARGET, Storage},
+    prelude::{
+        AccountId, ChainClient, Config, ExtrinsicParamsBuilder, Hash,
+        MULTI_BLOCK_LOG_TARGET as LOG_TARGET, Storage,
+    },
     runtime::multi_block::{
         self as runtime, runtime_types::pallet_election_provider_multi_block::types::Phase,
     },
     signer::Signer,
-    utils,
+    static_types, utils,
 };
 use codec::Decode;
 use futures::{StreamExt, stream::FuturesUnordered};
@@ -231,6 +234,59 @@ where
         .map_err(|e| Error::Other(format!("{e:?}")))?
 }

+/// Try to fetch the election snapshot from chain storage.
+pub(crate) async fn try_fetch_snapshot<T>(
+    n_pages: u32,
+    round: u32,
+    storage: &Storage,
+) -> Result<(TargetSnapshotPageOf<T>, BoundedVec<VoterSnapshotPageOf<T>, T::Pages>), Error>
+where
+    T: MinerConfig + Send + Sync + 'static,
+    T::Solution: Send,
+    T::Pages: Send,
+    T::TargetSnapshotPerBlock: Send,
+    T::VoterSnapshotPerBlock: Send,
+    T::MaxVotesPerVoter: Send,
+{
+    // Validate n_pages
+    let chain_pages = static_types::multi_block::Pages::get();
+    if n_pages != chain_pages {
+        return Err(Error::Other(format!("n_pages must be equal to {chain_pages}")));
+    }
+
+    // Fetch the (single) target snapshot, using the last page index
+    let target_snapshot: TargetSnapshotPageOf<T> =
+        target_snapshot::<T>(n_pages - 1, round, storage).await?;
+
+    log::trace!(target: LOG_TARGET, "Fetched {} targets from snapshot", target_snapshot.len());
+
+    // Fetch all voter snapshot pages
+    let mut voter_snapshot_paged: Vec<VoterSnapshotPageOf<T>> =
+        Vec::with_capacity(n_pages as usize);
+    for page in 0..n_pages {
+        let voter_page = paged_voter_snapshot::<T>(page, round, storage).await?;
+        log::trace!(target: LOG_TARGET, "Fetched {page}/{n_pages} pages of voter snapshot");
+        voter_snapshot_paged.push(voter_page);
+    }
+
+    log::trace!(
+        target: LOG_TARGET,
+        "try_fetch_snapshot: election target snap size: {:?}, voter snap size: {:?}",
+        target_snapshot.len(),
+        voter_snapshot_paged.len()
+    );
+
+    let voter_pages: BoundedVec<VoterSnapshotPageOf<T>, T::Pages> =
+        BoundedVec::truncate_from(voter_snapshot_paged);
+
+    log::trace!(
+        target: LOG_TARGET,
+        "Fetched: pages={n_pages}, target_snapshot_len={}, voters_pages_len={}, round={round}",
+        target_snapshot.len(), voter_pages.len()
+    );
+    Ok((target_snapshot, voter_pages))
+}
+
 /// Fetches the target snapshot and any voter snapshots that are missing;
 /// snapshots that do not exist yet are simply ignored.
 pub(crate) async fn fetch_missing_snapshots_lossy(
diff --git a/src/dynamic/pallet_api.rs b/src/dynamic/pallet_api.rs
index 3a0f79f32..78148c00a 100644
--- a/src/dynamic/pallet_api.rs
+++ b/src/dynamic/pallet_api.rs
@@ -147,5 +147,28 @@ pub mod system {
         use super::{super::*, *};
         pub const BLOCK_LENGTH: PalletConstant = PalletConstant::new(NAME, "BlockLength");
+        pub const SS58_PREFIX: PalletConstant = PalletConstant::new(NAME, "SS58Prefix");
     }
 }
+
+pub mod staking {
+    pub const NAME: &str = "Staking";
+
+    pub mod storage {
+        use super::{super::*, *};
+        pub const VALIDATORS: PalletItem = PalletItem::new(NAME, "Validators");
+        pub const LEDGER: PalletItem = PalletItem::new(NAME, "Ledger");
+        pub const NOMINATORS: PalletItem = PalletItem::new(NAME, "Nominators");
+        pub const BONDED: PalletItem = PalletItem::new(NAME, "Bonded");
+    }
+}
+
+pub mod voter_list {
+    pub const NAME: &str = "VoterList";
+
+    pub mod storage {
+        use super::{super::*, *};
+        pub const LIST_NODES: PalletItem = PalletItem::new(NAME, "ListNodes");
+        pub const LIST_BAGS: PalletItem = PalletItem::new(NAME, "ListBags");
+    }
+}
diff --git a/src/dynamic/staking.rs b/src/dynamic/staking.rs
new file mode 100644
index 000000000..70b7ad887
--- /dev/null
+++ b/src/dynamic/staking.rs
@@ -0,0 +1,642 @@
+//! Shared utilities for fetching staking data.
+
+use crate::{
+    commands::types::{NominatorData, ValidatorData},
+    dynamic::{pallet_api, utils::storage_addr},
+    error::Error,
+    prelude::{AccountId, STAKING_LOG_TARGET as LOG_TARGET, Storage},
+};
+use codec::{Decode, Encode};
+use scale_value::At;
+use std::{
+    collections::{HashMap, HashSet},
+    time::Duration,
+};
+use subxt::dynamic::Value;
+
+/// Fetch all candidate validators (stash AccountId) with their active stake.
+pub(crate) async fn fetch_candidates(storage: &Storage) -> Result<Vec<ValidatorData>, Error> {
+    log::debug!(target: LOG_TARGET, "Fetching candidate validators (Staking::Validators keys)");
+
+    let validators_addr = storage_addr(pallet_api::staking::storage::VALIDATORS, vec![]);
+    let mut iter = storage.iter(validators_addr).await?;
+
+    let mut candidate_accounts: Vec<AccountId> = Vec::new();
+    let mut count: usize = 0;
+
+    while let Some(next) = iter.next().await {
+        let kv = match next {
+            Ok(kv) => kv,
+            Err(e) => return Err(Error::Other(format!("storage iteration error: {e}"))),
+        };
+
+        let key_bytes = kv.key_bytes;
+        if key_bytes.len() < 32 {
+            return Err(Error::Other(format!(
+                "unexpected key length {} (< 32); cannot decode AccountId",
+                key_bytes.len()
+            )));
+        }
+        let tail = &key_bytes[key_bytes.len() - 32..];
+        let arr: [u8; 32] = tail
+            .try_into()
+            .map_err(|_| Error::Other("failed to slice AccountId32 bytes".into()))?;
+        let account = AccountId::from(arr);
+
+        candidate_accounts.push(account);
+        count += 1;
+
+        if count % 500 == 0 {
+            log::debug!(target: LOG_TARGET, "Fetched {count} candidate accounts...");
+        }
+    }
+
+    log::info!(
+        target: LOG_TARGET,
+        "Total candidate accounts fetched: {}",
+        candidate_accounts.len()
+    );
+
+    // Fetch self-stakes concurrently
+    log::debug!(target: LOG_TARGET, "Fetching stakes for {} validators...", candidate_accounts.len());
+
+    let mut stake_futures = Vec::with_capacity(candidate_accounts.len());
+
+    for account in &candidate_accounts {
+        let storage = storage.clone();
+        let account = account.clone();
+
+        stake_futures.push(async move {
+            let bytes_vec: Vec<u8> = account.encode();
+            let params: Vec<Value> = vec![Value::from_bytes(bytes_vec)];
+            let bonded_addr = storage_addr(pallet_api::staking::storage::BONDED, params.clone());
+
+            // Deterministic controller lookup: check `Bonded` first, then fall back to the
+            // account itself (the stash)
+            let ledger_key_addr = if let Some(bonded) = storage.fetch(&bonded_addr).await? {
+                let controller_bytes = bonded.encoded();
+                if controller_bytes.len() < 32 {
+                    return Err(Error::Other("Unexpected Bonded key length".into()));
+                }
+                let tail = &controller_bytes[controller_bytes.len() - 32..];
+                let arr: [u8; 32] = tail
+                    .try_into()
+                    .map_err(|_| Error::Other("Failed to slice Controller ID".into()))?;
+                let controller = AccountId::from(arr);
+                storage_addr(
+                    pallet_api::staking::storage::LEDGER,
+                    vec![Value::from_bytes(controller.encode())],
+                )
+            } else {
+                storage_addr(pallet_api::staking::storage::LEDGER, params)
+            };
+
+            let mut stake = 0u128;
+            if let Some(ledger) = storage.fetch(&ledger_key_addr).await? &&
+                let Ok(value) = ledger.to_value()
+            {
+                // Try to get 'active' first (self-stake); fall back to 'total' if 'active' is
+                // not available
+                if let Some(active) = value.at("active").and_then(|v| v.as_u128()) {
+                    stake = active;
+                } else if let Some(total) = value.at("total").and_then(|v| v.as_u128()) {
+                    stake = total;
+                }
+            }
+
+            Result::<u128, Error>::Ok(stake)
+        });
+    }
+
+    let stakes_results = futures::future::join_all(stake_futures).await;
+    let mut stakes = Vec::with_capacity(stakes_results.len());
+
+    for result in stakes_results {
+        stakes.push(result?);
+    }
+
+    log::debug!(target: LOG_TARGET, "Fetched stakes for {} validators", stakes.len());
+
+    let candidates: Vec<ValidatorData> =
+        candidate_accounts.into_iter().zip(stakes.into_iter()).collect();
+
+    log::info!(target: LOG_TARGET, "Total registered candidates: {}", candidates.len());
+
+    Ok(candidates)
+}
+
+/// Helper to fetch just the validator keys (a set) for O(1) existence checks.
+pub(crate) async fn fetch_validator_keys(storage: &Storage) -> Result<HashSet<AccountId>, Error> {
+    log::debug!(target: LOG_TARGET, "Fetching validator keys for existence checks");
+    let validators_addr = storage_addr(pallet_api::staking::storage::VALIDATORS, vec![]);
+    let mut iter = storage.iter(validators_addr).await?;
+    let mut keys = HashSet::new();
+
+    while let Some(next) = iter.next().await {
+        let kv = next.map_err(|e| Error::Other(format!("storage iteration error: {e}")))?;
+        let key_bytes = kv.key_bytes;
+        if key_bytes.len() >= 32 {
+            let tail = &key_bytes[key_bytes.len() - 32..];
+            if let Ok(arr) = <[u8; 32]>::try_from(tail) {
+                keys.insert(AccountId::from(arr));
+            }
+        }
+    }
+    log::debug!(target: LOG_TARGET, "Fetched {} validator keys", keys.len());
+    Ok(keys)
+}
+
+/// Node data structure for the VoterList (decoded from storage).
+/// The actual on-chain structure from pallet-bags-list;
+/// structure based on the substrate bags-list pallet Node definition.
+#[derive(Debug, Clone, Decode)]
+struct ListNode {
+    id: AccountId,
+    #[allow(dead_code)]
+    prev: Option<AccountId>,
+    next: Option<AccountId>,
+    bag_upper: u64,
+    score: u64,
+}
+
+/// Bag data structure for the VoterList.
+#[derive(Debug, Clone, Decode)]
+struct ListBag {
+    head: Option<AccountId>,
+    #[allow(dead_code)]
+    tail: Option<AccountId>,
+}
+
+/// Node data structure for VoterList processing.
+#[derive(Debug, Clone)]
+struct VoterNode {
+    id: AccountId,
+    score: u64,
+    next: Option<AccountId>,
+    #[allow(dead_code)]
+    bag_upper: u64,
+}
+
+/// Fetch and sort voters from the VoterList (BagsList).
+/// Returns the top voters limited by voter_limit, sorted by score (stake) in descending order.
+pub(crate) async fn fetch_voter_list(
+    voter_limit: usize,
+    storage: &Storage,
+) -> Result<Vec<(AccountId, u64)>, Error> {
+    // Increase the voter limit to leave a buffer for filtering ineligible voters later
+    let extended_voter_limit = voter_limit.saturating_add(100);
+
+    log::info!(target: LOG_TARGET, "Fetching from VoterList");
+
+    // Fetch all bags (ListBags) - store as a HashMap keyed by bag_upper
+    log::trace!(target: LOG_TARGET, "Fetching ListBags...");
+    let list_bags_addr = storage_addr(pallet_api::voter_list::storage::LIST_BAGS, vec![]);
+    let mut bags_iter = storage.iter(list_bags_addr).await?;
+
+    let mut bags: HashMap<u64, ListBag> = HashMap::new();
+
+    while let Some(next) = bags_iter.next().await {
+        let kv = match next {
+            Ok(kv) => kv,
+            Err(e) => return Err(Error::Other(format!("bags iteration error: {e}"))),
+        };
+
+        // Extract the bag score (bag_upper) from the key (u64)
+        let key_bytes = kv.key_bytes;
+        let bag_upper = if key_bytes.len() >= 8 {
+            let start = key_bytes.len().saturating_sub(8);
+            u64::from_le_bytes(key_bytes[start..].try_into().unwrap_or([0u8; 8]))
+        } else {
+            0u64
+        };
+
+        // Decode the bag structure
+        let bag: ListBag = Decode::decode(&mut &kv.value.encoded()[..])?;
+
+        // Store the bag with its upper bound as the key
+        bags.insert(bag_upper, bag);
+    }
+
+    log::trace!(target: LOG_TARGET, "Found {} bags", bags.len());
+
+    // Fetch all nodes (ListNodes) - store as a HashMap keyed by AccountId
+    log::trace!(target: LOG_TARGET, "Fetching ListNodes...");
+
+    let mut nodes: HashMap<AccountId, VoterNode> = HashMap::new();
+    let list_nodes_addr = storage_addr(pallet_api::voter_list::storage::LIST_NODES, vec![]);
+    let mut nodes_iter = storage.iter(list_nodes_addr).await?;
+
+    let mut nodes_count = 0;
+
+    while let Some(next) = nodes_iter.next().await {
+        let kv = match next {
+            Ok(kv) => kv,
+            Err(e) => return Err(Error::Other(format!("node iteration error: {e}"))),
+        };
+
+        // Extract the AccountId from the key
+        let key_bytes = &kv.key_bytes;
+        if key_bytes.len() < 32 {
+            continue;
+        }
+        let tail = &key_bytes[key_bytes.len() - 32..];
+        let arr: [u8; 32] = tail
+            .try_into()
+            .map_err(|_| Error::Other("failed to slice AccountId32 bytes".into()))?;
+        let account_id = AccountId::from(arr);
+
+        // Decode the node data using the struct
+        let list_node = match ListNode::decode(&mut &kv.value.encoded()[..]) {
+            Ok(node) => node,
+            Err(e) => {
+                log::warn!(
+                    target: LOG_TARGET,
+                    "Failed to decode ListNode for {account_id:?}: {e}"
+                );
+                continue;
+            },
+        };
+
+        // Store the node in a HashMap for O(1) lookup
+        nodes.insert(
+            list_node.id.clone(),
+            VoterNode {
+                id: list_node.id,
+                score: list_node.score,
+                next: list_node.next,
+                bag_upper: list_node.bag_upper,
+            },
+        );
+
+        nodes_count += 1;
+
+        if nodes_count % 5000 == 0 {
+            log::trace!(target: LOG_TARGET, "Fetched {nodes_count} nodes...");
+        }
+    }
+
+    log::trace!(target: LOG_TARGET, "Found {} nodes total", nodes.len());
+
+    // Sort bags from highest to lowest score (descending order)
+    log::trace!(target: LOG_TARGET, "Sorting bags by score (descending)...");
+    let mut sorted_bag_keys: Vec<u64> = bags.keys().copied().collect();
+    sorted_bag_keys.sort_by(|a, b| b.cmp(a)); // Descending order - highest stake first
+
+    log::info!(
+        target: LOG_TARGET,
+        "Bags sorted. Highest bag: {:?}, Lowest bag: {:?}",
+        sorted_bag_keys.first(),
+        sorted_bag_keys.last()
+    );
+
+    // Iterate through bags and follow the linked lists to build the voter snapshot
+    log::info!(
+        target: LOG_TARGET,
+        "Building voter snapshot (limit: {voter_limit})..."
+    );
+
+    let mut voters: Vec<(AccountId, u64)> = Vec::new();
+    let mut processed: HashMap<AccountId, bool> = HashMap::new();
+    let mut total_nodes_processed = 0;
+
+    // Iterate through each bag from highest to lowest
+    for bag_upper in sorted_bag_keys {
+        // Check if we've reached the extended voter limit
+        if voters.len() >= extended_voter_limit {
+            log::trace!(target: LOG_TARGET, "Reached extended voter limit of {extended_voter_limit}");
+            break;
+        }
+
+        // Get the bag
+        let bag = match bags.get(&bag_upper) {
+            Some(b) => b,
+            None => continue,
+        };
+
+        // Skip empty bags (no head)
+        if bag.head.is_none() {
+            continue;
+        }
+
+        // Start from the head of this bag's linked list
+        let mut current_node_id = bag.head.clone();
+        let mut nodes_in_bag = 0;
+
+        // Walk through the linked list for this bag
+        while let Some(node_id) = current_node_id {
+            // Check extended voter limit
+            if voters.len() >= extended_voter_limit {
+                break;
+            }
+
+            // Skip if already processed (detect cycles)
+            if processed.contains_key(&node_id) {
+                log::warn!(
+                    target: LOG_TARGET,
+                    "Cycle detected: node {node_id:?} already processed in bag {bag_upper}"
+                );
+                break;
+            }
+
+            // Mark as processed
+            processed.insert(node_id.clone(), true);
+
+            // Get the node from our HashMap
+            let node = match nodes.get(&node_id) {
+                Some(n) => n,
+                None => {
+                    log::warn!(
+                        target: LOG_TARGET,
+                        "Broken chain: Node {node_id:?} not found in bag {bag_upper}"
+                    );
+                    break;
+                },
+            };
+
+            // Add this voter to our snapshot
+            voters.push((node.id.clone(), node.score));
+            nodes_in_bag += 1;
+            total_nodes_processed += 1;
+
+            // Move to the next node in the linked list
+            current_node_id = node.next.clone();
+        }
+
+        if nodes_in_bag > 0 {
+            log::debug!(
+                target: LOG_TARGET,
+                "Bag (upper: {bag_upper}): processed {nodes_in_bag} nodes"
+            );
+        }
+    }
+
+    log::info!(
+        target: LOG_TARGET,
+        "VoterList fetch completed"
+    );
+    log::info!(
+        target: LOG_TARGET,
+        "Total voters fetched for selection pool: {} (voter_limit: {}, extended_voter_limit: {})",
+        voters.len(),
+        voter_limit,
+        extended_voter_limit
+    );
+    log::info!(
+        target: LOG_TARGET,
+        "Total nodes processed: {total_nodes_processed}"
+    );
+
+    Ok(voters)
+}
+
+/// Fetch complete voter data including nomination targets.
+/// First fetches voters from the VoterList, then queries staking.nominators in
+/// concurrently processed batches.
+pub(crate) async fn fetch_voters(
+    voter_limit: usize,
+    storage: &Storage,
+) -> Result<Vec<(AccountId, u64, Vec<AccountId>)>, Error> {
+    // Fetch voters from VoterList (BagsList) with their stakes (already sorted)
+    log::debug!(target: LOG_TARGET, "Fetching voters from VoterList...");
+    let voters = fetch_voter_list(voter_limit, storage).await?;
+
+    log::info!(target: LOG_TARGET, "Fetched {} voters from VoterList", voters.len());
+
+    // Fetch Validators (keys only) to perform the `Validators::<T>::contains_key(&voter)` check.
+    // This is critical for implicit self-votes, where no Nominators entry exists.
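+    // On-chain analogue (simplified sketch): for a voter `v` with no Nominators entry,
+    //   if Validators::<T>::contains_key(&v) { all_voters.push((v.clone(), stake, vec![v])) }
+    // i.e. a registered validator implicitly votes for itself with its own stake.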
+    let validator_keys = fetch_validator_keys(storage).await?;
+
+    // Prepare for batch fetching of nomination targets
+    const BATCH_SIZE: usize = 20;
+    const MAX_CONCURRENT_BATCHES: usize = 10;
+    const MAX_RETRIES: u32 = 8;
+
+    let total = voters.len();
+    log::info!(
+        target: LOG_TARGET,
+        "Fetching targets for {total} voters"
+    );
+
+    let mut complete_voter_data: Vec<(AccountId, u64, Vec<AccountId>)> = Vec::with_capacity(total);
+
+    // Note: the batch task returns Option<Vec<AccountId>> so that missing Nominators entries
+    // can be distinguished from nominators with empty target lists
+    type BatchResult = Vec<(AccountId, u64, Option<Vec<AccountId>>)>;
+    type BatchHandle = tokio::task::JoinHandle<Result<BatchResult, Error>>;
+    let mut batch_handles: Vec<BatchHandle> = Vec::new();
+    let mut processed_count: usize = 0;
+
+    // Process in batches with concurrency
+    for chunk in voters.chunks(BATCH_SIZE) {
+        let chunk = chunk.to_vec();
+        let storage_clone = storage.clone();
+
+        let handle = tokio::spawn(async move {
+            // Retry logic for an individual batch
+            let mut last_error: Option<Error> = None;
+            for attempt in 0..=MAX_RETRIES {
+                match fetch_voter_batch(&chunk, &storage_clone).await {
+                    Ok(results) => return Ok(results),
+                    Err(e) => {
+                        let error_msg = format!("{e}");
+                        if (error_msg.contains("limit reached") ||
+                            error_msg.contains("RPC error") ||
+                            error_msg.contains("timeout")) &&
+                            attempt < MAX_RETRIES
+                        {
+                            last_error = Some(e);
+                            // Exponential backoff starting at 2s, capped at 60s
+                            let delay = Duration::from_secs(2) * (1 << attempt);
+                            tokio::time::sleep(delay.min(Duration::from_secs(60))).await;
+                            continue;
+                        } else {
+                            return Err(e);
+                        }
+                    },
+                }
+            }
+            Err(last_error.unwrap_or_else(|| Error::Other("Failed after retries".into())))
+        });
+
+        batch_handles.push(handle);
+
+        // Manage concurrency and process results
+        if batch_handles.len() >= MAX_CONCURRENT_BATCHES {
+            let handle = batch_handles.remove(0);
+            match handle.await {
+                Ok(Ok(results)) => {
+                    // Mirrors the on-chain get_npos_voters logic
+                    for (account, stake, maybe_targets) in results {
+                        match maybe_targets {
+                            // Case 1: Nominators entry exists (include only if targets exist)
+                            Some(targets) =>
+                                if !targets.is_empty() {
+                                    complete_voter_data.push((account, stake, targets));
+                                } else {
+                                    log::debug!(target: LOG_TARGET, "Skipping nominator {account:?} - no targets");
+                                },
+                            // Case 2: No Nominators entry -> check whether it is a validator
+                            None => {
+                                if validator_keys.contains(&account) {
+                                    // Implicit self-vote
+                                    complete_voter_data.push((
+                                        account.clone(),
+                                        stake,
+                                        vec![account],
+                                    ));
+                                }
+                                // Else: defensive skip
+                            },
+                        }
+                    }
+                    processed_count += BATCH_SIZE.min(total - processed_count); // Approximate
+                    if processed_count % 5000 == 0 {
+                        log::trace!(target: LOG_TARGET, "Processed {processed_count} voters...");
+                    }
+                },
+                Ok(Err(e)) => return Err(e),
+                Err(e) => return Err(Error::Other(format!("Task join error: {e}"))),
+            }
+        }
+    }
+
+    // Await remaining tasks
+    if !batch_handles.is_empty() {
+        let batch_results = futures::future::join_all(batch_handles).await;
+        for result in batch_results {
+            match result {
+                Ok(Ok(results)) =>
+                    for (account, stake, maybe_targets) in results {
+                        match maybe_targets {
+                            // Include only nominators with non-empty targets
+                            Some(targets) =>
+                                if !targets.is_empty() {
+                                    complete_voter_data.push((account, stake, targets));
+                                },
+                            None =>
+                                if validator_keys.contains(&account) {
+                                    complete_voter_data.push((
+                                        account.clone(),
+                                        stake,
+                                        vec![account],
+                                    ));
+                                },
+                        }
+                    },
+                Ok(Err(e)) => return Err(e),
+                Err(e) => return Err(Error::Other(format!("Task join error: {e}"))),
+            }
+        }
+    }
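+
+    // The VoterList fetch above over-fetched with a buffer of 100 extra entries, so after
+    // dropping nominators without valid targets the pool can still exceed `voter_limit`.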
+
+    // Truncate to exactly voter_limit to match on-chain behavior
+    if complete_voter_data.len() > voter_limit {
+        log::info!(
+            target: LOG_TARGET,
+            "Truncating voters from {} to {}",
+            complete_voter_data.len(),
+            voter_limit
+        );
+        complete_voter_data.truncate(voter_limit);
+    }
+
+    log::info!(
+        target: LOG_TARGET,
+        "Completed fetching voter data with targets for {} voters",
+        complete_voter_data.len()
+    );
+
+    Ok(complete_voter_data)
+}
+
+/// Nominations data structure from the blockchain
+#[derive(Debug, Clone, Decode)]
+struct Nominations {
+    /// List of validator accounts being nominated
+    pub targets: Vec<AccountId>,
+    /// Era in which the nominations were submitted
+    #[allow(dead_code)]
+    pub submitted_in: u32,
+    /// Whether the nominator is suppressed
+    pub suppressed: bool,
+}
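+
+// The caller batches these lookups: with BATCH_SIZE = 20 and MAX_CONCURRENT_BATCHES = 10, at
+// most ~200 account lookups are in flight at once, bounding RPC load while still parallelising
+// the fetch.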
+
+/// Helper to fetch a single batch of nominators
+async fn fetch_voter_batch(
+    voters: &[(AccountId, u64)],
+    storage: &Storage,
+) -> Result<Vec<(AccountId, u64, Option<Vec<AccountId>>)>, Error> {
+    let mut batch_results = Vec::with_capacity(voters.len());
+
+    // Prepare futures for fetching individual nominator data
+    let mut futs = Vec::with_capacity(voters.len());
+
+    for (account_id, stake) in voters {
+        let storage = &storage;
+        let account_id = account_id.clone();
+        let score_stake = *stake;
+
+        futs.push(async move {
+            let bytes_vec: Vec<u8> = account_id.encode();
+            let params: Vec<Value> = vec![Value::from_bytes(bytes_vec)];
+            let nominators_addr =
+                storage_addr(pallet_api::staking::storage::NOMINATORS, params.clone());
+            let bonded_addr = storage_addr(pallet_api::staking::storage::BONDED, params.clone());
+
+            // Fetch targets
+            let nominator_data = storage.fetch(&nominators_addr).await?;
+
+            // Deterministic Controller lookup for stake accuracy
+            let ledger_key_addr = if let Some(bonded) = storage.fetch(&bonded_addr).await? {
+                let controller_bytes = bonded.encoded();
+                if controller_bytes.len() < 32 {
+                    return Err(Error::Other("Unexpected Bonded key length".into()));
+                }
+                let tail = &controller_bytes[controller_bytes.len() - 32..];
+                let arr: [u8; 32] = tail
+                    .try_into()
+                    .map_err(|_| Error::Other("Failed to slice Controller ID".into()))?;
+                let controller = AccountId::from(arr);
+                storage_addr(
+                    pallet_api::staking::storage::LEDGER,
+                    vec![Value::from_bytes(controller.encode())],
+                )
+            } else {
+                storage_addr(pallet_api::staking::storage::LEDGER, params)
+            };
+
+            // Default to the VoterList score, but override it with the actual Ledger active stake
+            let mut actual_stake = score_stake;
+
+            if let Some(ledger) = storage.fetch(&ledger_key_addr).await? &&
+                let Ok(value) = ledger.to_value()
+            {
+                if let Some(active) = value.at("active").and_then(|v| v.as_u128()) {
+                    actual_stake = active as u64;
+                }
+            }
+
+            Result::<_, Error>::Ok((account_id, actual_stake, nominator_data))
+        });
+    }
+
+    let results = futures::future::join_all(futs).await;
+
+    for result in results {
+        let (account_id, stake, nominator_data) = result?;
+
+        let targets = if let Some(data) = nominator_data {
+            // Decode the nominations
+            if let Ok(nominations) = Nominations::decode(&mut &data.encoded()[..]) {
+                // Only include active nominators (not suppressed)
+                if !nominations.suppressed { Some(nominations.targets) } else { Some(Vec::new()) }
+            } else {
+                // Failed to decode, treat as empty targets
+                Some(Vec::new())
+            }
+        } else {
+            // No Nominators entry found -> the caller later checks whether this is a validator
+            None
+        };
+        batch_results.push((account_id, stake, targets));
+    }
+
+    Ok(batch_results)
+}
diff --git a/src/error.rs b/src/error.rs
index 56e089891..bd99727d0 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -124,6 +124,8 @@ pub enum Error {
 	ChannelFailure(#[from] ChannelFailureError),
 	#[error("Task failure: {0}")]
 	TaskFailure(#[from] TaskFailureError),
+	#[error("JSON serialization error: `{0}`")]
+	SerdeJson(#[from] serde_json::Error),
 }
 
 impl From for Error {
diff --git a/src/macros.rs b/src/macros.rs
index 9161f84b2..0128ba7a9 100644
--- a/src/macros.rs
+++ b/src/macros.rs
@@ -104,7 +104,38 @@ macro_rules! for_multi_block_runtime {
 	};
 }
 
+/// Macro to mimic a polkadot-sdk runtime parameter type for ElectionAlgorithm
+macro_rules! impl_algorithm_parameter_type {
+	($mod:ident, $name:ident) => {
+		mod $mod {
+			use crate::commands::types::ElectionAlgorithm;
+			use std::sync::atomic::{AtomicU8, Ordering};
+			static VAL: AtomicU8 = AtomicU8::new(0); // 0 = SeqPhragmen, 1 = Phragmms
+			pub struct $name;
+
+			impl $name {
+				pub fn get() -> ElectionAlgorithm {
+					match VAL.load(Ordering::SeqCst) {
+						0 => ElectionAlgorithm::SeqPhragmen,
+						1 => ElectionAlgorithm::Phragmms,
+						_ => unreachable!(),
+					}
+				}
+				pub fn set(val: ElectionAlgorithm) {
+					let v = match val {
+						ElectionAlgorithm::SeqPhragmen => 0,
+						ElectionAlgorithm::Phragmms => 1,
+					};
+					VAL.store(v, Ordering::SeqCst);
+				}
+			}
+		}
+		pub use $mod::$name;
+	};
+}
+
 #[allow(unused)]
 pub(crate) use {
-	for_multi_block_runtime, impl_balancing_config_parameter_type, impl_u32_parameter_type,
+	for_multi_block_runtime, impl_algorithm_parameter_type, impl_balancing_config_parameter_type,
+	impl_u32_parameter_type,
 };
diff --git a/src/main.rs b/src/main.rs
index 0a01415bf..0730e664b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -86,6 +86,8 @@ pub enum Command {
 	Monitor(commands::types::MultiBlockMonitorConfig),
 	/// Check if the staking-miner metadata is compatible to a remote node.
 	Info,
+	/// Run election prediction
+	Predict(commands::types::PredictConfig),
 }
 
 #[tokio::main]
@@ -101,7 +103,11 @@ async fn main() -> Result<(), Error> {
 	// Initialize the timestamp so that if connection hangs, the stall detection alert can fire.
prometheus::set_last_block_processing_time(); - let client = Client::new(&uri).await?; + // Create client with appropriate backend based on command type + let client = match command { + Command::Predict(_) => Client::new_with_legacy_backend(&uri).await?, + _ => Client::new(&uri).await?, + }; let version_bytes = client .chain_api() @@ -153,6 +159,11 @@ async fn main() -> Result<(), Error> { commands::multi_block::monitor_cmd::(client, cfg).boxed() }) }, + Command::Predict(cfg) => { + macros::for_multi_block_runtime!(chain, { + commands::predict::predict_cmd::(client, cfg).boxed() + }) + }, }; let res = run_command(fut, rx_upgrade).await; @@ -390,7 +401,7 @@ async fn runtime_upgrade_task(client: Client, tx: oneshot::Sender) { #[cfg(test)] mod tests { use super::*; - use crate::commands::types::{MultiBlockMonitorConfig, SubmissionStrategy}; + use crate::commands::types::{ElectionAlgorithm, MultiBlockMonitorConfig, SubmissionStrategy}; #[test] fn cli_monitor_works() { @@ -421,6 +432,7 @@ mod tests { min_signed_phase_blocks: 10, // Default shady: false, // Default balancing_iterations: 10, // Default + algorithm: ElectionAlgorithm::SeqPhragmen, }), } ); @@ -449,6 +461,7 @@ mod tests { min_signed_phase_blocks: 10, // Default shady: false, // Default balancing_iterations: 10, // Default + algorithm: ElectionAlgorithm::SeqPhragmen, }) ); } @@ -477,6 +490,7 @@ mod tests { min_signed_phase_blocks: 10, // Default shady: false, // Default balancing_iterations: 10, // Default + algorithm: ElectionAlgorithm::SeqPhragmen, }) ); } @@ -505,6 +519,7 @@ mod tests { min_signed_phase_blocks: 5, // Explicitly set shady: false, // Default balancing_iterations: 10, // Default + algorithm: ElectionAlgorithm::SeqPhragmen, }) ); } @@ -532,6 +547,7 @@ mod tests { min_signed_phase_blocks: 10, // Default shady: true, // Explicitly set balancing_iterations: 10, // Default + algorithm: ElectionAlgorithm::SeqPhragmen, }) ); } diff --git a/src/prelude.rs b/src/prelude.rs index 0862f3fc9..f7fc2bd50 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -11,6 +11,10 @@ pub const DEFAULT_URI: &str = "ws://127.0.0.1:9944"; pub const DEFAULT_PROMETHEUS_PORT: u16 = 9999; /// The logging target. pub const LOG_TARGET: &str = "polkadot-staking-miner"; +/// The multi-block logging target. +pub const MULTI_BLOCK_LOG_TARGET: &str = "polkadot-staking-miner::multi-block"; +/// The staking logging target. +pub const STAKING_LOG_TARGET: &str = "polkadot-staking-miner::staking"; /// Subxt client used by the staking miner on all chains. 
pub type ChainClient = subxt::OnlineClient<Config>;
diff --git a/src/static_types/multi_block.rs b/src/static_types/multi_block.rs
index 6bec5965a..8b690c187 100644
--- a/src/static_types/multi_block.rs
+++ b/src/static_types/multi_block.rs
@@ -1,9 +1,12 @@
 use crate::{
-	macros::{impl_balancing_config_parameter_type, impl_u32_parameter_type},
+	macros::{
+		impl_algorithm_parameter_type, impl_balancing_config_parameter_type,
+		impl_u32_parameter_type,
+	},
 	prelude::{AccountId, Accuracy, Hash},
 };
 use polkadot_sdk::{
-	frame_election_provider_support::{self, SequentialPhragmen},
+	frame_election_provider_support::{self, PerThing128, PhragMMS, SequentialPhragmen},
 	frame_support, pallet_election_provider_multi_block as multi_block,
 	sp_runtime::{PerU16, Percent, traits::ConstU32},
 };
@@ -15,6 +18,64 @@ impl_u32_parameter_type!(max_winners_per_page, MaxWinnersPerPage);
 impl_u32_parameter_type!(max_backers_per_winner, MaxBackersPerWinner);
 impl_u32_parameter_type!(max_length, MaxLength);
 impl_balancing_config_parameter_type!(balancing, BalancingIterations);
+impl_algorithm_parameter_type!(algorithm, Algorithm);
+
+pub struct DynamicSolver<AccountId, Accuracy, Balancing>(
+	std::marker::PhantomData<(AccountId, Accuracy, Balancing)>,
+);
+
+impl<AccountId, Accuracy, Balancing> frame_election_provider_support::NposSolver
+	for DynamicSolver<AccountId, Accuracy, Balancing>
+where
+	AccountId: frame_election_provider_support::IdentifierT,
+	Accuracy: PerThing128,
+	Balancing: frame_support::traits::Get<Option<polkadot_sdk::sp_npos_elections::BalancingConfig>>,
+{
+	type AccountId = AccountId;
+	type Accuracy = Accuracy;
+	type Error = polkadot_sdk::sp_npos_elections::Error;
+
+	fn solve(
+		winners: usize,
+		targets: Vec<AccountId>,
+		voters: Vec<(
+			AccountId,
+			polkadot_sdk::sp_npos_elections::VoteWeight,
+			impl Clone + IntoIterator<Item = AccountId>,
+		)>,
+	) -> Result<polkadot_sdk::sp_npos_elections::ElectionResult<AccountId, Accuracy>, Self::Error>
+	{
+		match Algorithm::get() {
+			crate::commands::types::ElectionAlgorithm::SeqPhragmen =>
+				SequentialPhragmen::<AccountId, Accuracy, Balancing>::solve(
+					winners, targets, voters,
+				),
+			crate::commands::types::ElectionAlgorithm::Phragmms =>
+				PhragMMS::<AccountId, Accuracy, Balancing>::solve(winners, targets, voters),
+		}
+	}
+
+	fn weight<T: frame_election_provider_support::WeightInfo>(
+		voters: u32,
+		targets: u32,
+		vote_degree: u32,
+	) -> polkadot_sdk::frame_election_provider_support::Weight {
+		match Algorithm::get() {
+			crate::commands::types::ElectionAlgorithm::SeqPhragmen =>
+				SequentialPhragmen::<AccountId, Accuracy, Balancing>::weight::<T>(
+					voters,
+					targets,
+					vote_degree,
+				),
+			crate::commands::types::ElectionAlgorithm::Phragmms =>
+				PhragMMS::<AccountId, Accuracy, Balancing>::weight::<T>(
+					voters,
+					targets,
+					vote_degree,
+				),
+		}
+	}
+}
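+
+// Illustrative use (hypothetical call site): the predict command sets the algorithm once at
+// startup, e.g. `Algorithm::set(ElectionAlgorithm::Phragmms);`, after which every
+// `DynamicSolver::solve(..)` call dispatches to PhragMMS instead of SequentialPhragmen.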
 
 pub mod node {
 	use super::*;
@@ -37,7 +98,7 @@ impl multi_block::unsigned::miner::MinerConfig for MinerConfig {
 	type AccountId = AccountId;
 	type Solution = NposSolution16;
-	type Solver = SequentialPhragmen<AccountId, Accuracy, BalancingIterations>;
+	type Solver = DynamicSolver<AccountId, Accuracy, BalancingIterations>;
 	type Pages = Pages;
 	type MaxVotesPerVoter = ConstU32<16>;
 	type MaxWinnersPerPage = MaxWinnersPerPage;
@@ -52,7 +113,6 @@ pub mod polkadot {
 	use super::*;
-	use frame_election_provider_support::SequentialPhragmen;
 
 	// TODO: validate config https://github.com/paritytech/polkadot-staking-miner/issues/994
 	frame_election_provider_support::generate_solution_type!(
@@ -72,7 +132,7 @@ impl multi_block::unsigned::miner::MinerConfig for MinerConfig {
 	type AccountId = AccountId;
 	type Solution = NposSolution16;
-	type Solver = SequentialPhragmen<AccountId, Accuracy, BalancingIterations>;
+	type Solver = DynamicSolver<AccountId, Accuracy, BalancingIterations>;
 	type Pages = Pages;
 	type MaxVotesPerVoter = ConstU32<16>;
 	type MaxWinnersPerPage = MaxWinnersPerPage;
@@ -87,7 +147,6 @@ pub mod kusama {
 	use super::*;
-	use frame_election_provider_support::SequentialPhragmen;
 
 	// TODO: validate config https://github.com/paritytech/polkadot-staking-miner/issues/994
 	frame_election_provider_support::generate_solution_type!(
@@ -107,7 +166,7 @@ impl multi_block::unsigned::miner::MinerConfig for MinerConfig {
 	type AccountId = AccountId;
 	type Solution = NposSolution24;
-	type Solver = SequentialPhragmen<AccountId, Accuracy, BalancingIterations>;
+	type Solver = DynamicSolver<AccountId, Accuracy, BalancingIterations>;
 	type Pages = Pages;
 	type MaxVotesPerVoter = ConstU32<24>;
 	type MaxWinnersPerPage = MaxWinnersPerPage;
@@ -122,7 +181,6 @@ pub mod westend {
 	use super::*;
-	use frame_election_provider_support::SequentialPhragmen;
 
 	// TODO: validate config https://github.com/paritytech/polkadot-staking-miner/issues/994
 	frame_election_provider_support::generate_solution_type!(
@@ -142,7 +200,7 @@ impl multi_block::unsigned::miner::MinerConfig for MinerConfig {
 	type AccountId = AccountId;
 	type Solution = NposSolution16;
-	type Solver = SequentialPhragmen<AccountId, Accuracy, BalancingIterations>;
+	type Solver = DynamicSolver<AccountId, Accuracy, BalancingIterations>;
 	type Pages = Pages;
 	type MaxVotesPerVoter = ConstU32<16>;
 	type MaxWinnersPerPage = MaxWinnersPerPage;
@@ -158,7 +216,6 @@
 /// This is used to test against staking-async runtimes from the SDK.
 pub mod staking_async {
 	use super::*;
-	use frame_election_provider_support::SequentialPhragmen;
 
 	// TODO: validate config https://github.com/paritytech/polkadot-staking-miner/issues/994
 	frame_election_provider_support::generate_solution_type!(
@@ -178,7 +235,7 @@ impl multi_block::unsigned::miner::MinerConfig for MinerConfig {
 	type AccountId = AccountId;
 	type Solution = NposSolution16;
-	type Solver = SequentialPhragmen<AccountId, Accuracy, BalancingIterations>;
+	type Solver = DynamicSolver<AccountId, Accuracy, BalancingIterations>;
 	type Pages = Pages;
 	type MaxVotesPerVoter = ConstU32<16>;
 	type MaxWinnersPerPage = MaxWinnersPerPage;
diff --git a/src/utils.rs b/src/utils.rs
index f68ef2e7b..b0c3c5dd7 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -18,14 +18,23 @@ use crate::{
 	client::Client,
 	commands::types::SubmissionStrategy,
 	error::Error,
-	prelude::{ChainClient, Config, Hash, Storage},
+	prelude::{AccountId, ChainClient, Config, Hash, LOG_TARGET, Storage},
 };
 use pin_project_lite::pin_project;
-use polkadot_sdk::{sp_npos_elections, sp_runtime::Perbill};
+use polkadot_sdk::{
+	sp_core::crypto::{Ss58AddressFormat, Ss58Codec},
+	sp_npos_elections,
+	sp_runtime::Perbill,
+};
+use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use ss58_registry::Ss58AddressFormat as RegistryFormat;
 use std::{
+	fs::{self, File},
 	future::Future,
+	io::{BufWriter, Read, Write},
+	path::Path,
 	pin::Pin,
-	task::{Context, Poll},
+	task::{Context as TaskContext, Poll},
 	time::{Duration, Instant},
 };
 use subxt::tx::{TxInBlock, TxProgress};
@@ -47,7 +56,7 @@ where
 {
 	type Output = (Fut::Output, Duration);
 
-	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+	fn poll(self: Pin<&mut Self>, cx: &mut TaskContext) -> Poll<Self::Output> {
 		let this = self.project();
 		let start = this.start.get_or_insert_with(Instant::now);
@@ -74,7 +83,7 @@ pub async fn storage_at(block: Option<Hash>, api: &ChainClient) -> Result<Storage, Error>
+
+/// Write data to a JSON file
+pub async fn write_data_to_json_file<T, P>(data: &T, file_path: &P) -> Result<(), Error>
+where
+	T: Serialize,
+	P: AsRef<Path>,
+{
+	let path = file_path.as_ref();
+	if let Some(parent) = path.parent() &&
+		!parent.as_os_str().is_empty()
+	{
+		fs::create_dir_all(parent)?;
+	}
+
+	let file = File::create(path)?;
+	let mut writer = BufWriter::with_capacity(1024 * 1024, file);
+
+	let json = serde_json::to_string_pretty(data)?;
+	writer.write_all(json.as_bytes())?;
+	writer.flush()?;
+
+	Ok(())
+}
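+
+// Illustrative call (the path is only an example; predict writes files such as
+// results/validators_prediction.json):
+//   write_data_to_json_file(&predictions, &"results/validators_prediction.json").await?;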
+
+/// Read data from a JSON file
+pub async fn read_data_from_json_file<T, P>(file_path: P) -> Result<T, Error>
+where
+	T: DeserializeOwned,
+	P: AsRef<Path>,
+{
+	let path = file_path.as_ref();
+
+	let mut file = File::open(path)?;
+	let mut content = String::new();
+	file.read_to_string(&mut content)?;
+
+	Ok(serde_json::from_str(&content)?)
+}
+
+/// Chain properties from the system_properties RPC call
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct ChainProperties {
+	token_symbol: Option<String>,
+	token_decimals: Option<u8>,
+}
+
+/// Get the SS58 prefix from the chain
+pub async fn get_ss58_prefix(client: &Client) -> Result<u16, Error> {
+	match crate::dynamic::pallet_api::system::constants::SS58_PREFIX
+		.fetch(&*client.chain_api().await)
+	{
+		Ok(ss58_prefix) => Ok(ss58_prefix),
+		Err(e) => {
+			log::warn!(target: LOG_TARGET, "Failed to fetch SS58 prefix: {e}");
+			log::warn!(target: LOG_TARGET, "Using default SS58 prefix: 0");
+			Ok(0)
+		},
+	}
+}
+
+/// Get block hash from block number using RPC
+pub async fn get_block_hash(client: &Client, block_number: u32) -> Result<Hash, Error> {
+	use subxt_rpcs::{RpcClient, client::RpcParams};
+
+	let mut params = RpcParams::new();
+	params
+		.push(block_number)
+		.map_err(|e| Error::Other(format!("Failed to serialize block number: {e}")))?;
+
+	let endpoint = client.current_endpoint().await;
+	let rpc_client = RpcClient::from_url(endpoint).await?;
+
+	let block_hash: Option<Hash> =
+		rpc_client.request("chain_getBlockHash", params).await.map_err(|e| {
+			Error::Other(format!("Failed to get block hash for block {block_number}: {e}"))
+		})?;
+
+	block_hash.ok_or_else(|| {
+		Error::Other(format!("Block {block_number} not found (may be pruned or invalid)"))
+	})
+}
+
+/// Get chain properties (SS58 prefix, token decimals and symbol) from system_properties RPC
+pub async fn get_chain_properties(client: Client) -> Result<(u16, u8, String), Error> {
+	use subxt_rpcs::{RpcClient, client::RpcParams};
+
+	let endpoint = client.current_endpoint().await;
+	let rpc_client = RpcClient::from_url(endpoint).await?;
+
+	let response: ChainProperties = rpc_client
+		.request("system_properties", RpcParams::new())
+		.await
+		.map_err(|e| Error::Other(format!("Failed to call system_properties RPC: {e}")))?;
+
+	// Extract token decimals (default to 10, common for Substrate chains)
+	let decimals = response.token_decimals.unwrap_or(10);
+
+	// Extract token symbol (default to the generic "UNIT")
+	let symbol = response.token_symbol.unwrap_or_else(|| "UNIT".to_string());
+
+	// Fetch the SS58 prefix of the chain
+	let ss58_prefix = get_ss58_prefix(&client).await?;
+
+	log::info!(
+		target: LOG_TARGET,
+		"Fetched chain properties: ss58_prefix={ss58_prefix} token_symbol={symbol}, token_decimals={decimals}"
+	);
+
+	Ok((ss58_prefix, decimals, symbol))
+}
+
+/// Encode an AccountId to an SS58 string with a chain-specific prefix.
+/// Uses ss58-registry to check the prefix against known networks.
+pub fn encode_account_id(account: &AccountId, ss58_prefix: u16) -> String {
+	// Use ss58-registry to check whether the prefix belongs to a known network
+	let is_known = RegistryFormat::all().iter().any(|entry| {
+		let entry_format: RegistryFormat = (*entry).into();
+		entry_format.prefix() == ss58_prefix
+	});
+
+	if is_known {
+		log::trace!(
+			target: LOG_TARGET,
+			"Encoding with SS58 prefix {ss58_prefix} (validated in registry)"
+		);
+	} else {
+		log::trace!(
+			target: LOG_TARGET,
+			"Encoding with SS58 prefix {ss58_prefix} (custom format, not in registry)"
+		);
+	}
+
+	// Encode using the standard SS58 encoding with the provided prefix;
+	// the registry check above only logs whether the prefix is a known network
+	account
+		.clone()
+		.to_ss58check_with_version(Ss58AddressFormat::custom(ss58_prefix))
+}
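+
+// Example prefixes: 0 => Polkadot ("1..." addresses), 2 => Kusama, 42 => generic Substrate
+// ("5..." addresses); the same 32-byte public key encodes differently under each prefix.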
+
+/// Convert Plancks to tokens (divide by 10^decimals) and format with the token symbol,
+/// e.g. planck_to_token(1234, 3, "DOT") == "1.234 DOT"
+pub fn planck_to_token(planck: u128, decimals: u8, symbol: &str) -> String {
+	let divisor = 10_u128.pow(decimals as u32);
+	let whole = planck / divisor;
+	let remainder = planck % divisor;
+
+	let amount_str = if remainder == 0 {
+		whole.to_string()
+	} else {
+		// Format with proper decimal places
+		let remainder_str = format!("{:0>width$}", remainder, width = decimals as usize);
+		// Remove trailing zeros
+		let remainder_trimmed = remainder_str.trim_end_matches('0');
+		if remainder_trimmed.is_empty() {
+			whole.to_string()
+		} else {
+			format!("{whole}.{remainder_trimmed}")
+		}
+	};
+
+	format!("{amount_str} {symbol}")
+}
+
+/// Convert Plancks (u64) to tokens with symbol
+pub fn planck_to_token_u64(planck: u64, decimals: u8, symbol: &str) -> String {
+	planck_to_token(planck as u128, decimals, symbol)
+}
+
 #[cfg(test)]
 mod tests {
 	use super::*;
@@ -166,4 +342,53 @@
 		Ok(SubmissionStrategy::ClaimBetterThan(Accuracy::from_percent(99)))
 		);
 	}
+
+	#[tokio::test]
+	async fn test_read_write_json_file() {
+		let dir = tempfile::tempdir().unwrap();
+		let file_path = dir.path().join("test.json");
+
+		let data = vec![1, 2, 3];
+		write_data_to_json_file(&data, &file_path).await.unwrap();
+
+		let read_data: Vec<i32> = read_data_from_json_file(&file_path).await.unwrap();
+		assert_eq!(data, read_data);
+	}
+
+	#[test]
+	fn test_planck_to_token() {
+		assert_eq!(planck_to_token(100, 2, "DOT"), "1 DOT");
+		assert_eq!(planck_to_token(100, 0, "DOT"), "100 DOT");
+		assert_eq!(planck_to_token(1234, 3, "DOT"), "1.234 DOT");
+		assert_eq!(planck_to_token(1230, 3, "DOT"), "1.23 DOT");
+		assert_eq!(planck_to_token(5, 2, "DOT"), "0.05 DOT");
+		assert_eq!(planck_to_token(0, 5, "DOT"), "0 DOT");
+	}
+
+	#[test]
+	fn test_planck_to_token_u64() {
+		assert_eq!(planck_to_token_u64(100, 2, "KSM"), "1 KSM");
+		assert_eq!(planck_to_token_u64(123456789, 6, "KSM"), "123.456789 KSM");
+	}
+
+	#[test]
+	fn test_encode_account_id() {
+		// Alice's public key
+		let alice_pub_key = [
+			212, 53, 147, 199, 21, 253, 211, 28, 97, 20, 26, 189, 4, 169, 159, 214, 130, 44, 133,
+			88, 133, 76, 205, 227, 154, 86, 132, 231, 165, 109, 162, 125,
+		];
+		let account = AccountId::new(alice_pub_key);
+
+		// Polkadot prefix (0)
+		assert_eq!(
+			encode_account_id(&account, 0),
+			"15oF4uVJwmo4TdGW7VfQxNLavjCXviqxT9S1MgbjMNHr6Sp5"
+		);
+		// Generic Substrate (42)
+		assert_eq!(
+			encode_account_id(&account, 42),
+			"5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"
+		);
+	}
 }
diff --git a/tests/monitor.rs b/tests/monitor.rs
index e2e3c1667..60e65bac1 100644
--- a/tests/monitor.rs
+++ b/tests/monitor.rs
@@ -1,7 +1,6 @@
 #![cfg(feature = "integration-tests")]
 //! Integration tests for the multi-block monitor (pallet-election-multi-block).
 //! See nightly.yml for instructions on how to run it compared vs a zombienet setup.
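+//! (Gated behind the `integration-tests` feature; these tests expect a reachable test network.)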
-use assert_cmd::cargo::cargo_bin; use polkadot_staking_miner::{ prelude::ChainClient, runtime::multi_block::{ @@ -47,7 +46,7 @@ fn run_miner(port: u16, seed: &str, shady: bool) -> KillChildOnDrop { } let mut miner = KillChildOnDrop( - std::process::Command::new(cargo_bin(env!("CARGO_PKG_NAME"))) + std::process::Command::new(assert_cmd::cargo::cargo_bin!(env!("CARGO_PKG_NAME"))) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .args(args) diff --git a/tests/predict.rs b/tests/predict.rs new file mode 100644 index 000000000..fdde1c037 --- /dev/null +++ b/tests/predict.rs @@ -0,0 +1,188 @@ +//! Tests for the predict command + +use assert_cmd::Command; +use polkadot_staking_miner::commands::types::PredictConfig; +use std::fs; +use tempfile::TempDir; + +/// Test that the predict command help works +#[test] +fn predict_help_works() { + let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!(env!("CARGO_PKG_NAME"))); + cmd.args(["predict", "--help"]); + cmd.assert().success(); +} + +/// Test that predict command accepts basic arguments +#[test] +fn predict_cli_args_parsing() { + // Test with desired validators + let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!(env!("CARGO_PKG_NAME"))); + cmd.args(["predict", "--desired-validators", "19"]); + // This will fail because we need a valid URI, but we're just testing argument parsing + // In a real scenario, you'd need a running node or mock +} + +/// Test PredictConfig parsing +#[test] +fn predict_config_parsing() { + use clap::Parser; + + // Test with output directory + let config = PredictConfig::try_parse_from(["predict", "--output-dir", "outputs"]).unwrap(); + assert_eq!(config.output_dir, "outputs"); + + // Test with overrides + let config = + PredictConfig::try_parse_from(["predict", "--overrides", "overrides.json"]).unwrap(); + assert_eq!(config.overrides, Some("overrides.json".to_string())); + + // Test with algorithm + let config = PredictConfig::try_parse_from(["predict", "--algorithm", "phragmms"]).unwrap(); + assert_eq!( + config.algorithm, + polkadot_staking_miner::commands::types::ElectionAlgorithm::Phragmms + ); + + // Test with all options + let config = PredictConfig::try_parse_from([ + "predict", + "--desired-validators", + "50", + "--overrides", + "data/overrides.json", + "--output-dir", + "test_outputs", + "--algorithm", + "phragmms", + ]) + .unwrap(); + assert_eq!(config.desired_validators, Some(50)); + assert_eq!(config.overrides, Some("data/overrides.json".to_string())); + assert_eq!(config.output_dir, "test_outputs"); + assert_eq!( + config.algorithm, + polkadot_staking_miner::commands::types::ElectionAlgorithm::Phragmms + ); +} + +/// Test output directory creation +#[test] +fn test_output_directory_creation() { + let temp_dir = TempDir::new().unwrap(); + let output_dir = temp_dir.path().join("test_outputs"); + + // Directory should not exist initially + assert!(!output_dir.exists()); + + // Create directory + fs::create_dir_all(&output_dir).unwrap(); + + // Directory should now exist + assert!(output_dir.exists()); + assert!(output_dir.is_dir()); +} + +/// Test that output files are created in the correct directory +#[test] +fn test_output_files_creation() { + let temp_dir = TempDir::new().unwrap(); + let output_dir = temp_dir.path().join("test_outputs"); + fs::create_dir_all(&output_dir).unwrap(); + + // Create expected output files + let validators_output = output_dir.join("validators_prediction.json"); + let nominators_output = output_dir.join("nominators_prediction.json"); + + // Create test data + let 
validators_data = serde_json::json!({
+		"metadata": {
+			"timestamp": "1234567890",
+			"desired_validators": 19,
+			"round": 0,
+			"block_number": 1000,
+			"solution_score": null,
+			"data_source": "test"
+		},
+		"results": []
+	});
+
+	let nominators_data = serde_json::json!({
+		"metadata": {
+			"timestamp": "1234567890",
+			"desired_validators": 19,
+			"round": 0,
+			"block_number": 1000,
+			"solution_score": null,
+			"data_source": "test"
+		},
+		"nominators": []
+	});
+
+	// Write files
+	fs::write(&validators_output, serde_json::to_string_pretty(&validators_data).unwrap()).unwrap();
+	fs::write(&nominators_output, serde_json::to_string_pretty(&nominators_data).unwrap()).unwrap();
+
+	// Verify files exist
+	assert!(validators_output.exists());
+	assert!(nominators_output.exists());
+
+	// Verify file contents
+	let validators_content: serde_json::Value =
+		serde_json::from_str(&fs::read_to_string(&validators_output).unwrap()).unwrap();
+	assert_eq!(validators_content["metadata"]["desired_validators"], 19);
+
+	let nominators_content: serde_json::Value =
+		serde_json::from_str(&fs::read_to_string(&nominators_output).unwrap()).unwrap();
+	assert_eq!(nominators_content["metadata"]["desired_validators"], 19);
+}
+
+/// Create a temporary overrides file for testing
+fn create_test_overrides_file(temp_dir: &TempDir) -> std::path::PathBuf {
+	let overrides_data = serde_json::json!({
+		"candidates_include": ["15S7YtETM31QxYYqubAwRJKRSM4v4Ua6WGFYnx1VuFBnWqdG"],
+		"candidates_exclude": [],
+		"voters_include": [
+			["15S7YtETM31QxYYqubAwRJKRSM4v4Ua6WGFYnx1VuFBnWqdG", 1000000, ["15S7YtETM31QxYYqubAwRJKRSM4v4Ua6WGFYnx1VuFBnWqdG"]]
+		],
+		"voters_exclude": []
+	});
+
+	let file_path = temp_dir.path().join("test_overrides.json");
+	fs::write(&file_path, serde_json::to_string_pretty(&overrides_data).unwrap()).unwrap();
+	file_path
+}
+
+/// Test that the overrides file format is correctly validated
+#[test]
+fn test_overrides_file_format_validation() {
+	let temp_dir = TempDir::new().unwrap();
+	let overrides_path = create_test_overrides_file(&temp_dir);
+
+	// Read and parse the file
+	let content = fs::read_to_string(&overrides_path).unwrap();
+	let parsed: polkadot_staking_miner::commands::types::ElectionOverrides =
+		serde_json::from_str(&content).unwrap();
+
+	// Validate structure
+	assert_eq!(parsed.candidates_include.len(), 1);
+	assert_eq!(parsed.voters_include.len(), 1);
+	assert_eq!(parsed.voters_include[0].1, 1000000);
+}
+
+/// Test invalid overrides file format
+#[test]
+fn test_invalid_overrides_file_format() {
+	let temp_dir = TempDir::new().unwrap();
+	let invalid_file = temp_dir.path().join("invalid_overrides.json");
+
+	// Write invalid JSON (missing fields are fine thanks to #[serde(default)], but outright
+	// malformed JSON must fail)
+	fs::write(&invalid_file, "{ this is not json }").unwrap();
+
+	// Try to parse - should fail
+	let content = fs::read_to_string(&invalid_file).unwrap();
+	let result: Result<polkadot_staking_miner::commands::types::ElectionOverrides, _> =
+		serde_json::from_str(&content);
+	assert!(result.is_err());
+}
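+
+// These unit-level tests run without a node: `cargo test --test predict`.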