session: Prepare on one shard per node only #1320

Merged: 12 commits, May 7, 2025
105 changes: 82 additions & 23 deletions scylla/src/client/session.rs
@@ -30,9 +30,7 @@ use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
use crate::policies::speculative_execution;
use crate::policies::timestamp_generator::TimestampGenerator;
use crate::response::query_result::{MaybeFirstRowError, QueryResult, RowsError};
use crate::response::{
NonErrorQueryResponse, PagingState, PagingStateResponse, QueryResponse, RawPreparedStatement,
};
use crate::response::{NonErrorQueryResponse, PagingState, PagingStateResponse, QueryResponse};
use crate::routing::partitioner::PartitionerName;
use crate::routing::{Shard, ShardAwarePortRange};
use crate::statement::batch::batch_values;
@@ -1181,7 +1179,7 @@ impl Session {
// Making QueryPager::new_for_query work with values is too hard (if even possible)
// so instead of sending one prepare to a specific connection on each iterator query,
// we fully prepare a statement beforehand.
let prepared = self.prepare(statement).await?;
let prepared = self.prepare_nongeneric(&statement).await?;
let values = prepared.serialize_values(&values)?;
QueryPager::new_for_prepared_statement(PreparedPagerConfig {
prepared,
@@ -1199,6 +1197,9 @@
/// Prepares a statement on the server side and returns a prepared statement,
/// which can later be used to perform more efficient requests.
///
/// The statement is prepared on all nodes. This function finishes once all nodes respond
/// with either success or an error.
///
/// Prepared statements are much faster than unprepared statements:
/// * The database doesn't need to parse the statement string upon each execution (it is parsed only once)
/// * They are properly load balanced using token aware routing
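As a usage illustration (not part of this diff; `my_ks.users` and the bind values are made up, and a connected `session` is assumed), the prepare-then-execute flow looks roughly like this:

```rust
// Sketch only: prepare once, execute many times.
let prepared = session
    .prepare("INSERT INTO my_ks.users (id, name) VALUES (?, ?)")
    .await?;
// By now preparation has finished (with success or error) on every node;
// executions below are token-aware and skip server-side parsing.
session.execute_unpaged(&prepared, (42_i32, "alice")).await?;
```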
@@ -1237,38 +1238,97 @@
statement: impl Into<Statement>,
) -> Result<PreparedStatement, PrepareError> {
let statement = statement.into();
let statement_ref = &statement;
self.prepare_nongeneric(&statement).await
}

// Introduced to avoid monomorphisation of this large function.
async fn prepare_nongeneric(
&self,
statement: &Statement,
) -> Result<PreparedStatement, PrepareError> {
let cluster_state = self.get_cluster_state();
let connections_iter = cluster_state.iter_working_connections()?;

// Prepare statements on all connections concurrently
let handles = connections_iter.map(|c| async move { c.prepare_raw(statement_ref).await });
let mut results = join_all(handles).await.into_iter();
// Start by attempting preparation on a single (random) connection to every node.
{
let mut connections_to_nodes = cluster_state.iter_working_connections_to_nodes()?;
let on_all_nodes_result =
Self::prepare_on_all(statement, &cluster_state, &mut connections_to_nodes).await;
if let Ok(prepared) = on_all_nodes_result {
// We succeeded in preparing the statement on at least one node. We're done.
// Other nodes could have failed to prepare the statement, but this will be handled
// as `DbError::Unprepared` upon execution, followed by a repreparation attempt.
return Ok(prepared);
}
}

// We may have simply been unlucky: all the randomly chosen connections could have been defunct
// (one possibility is that we targeted overloaded shards).
// Let's try again, this time on connections to every shard. This is a "last call" fallback.
{
let mut connections_to_shards = cluster_state.iter_working_connections_to_shards()?;

// If at least one prepare was successful, `prepare_on_all` returns Ok.
Self::prepare_on_all(statement, &cluster_state, &mut connections_to_shards).await
}
}
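The control flow above boils down to: attempt a cheap candidate set first (one random connection per node), and only on total failure retry with the exhaustive set (a connection per shard). A minimal standalone sketch of that shape, with hypothetical futures standing in for the two attempts:

```rust
// Illustrative only: the driver inlines this logic instead of using a helper.
async fn with_last_call_fallback<T, E, F1, F2>(primary: F1, last_call: F2) -> Result<T, E>
where
    F1: std::future::Future<Output = Result<T, E>>,
    F2: std::future::Future<Output = Result<T, E>>,
{
    match primary.await {
        // E.g. the statement got prepared on one shard per node.
        Ok(value) => Ok(value),
        // E.g. retry the preparation on every shard.
        Err(_) => last_call.await,
    }
}
```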

/// Prepares the statement on all given connections.
/// These are intended to be connections to either all nodes or all shards.
///
/// ASSUMPTION: the `working_connections` Iterator is nonempty.
///
/// Returns:
/// - `Ok(PreparedStatement)`, if preparation succeeded on at least one connection;
/// - `Err(PrepareError)`, if no connection is working or preparation failed on all attempted connections.
// TODO: There are no timeouts here, so a single stuck node can freeze the driver, potentially indefinitely.
// Also, all the driver needs from the cluster is the prepared statement metadata,
// and one copy of it, from a single successful response, suffices. A possible optimisation is therefore
// to return as soon as the first preparation finishes successfully, while the remaining
// preparations continue in the background, on a separate tokio task.
// Describing issue: #1332.
async fn prepare_on_all(
statement: &Statement,
cluster_state: &ClusterState,
working_connections: &mut (dyn Iterator<Item = Arc<Connection>> + Send),
) -> Result<PreparedStatement, PrepareError> {
// Prepare concurrently on all connections; below, the first Ok result is picked,
// or the first Err if every attempt failed.
let preparations =
working_connections.map(|c| async move { c.prepare_raw(statement).await });
let raw_prepared_statements_results = join_all(preparations).await;

// Safety: there is at least one node in the cluster, and `Cluster::iter_working_connections()`
// returns either an error or an iterator with at least one connection, so there will be at least one result.
let first_ok: Result<RawPreparedStatement, RequestAttemptError> =
results.by_ref().find_or_first(Result::is_ok).unwrap();
let mut prepared: PreparedStatement = first_ok
let mut raw_prepared_statements_results_iter = raw_prepared_statements_results.into_iter();

// Safety: We pass a nonempty iterator, so there will be at least one result.
let first_ok_or_error = raw_prepared_statements_results_iter
.by_ref()
.find_or_first(|res| res.is_ok())
.unwrap(); // Safety: there is at least one connection.

let mut prepared: PreparedStatement = first_ok_or_error
.map_err(|first_attempt| PrepareError::AllAttemptsFailed { first_attempt })?
.into_prepared_statement();

// Validate prepared ids equality
for statement in results.flatten() {
if prepared.get_id() != &statement.prepared_response.id {
// Validate prepared ids equality.
for another_raw_prepared in raw_prepared_statements_results_iter.flatten() {
if prepared.get_id() != &another_raw_prepared.prepared_response.id {
tracing::error!(
"Got differing ids upon statement preparation: statement \"{}\", id1: {:?}, id2: {:?}",
prepared.get_statement(),
prepared.get_id(),
another_raw_prepared.prepared_response.id
);
return Err(PrepareError::PreparedStatementIdsMismatch);
}

// Collect all tracing ids from prepare() queries in the final result
prepared.prepare_tracing_ids.extend(statement.tracing_id);
prepared
.prepare_tracing_ids
.extend(another_raw_prepared.tracing_id);
}

// `prepared` is the first preparation that succeeded; fill in the partitioner
// name and return it.
prepared.set_partitioner_name(
self.extract_partitioner_name(&prepared, &self.cluster.get_state())
Self::extract_partitioner_name(&prepared, cluster_state)
.and_then(PartitionerName::from_str)
.unwrap_or_default(),
);
@@ -1277,7 +1337,6 @@
}
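The TODO on `prepare_on_all` (issue #1332) sketches a possible optimisation: resolve as soon as the first preparation succeeds and let the rest continue detached. A rough shape of that idea, assuming tokio; `prepare_first_success` is a hypothetical helper, and the id-mismatch validation plus tracing-id collection done above are deliberately omitted:

```rust
// Hypothetical sketch of #1332, not part of this PR.
async fn prepare_first_success(
    statement: Arc<Statement>,
    connections: Vec<Arc<Connection>>,
) -> Option<RawPreparedStatement> {
    let (tx, mut rx) = tokio::sync::mpsc::channel(1);
    for connection in connections {
        let tx = tx.clone();
        let statement = Arc::clone(&statement);
        // Each preparation keeps running even after the first success returns.
        tokio::spawn(async move {
            if let Ok(raw) = connection.prepare_raw(&statement).await {
                // Only the first success fits in the channel; later ones are dropped.
                let _ = tx.try_send(raw);
            }
        });
    }
    drop(tx); // If every attempt fails, `rx.recv()` resolves to `None`.
    rx.recv().await
}
```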

fn extract_partitioner_name<'a>(
&self,
prepared: &PreparedStatement,
cluster_state: &'a ClusterState,
) -> Option<&'a str> {
@@ -1583,7 +1642,7 @@ impl Session {
.iter_mut()
.map(|statement| async move {
if let BatchStatement::Query(query) = statement {
let prepared = self.prepare(query.clone()).await?;
let prepared = self.prepare_nongeneric(query).await?;
*statement = BatchStatement::PreparedStatement(prepared);
}
Ok::<(), PrepareError>(())
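The loop above is what backs batch preparation; a hedged usage sketch, assuming the driver's `prepare_batch` entry point and a made-up `my_ks.events` table:

```rust
use scylla::statement::batch::{Batch, BatchType};

// Sketch only: every unprepared statement in the batch gets prepared up front.
let mut batch = Batch::new(BatchType::Logged);
batch.append_statement("INSERT INTO my_ks.events (id, ts) VALUES (?, ?)");
let prepared_batch = session.prepare_batch(&batch).await?;
```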
@@ -2103,7 +2162,7 @@ impl Session {

pub async fn check_schema_agreement(&self) -> Result<Option<Uuid>, SchemaAgreementError> {
let cluster_state = self.get_cluster_state();
let connections_iter = cluster_state.iter_working_connections()?;
let connections_iter = cluster_state.iter_working_connections_to_shards()?;

let handles = connections_iter.map(|c| async move { c.fetch_schema_version().await });
let versions = try_join_all(handles).await?;
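For context, a typical call site of `check_schema_agreement` (illustrative; the DDL statement and the polling interval are made up): after a schema change, poll until every reachable shard reports the same schema version:

```rust
// Sketch only: wait for schema agreement after a DDL statement.
session
    .query_unpaged("CREATE TABLE my_ks.t (id int PRIMARY KEY)", ())
    .await?;
// `Ok(None)` means the fetched versions still differ; `Ok(Some(version))`
// means all reachable shards agree.
while session.check_schema_agreement().await?.is_none() {
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
}
```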
4 changes: 4 additions & 0 deletions scylla/src/cluster/node.rs
@@ -207,6 +207,10 @@ impl Node {
self.get_pool()?.get_working_connections()
}

pub(crate) fn get_random_connection(&self) -> Result<Arc<Connection>, ConnectionPoolError> {
self.get_pool()?.random_connection()
}

pub(crate) async fn wait_until_pool_initialized(&self) {
if let Some(pool) = &self.pool {
pool.wait_until_initialized().await;
74 changes: 60 additions & 14 deletions scylla/src/cluster/state.rs
@@ -319,31 +319,77 @@ impl ClusterState {
}

/// Returns nonempty iterator of working connections to all shards.
pub(crate) fn iter_working_connections(
pub(crate) fn iter_working_connections_to_shards(
&self,
) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
// The returned iterator is nonempty by the nonemptiness invariant of `self.known_peers`.
assert!(!self.known_peers.is_empty());
let mut peers_iter = self.known_peers.values();
let nodes_iter = self.known_peers.values();
let mut connection_pool_per_node_iter =
nodes_iter.map(|node| node.get_working_connections());

// First we try to find the first working pool of connections.
// If none is found, return error.
let first_working_pool = peers_iter
.by_ref()
.map(|node| node.get_working_connections())
.find_or_first(Result::is_ok)
.expect("impossible: known_peers was asserted to be nonempty")?;

let remaining_pools_iter = peers_iter
.map(|node| node.get_working_connections())
.flatten_ok()
.flatten();

Ok(first_working_pool.into_iter().chain(remaining_pools_iter))
let first_working_pool_or_error: Result<Vec<Arc<Connection>>, ConnectionPoolError> =
connection_pool_per_node_iter
.by_ref()
.find_or_first(Result::is_ok)
.expect("impossible: known_peers was asserted to be nonempty");

// We have:
// 1. either consumed the whole iterator without success and got the first error,
// in which case we propagate it;
// 2. or found the first working pool of connections.
let first_working_pool: Vec<Arc<Connection>> = first_working_pool_or_error?;

// We retrieve connection pools for remaining nodes (those that are left in the iterator
// once the first working pool has been found).
let remaining_pools_iter = connection_pool_per_node_iter;
// Pools are flattened, so now we have `impl Iterator<Item = Result<Arc<Connection>, ConnectionPoolError>>`.
let remaining_connections_iter = remaining_pools_iter.flatten_ok();
// Errors (non-working pools) are filtered out.
let remaining_working_connections_iter = remaining_connections_iter.filter_map(Result::ok);

Ok(first_working_pool
.into_iter()
.chain(remaining_working_connections_iter))
// By an invariant `self.known_peers` is nonempty, so the returned iterator
// is nonempty, too.
}
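Both iterator helpers lean on itertools' `find_or_first`, which returns the first element satisfying the predicate, or the very first element when none does. A standalone illustration of those semantics on plain `Result` values:

```rust
use itertools::Itertools;

// No Ok present: the very first element (here the first error) comes back.
let all_failed: Vec<Result<i32, &str>> = vec![Err("a"), Err("b")];
assert_eq!(
    all_failed.into_iter().find_or_first(Result::is_ok),
    Some(Err("a"))
);

// An Ok is present: the first Ok is returned, regardless of position.
let one_ok: Vec<Result<i32, &str>> = vec![Err("a"), Ok(7), Ok(8)];
assert_eq!(one_ok.into_iter().find_or_first(Result::is_ok), Some(Ok(7)));
```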

/// Returns nonempty iterator of working connections to all nodes.
pub(crate) fn iter_working_connections_to_nodes(
&self,
) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
// The returned iterator is nonempty by the nonemptiness invariant of `self.known_peers`.
assert!(!self.known_peers.is_empty());
let nodes_iter = self.known_peers.values();
let mut single_connection_per_node_iter =
nodes_iter.map(|node| node.get_random_connection());

// First we try to find the first working connection.
// If none is found, return error.
let first_working_connection_or_error: Result<Arc<Connection>, ConnectionPoolError> =
single_connection_per_node_iter
.by_ref()
.find_or_first(Result::is_ok)
.expect("impossible: known_peers was asserted to be nonempty");

// We have:
// 1. either consumed the whole iterator without success and got the first error,
// in which case we propagate it;
// 2. or found the first working connection.
let first_working_connection: Arc<Connection> = first_working_connection_or_error?;

// We retrieve single random connections for remaining nodes (those that are left in the iterator
// once the first working connection has been found). Errors (non-working connections) are filtered out.
let remaining_connection_iter = single_connection_per_node_iter.filter_map(Result::ok);

// Connections to the remaining nodes are chained to the first working connection.
Ok(std::iter::once(first_working_connection).chain(remaining_connection_iter))
// The returned iterator is nonempty, because it yields at least `first_working_connection`.
}
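For a sense of scale (numbers are illustrative): on a 3-node cluster with 8 shards per node, `iter_working_connections_to_nodes` yields at most 3 connections, while `iter_working_connections_to_shards` yields roughly one per shard (around 24, or more if a pool holds several connections per shard). That difference is the saving the PR title refers to.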

pub(super) fn update_tablets(&mut self, raw_tablets: Vec<(TableSpec<'static>, RawTablet)>) {
let replica_translator = |uuid: Uuid| self.known_peers.get(&uuid).cloned();
