Skip to content
This repository was archived by the owner on Jan 2, 2025. It is now read-only.

Send prepared statements to all nodes #320

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ pub trait GetConnection<
fn get_connection(&self) -> Option<r2d2::PooledConnection<M>>;
}

pub trait GetAllConnections<
T: CDRSTransport + Send + Sync + 'static,
M: r2d2::ManageConnection<Connection = cell::RefCell<T>, Error = error::Error>,
>
{
/// Returns all connections from the load balancer.
fn get_all_connections(&self) -> Vec<Option<r2d2::PooledConnection<M>>>;
}

/// `GetCompressor` trait provides a unified interface for Session to get a compressor
/// for further decompressing received data.
pub trait GetCompressor<'a> {
Expand Down
20 changes: 19 additions & 1 deletion src/cluster/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::cluster::NodeTcpConfig;
use crate::cluster::{new_ssl_pool, ClusterSslConfig, NodeSslConfig, SslConnectionPool};
use crate::cluster::{
new_tcp_pool, startup, CDRSSession, ClusterTcpConfig, ConnectionPool, GetCompressor,
GetConnection, TcpConnectionPool,
GetConnection, GetAllConnections, TcpConnectionPool,
};
use crate::error;
use crate::load_balancing::LoadBalancingStrategy;
Expand Down Expand Up @@ -106,6 +106,24 @@ impl<
}
}

impl<
T: CDRSTransport + Send + Sync + 'static,
M: r2d2::ManageConnection<Connection = RefCell<T>, Error = error::Error> + Sized,
LB: LoadBalancingStrategy<ConnectionPool<M>> + Sized,
> GetAllConnections<T, M> for Session<LB>
{
fn get_all_connections(&self) -> Vec<Option<r2d2::PooledConnection<M>>> {
self.load_balancing
.lock()
.ok()
.unwrap()
.get_all_nodes()
.into_iter()
.map(|pool_ref| pool_ref.get_pool().get().ok())
.collect()
}
}

impl<
'a,
T: CDRSTransport + 'static,
Expand Down
8 changes: 4 additions & 4 deletions src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub mod traits;

use crate::error;

#[derive(Debug)]
#[derive(Clone,Debug)]
pub struct Frame {
pub version: Version,
pub flags: Vec<Flag>,
Expand Down Expand Up @@ -98,7 +98,7 @@ impl<'a> IntoBytes for Frame {
}

/// Frame's version
#[derive(Debug, PartialEq)]
#[derive(Clone,Debug, PartialEq)]
pub enum Version {
Request,
Response,
Expand Down Expand Up @@ -187,7 +187,7 @@ impl From<Vec<u8>> for Version {

/// Frame's flag
// Is not implemented functionality. Only Igonore works for now
#[derive(Debug, PartialEq)]
#[derive(Clone,Debug, PartialEq)]
pub enum Flag {
Compression,
Tracing,
Expand Down Expand Up @@ -276,7 +276,7 @@ impl From<u8> for Flag {
}
}

#[derive(Debug, PartialEq)]
#[derive(Clone,Debug, PartialEq)]
pub enum Opcode {
Error,
Startup,
Expand Down
1 change: 1 addition & 0 deletions src/load_balancing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub use crate::load_balancing::single_node::SingleNode;
pub trait LoadBalancingStrategy<N>: Sized {
fn init(&mut self, cluster: Vec<N>);
fn next(&self) -> Option<&N>;
fn get_all_nodes(&self) -> &Vec<N>;
fn remove_node<F>(&mut self, _filter: F)
where
F: FnMut(&N) -> bool,
Expand Down
5 changes: 5 additions & 0 deletions src/load_balancing/random.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use rand;

use super::LoadBalancingStrategy;
use std::borrow::Borrow;

pub struct Random<N> {
pub cluster: Vec<N>,
Expand Down Expand Up @@ -40,6 +41,10 @@ impl<N> LoadBalancingStrategy<N> for Random<N> {
self.cluster.get(Self::rnd_idx((0, len)))
}

fn get_all_nodes(&self) -> &Vec<N> {
self.cluster.borrow()
}

fn remove_node<F>(&mut self, filter: F)
where
F: FnMut(&N) -> bool,
Expand Down
5 changes: 5 additions & 0 deletions src/load_balancing/round_robin.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::cell::RefCell;

use super::LoadBalancingStrategy;
use std::borrow::Borrow;

pub struct RoundRobin<N> {
cluster: Vec<N>,
Expand Down Expand Up @@ -38,6 +39,10 @@ impl<N> LoadBalancingStrategy<N> for RoundRobin<N> {
self.cluster.get(next_idx)
}

fn get_all_nodes(&self) -> &Vec<N> {
self.cluster.borrow()
}

fn remove_node<F>(&mut self, filter: F)
where
F: FnMut(&N) -> bool,
Expand Down
5 changes: 5 additions & 0 deletions src/load_balancing/round_robin_sync.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Mutex;

use super::LoadBalancingStrategy;
use std::borrow::Borrow;

pub struct RoundRobinSync<N> {
cluster: Vec<N>,
Expand Down Expand Up @@ -42,6 +43,10 @@ impl<N> LoadBalancingStrategy<N> for RoundRobinSync<N> {
}
}

fn get_all_nodes(&self) -> &Vec<N> {
self.cluster.borrow()
}

fn remove_node<F>(&mut self, filter: F)
where
F: FnMut(&N) -> bool,
Expand Down
5 changes: 5 additions & 0 deletions src/load_balancing/single_node.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::LoadBalancingStrategy;
use std::borrow::Borrow;

pub struct SingleNode<N> {
cluster: Vec<N>,
Expand All @@ -25,6 +26,10 @@ impl<N> LoadBalancingStrategy<N> for SingleNode<N> {
fn next(&self) -> Option<&N> {
self.cluster.get(0)
}

fn get_all_nodes(&self) -> &Vec<N> {
self.cluster.borrow()
}
}

#[cfg(test)]
Expand Down
10 changes: 6 additions & 4 deletions src/query/prepare_executor.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use r2d2;
use std::cell::RefCell;

use crate::cluster::{GetCompressor, GetConnection};
use crate::cluster::{GetCompressor, GetAllConnections};
use crate::error;
use crate::frame::{Frame, IntoBytes};
use crate::transport::CDRSTransport;
use crate::types::CBytesShort;

use super::utils::{prepare_flags, send_frame};
use super::utils::prepare_flags;
use crate::query::utils::send_frame_to_all_connections;


pub type PreparedQuery = CBytesShort;

pub trait PrepareExecutor<
T: CDRSTransport + 'static,
M: r2d2::ManageConnection<Connection = RefCell<T>, Error = error::Error> + Sized,
>: GetConnection<T, M> + GetCompressor<'static>
>: GetAllConnections<T, M> + GetCompressor<'static>
{
/// It prepares a query for execution, along with query itself
/// the method takes `with_tracing` and `with_warnings` flags
Expand All @@ -32,7 +34,7 @@ pub trait PrepareExecutor<

let query_frame = Frame::new_req_prepare(query.to_string(), flags).into_cbytes();

send_frame(self, query_frame)
send_frame_to_all_connections(self, query_frame)
.and_then(|response| response.get_body())
.and_then(|body| {
Ok(body
Expand Down
54 changes: 53 additions & 1 deletion src/query/utils.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::cell::RefCell;

use crate::cluster::{GetCompressor, GetConnection};
use crate::cluster::{GetCompressor, GetConnection, GetAllConnections};
use crate::error;
use crate::frame::parser::from_connection;
use crate::frame::{Flag, Frame};
use crate::transport::CDRSTransport;
use crate::compression::Compression;
use r2d2::PooledConnection;
use std::error::Error;

pub fn prepare_flags(with_tracing: bool, with_warnings: bool) -> Vec<Flag> {
let mut flags = vec![];
Expand Down Expand Up @@ -41,6 +44,55 @@ where
.and_then(|transport_cell| from_connection(&transport_cell, compression))
}

pub fn send_frame_to_all_connections<S, T, M>(sender: &S, frame_bytes: Vec<u8>) -> error::Result<Frame>
where
S: GetAllConnections<T, M> + GetCompressor<'static> + Sized,
T: CDRSTransport + 'static,
M: r2d2::ManageConnection<Connection=RefCell<T>, Error=error::Error> + Sized,
{
let ref compression = sender.get_compressor();

let mut results: Vec<error::Result<Frame>> = Vec::new();

for connection in sender.get_all_connections() {
match connection {
Some(conn) => {
results.push(write_frame_to_connection(&frame_bytes, compression, conn));
}
None => { results.push(transform_error()); }
}
}

let result = results.iter().find(|r| r.is_err())
.unwrap_or_else( ||get_any_valid_result(&results));
Copy link
Owner

@AlexPikalov AlexPikalov Mar 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, getting any valid result may not work as expected. Each node will prepare a query with a different ID. That's why if we have a cluster with >1 nodes load balancer may select a "wrong" node which has the query prepared with a different ID.

What I had been thinking about is to have a map of prepared queries to nodes so that during preparing new entity is added to this map. Whenever exec-like call is made, the executor picks up a node where the prepared query ID is residing. However for that a LoadBalancingStrategy should be able to return a node basing on providing criteria, e.g.

load_balancer
  .find_node(|node_in_cluster| node.addr() == addr_assigned_to_prepared_id)
  .send_frame(frame_bytes)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I will work on it.


match result {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this statement can be simplified so that no clone is needed. Please consider following line to put instead the whole match.

result.map_err(|err| error::Error::General(e.description().to_string())

Ok(frame) => Result::Ok(frame.clone()),
Err(e) => Result::Err(error::Error::General(e.description().to_string()))
}
}

fn get_any_valid_result(results: &Vec<error::Result<Frame>>) -> &error::Result<Frame> {
results.iter().find(|r| r.is_ok()).unwrap()
}

fn transform_error() -> error::Result<Frame> {
Copy link
Owner

@AlexPikalov AlexPikalov Mar 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small typo. Perhaps it had to be transport_error

return error::Result::Err(error::Error::from("Unable to get transport"));
}

fn write_frame_to_connection<T, M>(frame_bytes: &Vec<u8>, compression: &Compression, conn: PooledConnection<M>) -> error::Result<Frame>
where
T: CDRSTransport + 'static,
M: r2d2::ManageConnection<Connection=RefCell<T>, Error=error::Error> + Sized {
let result = conn
.borrow_mut()
.write(frame_bytes.as_slice())
.map_err(error::Error::from);

result.map(|_| conn)
.and_then(|transport_cell| from_connection(&transport_cell, compression))
}

#[cfg(test)]
mod test {
use super::*;
Expand Down