diff --git a/scylla/src/client/session.rs b/scylla/src/client/session.rs
index 44de7862fb..ec8b67ce23 100644
--- a/scylla/src/client/session.rs
+++ b/scylla/src/client/session.rs
@@ -12,8 +12,9 @@ use crate::cluster::node::CloudEndpoint;
 use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
 use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
 use crate::errors::{
-    BadQuery, ExecutionError, MetadataError, NewSessionError, PagerExecutionError, PrepareError,
-    RequestAttemptError, RequestError, SchemaAgreementError, TracingError, UseKeyspaceError,
+    BadQuery, ConnectionPoolError, ExecutionError, MetadataError, NewSessionError,
+    PagerExecutionError, PrepareError, RequestAttemptError, RequestError, SchemaAgreementError,
+    TracingError, UseKeyspaceError,
 };
 use crate::frame::response::result;
 use crate::network::tls::TlsProvider;
@@ -41,7 +42,6 @@ use crate::statement::{Consistency, PageSize, StatementConfig};
 use arc_swap::ArcSwapOption;
 use futures::future::join_all;
 use futures::future::try_join_all;
-use itertools::Itertools;
 use scylla_cql::frame::response::NonErrorResponse;
 use scylla_cql::serialize::batch::BatchValues;
 use scylla_cql::serialize::row::{SerializeRow, SerializedValues};
@@ -1198,6 +1198,10 @@ impl Session {
     /// Prepares a statement on the server side and returns a prepared statement,
     /// which can later be used to perform more efficient requests.
     ///
+    /// The statement is prepared on all nodes. This function finishes once any node reports preparation success
+    /// or when preparation on all nodes fails.
+    // TODO: Consider introducing timeouts here.
+    ///
     /// Prepared statements are much faster than unprepared statements:
     /// * Database doesn't need to parse the statement string upon each execution (only once)
     /// * They are properly load balanced using token aware routing
@@ -1236,44 +1240,160 @@ impl Session {
         statement: impl Into<Statement>,
     ) -> Result<PreparedStatement, PrepareError> {
         let statement = statement.into();
-        let statement_ref = &statement;
+        self.prepare_nongeneric(statement).await
+    }
+
+    // Introduced to avoid monomorphisation of this large function.
+    async fn prepare_nongeneric(
+        &self,
+        statement: Statement,
+    ) -> Result<PreparedStatement, PrepareError> {
+        type PreparationResult = Result<PreparedStatement, RequestAttemptError>;
 
         let cluster_state = self.get_cluster_state();
-        let connections_iter = cluster_state.iter_working_connections()?;
-
-        // Prepare statements on all connections concurrently
-        let handles = connections_iter.map(|c| async move { c.prepare(statement_ref).await });
-        let mut results = join_all(handles).await.into_iter();
-
-        // If at least one prepare was successful, `prepare()` returns Ok.
-        // Find the first result that is Ok, or Err if all failed.
-
-        // Safety: there is at least one node in the cluster, and `Cluster::iter_working_connections()`
-        // returns either an error or an iterator with at least one connection, so there will be at least one result.
-        let first_ok: Result<PreparedStatement, RequestAttemptError> =
-            results.by_ref().find_or_first(Result::is_ok).unwrap();
-        let mut prepared: PreparedStatement =
-            first_ok.map_err(|first_attempt| PrepareError::AllAttemptsFailed { first_attempt })?;
-
-        // Validate prepared ids equality
-        for statement in results.flatten() {
-            if prepared.get_id() != statement.get_id() {
-                return Err(PrepareError::PreparedStatementIdsMismatch);
-            }
-
-            // Collect all tracing ids from prepare() queries in the final result
-            prepared
-                .prepare_tracing_ids
-                .extend(statement.prepare_tracing_ids);
+        /// Prepares the statement on all nodes/shards concurrently.
+        ///
+        /// Sends the result of each preparation attempt through a channel, whose receiving end is first sent
+        /// through the oneshot channel accepted as an argument.
+        ///
+        /// If no connection is working, sends a `ConnectionPoolError` instead of the channel's RX.
+        async fn preparation_worker(
+            cluster_state: Arc<ClusterState>,
+            statement: Statement,
+            oneshot_tx: tokio::sync::oneshot::Sender<
+                Result<tokio::sync::mpsc::Receiver<PreparationResult>, ConnectionPoolError>,
+            >,
+            prepare_on_all_shards: bool,
+        ) {
+            // `iter_working_connections_to_nodes()` returns no more than one connection per node, so the number of all nodes
+            // is a reasonable capacity for the channel.
+            let (tx, rx) =
+                tokio::sync::mpsc::channel::<PreparationResult>(cluster_state.all_nodes.len());
+
+            let working_connections = if prepare_on_all_shards {
+                cluster_state
+                    .iter_working_connections_to_shards()
+                    .map(itertools::Either::Left)
+            } else {
+                cluster_state
+                    .iter_working_connections_to_nodes()
+                    .map(itertools::Either::Right)
+            };
+            let connections_iter = match working_connections {
+                Ok(iter) => {
+                    // We have at least one working connection to some node.
+                    // Let's provide our listener with the receiving end of the preparation results channel.
+                    let _ = oneshot_tx.send(Ok(rx));
+                    iter
+                }
+                Err(pool_error) => {
+                    // We have no working connection to any node.
+                    // Notify our listener and finish.
+                    let _ = oneshot_tx.send(Err(pool_error));
+                    return;
+                }
+            };
+
+            let tx_ref = &tx;
+            let statement_ref = &statement;
+            let preparations = connections_iter.map(|c| async move {
+                let res = c.prepare(statement_ref).await;
+                let _ = tx_ref.send(res).await;
+            });
+            join_all(preparations).await;
         }
 
-        prepared.set_partitioner_name(
-            self.extract_partitioner_name(&prepared, &self.cluster.get_state())
-                .and_then(PartitionerName::from_str)
-                .unwrap_or_default(),
-        );
+        /// Prepares the statement on either all nodes or all shards.
+        ///
+        /// Sets up a worker task that attempts preparation on all nodes or on all shards, depending on the flag value.
+        /// Finishes once any preparation succeeds or when all attempts fail. If this function returns successfully,
+        /// the worker task keeps preparing on the remaining connections in the background.
+        ///
+        /// Returns:
+        /// - `Err(ConnectionPoolError)`, if no connection is working;
+        /// - `Ok(Ok(PreparedStatement))`, if preparation succeeded on at least one connection;
+        /// - `Ok(Err(RequestAttemptError))`, if preparation failed on all attempted connections.
+        async fn prepare_on_all(
+            session: &Session,
+            statement: Statement,
+            cluster_state: Arc<ClusterState>,
+            on_all_shards: bool,
+        ) -> Result<PreparationResult, ConnectionPoolError> {
+            // Spawning a worker task and communicating with it over channels is required for the following reasons:
+            // 1. The iterator returned from `ClusterState::iter_working_connections_to_{nodes,shards}()` borrows `ClusterState`.
+            // 2. Only after we call `ClusterState::iter_working_connections_to_{nodes,shards}()` do we know if there is at least
+            //    one working connection. It would thus be ideal to call it here (not in the worker task) and return
+            //    the error early if no connection is working. However, we cannot send the resulting iterator to the task,
+            //    because the iterator is not 'static.
+            // 3. Thus, it must be the worker task that calls `ClusterState::iter_working_connections_to_{nodes,shards}()`. If it fails,
+            //    it signals `ConnectionPoolError` to the listening task (us). Otherwise, it provides the listening task (us)
+            //    with an mpsc channel that will be used to send subsequent results of preparation attempts on connections.
+            let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel::<
+                Result<tokio::sync::mpsc::Receiver<PreparationResult>, ConnectionPoolError>,
+            >();
+
+            tokio::task::spawn(preparation_worker(
+                cluster_state.clone(),
+                statement,
+                oneshot_tx,
+                on_all_shards,
+            ));
+
+            // If at least one prepare was successful, `prepare()` returns Ok.
+            // Find the first result that is Ok, or Err if all failed.
+            let mut rx = oneshot_rx
+                .await
+                .expect("statement preparation tokio task terminated prematurely")?;
+
+            let mut first_error = None;
+            while let Some(prepare_result) = rx.recv().await {
+                match prepare_result {
+                    Ok(mut prepared) => {
+                        // This is the first preparation that succeeded.
+                        // Let's return the PreparedStatement.
+                        // Preparation on other nodes will continue in the background tokio task.
+                        prepared.set_partitioner_name(
+                            session
+                                .extract_partitioner_name(&prepared, &cluster_state)
+                                .and_then(PartitionerName::from_str)
+                                .unwrap_or_default(),
+                        );
+                        return Ok(Ok(prepared));
+                    }
+                    Err(attempt_error) => {
+                        if first_error.is_none() {
+                            first_error = Some(attempt_error);
+                        }
+                    }
+                }
+            }
+            // Safety: there is at least one node in the cluster, and `ClusterState::iter_working_connections_to_{nodes,shards}()`
+            // returns either an error or an iterator with at least one connection, so there will be at least one result.
+            Ok(Err(first_error.expect(
+                "ClusterState::iter_working_connections_to_{nodes,shards}() returns at least one connection or errors out",
+            )))
+        }
 
-        Ok(prepared)
+        // Start by attempting preparation on a single (random) connection to every node.
+        {
+            let on_all_nodes_result =
+                prepare_on_all(self, statement.clone(), cluster_state.clone(), false).await?;
+            if let Ok(prepared) = on_all_nodes_result {
+                // We succeeded in preparing the statement on at least one node. We're done; at the same time,
+                // the background tokio task attempts preparation on the remaining nodes.
+                return Ok(prepared);
+            }
+        }
+
+        // We could have just been unlucky: perhaps the random connections we chose were all defunct
+        // (one possibility is that we targeted overloaded shards).
+        // Let's try again, this time on connections to every shard. This is a "last call" fallback.
+        {
+            let on_all_shards_result = prepare_on_all(self, statement, cluster_state, true).await?;
+            on_all_shards_result
+                .map_err(|err| PrepareError::AllAttemptsFailed { first_attempt: err })
+        }
     }
 
     fn extract_partitioner_name<'a>(
@@ -2098,7 +2218,7 @@ impl Session {
     pub async fn check_schema_agreement(&self) -> Result<Option<Uuid>, SchemaAgreementError> {
         let cluster_state = self.get_cluster_state();
 
-        let connections_iter = cluster_state.iter_working_connections()?;
+        let connections_iter = cluster_state.iter_working_connections_to_shards()?;
         let handles = connections_iter.map(|c| async move { c.fetch_schema_version().await });
         let versions = try_join_all(handles).await?;
 
diff --git a/scylla/src/cluster/node.rs b/scylla/src/cluster/node.rs
index e4e9332ae3..4c38e75706 100644
--- a/scylla/src/cluster/node.rs
+++ b/scylla/src/cluster/node.rs
@@ -207,6 +207,10 @@ impl Node {
         self.get_pool()?.get_working_connections()
     }
 
+    pub(crate) fn get_random_connection(&self) -> Result<Arc<Connection>, ConnectionPoolError> {
+        self.get_pool()?.random_connection()
+    }
+
     pub(crate) async fn wait_until_pool_initialized(&self) {
         if let Some(pool) = &self.pool {
             pool.wait_until_initialized().await;
diff --git a/scylla/src/cluster/state.rs b/scylla/src/cluster/state.rs
index 9796d4b33b..570da2dc59 100644
--- a/scylla/src/cluster/state.rs
+++ b/scylla/src/cluster/state.rs
@@ -319,7 +319,7 @@ impl ClusterState {
     }
 
     /// Returns nonempty iterator of working connections to all shards.
-    pub(crate) fn iter_working_connections(
+    pub(crate) fn iter_working_connections_to_shards(
        &self,
     ) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
         // The returned iterator is nonempty by nonemptiness invariant of `self.known_peers`.
@@ -344,6 +344,28 @@ impl ClusterState {
         // is nonempty, too.
     }
 
+    /// Returns nonempty iterator of working connections to all nodes (at most one connection per node).
+    pub(crate) fn iter_working_connections_to_nodes(
+        &self,
+    ) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
+        // The returned iterator is nonempty by nonemptiness invariant of `self.known_peers`.
+        assert!(!self.known_peers.is_empty());
+        let mut peers_iter = self.known_peers.values();
+
+        // First we try to find a node with a working connection.
+        // If none is found, return the error.
+        let first_working_pool = peers_iter
+            .by_ref()
+            .map(|node| node.get_random_connection())
+            .find_or_first(Result::is_ok)
+            .expect("impossible: known_peers was asserted to be nonempty")?;
+
+        let remaining_pools_iter = peers_iter.flat_map(|node| node.get_random_connection());
+
+        Ok(std::iter::once(first_working_pool).chain(remaining_pools_iter))
+        // The returned iterator is nonempty, because it returns at least `first_working_pool`.
+    }
+
     pub(super) fn update_tablets(&mut self, raw_tablets: Vec<(TableSpec<'static>, RawTablet)>) {
         let replica_translator = |uuid: Uuid| self.known_peers.get(&uuid).cloned();
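For illustration only, not part of the patch: a minimal, self-contained sketch of the oneshot-plus-mpsc handoff that `preparation_worker` and `prepare_on_all` use above. Every name here (`worker`, the integer "targets", the `String` error stand-ins) is invented for the example; only the communication pattern mirrors the change. Assumes the `tokio` and `futures` crates.

```rust
// Standalone sketch of the "hand the mpsc receiver over via a oneshot, then
// return on the first success while the worker keeps going" pattern.
use tokio::sync::{mpsc, oneshot};

type AttemptResult = Result<String, String>; // stand-in for Result<PreparedStatement, RequestAttemptError>

async fn worker(
    targets: Vec<u32>,
    oneshot_tx: oneshot::Sender<Result<mpsc::Receiver<AttemptResult>, String>>,
) {
    if targets.is_empty() {
        // Analogous to "no working connections": report the error instead of the receiver.
        let _ = oneshot_tx.send(Err("no working connections".to_owned()));
        return;
    }
    let (tx, rx) = mpsc::channel::<AttemptResult>(targets.len());
    // Hand the receiving end over to the listener before starting the attempts.
    let _ = oneshot_tx.send(Ok(rx));

    let tx_ref = &tx;
    let attempts = targets.iter().map(|id| async move {
        // Simulated per-connection preparation attempt; even ids "succeed".
        let res: AttemptResult = if id % 2 == 0 {
            Ok(format!("prepared on {id}"))
        } else {
            Err(format!("attempt on {id} failed"))
        };
        let _ = tx_ref.send(res).await;
    });
    futures::future::join_all(attempts).await;
}

#[tokio::main]
async fn main() {
    let (oneshot_tx, oneshot_rx) = oneshot::channel();
    tokio::spawn(worker(vec![1, 2, 3, 4], oneshot_tx));

    // Listener: take the first success, remember the first error otherwise.
    let mut rx = oneshot_rx.await.expect("worker dropped the sender").unwrap();
    let mut first_error = None;
    while let Some(res) = rx.recv().await {
        match res {
            Ok(ok) => {
                println!("first success: {ok}");
                // The worker keeps draining the remaining attempts in the background.
                return;
            }
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
    }
    println!("all attempts failed: {first_error:?}");
}
```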
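Also for illustration only, not part of the patch: how the change looks from the caller's side. This assumes the current `scylla` crate API (`SessionBuilder`, `Session::execute_unpaged`), a cluster reachable at 127.0.0.1:9042, and an existing `ks.t (a int primary key, b text)` table; adjust names to your schema.

```rust
// User-side sketch: `prepare()` now returns as soon as any node reports success,
// while preparation on the remaining nodes/shards continues in a background task.
use scylla::client::session_builder::SessionBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let session = SessionBuilder::new()
        .known_node("127.0.0.1:9042")
        .build()
        .await?;

    // Prepared once...
    let prepared = session
        .prepare("INSERT INTO ks.t (a, b) VALUES (?, ?)")
        .await?;

    // ...executed many times with token-aware routing.
    session.execute_unpaged(&prepared, (42_i32, "hello")).await?;
    Ok(())
}
```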