Integration test for CachingSession #1237

Open
wants to merge 2 commits into base: branch-hackathon

28 changes: 28 additions & 0 deletions scylla/tests/common/utils.rs
@@ -300,3 +300,31 @@ impl PerformDDL for CachingSession {
self.execute_unpaged(query, &[]).await.map(|_| ())
}
}

#[derive(Debug)]
#[allow(dead_code)]
@Bouncheck (Author) commented on Feb 13, 2025:

I don't understand why I had to add this line here even though I use this struct in my test. Shouldn't that usage mean it's not dead code?

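/// Load-balancing policy that always picks the single, fixed (node, shard) target it was built with and provides no fallback.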
pub(crate) struct SingleTargetLBP {
pub(crate) target: (Arc<scylla::cluster::Node>, Option<u32>),
}

impl LoadBalancingPolicy for SingleTargetLBP {
fn pick<'a>(
&'a self,
_query: &'a RoutingInfo,
_cluster: &'a ClusterState,
) -> Option<(NodeRef<'a>, Option<u32>)> {
Some((&self.target.0, self.target.1))
}

fn fallback<'a>(
&'a self,
_query: &'a RoutingInfo,
_cluster: &'a ClusterState,
) -> FallbackPlan<'a> {
Box::new(std::iter::empty())
}

fn name(&self) -> String {
"SingleTargetLBP".to_owned()
}
}
193 changes: 193 additions & 0 deletions scylla/tests/integration/caching_session.rs
@@ -0,0 +1,193 @@
use std::sync::Arc;

use crate::utils::test_with_3_node_cluster;
use crate::utils::{setup_tracing, unique_keyspace_name, PerformDDL};
use scylla::batch::Batch;
use scylla::batch::BatchType;
use scylla::client::caching_session::CachingSession;
use scylla_proxy::RequestOpcode;
use scylla_proxy::RequestReaction;
use scylla_proxy::RequestRule;
use scylla_proxy::ShardAwareness;
use scylla_proxy::{Condition, ProxyError, Reaction, RequestFrame, TargetShard, WorkerError};
use tokio::sync::mpsc;

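// Drain all feedback frames currently queued in the channel and return how many there were.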
fn consume_current_feedbacks(
rx: &mut mpsc::UnboundedReceiver<(RequestFrame, Option<TargetShard>)>,
) -> usize {
std::iter::from_fn(|| rx.try_recv().ok()).count()
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn ensure_cache_is_used() {
use scylla::client::execution_profile::ExecutionProfile;

use crate::utils::SingleTargetLBP;

setup_tracing();
let res = test_with_3_node_cluster(
ShardAwareness::QueryNode,
|proxy_uris, translation_map, mut running_proxy| async move {
let session = scylla::client::session_builder::SessionBuilder::new()
.known_node(proxy_uris[0].as_str())
.address_translator(Arc::new(translation_map))
.build()
.await
.unwrap();

let cluster_size: usize = 3;
let (feedback_txs, mut feedback_rxs): (Vec<_>, Vec<_>) = (0..cluster_size)
.map(|_| mpsc::unbounded_channel::<(RequestFrame, Option<TargetShard>)>())
.unzip();
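// On every proxy node, install a rule that reports each PREPARE request frame through the matching
// feedback channel (skipping connections registered for events, i.e. the control connection).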
for (i, tx) in feedback_txs.iter().cloned().enumerate() {
running_proxy.running_nodes[i].change_request_rules(Some(vec![RequestRule(
Condition::and(
Condition::RequestOpcode(RequestOpcode::Prepare),
Condition::not(Condition::ConnectionRegisteredAnyEvent),
),
RequestReaction::noop().with_feedback_when_performed(tx),
)]));
}

let ks = unique_keyspace_name();
let rs = "{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}";
session
.ddl(format!(
"CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {}",
ks, rs
))
.await
.unwrap();
session.use_keyspace(ks, false).await.unwrap();
session
.ddl("CREATE TABLE IF NOT EXISTS tab (a int, b int, c int, primary key (a, b, c))")
.await
.unwrap();
// Assumption: all nodes have the same number of shards
let nr_shards = session
.get_cluster_state()
.get_nodes_info()
.first()
.expect("No nodes information available")
.sharder()
.map(|sharder| sharder.nr_shards.get() as usize)
.unwrap_or_else(|| 1); // If there is no sharder, assume 1 shard.

// Consume all feedbacks so far to ensure we will not count something unrelated.
let _feedbacks = feedback_rxs
.iter_mut()
.map(consume_current_feedbacks)
.sum::<usize>();

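// Wrap the session in a CachingSession that caches up to 100 prepared statements.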
let caching_session: CachingSession = CachingSession::from(session, 100);

let batch_size: usize = 4;
let mut batch = Batch::new(BatchType::Logged);
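// Append batch_size INSERT statements with distinct texts, so each of them has to be prepared separately.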
for i in 1..=batch_size {
let insert_b_c = format!("INSERT INTO tab (a, b, c) VALUES ({}, ?, ?)", i);
batch.append_statement(insert_b_c.as_str());
}
let batch_values: Vec<(i32, i32)> = (1..=batch_size as i32).map(|i| (i, i)).collect();

// The first batch should generate prepares for each shard.
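// The assertion below expects one PREPARE frame per batch statement, per shard, per node.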
caching_session
.batch(&batch, batch_values.clone())
.await
.unwrap();
let feedbacks: usize = feedback_rxs.iter_mut().map(consume_current_feedbacks).sum();
assert_eq!(feedbacks, batch_size * nr_shards * cluster_size);

// A few extra runs; these batches should hit the cache and not result in any prepares being sent.
for _ in 0..4 {
caching_session
.batch(&batch, batch_values.clone())
.await
.unwrap();
let feedbacks: usize = feedback_rxs.iter_mut().map(consume_current_feedbacks).sum();
assert_eq!(feedbacks, 0);
}

let prepared_batch_res_rows: Vec<(i32, i32, i32)> = caching_session
.execute_unpaged("SELECT * FROM tab", &[])
.await
.unwrap()
.into_rows_result()
.unwrap()
.rows()
.unwrap()
.collect::<Result<_, _>>()
.unwrap();

// The SELECT should have been prepared on all shards of every node
let feedbacks: usize = feedback_rxs.iter_mut().map(consume_current_feedbacks).sum();
assert_eq!(feedbacks, nr_shards * cluster_size);

// Verify the data from inserts
let mut prepared_batch_res_rows = prepared_batch_res_rows;
prepared_batch_res_rows.sort();
let expected_rows: Vec<(i32, i32, i32)> =
(1..=batch_size as i32).map(|i| (i, i, i)).collect();
assert_eq!(prepared_batch_res_rows, expected_rows);

// Run some ALTERs to invalidate the server-side cache, similarly to scylla/src/session_test.rs
caching_session
.ddl("ALTER TABLE tab RENAME c to tmp")
.await
.unwrap();
caching_session
.ddl("ALTER TABLE tab RENAME b to c")
.await
.unwrap();
caching_session
.ddl("ALTER TABLE tab RENAME tmp to b")
.await
.unwrap();

// The ALTER statements above went through the caching session's execute_unpaged and likely resulted in some prepares being sent.
// Consume those frames.
feedback_rxs
.iter_mut()
.map(consume_current_feedbacks)
.sum::<usize>();

// Run the batch against each shard. The server-side cache should be updated on the first mismatch,
// so only the first contacted shard should request a reprepare.
for node_info in caching_session
.get_session()
.get_cluster_state()
.get_nodes_info()
.iter()
{
for shard_id in 0..nr_shards {
let policy = SingleTargetLBP {
target: (node_info.clone(), Some(shard_id as u32)),
};
let execution_profile = ExecutionProfile::builder()
.load_balancing_policy(Arc::new(policy))
.build();
batch.set_execution_profile_handle(Some(execution_profile.into_handle()));
caching_session
.batch(&batch, batch_values.clone())
.await
.unwrap();
let feedbacks: usize =
feedback_rxs.iter_mut().map(consume_current_feedbacks).sum();
let expected_feedbacks = if shard_id == 0 { batch_size } else { 0 };
assert_eq!(
feedbacks, expected_feedbacks,
"Mismatch in feedbacks on execution for node: {:?}, shard: {}",
node_info, shard_id
);
}
}
running_proxy
},
)
.await;
match res {
Ok(()) => (),
Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
Err(err) => panic!("{}", err),
}
}
1 change: 1 addition & 0 deletions scylla/tests/integration/main.rs
@@ -1,5 +1,6 @@
mod authenticate;
mod batch;
mod caching_session;
mod consistency;
mod cql_collections;
mod cql_types;
32 changes: 1 addition & 31 deletions scylla/tests/integration/tablets.rs
@@ -1,5 +1,6 @@
use std::sync::Arc;

use crate::utils::SingleTargetLBP;
use crate::utils::{
scylla_supports_tablets, setup_tracing, test_with_3_node_cluster, unique_keyspace_name,
PerformDDL,
@@ -12,10 +13,6 @@ use scylla::client::execution_profile::ExecutionProfile;
use scylla::client::session::Session;
use scylla::cluster::ClusterState;
use scylla::cluster::Node;
use scylla::cluster::NodeRef;
use scylla::policies::load_balancing::FallbackPlan;
use scylla::policies::load_balancing::LoadBalancingPolicy;
use scylla::policies::load_balancing::RoutingInfo;
use scylla::prepared_statement::PreparedStatement;
use scylla::query::Query;
use scylla::response::query_result::QueryResult;
@@ -156,33 +153,6 @@ fn calculate_key_per_tablet(tablets: &[Tablet], prepared: &PreparedStatement) ->
value_lists
}

#[derive(Debug)]
struct SingleTargetLBP {
target: (Arc<Node>, Option<u32>),
}

impl LoadBalancingPolicy for SingleTargetLBP {
fn pick<'a>(
&'a self,
_query: &'a RoutingInfo,
_cluster: &'a ClusterState,
) -> Option<(NodeRef<'a>, Option<u32>)> {
Some((&self.target.0, self.target.1))
}

fn fallback<'a>(
&'a self,
_query: &'a RoutingInfo,
_cluster: &'a ClusterState,
) -> FallbackPlan<'a> {
Box::new(std::iter::empty())
}

fn name(&self) -> String {
"SingleTargetLBP".to_owned()
}
}

async fn send_statement_everywhere(
session: &Session,
cluster: &ClusterState,