Skip to content

Commit c3ba626

Browse files
committed
Implemented separate clients for monitor and predict; some refactoring
1 parent 8415e40 commit c3ba626

File tree

6 files changed

+79
-30
lines changed

6 files changed

+79
-30
lines changed

src/client.rs

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
error::{Error, TimeoutError},
3-
prelude::{ChainClient, LOG_TARGET},
3+
prelude::{ChainClient, Config, LOG_TARGET},
44
prometheus,
55
};
66
use std::{
@@ -11,6 +11,7 @@ use std::{
1111
time::Duration,
1212
};
1313
use subxt::backend::{
14+
chain_head::{ChainHeadBackend, ChainHeadBackendBuilder},
1415
legacy::LegacyBackend,
1516
rpc::reconnecting_rpc_client::{ExponentialBackoff, RpcClient as ReconnectingRpcClient},
1617
};
@@ -52,7 +53,7 @@ pub struct Client {
5253
}
5354

5455
impl Client {
55-
/// Create a new client from a comma-separated list of RPC endpoints.
56+
/// Create a new client from a comma-separated list of RPC endpoints using ChainHeadBackend.
5657
///
5758
/// The client will try each endpoint in sequence until one connects successfully.
5859
/// Multiple endpoints can be specified for failover:
@@ -72,7 +73,38 @@ impl Client {
7273
log::info!(target: LOG_TARGET, "RPC endpoint pool: {} endpoint(s)", endpoints.len());
7374
}
7475

75-
let (chain_api, connected_index) = Self::connect_with_failover(&endpoints, 0).await?;
76+
let (chain_api, connected_index) =
77+
Self::connect_with_failover(&endpoints, 0, false).await?;
78+
79+
Ok(Self {
80+
chain_api: Arc::new(RwLock::new(chain_api)),
81+
endpoints: Arc::new(endpoints),
82+
current_endpoint_index: Arc::new(AtomicUsize::new(connected_index)),
83+
reconnect_generation: Arc::new(AtomicUsize::new(0)),
84+
})
85+
}
86+
87+
/// Create a new client from a comma-separated list of RPC endpoints using LegacyBackend.
88+
///
89+
/// The client will try each endpoint in sequence until one connects successfully.
90+
/// Multiple endpoints can be specified for failover:
91+
/// "wss://rpc1.example.com,wss://rpc2.example.com"
92+
pub async fn new_with_legacy_backend(uris: &str) -> Result<Self, Error> {
93+
let endpoints: Vec<String> = uris
94+
.split(',')
95+
.map(|s| s.trim().to_string())
96+
.filter(|s| !s.is_empty())
97+
.collect();
98+
99+
if endpoints.is_empty() {
100+
return Err(Error::Other("No RPC endpoints provided".into()));
101+
}
102+
103+
if endpoints.len() > 1 {
104+
log::info!(target: LOG_TARGET, "RPC endpoint pool: {} endpoint(s)", endpoints.len());
105+
}
106+
107+
let (chain_api, connected_index) = Self::connect_with_failover(&endpoints, 0, true).await?;
76108

77109
Ok(Self {
78110
chain_api: Arc::new(RwLock::new(chain_api)),
@@ -95,6 +127,7 @@ impl Client {
95127
async fn connect_with_failover(
96128
endpoints: &[String],
97129
start_index: usize,
130+
use_legacy: bool,
98131
) -> Result<(ChainClient, usize), Error> {
99132
let mut last_error = None;
100133
let total = endpoints.len();
@@ -112,7 +145,7 @@ impl Client {
112145
"attempting to connect to {uri:?} (endpoint {endpoint_num}/{total}, attempt {attempt}/{max_attempts})"
113146
);
114147

115-
match Self::try_connect(uri).await {
148+
match Self::try_connect(uri, use_legacy).await {
116149
Ok(client) => {
117150
if total > 1 {
118151
log::info!(
@@ -150,8 +183,8 @@ impl Client {
150183
}
151184

152185
/// Try to connect to a single endpoint with timeout.
153-
async fn try_connect(uri: &str) -> Result<ChainClient, Error> {
154-
let connect_future = async {
186+
async fn try_connect(uri: &str, use_legacy: bool) -> Result<ChainClient, Error> {
187+
let connect_future = async move {
155188
let reconnecting_rpc = ReconnectingRpcClient::builder()
156189
.retry_policy(
157190
ExponentialBackoff::from_millis(500).max_delay(Duration::from_secs(10)).take(3),
@@ -160,11 +193,18 @@ impl Client {
160193
.await
161194
.map_err(|e| Error::Other(format!("Failed to connect: {e:?}")))?;
162195

163-
let backend = LegacyBackend::builder().build(reconnecting_rpc.clone());
164-
165-
let chain_api = ChainClient::from_backend(Arc::new(backend)).await?;
166-
167-
log::info!(target: LOG_TARGET, "Connected to {uri} with Legacy backend");
196+
let chain_api = if use_legacy {
197+
let backend = LegacyBackend::builder().build(reconnecting_rpc.clone());
198+
let client = ChainClient::from_backend(Arc::new(backend)).await?;
199+
log::info!(target: LOG_TARGET, "Connected to {uri} with Legacy backend");
200+
client
201+
} else {
202+
let backend: ChainHeadBackend<Config> = ChainHeadBackendBuilder::default()
203+
.build_with_background_driver(reconnecting_rpc);
204+
let client = ChainClient::from_backend(Arc::new(backend)).await?;
205+
log::info!(target: LOG_TARGET, "Connected to {uri} with ChainHead backend");
206+
client
207+
};
168208

169209
Ok::<ChainClient, Error>(chain_api)
170210
};
@@ -216,7 +256,7 @@ impl Client {
216256

217257
// Establish new connection before acquiring write lock
218258
let (new_client, connected_idx) =
219-
Self::connect_with_failover(&self.endpoints, start_idx).await?;
259+
Self::connect_with_failover(&self.endpoints, start_idx, false).await?;
220260

221261
// Acquire write lock and check if another task already reconnected
222262
let mut guard = self.chain_api.write().await;

src/commands/predict.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ where
7979
.map_err(|e| {
8080
Error::Other(format!("Failed to fetch Desired Targets from chain: {e}"))
8181
})?
82-
.expect("Error in fetching desired validators from chain")
82+
.ok_or_else(|| {
83+
Error::Other("Desired validators not found in chain storage".to_string())
84+
})?
8385
},
8486
};
8587

@@ -100,7 +102,10 @@ where
100102
let (target_snapshot, mut voter_snapshot) =
101103
convert_election_data_to_snapshots::<T>(candidates, nominators)?;
102104

103-
// Fix the order for staking data source
105+
// When fetching from staking data, voters come from BagsList in descending order (highest
106+
// stake first). The SDK expects page 0 (lsp) to contain lowest stake voters and page n-1
107+
// (msp) to contain highest stake voters. Reversing ensures correct page assignment during
108+
// pagination.
104109
if matches!(data_source, ElectionDataSource::Staking) {
105110
voter_snapshot.reverse();
106111
}
@@ -114,7 +119,7 @@ where
114119
);
115120

116121
// Use actual voter page count, not the chain's max pages
117-
// Staking data may have added some pages
122+
// Staking/Overridden data may have added some pages
118123
let n_pages = n_pages.max(voter_snapshot.len() as u32);
119124

120125
// Mine the solution with timeout to prevent indefinite hanging
@@ -181,7 +186,7 @@ where
181186
let validators_output = output_dir.join("validators_prediction.json");
182187
let nominators_output = output_dir.join("nominators_prediction.json");
183188
// Save validators prediction
184-
write_data_to_json_file(&validators_prediction, validators_output.to_str().unwrap()).await?;
189+
write_data_to_json_file(&validators_prediction, &validators_output).await?;
185190

186191
log::info!(
187192
target: LOG_TARGET,
@@ -190,7 +195,7 @@ where
190195
);
191196

192197
// Save nominators prediction
193-
write_data_to_json_file(&nominators_prediction, nominators_output.to_str().unwrap()).await?;
198+
write_data_to_json_file(&nominators_prediction, &nominators_output).await?;
194199

195200
log::info!(
196201
target: LOG_TARGET,

src/dynamic/election_data.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,10 @@ where
267267
let validators_with_only_self_vote: HashSet<AccountId> = all_voters
268268
.iter()
269269
.filter(|(nominator, _, targets)| {
270-
// validator has only self-vote if:
270+
// validator has only self-vote if either:
271271
// 1. They are a validator (in active_set)
272272
// 2. Their only target is themselves
273-
// NOTE: Reverted to your original logic as requested, assuming you want strictly this
274-
// behavior.
273+
275274
active_set.contains(nominator) || (targets.len() == 1 && targets[0] == *nominator)
276275
})
277276
.map(|(nominator, _, _)| nominator.clone())

src/dynamic/staking.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ pub(crate) async fn fetch_voter_list(
185185
// Increase voter limit to have a buffer for filtering ineligible voters later
186186
let extended_voter_limit = voter_limit.saturating_add(100);
187187

188-
log::info!(target: LOG_TARGET, "Fetching From Voter List");
188+
log::info!(target: LOG_TARGET, "Fetching From VoterList");
189189

190190
// Fetch all bags (ListBags) - store as HashMap with bag_upper as key
191191
log::trace!(target: LOG_TARGET, "Fetching ListBags...");
@@ -371,7 +371,7 @@ pub(crate) async fn fetch_voter_list(
371371

372372
log::info!(
373373
target: LOG_TARGET,
374-
"Voter List Fetch Completed"
374+
"VoterList Fetch Completed"
375375
);
376376
log::info!(
377377
target: LOG_TARGET,

src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,11 @@ async fn main() -> Result<(), Error> {
103103
// Initialize the timestamp so that if connection hangs, the stall detection alert can fire.
104104
prometheus::set_last_block_processing_time();
105105

106-
let client = Client::new(&uri).await?;
106+
// Create client with appropriate backend based on command type
107+
let client = match command {
108+
Command::Predict(_) => Client::new_with_legacy_backend(&uri).await?,
109+
_ => Client::new(&uri).await?,
110+
};
107111

108112
let version_bytes = client
109113
.chain_api()

src/utils.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,12 @@ pub async fn wait_tx_in_finalized_block(
123123
}
124124

125125
/// Write data to a JSON file
126-
pub async fn write_data_to_json_file<T>(data: &T, file_path: &str) -> Result<(), Error>
126+
pub async fn write_data_to_json_file<T, P>(data: &T, file_path: &P) -> Result<(), Error>
127127
where
128128
T: Serialize,
129+
P: AsRef<Path>,
129130
{
130-
let path = Path::new(file_path);
131+
let path = file_path.as_ref();
131132
if let Some(parent) = path.parent() &&
132133
!parent.as_os_str().is_empty()
133134
{
@@ -145,11 +146,12 @@ where
145146
}
146147

147148
/// Read data from a JSON file
148-
pub async fn read_data_from_json_file<T>(file_path: &str) -> Result<T, Error>
149+
pub async fn read_data_from_json_file<T, P>(file_path: P) -> Result<T, Error>
149150
where
150151
T: DeserializeOwned,
152+
P: AsRef<Path>,
151153
{
152-
let path = Path::new(file_path);
154+
let path = file_path.as_ref();
153155

154156
let mut file = File::open(path)?;
155157
let mut content = String::new();
@@ -345,12 +347,11 @@ mod tests {
345347
async fn test_read_write_json_file() {
346348
let dir = tempfile::tempdir().unwrap();
347349
let file_path = dir.path().join("test.json");
348-
let file_path_str = file_path.to_str().unwrap();
349350

350351
let data = vec![1, 2, 3];
351-
write_data_to_json_file(&data, file_path_str).await.unwrap();
352+
write_data_to_json_file(&data, &file_path).await.unwrap();
352353

353-
let read_data: Vec<i32> = read_data_from_json_file(file_path_str).await.unwrap();
354+
let read_data: Vec<i32> = read_data_from_json_file(&file_path).await.unwrap();
354355
assert_eq!(data, read_data);
355356
}
356357

0 commit comments

Comments
 (0)