session: Prepare on one shard per node only #1320

Open · wants to merge 11 commits into base: main
192 changes: 156 additions & 36 deletions scylla/src/client/session.rs
@@ -12,8 +12,9 @@ use crate::cluster::node::CloudEndpoint;
use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
use crate::errors::{
BadQuery, ExecutionError, MetadataError, NewSessionError, PagerExecutionError, PrepareError,
RequestAttemptError, RequestError, SchemaAgreementError, TracingError, UseKeyspaceError,
BadQuery, ConnectionPoolError, ExecutionError, MetadataError, NewSessionError,
PagerExecutionError, PrepareError, RequestAttemptError, RequestError, SchemaAgreementError,
TracingError, UseKeyspaceError,
};
use crate::frame::response::result;
use crate::network::tls::TlsProvider;
@@ -41,7 +42,6 @@ use crate::statement::{Consistency, PageSize, StatementConfig};
use arc_swap::ArcSwapOption;
use futures::future::join_all;
use futures::future::try_join_all;
use itertools::Itertools;
use scylla_cql::frame::response::NonErrorResponse;
use scylla_cql::serialize::batch::BatchValues;
use scylla_cql::serialize::row::{SerializeRow, SerializedValues};
@@ -1198,6 +1198,10 @@ impl Session {
/// Prepares a statement on the server side and returns a prepared statement,
/// which can later be used to perform more efficient requests.
///
/// The statement is prepared on all nodes. This function finishes once any node reports preparation success
/// or when preparation on all nodes fails.
// TODO: Consider introducing timeouts here.
///
/// Prepared statements are much faster than unprepared statements:
/// * Database doesn't need to parse the statement string upon each execution (only once)
/// * They are properly load balanced using token aware routing
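For orientation, here is a minimal usage sketch of the API this doc comment describes. The table, values, and error handling are hypothetical, and the method names (`prepare`, `execute_unpaged`) follow recent driver versions, so adjust them if your version differs:

```rust
use scylla::client::session::Session;

async fn insert_rows(session: &Session) -> Result<(), Box<dyn std::error::Error>> {
    // Prepared once; with this PR, preparation is attempted on one shard per node
    // and the call returns as soon as any node reports success.
    let prepared = session
        .prepare("INSERT INTO ks.t (a, b) VALUES (?, ?)")
        .await?;

    // Executed many times; token-aware routing applies to prepared statements.
    session.execute_unpaged(&prepared, (42_i32, "hello")).await?;
    Ok(())
}
```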
@@ -1236,44 +1240,160 @@
statement: impl Into<Statement>,
) -> Result<PreparedStatement, PrepareError> {
let statement = statement.into();
let statement_ref = &statement;
self.prepare_nongeneric(statement).await
}

// Introduced to avoid monomorphisation of this large function.
async fn prepare_nongeneric(
&self,
statement: Statement,
) -> Result<PreparedStatement, PrepareError> {
type PreparationResult = Result<PreparedStatement, RequestAttemptError>;

let cluster_state = self.get_cluster_state();
let connections_iter = cluster_state.iter_working_connections()?;

// Prepare statements on all connections concurrently
let handles = connections_iter.map(|c| async move { c.prepare(statement_ref).await });
let mut results = join_all(handles).await.into_iter();

// If at least one prepare was successful, `prepare()` returns Ok.
// Find the first result that is Ok, or Err if all failed.

// Safety: there is at least one node in the cluster, and `Cluster::iter_working_connections()`
// returns either an error or an iterator with at least one connection, so there will be at least one result.
let first_ok: Result<PreparedStatement, RequestAttemptError> =
results.by_ref().find_or_first(Result::is_ok).unwrap();
let mut prepared: PreparedStatement =
first_ok.map_err(|first_attempt| PrepareError::AllAttemptsFailed { first_attempt })?;

// Validate prepared ids equality
for statement in results.flatten() {
if prepared.get_id() != statement.get_id() {
return Err(PrepareError::PreparedStatementIdsMismatch);
}

// Collect all tracing ids from prepare() queries in the final result
prepared
.prepare_tracing_ids
.extend(statement.prepare_tracing_ids);
/// Prepares the statement on all nodes (or all shards) concurrently.
///
/// Sends the result of each preparation attempt through a channel, whose receiving end is first sent
/// through the oneshot channel accepted as an argument.
///
/// If no connection is working, sends a `ConnectionPoolError` instead of the channel's RX.
async fn preparation_worker(
cluster_state: Arc<ClusterState>,
statement: Statement,
oneshot_tx: tokio::sync::oneshot::Sender<
Result<tokio::sync::mpsc::Receiver<PreparationResult>, ConnectionPoolError>,
>,
prepare_on_all_shards: bool,
) {
// `iter_working_connections_to_nodes()` returns no more than one connection per node, so the number of all nodes
// is a reasonable capacity for the channel.
let (tx, rx) =
tokio::sync::mpsc::channel::<PreparationResult>(cluster_state.all_nodes.len());

let working_connections = if prepare_on_all_shards {
cluster_state
.iter_working_connections_to_shards()
.map(itertools::Either::Left)
} else {
cluster_state
.iter_working_connections_to_nodes()
.map(itertools::Either::Right)
};
let connections_iter = match working_connections {
Ok(iter) => {
// We have at least one working connection to some node.
// Let's provide our listener with the receiving end of the preparation results channel.
let _ = oneshot_tx.send(Ok(rx));
iter
}
Err(pool_error) => {
// We have no working connection to any node.
// Notify our listener and finish.
let _ = oneshot_tx.send(Err(pool_error));
return;
}
};

let tx_ref = &tx;
let statement_ref = &statement;
let preparations = connections_iter.map(|c| async move {
let res = c.prepare(statement_ref).await;
let _ = tx_ref.send(res).await;
});
join_all(preparations).await;
}
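The channel handoff used by `preparation_worker` may be easier to follow in isolation. Below is a self-contained sketch of the same pattern with toy types instead of driver internals; the `worker` and `targets` names are invented, and the real worker runs its attempts concurrently via `join_all` rather than in a loop:

```rust
use tokio::sync::{mpsc, oneshot};

type AttemptResult = Result<u32, String>;

async fn worker(ready_tx: oneshot::Sender<Result<mpsc::Receiver<AttemptResult>, String>>) {
    // Stand-in for `iter_working_connections_to_nodes()`: setup may fail up front.
    let targets: Vec<u32> = vec![1, 2, 3];
    if targets.is_empty() {
        // Nothing to stream: report the setup error through the oneshot and stop.
        let _ = ready_tx.send(Err("no working connections".to_string()));
        return;
    }

    let (tx, rx) = mpsc::channel::<AttemptResult>(targets.len());
    // Hand the receiving end to the listener before starting the attempts.
    let _ = ready_tx.send(Ok(rx));

    for t in targets {
        // Stand-in for `connection.prepare(...)`.
        let res: AttemptResult = Ok(t * 10);
        let _ = tx.send(res).await;
    }
}

#[tokio::main]
async fn main() {
    let (ready_tx, ready_rx) = oneshot::channel();
    tokio::spawn(worker(ready_tx));

    match ready_rx.await.expect("worker dropped the oneshot") {
        Err(e) => eprintln!("setup failed: {e}"),
        Ok(mut rx) => {
            // Stop at the first success; the worker keeps sending in the background.
            while let Some(res) = rx.recv().await {
                if let Ok(v) = res {
                    println!("first success: {v}");
                    break;
                }
            }
        }
    }
}
```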

prepared.set_partitioner_name(
self.extract_partitioner_name(&prepared, &self.cluster.get_state())
.and_then(PartitionerName::from_str)
.unwrap_or_default(),
);
/// Prepares the statement on either all nodes or all shards.
///
/// Sets up a worker task that attempts preparation on all nodes or on all shards, depending on the flag value.
/// Finishes once any preparation succeeds or all attempts fail. If this function returns successfully,
/// the worker task keeps preparing on the other connections in the background.
///
/// Returns:
/// - `Err(ConnectionPoolError)`, if no connection is working;
/// - `Ok(Ok(PreparedStatement))`, if preparation succeeded on at least one connection;
/// - `Ok(Err(RequestAttemptError))`, if preparation failed on all attempted connections.
async fn prepare_on_all(
session: &Session,
statement: Statement,
cluster_state: Arc<ClusterState>,
on_all_shards: bool,
) -> Result<PreparationResult, ConnectionPoolError> {
// This is required for the following reason:
// 1. The iterator returned from `ClusterState::iter_working_connections_to_{nodes,shards}()` borrows `ClusterState`.
// 2. Only after we call `ClusterState::iter_working_connections_to_{nodes,shards}()` do we know if there is at least
// one working connection. It would be thus perfect to call it here (not in the worker task) and return
// the error early if no connection is working. However, we cannot send the resulting iterator to the task,
// because the iterator is not 'static.
// 3. Thus, it must be the worker task that calls `ClusterState::iter_working_connections_to_{nodes,shards}()`. If it fails,
// it signals `ConnectionPoolError` to the listening task (us). Else, it provides the listening task (us)
// with an mpsc channel that will be used to send subsequent results of preparation attempts on connections.
let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel::<
Result<tokio::sync::mpsc::Receiver<PreparationResult>, ConnectionPoolError>,
>();

tokio::task::spawn(preparation_worker(
cluster_state.clone(),
statement,
oneshot_tx,
on_all_shards,
));

// If at least one preparation attempt succeeds, return Ok with that statement.
// Otherwise, return the first error encountered.
let mut rx = oneshot_rx
.await
.expect("statement preparation tokio task terminated prematurely")?;

let mut first_error = None;
while let Some(prepare_result) = rx.recv().await {
match prepare_result {
Ok(mut prepared) => {
// This is the first preparation that succeeded.
// Let's return the PreparedStatement.
// Preparation on other nodes will continue in the background tokio task.
prepared.set_partitioner_name(
session
.extract_partitioner_name(&prepared, &cluster_state)
.and_then(PartitionerName::from_str)
.unwrap_or_default(),
);
return Ok(Ok(prepared));
}
Err(attempt_error) => {
if first_error.is_none() {
first_error = Some(attempt_error);
}
}
}
}
// Safety: there is at least one node in the cluster, and `ClusterState::iter_working_connections_to_{nodes,shards}()`
// returns either an error or an iterator with at least one connection, so there will be at least one result.
Ok(Err(first_error.expect(
"ClusterState::iter_working_connections_to_{nodes,shards}() returns at least one connection or errors out",
)))
}
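For readers unfamiliar with the nested-result convention documented above, here is a tiny self-contained illustration of how the three outcomes are told apart. The types below are toy stand-ins, not the driver's:

```rust
// Toy stand-ins that only mirror the shape of the real types.
struct PreparedStatement;
struct RequestAttemptError;
struct ConnectionPoolError;

type PreparationResult = Result<PreparedStatement, RequestAttemptError>;

fn interpret(res: Result<PreparationResult, ConnectionPoolError>) -> &'static str {
    match res {
        // No working connection: preparation was never even attempted.
        Err(_pool_error) => "connection pool error",
        // At least one connection accepted the statement.
        Ok(Ok(_prepared)) => "prepared successfully somewhere",
        // Every attempted connection failed; the first error is preserved.
        Ok(Err(_first_attempt)) => "all attempts failed",
    }
}

fn main() {
    assert_eq!(interpret(Ok(Ok(PreparedStatement))), "prepared successfully somewhere");
    assert_eq!(interpret(Err(ConnectionPoolError)), "connection pool error");
}
```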

Ok(prepared)
// Start by attempting preparation on a single (random) connection to every node.
{
let on_all_nodes_result =
prepare_on_all(self, statement.clone(), cluster_state.clone(), false).await?;
if let Ok(prepared) = on_all_nodes_result {
// We succeeded in preparing the statement on at least one node. We're done; at the same time,
// the background tokio task attempts preparation on remaining nodes.
return Ok(prepared);
}
}

// We could have simply been unlucky: the randomly chosen connections might all have been defunct
// (for example, we might have targeted overloaded shards).
// Let's try again, this time on connections to every shard. This is a "last call" fallback.
{
let on_all_shards_result = prepare_on_all(self, statement, cluster_state, true).await?;
on_all_shards_result
.map_err(|err| PrepareError::AllAttemptsFailed { first_attempt: err })
}
}

fn extract_partitioner_name<'a>(
@@ -2098,7 +2218,7 @@ impl Session

pub async fn check_schema_agreement(&self) -> Result<Option<Uuid>, SchemaAgreementError> {
let cluster_state = self.get_cluster_state();
let connections_iter = cluster_state.iter_working_connections()?;
let connections_iter = cluster_state.iter_working_connections_to_shards()?;

let handles = connections_iter.map(|c| async move { c.fetch_schema_version().await });
let versions = try_join_all(handles).await?;
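As an aside on the two combinators involved: `prepare` collects every per-connection result with `join_all` and is satisfied by a single success, while `check_schema_agreement` above uses `try_join_all` and must hear back from every shard. A self-contained sketch of the difference, using toy futures rather than driver calls:

```rust
use futures::executor::block_on;
use futures::future::{join_all, ready, try_join_all};

fn main() {
    let attempts = || {
        vec![
            ready(Ok::<u32, String>(1)),
            ready(Err::<u32, String>("boom".to_string())),
            ready(Ok::<u32, String>(3)),
        ]
    };

    // join_all: drive every future and keep each individual result
    // (the behavior statement preparation relies on: one success is enough).
    let all: Vec<Result<u32, String>> = block_on(join_all(attempts()));
    assert!(all.iter().any(|r| r.is_ok()));

    // try_join_all: stop at the first error
    // (the behavior schema-agreement checking relies on: every shard must answer).
    let agreed: Result<Vec<u32>, String> = block_on(try_join_all(attempts()));
    assert!(agreed.is_err());
}
```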
4 changes: 4 additions & 0 deletions scylla/src/cluster/node.rs
@@ -207,6 +207,10 @@ impl Node {
self.get_pool()?.get_working_connections()
}

pub(crate) fn get_random_connection(&self) -> Result<Arc<Connection>, ConnectionPoolError> {
self.get_pool()?.random_connection()
}

pub(crate) async fn wait_until_pool_initialized(&self) {
if let Some(pool) = &self.pool {
pool.wait_until_initialized().await;
24 changes: 23 additions & 1 deletion scylla/src/cluster/state.rs
@@ -319,7 +319,7 @@ impl ClusterState {
}

/// Returns nonempty iterator of working connections to all shards.
pub(crate) fn iter_working_connections(
pub(crate) fn iter_working_connections_to_shards(
&self,
) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
// The returned iterator is nonempty by nonemptiness invariant of `self.known_peers`.
@@ -344,6 +344,28 @@
// is nonempty, too.
}

/// Returns a nonempty iterator of working connections to all nodes (at most one connection per node).
pub(crate) fn iter_working_connections_to_nodes(
&self,
) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
// The returned iterator is nonempty by nonemptiness invariant of `self.known_peers`.
assert!(!self.known_peers.is_empty());
let mut peers_iter = self.known_peers.values();

// First, we try to find a node with a working pool and take a random connection from it.
// If no such node is found, return the first error.
let first_working_pool = peers_iter
.by_ref()
.map(|node| node.get_random_connection())
.find_or_first(Result::is_ok)
.expect("impossible: known_peers was asserted to be nonempty")?;

let remaining_pools_iter = peers_iter.flat_map(|node| node.get_random_connection());

Ok(std::iter::once(first_working_pool).chain(remaining_pools_iter))
// The returned iterator is nonempty, because it returns at least `first_working_pool`.
}
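The `find_or_first` call above comes from `itertools` and implements the "first Ok, otherwise first element" selection; a minimal standalone illustration with toy data:

```rust
use itertools::Itertools;

fn main() {
    // Mimics per-node attempts to grab a random connection from the pool.
    let attempts: Vec<Result<&str, &str>> = vec![
        Err("pool of node A is broken"),
        Ok("connection to node B"),
        Ok("connection to node C"),
    ];

    // `find_or_first` yields the first element matching the predicate (the first Ok),
    // or falls back to the very first element (an Err here) if nothing matches.
    let picked = attempts
        .into_iter()
        .find_or_first(Result::is_ok)
        .expect("input is nonempty");

    assert_eq!(picked, Ok("connection to node B"));
}
```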

pub(super) fn update_tablets(&mut self, raw_tablets: Vec<(TableSpec<'static>, RawTablet)>) {
let replica_translator = |uuid: Uuid| self.known_peers.get(&uuid).cloned();
