Skip to content

tests: introducing keep-alive tests #1241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: branch-hackathon
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 276 additions & 0 deletions scylla/tests/ccm_integration/keep_alive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
use crate::common::utils::test_with_3_node_cluster;
use scylla::client::session::Session;
use scylla::client::session_builder::SessionBuilder;

use scylla_proxy::{
Condition, ProxyError, Reaction, RequestOpcode, RequestReaction, RequestRule, ShardAwareness,
WorkerError,
};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::{self, UnboundedReceiver};
use tokio::time::sleep;

/// Waits for a specified number of keep-alive messages to be received within a given timeout period.
///
/// # Arguments
///
/// * `request_rx` - An `UnboundedReceiver` that receives the keep-alive messages.
/// * `expected_number_of_keep_alive` - The number of keep-alive messages expected to be received.
/// * `timeout` - The maximum duration to wait for the expected number of keep-alive messages.
///
/// # Panics
///
/// This function will panic if the number of received keep-alive messages does not match the
/// expected number within the timeout period.
///
async fn wait_for_keep_alive<T>(
Comment on lines +14 to +27
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 I really like that you put such extensive comment, such things are sadly missing in many of our tests.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mostly copilot, I just removed a few unneeded things after him

mut request_rx: UnboundedReceiver<T>,
expected_number_of_keep_alive: u32,
timeout: Duration,
) {
let mut total_keep_alives: u32 = 0;
let start = tokio::time::Instant::now();
while total_keep_alives < expected_number_of_keep_alive && start.elapsed() < timeout {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is the best idiomatic way of doing such a thing

Copy link
Collaborator

@Lorak-mmk Lorak-mmk Feb 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not. I'd do it roughly like that (didn't test the code, so it may not compile):

tokio::time::timeout(timeout,  async move {
    for _ in 0..expected_number_of_keep_alive {
        request_rx.recv().await.unwrap();
    }
});

if let Ok(_t) = request_rx.try_recv() {
total_keep_alives += 1;
}
sleep(Duration::from_millis(1)).await;
}
println!("Total keep alive: {}", total_keep_alives);
assert_eq!(total_keep_alives, expected_number_of_keep_alive);
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn node_should_go_down_gracefully_when_connection_closed_during_heartbeat() {
let res = test_with_3_node_cluster(
ShardAwareness::QueryNode,
|proxy_uris, translation_map, mut running_proxy| async move {
let _session: Session = SessionBuilder::new()
.known_node(proxy_uris[0].as_str())
.address_translator(Arc::new(translation_map))
.keepalive_interval(Duration::from_millis(20))
.build()
.await
.unwrap();

// TODO: No way to get node status, no as in java-driver

Comment on lines +56 to +59
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYM by node status? Do you need information if the driver is connected to a given node?
If so you can get it with session.cluster_state().get_nodes_info()[node_idx].is_connected()

// Stop listening for new connections (so it can't reconnect)

let request_rule = |tx| {
vec![RequestRule(
Condition::RequestOpcode(RequestOpcode::Options),
RequestReaction::drop_connection().with_feedback_when_performed(tx),
)]
};
let (request_tx, request_rx) = mpsc::unbounded_channel();
running_proxy.running_nodes[0]
.change_request_rules(Some(request_rule(request_tx.clone())));

// Wait for heartbeat and for node to subsequently close its connection.

wait_for_keep_alive(request_rx, 1, Duration::from_secs(1)).await;

//int heartbeatCount = getHeartbeatsForNode().size();
// When node receives a heartbeat, close the connection.
//simulacronNode.prime(
// when(Options.INSTANCE)
// .then(closeConnection(DisconnectAction.Scope.CONNECTION, CloseType.DISCONNECT)));

// Wait for heartbeat and for node to subsequently close its connection.
//waitForDown(node);

// Should have been a heartbeat received since that's what caused the disconnect.
// assertThat(getHeartbeatsForNode().size()).isGreaterThan(heartbeatCount);

running_proxy
},
)
.await;

match res {
Ok(()) => (),
Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
Err(err) => panic!("{}", err),
}
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn should_not_send_heartbeat_during_protocol_initialization() {
let res = test_with_3_node_cluster(
ShardAwareness::QueryNode,
|proxy_uris, translation_map, mut running_proxy| async move {
let request_rule = |tx| {
vec![RequestRule(
Condition::RequestOpcode(RequestOpcode::Options),
RequestReaction::drop_connection().with_feedback_when_performed(tx),
)]
};
let (request_tx, request_rx) = mpsc::unbounded_channel();
running_proxy.running_nodes[0]
.change_request_rules(Some(request_rule(request_tx.clone())));

let result = SessionBuilder::new()
.known_node(proxy_uris[0].as_str())
.address_translator(Arc::new(translation_map))
.keepalive_interval(Duration::from_millis(20))
.build()
.await;
assert!(result.is_err());

wait_for_keep_alive(request_rx, 1, Duration::from_secs(1)).await;

running_proxy
},
)
.await;

match res {
Ok(()) => (),
Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
Err(err) => panic!("{}", err),
}
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn should_send_heartbeat_on_control_connection() {
let res = test_with_3_node_cluster(
ShardAwareness::QueryNode,
|proxy_uris, translation_map, mut running_proxy| async move {
let request_rule = |tx| {
vec![RequestRule(
Condition::RequestOpcode(RequestOpcode::Options)
.and(Condition::ConnectionRegisteredAnyEvent),
RequestReaction::noop().with_feedback_when_performed(tx),
)]
};

let (request_tx, request_rx) = mpsc::unbounded_channel();
running_proxy.running_nodes[0]
.change_request_rules(Some(request_rule(request_tx.clone())));

let _session = SessionBuilder::new()
.known_node(proxy_uris[0].as_str())
.address_translator(Arc::new(translation_map))
.keepalive_interval(Duration::from_millis(20))
//.tcp_keepalive_interval(Duration::from_millis(20))
.build()
.await
.unwrap();

wait_for_keep_alive(request_rx, 1, Duration::from_secs(1)).await;

running_proxy
},
)
.await;

match res {
Ok(()) => (),
Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
Err(err) => panic!("{}", err),
}
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn should_send_heartbeat_on_regular_connection() {
/*
// Prime a simple query so we get at least some results
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't got to implementing all of the tests, left the java code as skeletons

simulacronNode.prime(when(QUERY).then(PrimeDsl.rows().row("column1", "1", "column2", "2")));

try (CqlSession session = newSession()) {
// Make a bunch of queries over two seconds. This should preempt any heartbeats.
assertThat(session.execute(QUERY)).hasSize(1);
final AtomicInteger nonControlHeartbeats = countHeartbeatsOnRegularConnection();
for (int i = 0; i < 20; i++) {
assertThat(session.execute(QUERY)).hasSize(1);
MILLISECONDS.sleep(100);
}

// No heartbeats should be sent, except those on the control connection.
assertThat(nonControlHeartbeats.get()).isZero();

// Stop querying, heartbeats should be sent again
await()
.pollInterval(500, TimeUnit.MILLISECONDS)
.atMost(60, TimeUnit.SECONDS)
.until(() -> nonControlHeartbeats.get() >= 1);
}
*/
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn should_send_heartbeat_when_requests_being_written_but_nothing_received() {
/*
// Prime a query that will never return a response.
String noResponseQueryStr = "delay";
SIMULACRON_RULE.cluster().prime(when(noResponseQueryStr).then(noResult()));

try (CqlSession session = newSession()) {
AtomicInteger heartbeats = countHeartbeatsOnRegularConnection();

for (int i = 0; i < 25; i++) {
session.executeAsync(noResponseQueryStr);
session.executeAsync(noResponseQueryStr);
MILLISECONDS.sleep(100);
}

// We should expect at least 2 heartbeats
assertThat(heartbeats.get()).isGreaterThanOrEqualTo(2);
}
*/
}

#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn should_close_connection_when_heartbeat_times_out() {
/*
try (CqlSession session = newSession()) {
Node node = session.getMetadata().getNodes().values().iterator().next();
assertThat(node.getState()).isEqualTo(NodeState.UP);

// Ensure we get some heartbeats and the node remains up.
AtomicInteger heartbeats = new AtomicInteger();
simulacronNode.registerQueryListener(
(n, l) -> heartbeats.incrementAndGet(), true, this::isOptionRequest);

await()
.pollInterval(500, TimeUnit.MILLISECONDS)
.atMost(60, TimeUnit.SECONDS)
.until(() -> heartbeats.get() >= 2);
assertThat(node.getState()).isEqualTo(NodeState.UP);

// configure node to not respond to options request, which should cause a timeout.
simulacronNode.prime(when(Options.INSTANCE).then(noResult()));
heartbeats.set(0);

// wait for heartbeat to be sent.
await()
.pollInterval(500, TimeUnit.MILLISECONDS)
.atMost(60, TimeUnit.SECONDS)
.until(() -> heartbeats.get() >= 1);
heartbeats.set(0);

// node should go down because heartbeat was unanswered.
waitForDown(node);

// clear prime so now responds to options request again.
simulacronNode.clearPrimes();

// wait for node to come up again and ensure heartbeats are successful and node remains up.
waitForUp(node);

await()
.pollInterval(500, TimeUnit.MILLISECONDS)
.atMost(60, TimeUnit.SECONDS)
.until(() -> heartbeats.get() >= 2);
assertThat(node.getState()).isEqualTo(NodeState.UP);
}
*/
}
2 changes: 2 additions & 0 deletions scylla/tests/ccm_integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ pub(crate) mod ccm;
mod test_example;
#[cfg(feature = "ssl")]
mod tls;

mod keep_alive;
Loading